spacr 0.4.15__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +2 -2
- spacr/core.py +52 -10
- spacr/deep_spacr.py +2 -3
- spacr/gui.py +0 -1
- spacr/gui_core.py +247 -41
- spacr/gui_elements.py +133 -2
- spacr/gui_utils.py +22 -17
- spacr/io.py +624 -149
- spacr/ml.py +141 -258
- spacr/plot.py +76 -34
- spacr/resources/MEDIAR/__pycache__/SetupDict.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/__pycache__/evaluate.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/__pycache__/generate_mapping.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/__pycache__/main.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/Baseline/__pycache__/Predictor.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/Baseline/__pycache__/Trainer.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/Baseline/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/Baseline/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/EnsemblePredictor.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/Predictor.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/Trainer.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/__pycache__/BasePredictor.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/__pycache__/BaseTrainer.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/__pycache__/measures.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/datasetter.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/transforms.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/CellAware.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/LoadImage.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/NormalizeImage.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/models/__pycache__/MEDIARFormer.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/models/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/sequencing.py +73 -38
- spacr/settings.py +161 -135
- spacr/submodules.py +618 -215
- spacr/timelapse.py +197 -29
- spacr/toxo.py +23 -23
- spacr/utils.py +186 -128
- {spacr-0.4.15.dist-info → spacr-0.5.0.dist-info}/METADATA +5 -2
- {spacr-0.4.15.dist-info → spacr-0.5.0.dist-info}/RECORD +53 -24
- spacr/stats.py +0 -221
- /spacr/{cellpose.py → spacr_cellpose.py} +0 -0
- {spacr-0.4.15.dist-info → spacr-0.5.0.dist-info}/LICENSE +0 -0
- {spacr-0.4.15.dist-info → spacr-0.5.0.dist-info}/WHEEL +0 -0
- {spacr-0.4.15.dist-info → spacr-0.5.0.dist-info}/entry_points.txt +0 -0
- {spacr-0.4.15.dist-info → spacr-0.5.0.dist-info}/top_level.txt +0 -0
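The rename `spacr/cellpose.py → spacr/spacr_cellpose.py` is a breaking change for anything importing the old module path. A minimal compatibility sketch (assuming the pre-rename module was importable as `spacr.cellpose`; `identify_masks_finetune` is the symbol the diff below re-imports in `analyze_plaques`):

```python
# Version-tolerant import across the 0.4.15 -> 0.5.0 module rename.
# Assumption: the old path was spacr.cellpose, per the rename entry above.
try:
    from spacr.spacr_cellpose import identify_masks_finetune  # spacr >= 0.5.0
except ImportError:
    from spacr.cellpose import identify_masks_finetune        # spacr <= 0.4.15
```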
spacr/submodules.py
CHANGED
@@ -1,14 +1,21 @@
-
-
-
 import seaborn as sns
-import os, random, sqlite3, re, shap
+import os, random, sqlite3, re, shap, string, time
 import pandas as pd
 import numpy as np
-
+
 from skimage.measure import regionprops, label
+from skimage.transform import resize as sk_resize, rotate
+from skimage.exposure import rescale_intensity
+
+import cellpose
+from cellpose import models as cp_models
+from cellpose import train as train_cp
 from cellpose import models as cp_models
+from cellpose import io as cp_io
 from cellpose import train as train_cp
+from cellpose.metrics import aggregated_jaccard_index
+from cellpose.metrics import average_precision
+
 from IPython.display import display
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.inspection import permutation_importance
@@ -17,10 +24,545 @@ from scipy.stats import chi2_contingency, pearsonr
 from scipy.spatial.distance import cosine

 from sklearn.metrics import mean_absolute_error
-
+from skimage.measure import regionprops, label as sklabel
 import matplotlib.pyplot as plt
 from natsort import natsorted

+from torch.utils.data import Dataset
+
+class CellposeLazyDataset(Dataset):
+    def __init__(self, image_files, label_files, settings, randomize=True, augment=False):
+        combined = list(zip(image_files, label_files))
+        if randomize:
+            random.shuffle(combined)
+        self.image_files, self.label_files = zip(*combined)
+        self.normalize = settings['normalize']
+        self.percentiles = settings.get('percentiles', [2, 99])
+        self.target_size = settings['target_size']
+        self.augment = augment
+
+    def __len__(self):
+        return len(self.image_files) * (8 if self.augment else 1)
+
+    def apply_augmentation(self, image, label, aug_idx):
+        if aug_idx == 1:
+            return rotate(image, 90, resize=False, preserve_range=True), rotate(label, 90, resize=False, preserve_range=True)
+        elif aug_idx == 2:
+            return rotate(image, 180, resize=False, preserve_range=True), rotate(label, 180, resize=False, preserve_range=True)
+        elif aug_idx == 3:
+            return rotate(image, 270, resize=False, preserve_range=True), rotate(label, 270, resize=False, preserve_range=True)
+        elif aug_idx == 4:
+            return np.fliplr(image), np.fliplr(label)
+        elif aug_idx == 5:
+            return np.flipud(image), np.flipud(label)
+        elif aug_idx == 6:
+            return np.fliplr(rotate(image, 90, resize=False, preserve_range=True)), np.fliplr(rotate(label, 90, resize=False, preserve_range=True))
+        elif aug_idx == 7:
+            return np.flipud(rotate(image, 90, resize=False, preserve_range=True)), np.flipud(rotate(label, 90, resize=False, preserve_range=True))
+        return image, label
+
+    def __getitem__(self, idx):
+        base_idx = idx // 8 if self.augment else idx
+        aug_idx = idx % 8 if self.augment else 0
+
+        image = cp_io.imread(self.image_files[base_idx])
+        label = cp_io.imread(self.label_files[base_idx])
+
+        if image.ndim == 3:
+            image = image.mean(axis=-1)
+
+        if image.max() > 1:
+            image = image / image.max()
+
+        if self.normalize:
+            lower_p, upper_p = np.percentile(image, self.percentiles)
+            image = rescale_intensity(image, in_range=(lower_p, upper_p), out_range=(0, 1))
+
+        image, label = self.apply_augmentation(image, label, aug_idx)
+
+        image_shape = (self.target_size, self.target_size)
+        image = sk_resize(image, image_shape, preserve_range=True, anti_aliasing=True).astype(np.float32)
+        label = sk_resize(label, image_shape, order=0, preserve_range=True, anti_aliasing=False).astype(np.uint8)
+
+        return image, label
+
|
+
def train_cellpose(settings):
|
90
|
+
|
91
|
+
from .settings import get_train_cellpose_default_settings
|
92
|
+
from .utils import save_settings
|
93
|
+
|
94
|
+
settings = get_train_cellpose_default_settings(settings)
|
95
|
+
img_src = os.path.join(settings['src'], 'train', 'images')
|
96
|
+
mask_src = os.path.join(settings['src'], 'train', 'masks')
|
97
|
+
target_size = settings['target_size']
|
98
|
+
|
99
|
+
model_name = f"{settings['model_name']}_cyto_e{settings['n_epochs']}_X{target_size}_Y{target_size}.CP_model"
|
100
|
+
model_save_path = os.path.join(settings['src'], 'models', 'cellpose_model')
|
101
|
+
os.makedirs(model_save_path, exist_ok=True)
|
102
|
+
|
103
|
+
save_settings(settings, name=model_name)
|
104
|
+
|
105
|
+
model = cp_models.CellposeModel(gpu=True, model_type='cyto', diam_mean=30, pretrained_model='cyto')
|
106
|
+
cp_channels = [0, 0]
|
107
|
+
|
108
|
+
#train_image_files = sorted([os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')])
|
109
|
+
#train_label_files = sorted([os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')])
|
110
|
+
|
111
|
+
image_filenames = set(f for f in os.listdir(img_src) if f.endswith('.tif'))
|
112
|
+
label_filenames = set(f for f in os.listdir(mask_src) if f.endswith('.tif'))
|
113
|
+
|
114
|
+
# Only keep files that are present in both folders
|
115
|
+
matched_filenames = sorted(image_filenames & label_filenames)
|
116
|
+
|
117
|
+
train_image_files = [os.path.join(img_src, f) for f in matched_filenames]
|
118
|
+
train_label_files = [os.path.join(mask_src, f) for f in matched_filenames]
|
119
|
+
|
120
|
+
train_dataset = CellposeLazyDataset(train_image_files, train_label_files, settings, randomize=True, augment=settings['augment'])
|
121
|
+
|
122
|
+
n_aug = 8 if settings['augment'] else 1
|
123
|
+
max_base_images = len(train_dataset) // n_aug if settings['augment'] else len(train_dataset)
|
124
|
+
n_base = min(settings['batch_size'], max_base_images)
|
125
|
+
|
126
|
+
unique_base_indices = list(range(max_base_images))
|
127
|
+
random.shuffle(unique_base_indices)
|
128
|
+
selected_indices = unique_base_indices[:n_base]
|
129
|
+
|
130
|
+
images, labels = [], []
|
131
|
+
for idx in selected_indices:
|
132
|
+
for aug_idx in range(n_aug):
|
133
|
+
i = idx * n_aug + aug_idx if settings['augment'] else idx
|
134
|
+
img, lbl = train_dataset[i]
|
135
|
+
images.append(img)
|
136
|
+
labels.append(lbl)
|
137
|
+
try:
|
138
|
+
plot_cellpose_batch(images, labels)
|
139
|
+
except:
|
140
|
+
print(f"could not print batch images")
|
141
|
+
|
142
|
+
print(f"Training model with {len(images)} ber patch for {settings['n_epochs']} Epochs")
|
143
|
+
|
144
|
+
train_cp.train_seg(model.net,
|
145
|
+
train_data=images,
|
146
|
+
train_labels=labels,
|
147
|
+
channels=cp_channels,
|
148
|
+
save_path=model_save_path,
|
149
|
+
n_epochs=settings['n_epochs'],
|
150
|
+
batch_size=settings['batch_size'],
|
151
|
+
learning_rate=settings['learning_rate'],
|
152
|
+
weight_decay=settings['weight_decay'],
|
153
|
+
model_name=model_name,
|
154
|
+
save_every=max(1, (settings['n_epochs'] // 10)),
|
155
|
+
rescale=False)
|
156
|
+
|
157
|
+
print(f"Model saved at: {model_save_path}/{model_name}")
|
158
|
+
|
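`train_cellpose` now assumes a single `src` folder with `train/images` and `train/masks` subfolders holding identically named `.tif` pairs, and always fine-tunes the pretrained `cyto` model with grayscale channels `[0, 0]`. A hedged sketch of a call; the keys mirror the function body, the values are invented, and the real defaults come from `get_train_cellpose_default_settings`:

```python
# Sketch only: values are hypothetical; defaults are filled in by
# spacr.settings.get_train_cellpose_default_settings.
settings = {
    'src': '/data/screen',    # expects train/images and train/masks inside
    'model_name': 'plaque',
    'target_size': 512,
    'normalize': True,
    'percentiles': [2, 99],
    'augment': True,
    'n_epochs': 100,
    'batch_size': 8,
    'learning_rate': 0.01,    # hypothetical value
    'weight_decay': 1e-5,     # hypothetical value
}
train_cellpose(settings)
# -> /data/screen/models/cellpose_model/plaque_cyto_e100_X512_Y512.CP_model
```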
+def test_cellpose_model(settings):
+
+    from .utils import save_settings, print_progress
+    from .settings import get_default_test_cellpose_model_settings
+
+    def plot_cellpose_results(i, j, results_dir, img, lbl, pred, flow):
+        from .plot import generate_mask_random_cmap
+        fig, axs = plt.subplots(1, 5, figsize=(16, 4), gridspec_kw={'wspace': 0.1, 'hspace': 0.1})
+        cmap_lbl = generate_mask_random_cmap(lbl)
+        cmap_pred = generate_mask_random_cmap(pred)
+
+        axs[0].imshow(img, cmap='gray')
+        axs[0].set_title('Image')
+        axs[0].axis('off')
+
+        axs[1].imshow(lbl, cmap=cmap_lbl, interpolation='nearest')
+        axs[1].set_title('True Mask')
+        axs[1].axis('off')
+
+        axs[2].imshow(pred, cmap=cmap_pred, interpolation='nearest')
+        axs[2].set_title('Predicted Mask')
+        axs[2].axis('off')
+
+        axs[3].imshow(flow[2], cmap='gray')
+        axs[3].set_title('Cell Probability')
+        axs[3].axis('off')
+
+        axs[4].imshow(flow[0], cmap='gray')
+        axs[4].set_title('Flows')
+        axs[4].axis('off')
+
+        save_path = os.path.join(results_dir, f"cellpose_result_{i+j:03d}.png")
+        plt.savefig(save_path, dpi=200, bbox_inches='tight')
+        plt.show()
+        plt.close(fig)
+
+    settings = get_default_test_cellpose_model_settings(settings)
+
+    save_settings(settings, name='test_cellpose_model')
+    test_image_folder = os.path.join(settings['src'], 'test', 'images')
+    test_label_folder = os.path.join(settings['src'], 'test', 'masks')
+    results_dir = os.path.join(settings['src'], 'results')
+    os.makedirs(results_dir, exist_ok=True)
+
+    print(f"Results will be saved in: {results_dir}")
+
+    image_filenames = set(f for f in os.listdir(test_image_folder) if f.endswith('.tif'))
+    label_filenames = set(f for f in os.listdir(test_label_folder) if f.endswith('.tif'))
+
+    # Only keep files that are present in both folders
+    matched_filenames = sorted(image_filenames & label_filenames)
+
+    test_image_files = [os.path.join(test_image_folder, f) for f in matched_filenames]
+    test_label_files = [os.path.join(test_label_folder, f) for f in matched_filenames]
+
+    print(f"Found {len(test_image_files)} images and {len(test_label_files)} masks")
+
+    test_dataset = CellposeLazyDataset(test_image_files, test_label_files, settings, randomize=False, augment=False)
+
+    model = cp_models.CellposeModel(gpu=True, pretrained_model=settings['model_path'])
+
+    batch_size = settings['batch_size']
+    scores = []
+    names = []
+    time_ls = []
+    n_objects_true_ls = []
+    n_objects_pred_ls = []
+    mean_area_true_ls = []
+    mean_area_pred_ls = []
+    tp_ls, fp_ls, fn_ls = [], [], []
+    precision_ls, recall_ls, f1_ls, accuracy_ls = [], [], [], []
+
+    files_to_process = len(test_image_files)
+
+    for i in range(0, len(test_dataset), batch_size):
+        start = time.time()
+        batch = [test_dataset[j] for j in range(i, min(i + batch_size, len(test_dataset)))]
+        images, labels = zip(*batch)
+
+        masks_pred, flows, _ = model.eval(x=list(images),
+                                          channels=[0, 0],
+                                          normalize=False,
+                                          diameter=30,
+                                          flow_threshold=settings['FT'],
+                                          cellprob_threshold=settings['CP_probability'],
+                                          rescale=None,
+                                          resample=True,
+                                          interp=True,
+                                          anisotropy=None,
+                                          min_size=5,
+                                          augment=True,
+                                          tile=True,
+                                          tile_overlap=0.2,
+                                          bsize=224)
+
+        for j, (img, lbl, pred, flow) in enumerate(zip(images, labels, masks_pred, flows)):
+            score = float(aggregated_jaccard_index([lbl], [pred]))
+            fname = os.path.basename(test_label_files[i + j])
+            scores.append(score)
+            names.append(fname)
+
+            # Label masks
+            lbl_lab = label(lbl)
+            pred_lab = label(pred)
+
+            # Count objects
+            n_true = lbl_lab.max()
+            n_pred = pred_lab.max()
+            n_objects_true_ls.append(n_true)
+            n_objects_pred_ls.append(n_pred)
+
+            # Mean object size (area)
+            area_true = [p.area for p in regionprops(lbl_lab)]
+            area_pred = [p.area for p in regionprops(pred_lab)]
+
+            mean_area_true = np.mean(area_true) if area_true else 0
+            mean_area_pred = np.mean(area_pred) if area_pred else 0
+            mean_area_true_ls.append(mean_area_true)
+            mean_area_pred_ls.append(mean_area_pred)
+
+            # Compute object-level TP, FP, FN
+            ap, tp, fp, fn = average_precision([lbl], [pred], threshold=[0.5])
+            tp, fp, fn = int(tp[0, 0]), int(fp[0, 0]), int(fn[0, 0])
+            tp_ls.append(tp)
+            fp_ls.append(fp)
+            fn_ls.append(fn)
+
+            # Precision, Recall, F1, Accuracy
+            prec = tp / (tp + fp) if (tp + fp) > 0 else 0
+            rec = tp / (tp + fn) if (tp + fn) > 0 else 0
+            f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0
+            acc = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0
+
+            precision_ls.append(prec)
+            recall_ls.append(rec)
+            f1_ls.append(f1)
+            accuracy_ls.append(acc)
+
+            if settings['save']:
+                plot_cellpose_results(i, j, results_dir, img, lbl, pred, flow)
+
+        stop = time.time()
+        duration = stop - start
+        files_processed = (i + 1) * batch_size
+        time_ls.append(duration)
+        print_progress(files_processed, files_to_process, n_jobs=1, time_ls=None, batch_size=batch_size, operation_type="test custom cellpose model")
+
+    df_results = pd.DataFrame({
+        'label_image': names,
+        'Jaccard': scores,
+        'n_objects_true': n_objects_true_ls,
+        'n_objects_pred': n_objects_pred_ls,
+        'mean_area_true': mean_area_true_ls,
+        'mean_area_pred': mean_area_pred_ls,
+        'TP': tp_ls,
+        'FP': fp_ls,
+        'FN': fn_ls,
+        'Precision': precision_ls,
+        'Recall': recall_ls,
+        'F1': f1_ls,
+        'Accuracy': accuracy_ls
+    })
+
+    df_results['n_error'] = abs(df_results['n_objects_pred'] - df_results['n_objects_true'])
+
+    print(f"Average true objects/image: {df_results['n_objects_true'].mean():.2f}")
+    print(f"Average predicted objects/image: {df_results['n_objects_pred'].mean():.2f}")
+    print(f"Mean object area (true): {df_results['mean_area_true'].mean():.2f} px")
+    print(f"Mean object area (pred): {df_results['mean_area_pred'].mean():.2f} px")
+    print(f"Average Jaccard score: {df_results['Jaccard'].mean():.4f}")
+
+    print(f"Average Precision: {df_results['Precision'].mean():.3f}")
+    print(f"Average Recall: {df_results['Recall'].mean():.3f}")
+    print(f"Average F1-score: {df_results['F1'].mean():.3f}")
+    print(f"Average Accuracy: {df_results['Accuracy'].mean():.3f}")
+
+    display(df_results)
+
+    if settings['save']:
+        df_results.to_csv(os.path.join(results_dir, 'test_results.csv'), index=False)
+
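The per-image object metrics above are the standard matched-object definitions at an IoU threshold of 0.5 (`average_precision` returns per-threshold TP/FP/FN arrays), with `Accuracy` defined as TP / (TP + FP + FN), which matches how cellpose defines average precision at a single threshold. A small worked example of the arithmetic:

```python
# Worked example of the object-level scores computed above.
tp, fp, fn = 18, 2, 4
precision = tp / (tp + fp)                          # 0.90
recall = tp / (tp + fn)                             # ~0.818
f1 = 2 * precision * recall / (precision + recall)  # ~0.857
accuracy = tp / (tp + fp + fn)                      # 0.75
print(precision, recall, f1, accuracy)
```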
+def apply_cellpose_model(settings):
+
+    from .settings import get_default_apply_cellpose_model_settings
+    from .utils import save_settings, print_progress
+
+    def plot_cellpose_result(i, j, results_dir, img, pred, flow):
+
+        from .plot import generate_mask_random_cmap
+
+        fig, axs = plt.subplots(1, 4, figsize=(16, 4), gridspec_kw={'wspace': 0.1, 'hspace': 0.1})
+        cmap_pred = generate_mask_random_cmap(pred)
+
+        axs[0].imshow(img, cmap='gray')
+        axs[0].set_title('Image')
+        axs[0].axis('off')
+
+        axs[1].imshow(pred, cmap=cmap_pred, interpolation='nearest')
+        axs[1].set_title('Predicted Mask')
+        axs[1].axis('off')
+
+        axs[2].imshow(flow[2], cmap='gray')
+        axs[2].set_title('Cell Probability')
+        axs[2].axis('off')
+
+        axs[3].imshow(flow[0], cmap='gray')
+        axs[3].set_title('Flows')
+        axs[3].axis('off')
+
+        save_path = os.path.join(results_dir, f"cellpose_result_{i + j:03d}.png")
+        plt.savefig(save_path, dpi=200, bbox_inches='tight')
+        plt.show()
+        plt.close(fig)
+
+    settings = get_default_apply_cellpose_model_settings(settings)
+    save_settings(settings, name='apply_cellpose_model')
+
+    image_folder = os.path.join(settings['src'])
+    results_dir = os.path.join(settings['src'], 'results')
+    os.makedirs(results_dir, exist_ok=True)
+    print(f"Results will be saved in: {results_dir}")
+
+    image_files = sorted([os.path.join(image_folder, f) for f in os.listdir(image_folder) if f.endswith('.tif')])
+    print(f"Found {len(image_files)} images")
+
+    dummy_labels = [image_files[0]] * len(image_files)
+    dataset = CellposeLazyDataset(image_files, dummy_labels, settings, randomize=False, augment=False)
+
+    model = cp_models.CellposeModel(gpu=True, pretrained_model=settings['model_path'])
+    batch_size = settings['batch_size']
+    measurements = []
+
+    files_to_process = len(image_files)
+    time_ls = []
+
+    for i in range(0, len(dataset), batch_size):
+        start = time.time()
+        batch = [dataset[j] for j in range(i, min(i + batch_size, len(dataset)))]
+        images, _ = zip(*batch)
+
+        masks_pred, flows, _ = model.eval(x=list(images),
+                                          channels=[0, 0],
+                                          normalize=False,
+                                          diameter=30,
+                                          flow_threshold=settings['FT'],
+                                          cellprob_threshold=settings['CP_probability'],
+                                          rescale=None,
+                                          resample=True,
+                                          interp=True,
+                                          anisotropy=None,
+                                          min_size=5,
+                                          augment=True,
+                                          tile=True,
+                                          tile_overlap=0.2,
+                                          bsize=224)
+
+        for j, (img, pred, flow) in enumerate(zip(images, masks_pred, flows)):
+            fname = os.path.basename(image_files[i + j])
+
+            if settings.get('circularize', False):
+                h, w = pred.shape
+                Y, X = np.ogrid[:h, :w]
+                center_x, center_y = w / 2, h / 2
+                radius = min(center_x, center_y)
+                circular_mask = (X - center_x)**2 + (Y - center_y)**2 <= radius**2
+                pred = pred * circular_mask
+
+            if settings['save']:
+                plot_cellpose_result(i, j, results_dir, img, pred, flow)
+
+            props = regionprops(sklabel(pred))
+            for k, prop in enumerate(props):
+                measurements.append({
+                    'image': fname,
+                    'object_id': k + 1,
+                    'area': prop.area
+                })
+
+        stop = time.time()
+        duration = stop - start
+        files_processed = (i + 1) * batch_size
+        time_ls.append(duration)
+        print_progress(files_processed, files_to_process, n_jobs=1, time_ls=None, batch_size=batch_size, operation_type="apply custom cellpose model")
+
+        # Write after each batch
+        df_measurements = pd.DataFrame(measurements)
+        df_measurements.to_csv(os.path.join(results_dir, 'measurements.csv'), index=False)
+
+    print("Saved object counts and areas to measurements.csv")
+
+    df_summary = df_measurements.groupby('image').agg(
+        object_count=('object_id', 'count'),
+        average_area=('area', 'mean')
+    ).reset_index()
+    df_summary.to_csv(os.path.join(results_dir, 'summary.csv'), index=False)
+    print("Saved object count and average area to summary.csv")
+
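`apply_cellpose_model` runs a trained model over a flat folder of `.tif` images, optionally zeroes out predictions outside an inscribed circle (`circularize`, useful for circular wells), and writes per-object and per-image tables under `<src>/results`. An illustrative call (paths and threshold values are made up; defaults come from `get_default_apply_cellpose_model_settings`):

```python
# Sketch only: paths and threshold values are hypothetical.
settings = {
    'src': '/data/new_plate',   # flat folder of .tif images
    'model_path': '/data/screen/models/cellpose_model/plaque_cyto_e100_X512_Y512.CP_model',
    'normalize': True,
    'percentiles': [2, 99],
    'target_size': 512,
    'batch_size': 8,
    'FT': 0.4,                  # flow_threshold, hypothetical value
    'CP_probability': 0.0,      # cellprob_threshold, hypothetical value
    'circularize': True,
    'save': False,
}
apply_cellpose_model(settings)
# -> /data/new_plate/results/measurements.csv (per-object areas)
#    /data/new_plate/results/summary.csv      (count and mean area per image)
```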
+def plot_cellpose_batch(images, labels):
+    from .plot import generate_mask_random_cmap
+
+    cmap_lbl = generate_mask_random_cmap(labels)
+    batch_size = len(images)
+    fig, axs = plt.subplots(2, batch_size, figsize=(4 * batch_size, 8))
+    for i in range(batch_size):
+        axs[0, i].imshow(images[i], cmap='gray')
+        axs[0, i].set_title(f'Image {i+1}')
+        axs[0, i].axis('off')
+        axs[1, i].imshow(labels[i], cmap=cmap_lbl, interpolation='nearest')
+        axs[1, i].set_title(f'Label {i+1}')
+        axs[1, i].axis('off')
+    plt.show()
+
+def analyze_percent_positive(settings):
+    from .io import _read_and_merge_data
+    from .utils import save_settings
+    from .settings import default_settings_analyze_percent_positive
+
+    settings = default_settings_analyze_percent_positive(settings)
+
+    def translate_well_in_df(csv_loc):
+        # Load and extract metadata
+        df = pd.read_csv(csv_loc)
+        df[['plateID', 'well']] = df['Renamed TIFF'].str.replace('.tif', '', regex=False).str.split('_', expand=True)[[0, 1]]
+        df['plate_well'] = df['plateID'] + '_' + df['well']
+
+        # Retain one row per plate_well
+        df_2 = df.drop_duplicates(subset='plate_well').copy()
+
+        # Translate well to row and column
+        df_2['rowID'] = 'r' + df_2['well'].str[0].map(lambda x: str(string.ascii_uppercase.index(x) + 1))
+        df_2['column_name'] = 'c' + df_2['well'].str[1:].astype(int).astype(str)
+
+        # Optional: add prcf ID (plate_row_column_field)
+        df_2['fieldID'] = 'f1'  # default or extract from filename if needed
+        df_2['prc'] = 'p' + df_2['plateID'].str.extract(r'(\d+)')[0] + '_' + df_2['rowID'] + '_' + df_2['column_name']
+
+        return df_2
+
+    def annotate_and_summarize(df, value_col, condition_col, well_col, threshold, annotation_col='annotation'):
+        """
+        Annotate and summarize a DataFrame based on a threshold.
+
+        Parameters:
+        - df: pandas.DataFrame
+        - value_col: str, column name to apply threshold on
+        - condition_col: str, column name for experimental condition
+        - well_col: str, column name for wells
+        - threshold: float, threshold value for annotation
+        - annotation_col: str, name of the new annotation column
+
+        Returns:
+        - df: annotated DataFrame
+        - summary_df: DataFrame with counts and fractions per condition and well
+        """
+        # Annotate
+        df[annotation_col] = np.where(df[value_col] > threshold, 'above', 'below')
+
+        # Count per condition and well
+        count_df = df.groupby([condition_col, well_col, annotation_col]).size().unstack(fill_value=0)
+
+        # Calculate total and fractions
+        count_df['total'] = count_df.sum(axis=1)
+        count_df['fraction_above'] = count_df.get('above', 0) / count_df['total']
+        count_df['fraction_below'] = count_df.get('below', 0) / count_df['total']
+
+        return df, count_df.reset_index()
+
+    save_settings(settings, name='analyze_percent_positive', show=False)
+
+    df, _ = _read_and_merge_data(locs=[settings['src']+'/measurements/measurements.db'],
+                                 tables=settings['tables'],
+                                 verbose=True,
+                                 nuclei_limit=None,
+                                 pathogen_limit=None)
+
+    df['condition'] = 'none'
+
+    if settings['filter_1'] is not None:
+        df = df[df[settings['filter_1'][0]] > settings['filter_1'][1]]
+
+    condition_col = 'condition'
+    well_col = 'prc'
+
+    df, count_df = annotate_and_summarize(df, settings['value_col'], condition_col, well_col, settings['threshold'], annotation_col='annotation')
+    count_df[['plateID', 'rowID', 'column_name']] = count_df['prc'].str.split('_', expand=True)
+
+    csv_loc = os.path.join(settings['src'], 'rename_log.csv')
+    csv_out_loc = os.path.join(settings['src'], 'result.csv')
+    translate_df = translate_well_in_df(csv_loc)
+
+    merged = pd.merge(count_df, translate_df, on=['rowID', 'column_name'], how='inner')
+
+    merged = merged[['plate_y', 'well', 'plate_well', 'fieldID', 'rowID', 'column_name', 'prc_x', 'Original File', 'Renamed TIFF', 'above', 'below', 'fraction_above', 'fraction_below']]
+    merged[[f'part{i}' for i in range(merged['Original File'].str.count('_').max() + 1)]] = merged['Original File'].str.split('_', expand=True)
+    merged.to_csv(csv_out_loc, index=False)
+    display(merged)
+    return merged
+
 def analyze_recruitment(settings):
     """
     Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.
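The well translation inside `analyze_percent_positive` converts a well name like `B03` into the `rowID`/`column_name` scheme: the row letter becomes `r<index>` via `string.ascii_uppercase`, the numeric part becomes `c<int>`. A standalone sketch of that mapping:

```python
import string

def well_to_rc(well):
    # Mirrors translate_well_in_df: 'B03' -> ('r2', 'c3')
    row = 'r' + str(string.ascii_uppercase.index(well[0]) + 1)
    col = 'c' + str(int(well[1:]))
    return row, col

print(well_to_rc('A1'), well_to_rc('B03'), well_to_rc('H12'))
# ('r1', 'c1') ('r2', 'c3') ('r8', 'c12')
```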
@@ -136,7 +678,7 @@ def analyze_recruitment(settings):

 def analyze_plaques(settings):

-    from .cellpose import identify_masks_finetune
+    from .spacr_cellpose import identify_masks_finetune
     from .settings import get_analyze_plaque_settings
     from .utils import save_settings, download_models
     from spacr import __file__ as spacr_path
@@ -198,147 +740,6 @@ def analyze_plaques(settings):

     print(f"Analysis completed and saved to database '{db_name}'.")

-def train_cellpose(settings):
-
-    from .io import _load_normalized_images_and_labels, _load_images_and_labels
-    from .settings import get_train_cellpose_default_settings
-    from .utils import save_settings
-
-    settings = get_train_cellpose_default_settings(settings)
-
-    img_src = settings['img_src']
-    mask_src = os.path.join(img_src, 'masks')
-    test_img_src = settings['test_img_src']
-    test_mask_src = settings['test_mask_src']
-
-    if settings['resize']:
-        target_height = settings['width_height'][1]
-        target_width = settings['width_height'][0]
-
-    if settings['test']:
-        test_img_src = os.path.join(os.path.dirname(settings['img_src']), 'test')
-        test_mask_src = os.path.join(settings['test_img_src'], 'mask')
-
-    test_images, test_masks, test_image_names, test_mask_names = None,None,None,None
-    print(settings)
-
-    if settings['from_scratch']:
-        model_name=f"scratch_{settings['model_name']}_{settings['model_type']}_e{settings['n_epochs']}_X{target_width}_Y{target_height}.CP_model"
-    else:
-        if settings['resize']:
-            model_name=f"{settings['model_name']}_{settings['model_type']}_e{settings['n_epochs']}_X{target_width}_Y{target_height}.CP_model"
-        else:
-            model_name=f"{settings['model_name']}_{settings['model_type']}_e{settings['n_epochs']}.CP_model"
-
-    model_save_path = os.path.join(settings['mask_src'], 'models', 'cellpose_model')
-    print(model_save_path)
-    os.makedirs(model_save_path, exist_ok=True)
-
-    save_settings(settings, name=model_name)
-
-    if settings['from_scratch']:
-        model = cp_models.CellposeModel(gpu=True, model_type=settings['model_type'], diam_mean=settings['diameter'], pretrained_model=None)
-    else:
-        model = cp_models.CellposeModel(gpu=True, model_type=settings['model_type'])
-
-    if settings['normalize']:
-
-        image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
-        label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
-        images, masks, image_names, mask_names, orig_dims = _load_normalized_images_and_labels(image_files,
-                                                                                               label_files,
-                                                                                               settings['channels'],
-                                                                                               settings['percentiles'],
-                                                                                               settings['invert'],
-                                                                                               settings['verbose'],
-                                                                                               settings['remove_background'],
-                                                                                               settings['background'],
-                                                                                               settings['Signal_to_noise'],
-                                                                                               settings['target_height'],
-                                                                                               settings['target_width'])
-        images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
-
-        if settings['test']:
-            test_image_files = [os.path.join(test_img_src, f) for f in os.listdir(test_img_src) if f.endswith('.tif')]
-            test_label_files = [os.path.join(test_mask_src, f) for f in os.listdir(test_mask_src) if f.endswith('.tif')]
-            test_images, test_masks, test_image_names, test_mask_names = _load_normalized_images_and_labels(test_image_files,
-                                                                                                            test_label_files,
-                                                                                                            settings['channels'],
-                                                                                                            settings['percentiles'],
-                                                                                                            settings['invert'],
-                                                                                                            settings['verbose'],
-                                                                                                            settings['remove_background'],
-                                                                                                            settings['background'],
-                                                                                                            settings['Signal_to_noise'],
-                                                                                                            settings['target_height'],
-                                                                                                            settings['target_width'])
-            test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
-
-    else:
-        images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, settings['invert'])
-        images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
-
-        if settings['test']:
-            test_images, test_masks, test_image_names, test_mask_names = _load_images_and_labels(test_img_src,
-                                                                                                 test_mask_src,
-                                                                                                 settings['invert'])
-
-            test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
-
-    #if resize:
-    #    images, masks = resize_images_and_labels(images, masks, target_height, target_width, show_example=True)
-
-    if settings['model_type'] == 'cyto':
-        cp_channels = [0,1]
-    if settings['model_type'] == 'cyto2':
-        cp_channels = [0,2]
-    if settings['model_type'] == 'nucleus':
-        cp_channels = [0,0]
-    if settings['grayscale']:
-        cp_channels = [0,0]
-        images = [np.squeeze(img) if img.ndim == 3 and 1 in img.shape else img for img in images]
-
-    masks = [np.squeeze(mask) if mask.ndim == 3 and 1 in mask.shape else mask for mask in masks]
-
-    print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {masks[0].shape}, image type: masks[0].shape')
-    save_every = int(settings['n_epochs']/10)
-    if save_every < 10:
-        save_every = settings['n_epochs']
-
-    train_cp.train_seg(model.net,
-                       train_data=images,
-                       train_labels=masks,
-                       train_files=image_names,
-                       train_labels_files=mask_names,
-                       train_probs=None,
-                       test_data=test_images,
-                       test_labels=test_masks,
-                       test_files=test_image_names,
-                       test_labels_files=test_mask_names,
-                       test_probs=None,
-                       load_files=True,
-                       batch_size=settings['batch_size'],
-                       learning_rate=settings['learning_rate'],
-                       n_epochs=settings['n_epochs'],
-                       weight_decay=settings['weight_decay'],
-                       momentum=0.9,
-                       SGD=False,
-                       channels=cp_channels,
-                       channel_axis=None,
-                       normalize=False,
-                       compute_flows=False,
-                       save_path=model_save_path,
-                       save_every=save_every,
-                       nimg_per_epoch=None,
-                       nimg_test_per_epoch=None,
-                       rescale=settings['rescale'],
-                       #scale_range=None,
-                       #bsize=224,
-                       min_train_masks=1,
-                       model_name=settings['model_name'])
-
-    return print(f"Model saved at: {model_save_path}/{model_name}")
-
 def count_phenotypes(settings):
     from .io import _read_db

@@ -350,17 +751,17 @@ def count_phenotypes(settings):
     unique_values_count = df[settings['annotation_column']].nunique(dropna=True)
     print(f"Unique values in {settings['annotation_column']} (excluding NaN): {unique_values_count}")

-    # Count unique values in 'value' column, grouped by '
-    grouped_unique_count = df.groupby(['
+    # Count unique values in 'value' column, grouped by 'plateID', 'rowID', 'columnID'
+    grouped_unique_count = df.groupby(['plateID', 'rowID', 'columnID'])[settings['annotation_column']].nunique(dropna=True).reset_index(name='unique_count')
     display(grouped_unique_count)

     save_path = os.path.join(settings['src'], 'phenotype_counts.csv')

     # Group by plate, row, and column, then count the occurrences of each unique value
-    grouped_counts = df.groupby(['
+    grouped_counts = df.groupby(['plateID', 'rowID', 'columnID', 'value']).size().reset_index(name='count')

     # Pivot the DataFrame so that unique values are columns and their counts are in the rows
-    pivot_df = grouped_counts.pivot_table(index=['
+    pivot_df = grouped_counts.pivot_table(index=['plateID', 'rowID', 'columnID'], columns='value', values='count', fill_value=0)

     # Flatten the multi-level columns
     pivot_df.columns = [f"value_{int(col)}" for col in pivot_df.columns]
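A tiny self-contained illustration of the grouping and pivot used here, with made-up counts:

```python
import pandas as pd

# Two annotation values in one well (made-up data).
grouped_counts = pd.DataFrame({
    'plateID': ['p1', 'p1'], 'rowID': ['r1', 'r1'],
    'columnID': ['c1', 'c1'], 'value': [0, 1], 'count': [12, 3],
})
pivot_df = grouped_counts.pivot_table(index=['plateID', 'rowID', 'columnID'],
                                      columns='value', values='count', fill_value=0)
pivot_df.columns = [f"value_{int(col)}" for col in pivot_df.columns]
print(pivot_df)  # one row per well, columns value_0 and value_1
```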
@@ -382,20 +783,20 @@ def count_phenotypes(settings):
 def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),'r2':(90,10),'r3':(80,20),'r4':(80,20),'r5':(70,30),'r6':(70,30),'r7':(60,40),'r8':(60,40),'r9':(50,50),'r10':(50,50),'r11':(40,60),'r12':(40,60),'r13':(30,70),'r14':(30,70),'r15':(20,80),'r16':(20,80)},
                             pc_grna='TGGT1_220950_1', nc_grna='TGGT1_233460_4',
                             y_columns=['class_1_fraction', 'TGGT1_220950_1_fraction', 'nc_fraction'],
-                            column='
+                            column='columnID', value='c3', plate=None, save_paths=None):

     def calculate_well_score_fractions(df, class_columns='cv_predictions'):
-        if all(col in df.columns for col in ['
-            df['prc'] = df['
+        if all(col in df.columns for col in ['plateID', 'rowID', 'columnID']):
+            df['prc'] = df['plateID'] + '_' + df['rowID'] + '_' + df['columnID']
         else:
-            raise ValueError("Cannot find '
-        prc_summary = df.groupby(['
-        well_counts = (df.groupby(['
+            raise ValueError("Cannot find 'plateID', 'rowID', or 'columnID' in df.columns")
+        prc_summary = df.groupby(['plateID', 'rowID', 'columnID', 'prc']).size().reset_index(name='total_rows')
+        well_counts = (df.groupby(['plateID', 'rowID', 'columnID', 'prc', class_columns])
                        .size()
                        .unstack(fill_value=0)
                        .reset_index()
                        .rename(columns={0: 'class_0', 1: 'class_1'}))
-        summary_df = pd.merge(prc_summary, well_counts, on=['
+        summary_df = pd.merge(prc_summary, well_counts, on=['plateID', 'rowID', 'columnID', 'prc'], how='left')
         summary_df['class_0_fraction'] = summary_df['class_0'] / summary_df['total_rows']
         summary_df['class_1_fraction'] = summary_df['class_1'] / summary_df['total_rows']
         return summary_df
@@ -490,8 +891,8 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
         return result

     def calculate_well_read_fraction(df, count_column='count'):
-        if all(col in df.columns for col in ['
-            df['prc'] = df['
+        if all(col in df.columns for col in ['plateID', 'rowID', 'columnID']):
+            df['prc'] = df['plateID'] + '_' + df['rowID'] + '_' + df['columnID']
         else:
             raise ValueError("Cannot find plate, row or column in df.columns")
         grouped_df = df.groupby('prc')[count_column].sum().reset_index()
@@ -507,21 +908,17 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
         for i, reads_csv_temp in enumerate(reads_csv):
             reads_df_temp = pd.read_csv(reads_csv_temp)
             scores_df_temp = pd.read_csv(scores_csv[i])
-            reads_df_temp['
-            scores_df_temp['
+            reads_df_temp['plateID'] = f"plate{i+1}"
+            scores_df_temp['plateID'] = f"plate{i+1}"

+            if 'column' in reads_df_temp.columns:
+                reads_df_temp = reads_df_temp.rename(columns={'column': 'columnID'})
             if 'column_name' in reads_df_temp.columns:
-                reads_df_temp = reads_df_temp.rename(columns={'column_name': '
-            if '
-                reads_df_temp = reads_df_temp.rename(columns={'
-            if 'column_name' in scores_df_temp.columns:
-                scores_df_temp = scores_df_temp.rename(columns={'column_name': 'column'})
-            if 'column_name' in scores_df_temp.columns:
-                scores_df_temp = scores_df_temp.rename(columns={'column_name': 'column'})
-            if 'row_name' in reads_df_temp.columns:
-                reads_df_temp = reads_df_temp.rename(columns={'row_name': 'row_name'})
+                reads_df_temp = reads_df_temp.rename(columns={'column_name': 'columnID'})
+            if 'row' in reads_df_temp.columns:
+                reads_df_temp = reads_df_temp.rename(columns={'row_name': 'rowID'})
             if 'row_name' in scores_df_temp.columns:
-                scores_df_temp = scores_df_temp.rename(columns={'row_name': '
+                scores_df_temp = scores_df_temp.rename(columns={'row_name': 'rowID'})

             reads_ls.append(reads_df_temp)
             scores_ls.append(scores_df_temp)
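This block maps the assorted legacy headers onto the 0.5.0 `plateID`/`rowID`/`columnID` names, one alias at a time. A hypothetical helper that consolidates the same renames (the `'plate'` alias is an assumption; the other aliases appear in the diff):

```python
import pandas as pd

# Hypothetical consolidation of the renames performed above.
LEGACY_TO_ID = {
    'plate': 'plateID',  # assumed alias
    'row': 'rowID', 'row_name': 'rowID',
    'column': 'columnID', 'column_name': 'columnID',
}

def normalize_prc_columns(df):
    return df.rename(columns={k: v for k, v in LEGACY_TO_ID.items()
                              if k in df.columns})

df = pd.DataFrame({'plate': ['p1'], 'row_name': ['r1'], 'column': ['c3']})
print(normalize_prc_columns(df).columns.tolist())
# ['plateID', 'rowID', 'columnID']
```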
@@ -535,8 +932,8 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
         reads_df = pd.read_csv(reads_csv)
         scores_df = pd.read_csv(scores_csv)
         if plate != None:
-            reads_df['
-            scores_df['
+            reads_df['plateID'] = plate
+            scores_df['plateID'] = plate

     reads_df = calculate_well_read_fraction(reads_df)
     scores_df = calculate_well_score_fractions(scores_df)
@@ -548,7 +945,7 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),

     df_emp = pd.DataFrame([(key, val[0], val[1], val[0] / (val[0] + val[1]), val[1] / (val[0] + val[1])) for key, val in empirical_dict.items()], columns=['key', 'value1', 'value2', 'pc_fraction', 'nc_fraction'])

-    df = pd.merge(df, df_emp, left_on='
+    df = pd.merge(df, df_emp, left_on='rowID', right_on='key')

     if any in y_columns not in df.columns:
         print(f"columns in dataframe:")
@@ -620,7 +1017,7 @@ def interperate_vision_model(settings={}):
     else:
         return None

-    from 
+    from .plot import spacrGraph

     df[name] = df['feature'].apply(lambda x: find_feature_class(x, feature_groups))

@@ -698,11 +1095,17 @@ def interperate_vision_model(settings={}):
     # Clean and align columns for merging
     df['object_label'] = df['object_label'].str.replace('o', '')

-    if '
-
+    if 'rowID' not in scores_df.columns:
+        if 'row' in scores_df.columns:
+            scores_df['rowID'] = scores_df['row']
+        if 'row_name' in scores_df.columns:
+            scores_df['rowID'] = scores_df['row_name']

-    if '
-
+    if 'columnID' not in scores_df.columns:
+        if 'column_name' in scores_df.columns:
+            scores_df['columnID'] = scores_df['column_name']
+        if 'column' in scores_df.columns:
+            scores_df['columnID'] = scores_df['column']

     if 'object_label' not in scores_df.columns:
         scores_df['object_label'] = scores_df['object']
@@ -714,14 +1117,14 @@ def interperate_vision_model(settings={}):
         scores_df['object_label'] = scores_df['object'].astype(str)

     # Ensure all join columns have the same data type in both DataFrames
-    df[['
-    scores_df[['
+    df[['plateID', 'rowID', 'column_name', 'fieldID', 'object_label']] = df[['plateID', 'rowID', 'column_name', 'fieldID', 'object_label']].astype(str)
+    scores_df[['plateID', 'rowID', 'column_name', 'fieldID', 'object_label']] = scores_df[['plateID', 'rowID', 'column_name', 'fieldID', 'object_label']].astype(str)

     # Select only the necessary columns from scores_df for merging
-    scores_df = scores_df[['
+    scores_df = scores_df[['plateID', 'rowID', 'column_name', 'fieldID', 'object_label', settings['score_column']]]

     # Now merge DataFrames
-    merged_df = pd.merge(df, scores_df, on=['
+    merged_df = pd.merge(df, scores_df, on=['plateID', 'rowID', 'column_name', 'fieldID', 'object_label'], how='inner')

     # Separate numerical features and the score column
     X = merged_df.select_dtypes(include='number').drop(columns=[settings['score_column']])
@@ -997,8 +1400,8 @@ def analyze_endodyogeny(settings):
     output['data'] = df


-    if settings['level'] == '
-        prc_column = '
+    if settings['level'] == 'plateID':
+        prc_column = 'plateID'
     else:
         prc_column = 'prc'

@@ -1144,28 +1547,28 @@ def generate_score_heatmap(settings):
     def group_cv_score(csv, plate=1, column='c3', data_column='pred'):

         df = pd.read_csv(csv)
-        if '
-            df = df[df['
+        if 'columnID' in df.columns:
+            df = df[df['columnID']==column]
         elif 'column' in df.columns:
-            df['
-            df = df[df['
+            df['columnID'] = df['column']
+            df = df[df['columnID']==column]
         if not plate is None:
-            df['
-        grouped_df = df.groupby(['
-        grouped_df['prc'] = grouped_df['
+            df['plateID'] = f"plate{plate}"
+        grouped_df = df.groupby(['plateID', 'rowID', 'columnID'])[data_column].mean().reset_index()
+        grouped_df['prc'] = grouped_df['plateID'].astype(str) + '_' + grouped_df['rowID'].astype(str) + '_' + grouped_df['columnID'].astype(str)
         return grouped_df

     def calculate_fraction_mixed_condition(csv, plate=1, column='c3', control_sgrnas = ['TGGT1_220950_1', 'TGGT1_233460_4']):
         df = pd.read_csv(csv)
         df = df[df['column_name']==column]
         if plate not in df.columns:
-            df['
+            df['plateID'] = f"plate{plate}"
         df = df[df['grna_name'].str.match(f'^{control_sgrnas[0]}$|^{control_sgrnas[1]}$')]
-        grouped_df = df.groupby(['
+        grouped_df = df.groupby(['plateID', 'rowID', 'columnID'])['count'].sum().reset_index()
         grouped_df = grouped_df.rename(columns={'count': 'total_count'})
-        merged_df = pd.merge(df, grouped_df, on=['
+        merged_df = pd.merge(df, grouped_df, on=['plateID', 'rowID', 'column_name'])
         merged_df['fraction'] = merged_df['count'] / merged_df['total_count']
-        merged_df['prc'] = merged_df['
+        merged_df['prc'] = merged_df['plateID'].astype(str) + '_' + merged_df['rowID'].astype(str) + '_' + merged_df['column_name'].astype(str)
         return merged_df

     def plot_multi_channel_heatmap(df, column='c3', cmap='coolwarm'):
@@ -1177,17 +1580,17 @@ def generate_score_heatmap(settings):
         - column: Column to filter by (default is 'c3').
         """
         # Extract row number and convert to integer for sorting
-        df['row_num'] = df['
+        df['row_num'] = df['rowID'].str.extract(r'(\d+)').astype(int)

         # Filter and sort by plate, row, and column
-        df = df[df['
-        df = df.sort_values(by=['
+        df = df[df['columnID'] == column]
+        df = df.sort_values(by=['plateID', 'row_num', 'columnID'])

         # Drop temporary 'row_num' column after sorting
         df = df.drop('row_num', axis=1)

         # Create a new column combining plate, row, and column for the index
-        df['plate_row_col'] = df['
+        df['plate_row_col'] = df['plateID'] + '-' + df['rowID'] + '-' + df['columnID']

         # Set 'plate_row_col' as the index
         df.set_index('plate_row_col', inplace=True)
@@ -1244,11 +1647,11 @@ def generate_score_heatmap(settings):
         # Loop through all collected CSV files and process them
         for csv_file in ls:
             df = pd.read_csv(csv_file)  # Read CSV into DataFrame
-            df = df[df['
+            df = df[df['columnID']==column]
             if not plate is None:
-                df['
-            # Group the data by '
-            grouped_df = df.groupby(['
+                df['plateID'] = f"plate{plate}"
+            # Group the data by 'plateID', 'rowID', and 'columnID'
+            grouped_df = df.groupby(['plateID', 'rowID', 'columnID'])[data_column].mean().reset_index()
             # Use the CSV filename to create a new column name
             folder_name = os.path.dirname(csv_file).replace(".csv", "")
             new_column_name = os.path.basename(f"{folder_name}_{data_column}")
@@ -1259,8 +1662,8 @@ def generate_score_heatmap(settings):
             if combined_df is None:
                 combined_df = grouped_df
             else:
-                combined_df = pd.merge(combined_df, grouped_df, on=['
-        combined_df['prc'] = combined_df['
+                combined_df = pd.merge(combined_df, grouped_df, on=['plateID', 'rowID', 'columnID'], how='outer')
+        combined_df['prc'] = combined_df['plateID'].astype(str) + '_' + combined_df['rowID'].astype(str) + '_' + combined_df['columnID'].astype(str)
         return combined_df

     def calculate_mae(df):
@@ -1282,16 +1685,16 @@ def generate_score_heatmap(settings):
         mae_df = pd.DataFrame(mae_data)
         return mae_df

-    result_df = combine_classification_scores(settings['folders'], settings['csv_name'], settings['data_column'], settings['
-    df = calculate_fraction_mixed_condition(settings['csv'], settings['
+    result_df = combine_classification_scores(settings['folders'], settings['csv_name'], settings['data_column'], settings['plateID'], settings['columnID'], )
+    df = calculate_fraction_mixed_condition(settings['csv'], settings['plateID'], settings['columnID'], settings['control_sgrnas'])
     df = df[df['grna_name']==settings['fraction_grna']]
     fraction_df = df[['fraction', 'prc']]
     merged_df = pd.merge(fraction_df, result_df, on=['prc'])
-    cv_df = group_cv_score(settings['cv_csv'], settings['
+    cv_df = group_cv_score(settings['cv_csv'], settings['plateID'], settings['columnID'], settings['data_column_cv'])
     cv_df = cv_df[[settings['data_column_cv'], 'prc']]
     merged_df = pd.merge(merged_df, cv_df, on=['prc'])

-    fig = plot_multi_channel_heatmap(merged_df, settings['
+    fig = plot_multi_channel_heatmap(merged_df, settings['columnID'], settings['cmap'])
     if 'row_number' in merged_df.columns:
         merged_df = merged_df.drop('row_num', axis=1)
     mae_df = calculate_mae(merged_df)
@@ -1299,9 +1702,9 @@ def generate_score_heatmap(settings):
         mae_df = mae_df.drop('row_num', axis=1)

     if not settings['dst'] is None:
-        mae_dst = os.path.join(settings['dst'], f"mae_scores_comparison_plate_{settings['
-        merged_dst = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['
-        heatmap_save = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['
+        mae_dst = os.path.join(settings['dst'], f"mae_scores_comparison_plate_{settings['plateID']}.csv")
+        merged_dst = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plateID']}_data.csv")
+        heatmap_save = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plateID']}.pdf")
         mae_df.to_csv(mae_dst, index=False)
         merged_df.to_csv(merged_dst, index=False)
         fig.savefig(heatmap_save, format='pdf', dpi=600, bbox_inches='tight')
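A hedged sketch of a `generate_score_heatmap` call after the renames; every key below is read somewhere in the function body, but the paths and values are invented:

```python
# Sketch only: keys mirror the function body, values are hypothetical.
settings = {
    'folders': ['/data/plate1/model_a', '/data/plate1/model_b'],
    'csv_name': 'scores.csv',
    'data_column': 'pred',
    'cv_csv': '/data/plate1/cv_scores.csv',
    'data_column_cv': 'pred',
    'csv': '/data/plate1/reads.csv',
    'control_sgrnas': ['TGGT1_220950_1', 'TGGT1_233460_4'],
    'fraction_grna': 'TGGT1_220950_1',
    'plateID': 1,
    'columnID': 'c3',
    'cmap': 'coolwarm',
    'dst': '/data/plate1/results',
}
generate_score_heatmap(settings)
# -> mae_scores_comparison_plate_1.csv, scores_comparison_plate_1_data.csv,
#    scores_comparison_plate_1.pdf under settings['dst']
```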