spacr 0.4.15__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. spacr/__init__.py +2 -2
  2. spacr/core.py +52 -10
  3. spacr/deep_spacr.py +2 -3
  4. spacr/gui.py +0 -1
  5. spacr/gui_core.py +247 -41
  6. spacr/gui_elements.py +133 -2
  7. spacr/gui_utils.py +22 -17
  8. spacr/io.py +624 -149
  9. spacr/ml.py +141 -258
  10. spacr/plot.py +76 -34
  11. spacr/resources/MEDIAR/__pycache__/SetupDict.cpython-39.pyc +0 -0
  12. spacr/resources/MEDIAR/__pycache__/evaluate.cpython-39.pyc +0 -0
  13. spacr/resources/MEDIAR/__pycache__/generate_mapping.cpython-39.pyc +0 -0
  14. spacr/resources/MEDIAR/__pycache__/main.cpython-39.pyc +0 -0
  15. spacr/resources/MEDIAR/core/Baseline/__pycache__/Predictor.cpython-39.pyc +0 -0
  16. spacr/resources/MEDIAR/core/Baseline/__pycache__/Trainer.cpython-39.pyc +0 -0
  17. spacr/resources/MEDIAR/core/Baseline/__pycache__/__init__.cpython-39.pyc +0 -0
  18. spacr/resources/MEDIAR/core/Baseline/__pycache__/utils.cpython-39.pyc +0 -0
  19. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/EnsemblePredictor.cpython-39.pyc +0 -0
  20. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/Predictor.cpython-39.pyc +0 -0
  21. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/Trainer.cpython-39.pyc +0 -0
  22. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/__init__.cpython-39.pyc +0 -0
  23. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/utils.cpython-39.pyc +0 -0
  24. spacr/resources/MEDIAR/core/__pycache__/BasePredictor.cpython-39.pyc +0 -0
  25. spacr/resources/MEDIAR/core/__pycache__/BaseTrainer.cpython-39.pyc +0 -0
  26. spacr/resources/MEDIAR/core/__pycache__/__init__.cpython-39.pyc +0 -0
  27. spacr/resources/MEDIAR/core/__pycache__/utils.cpython-39.pyc +0 -0
  28. spacr/resources/MEDIAR/train_tools/__pycache__/__init__.cpython-39.pyc +0 -0
  29. spacr/resources/MEDIAR/train_tools/__pycache__/measures.cpython-39.pyc +0 -0
  30. spacr/resources/MEDIAR/train_tools/__pycache__/utils.cpython-39.pyc +0 -0
  31. spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/__init__.cpython-39.pyc +0 -0
  32. spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/datasetter.cpython-39.pyc +0 -0
  33. spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/transforms.cpython-39.pyc +0 -0
  34. spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/utils.cpython-39.pyc +0 -0
  35. spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/CellAware.cpython-39.pyc +0 -0
  36. spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/LoadImage.cpython-39.pyc +0 -0
  37. spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/NormalizeImage.cpython-39.pyc +0 -0
  38. spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/__init__.cpython-39.pyc +0 -0
  39. spacr/resources/MEDIAR/train_tools/models/__pycache__/MEDIARFormer.cpython-39.pyc +0 -0
  40. spacr/resources/MEDIAR/train_tools/models/__pycache__/__init__.cpython-39.pyc +0 -0
  41. spacr/sequencing.py +73 -38
  42. spacr/settings.py +161 -135
  43. spacr/submodules.py +618 -215
  44. spacr/timelapse.py +197 -29
  45. spacr/toxo.py +23 -23
  46. spacr/utils.py +186 -128
  47. {spacr-0.4.15.dist-info → spacr-0.5.0.dist-info}/METADATA +5 -2
  48. {spacr-0.4.15.dist-info → spacr-0.5.0.dist-info}/RECORD +53 -24
  49. spacr/stats.py +0 -221
  50. /spacr/{cellpose.py → spacr_cellpose.py} +0 -0
  51. {spacr-0.4.15.dist-info → spacr-0.5.0.dist-info}/LICENSE +0 -0
  52. {spacr-0.4.15.dist-info → spacr-0.5.0.dist-info}/WHEEL +0 -0
  53. {spacr-0.4.15.dist-info → spacr-0.5.0.dist-info}/entry_points.txt +0 -0
  54. {spacr-0.4.15.dist-info → spacr-0.5.0.dist-info}/top_level.txt +0 -0
spacr/submodules.py CHANGED
@@ -1,14 +1,21 @@
-
-
-
 import seaborn as sns
-import os, random, sqlite3, re, shap
+import os, random, sqlite3, re, shap, string, time
 import pandas as pd
 import numpy as np
-import cellpose
+
 from skimage.measure import regionprops, label
+from skimage.transform import resize as sk_resize, rotate
+from skimage.exposure import rescale_intensity
+
+import cellpose
+from cellpose import models as cp_models
+from cellpose import train as train_cp
 from cellpose import models as cp_models
+from cellpose import io as cp_io
 from cellpose import train as train_cp
+from cellpose.metrics import aggregated_jaccard_index
+from cellpose.metrics import average_precision
+
 from IPython.display import display
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.inspection import permutation_importance
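The new cellpose.metrics imports feed the test loop added below: aggregated_jaccard_index scores mask overlap per image and, as the tp[0, 0] indexing later in this diff suggests, average_precision returns per-image, per-threshold arrays of matched-object counts. The statistics derived from those counts are plain arithmetic; a minimal, self-contained sketch (the helper name and the toy counts are illustrative, not part of the package):

    def object_level_stats(tp, fp, fn):
        # Precision/recall/F1/accuracy from matched-object counts, with the
        # same zero-division guards used in test_cellpose_model below.
        prec = tp / (tp + fp) if (tp + fp) > 0 else 0
        rec = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0
        acc = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0
        return prec, rec, f1, acc

    print(object_level_stats(tp=18, fp=2, fn=4))  # (0.9, 0.818..., 0.857..., 0.75)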
@@ -17,10 +24,545 @@ from scipy.stats import chi2_contingency, pearsonr
 from scipy.spatial.distance import cosine
 
 from sklearn.metrics import mean_absolute_error
-
+from skimage.measure import regionprops, label as sklabel
 import matplotlib.pyplot as plt
 from natsort import natsorted
 
+from torch.utils.data import Dataset
+
+class CellposeLazyDataset(Dataset):
+    def __init__(self, image_files, label_files, settings, randomize=True, augment=False):
+        combined = list(zip(image_files, label_files))
+        if randomize:
+            random.shuffle(combined)
+        self.image_files, self.label_files = zip(*combined)
+        self.normalize = settings['normalize']
+        self.percentiles = settings.get('percentiles', [2, 99])
+        self.target_size = settings['target_size']
+        self.augment = augment
+
+    def __len__(self):
+        return len(self.image_files) * (8 if self.augment else 1)
+
+    def apply_augmentation(self, image, label, aug_idx):
+        if aug_idx == 1:
+            return rotate(image, 90, resize=False, preserve_range=True), rotate(label, 90, resize=False, preserve_range=True)
+        elif aug_idx == 2:
+            return rotate(image, 180, resize=False, preserve_range=True), rotate(label, 180, resize=False, preserve_range=True)
+        elif aug_idx == 3:
+            return rotate(image, 270, resize=False, preserve_range=True), rotate(label, 270, resize=False, preserve_range=True)
+        elif aug_idx == 4:
+            return np.fliplr(image), np.fliplr(label)
+        elif aug_idx == 5:
+            return np.flipud(image), np.flipud(label)
+        elif aug_idx == 6:
+            return np.fliplr(rotate(image, 90, resize=False, preserve_range=True)), np.fliplr(rotate(label, 90, resize=False, preserve_range=True))
+        elif aug_idx == 7:
+            return np.flipud(rotate(image, 90, resize=False, preserve_range=True)), np.flipud(rotate(label, 90, resize=False, preserve_range=True))
+        return image, label
+
+    def __getitem__(self, idx):
+        base_idx = idx // 8 if self.augment else idx
+        aug_idx = idx % 8 if self.augment else 0
+
+        image = cp_io.imread(self.image_files[base_idx])
+        label = cp_io.imread(self.label_files[base_idx])
+
+        if image.ndim == 3:
+            image = image.mean(axis=-1)
+
+        if image.max() > 1:
+            image = image / image.max()
+
+        if self.normalize:
+            lower_p, upper_p = np.percentile(image, self.percentiles)
+            image = rescale_intensity(image, in_range=(lower_p, upper_p), out_range=(0, 1))
+
+        image, label = self.apply_augmentation(image, label, aug_idx)
+
+        image_shape = (self.target_size, self.target_size)
+        image = sk_resize(image, image_shape, preserve_range=True, anti_aliasing=True).astype(np.float32)
+        label = sk_resize(label, image_shape, order=0, preserve_range=True, anti_aliasing=False).astype(np.uint8)
+
+        return image, label
+
+def train_cellpose(settings):
+
+    from .settings import get_train_cellpose_default_settings
+    from .utils import save_settings
+
+    settings = get_train_cellpose_default_settings(settings)
+    img_src = os.path.join(settings['src'], 'train', 'images')
+    mask_src = os.path.join(settings['src'], 'train', 'masks')
+    target_size = settings['target_size']
+
+    model_name = f"{settings['model_name']}_cyto_e{settings['n_epochs']}_X{target_size}_Y{target_size}.CP_model"
+    model_save_path = os.path.join(settings['src'], 'models', 'cellpose_model')
+    os.makedirs(model_save_path, exist_ok=True)
+
+    save_settings(settings, name=model_name)
+
+    model = cp_models.CellposeModel(gpu=True, model_type='cyto', diam_mean=30, pretrained_model='cyto')
+    cp_channels = [0, 0]
+
+    #train_image_files = sorted([os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')])
+    #train_label_files = sorted([os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')])
+
+    image_filenames = set(f for f in os.listdir(img_src) if f.endswith('.tif'))
+    label_filenames = set(f for f in os.listdir(mask_src) if f.endswith('.tif'))
+
+    # Only keep files that are present in both folders
+    matched_filenames = sorted(image_filenames & label_filenames)
+
+    train_image_files = [os.path.join(img_src, f) for f in matched_filenames]
+    train_label_files = [os.path.join(mask_src, f) for f in matched_filenames]
+
+    train_dataset = CellposeLazyDataset(train_image_files, train_label_files, settings, randomize=True, augment=settings['augment'])
+
+    n_aug = 8 if settings['augment'] else 1
+    max_base_images = len(train_dataset) // n_aug if settings['augment'] else len(train_dataset)
+    n_base = min(settings['batch_size'], max_base_images)
+
+    unique_base_indices = list(range(max_base_images))
+    random.shuffle(unique_base_indices)
+    selected_indices = unique_base_indices[:n_base]
+
+    images, labels = [], []
+    for idx in selected_indices:
+        for aug_idx in range(n_aug):
+            i = idx * n_aug + aug_idx if settings['augment'] else idx
+            img, lbl = train_dataset[i]
+            images.append(img)
+            labels.append(lbl)
+    try:
+        plot_cellpose_batch(images, labels)
+    except:
+        print(f"could not print batch images")
+
+    print(f"Training model with {len(images)} images per batch for {settings['n_epochs']} epochs")
+
+    train_cp.train_seg(model.net,
+                       train_data=images,
+                       train_labels=labels,
+                       channels=cp_channels,
+                       save_path=model_save_path,
+                       n_epochs=settings['n_epochs'],
+                       batch_size=settings['batch_size'],
+                       learning_rate=settings['learning_rate'],
+                       weight_decay=settings['weight_decay'],
+                       model_name=model_name,
+                       save_every=max(1, (settings['n_epochs'] // 10)),
+                       rescale=False)
+
+    print(f"Model saved at: {model_save_path}/{model_name}")
+
+def test_cellpose_model(settings):
+
+    from .utils import save_settings, print_progress
+    from .settings import get_default_test_cellpose_model_settings
+
+    def plot_cellpose_resilts(i, j, results_dir, img, lbl, pred, flow):
+        from .plot import generate_mask_random_cmap
+        fig, axs = plt.subplots(1, 5, figsize=(16, 4), gridspec_kw={'wspace': 0.1, 'hspace': 0.1})
+        cmap_lbl = generate_mask_random_cmap(lbl)
+        cmap_pred = generate_mask_random_cmap(pred)
+
+        axs[0].imshow(img, cmap='gray')
+        axs[0].set_title('Image')
+        axs[0].axis('off')
+
+        axs[1].imshow(lbl, cmap=cmap_lbl, interpolation='nearest')
+        axs[1].set_title('True Mask')
+        axs[1].axis('off')
+
+        axs[2].imshow(pred, cmap=cmap_pred, interpolation='nearest')
+        axs[2].set_title('Predicted Mask')
+        axs[2].axis('off')
+
+        axs[3].imshow(flow[2], cmap='gray')
+        axs[3].set_title('Cell Probability')
+        axs[3].axis('off')
+
+        axs[4].imshow(flow[0], cmap='gray')
+        axs[4].set_title('Flows')
+        axs[4].axis('off')
+
+        save_path = os.path.join(results_dir, f"cellpose_result_{i+j:03d}.png")
+        plt.savefig(save_path, dpi=200, bbox_inches='tight')
+        plt.show()
+        plt.close(fig)
+
+
+    settings = get_default_test_cellpose_model_settings(settings)
+
+    save_settings(settings, name='test_cellpose_model')
+    test_image_folder = os.path.join(settings['src'], 'test', 'images')
+    test_label_folder = os.path.join(settings['src'], 'test', 'masks')
+    results_dir = os.path.join(settings['src'], 'results')
+    os.makedirs(results_dir, exist_ok=True)
+
+    print(f"Results will be saved in: {results_dir}")
+
+    image_filenames = set(f for f in os.listdir(test_image_folder) if f.endswith('.tif'))
+    label_filenames = set(f for f in os.listdir(test_label_folder) if f.endswith('.tif'))
+
+    # Only keep files that are present in both folders
+    matched_filenames = sorted(image_filenames & label_filenames)
+
+    test_image_files = [os.path.join(test_image_folder, f) for f in matched_filenames]
+    test_label_files = [os.path.join(test_label_folder, f) for f in matched_filenames]
+
+    print(f"Found {len(test_image_files)} images and {len(test_label_files)} masks")
+
+    test_dataset = CellposeLazyDataset(test_image_files, test_label_files, settings, randomize=False, augment=False)
+
+    model = cp_models.CellposeModel(gpu=True, pretrained_model=settings['model_path'])
+
+    batch_size = settings['batch_size']
+    scores = []
+    names = []
+    time_ls = []
+
+    files_to_process = len(test_image_files)
+
+    for i in range(0, len(test_dataset), batch_size):
+        start = time.time()
+        batch = [test_dataset[j] for j in range(i, min(i + batch_size, len(test_dataset)))]
+        images, labels = zip(*batch)
+
+        masks_pred, flows, _ = model.eval(x=list(images),
+                                          channels=[0, 0],
+                                          normalize=False,
+                                          diameter=30,
+                                          flow_threshold=settings['FT'],
+                                          cellprob_threshold=settings['CP_probability'],
+                                          rescale=None,
+                                          resample=True,
+                                          interp=True,
+                                          anisotropy=None,
+                                          min_size=5,
+                                          augment=True,
+                                          tile=True,
+                                          tile_overlap=0.2,
+                                          bsize=224)
+
+        n_objects_true_ls = []
+        n_objects_pred_ls = []
+        mean_area_true_ls = []
+        mean_area_pred_ls = []
+        tp_ls, fp_ls, fn_ls = [], [], []
+        precision_ls, recall_ls, f1_ls, accuracy_ls = [], [], [], []
+
+        for j, (img, lbl, pred, flow) in enumerate(zip(images, labels, masks_pred, flows)):
+            score = float(aggregated_jaccard_index([lbl], [pred]))
+            fname = os.path.basename(test_label_files[i + j])
+            scores.append(score)
+            names.append(fname)
+
+            # Label masks
+            lbl_lab = label(lbl)
+            pred_lab = label(pred)
+
+            # Count objects
+            n_true = lbl_lab.max()
+            n_pred = pred_lab.max()
+            n_objects_true_ls.append(n_true)
+            n_objects_pred_ls.append(n_pred)
+
+            # Mean object size (area)
+            area_true = [p.area for p in regionprops(lbl_lab)]
+            area_pred = [p.area for p in regionprops(pred_lab)]
+
+            mean_area_true = np.mean(area_true) if area_true else 0
+            mean_area_pred = np.mean(area_pred) if area_pred else 0
+            mean_area_true_ls.append(mean_area_true)
+            mean_area_pred_ls.append(mean_area_pred)
+
+            # Compute object-level TP, FP, FN
+            ap, tp, fp, fn = average_precision([lbl], [pred], threshold=[0.5])
+            tp, fp, fn = int(tp[0, 0]), int(fp[0, 0]), int(fn[0, 0])
+            tp_ls.append(tp)
+            fp_ls.append(fp)
+            fn_ls.append(fn)
+
+            # Precision, Recall, F1, Accuracy
+            prec = tp / (tp + fp) if (tp + fp) > 0 else 0
+            rec = tp / (tp + fn) if (tp + fn) > 0 else 0
+            f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0
+            acc = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0
+
+            precision_ls.append(prec)
+            recall_ls.append(rec)
+            f1_ls.append(f1)
+            accuracy_ls.append(acc)
+
+            if settings['save']:
+                plot_cellpose_resilts(i, j, results_dir, img, lbl, pred, flow)
+
+        if settings['save']:
+            plot_cellpose_resilts(i,j,results_dir, img, lbl, pred, flow)
+
+        stop = time.time()
+        duration = stop-start
+        files_processed = (i+1) * batch_size
+        time_ls.append(duration)
+        print_progress(files_processed, files_to_process, n_jobs=1, time_ls=None, batch_size=batch_size, operation_type="test custom cellpose model")
+
+    df_results = pd.DataFrame({
+        'label_image': names,
+        'Jaccard': scores,
+        'n_objects_true': n_objects_true_ls,
+        'n_objects_pred': n_objects_pred_ls,
+        'mean_area_true': mean_area_true_ls,
+        'mean_area_pred': mean_area_pred_ls,
+        'TP': tp_ls,
+        'FP': fp_ls,
+        'FN': fn_ls,
+        'Precision': precision_ls,
+        'Recall': recall_ls,
+        'F1': f1_ls,
+        'Accuracy': accuracy_ls
+    })
+
+    df_results['n_error'] = abs(df_results['n_objects_pred'] - df_results['n_objects_true'])
+
+    print(f"Average true objects/image: {df_results['n_objects_true'].mean():.2f}")
+    print(f"Average predicted objects/image: {df_results['n_objects_pred'].mean():.2f}")
+    print(f"Mean object area (true): {df_results['mean_area_true'].mean():.2f} px")
+    print(f"Mean object area (pred): {df_results['mean_area_pred'].mean():.2f} px")
+    print(f"Average Jaccard score: {df_results['Jaccard'].mean():.4f}")
+
+    print(f"Average Precision: {df_results['Precision'].mean():.3f}")
+    print(f"Average Recall: {df_results['Recall'].mean():.3f}")
+    print(f"Average F1-score: {df_results['F1'].mean():.3f}")
+    print(f"Average Accuracy: {df_results['Accuracy'].mean():.3f}")
+
+    display(df_results)
+
+    if settings['save']:
+        df_results.to_csv(os.path.join(results_dir, 'test_results.csv'), index=False)
+
+def apply_cellpose_model(settings):
+
+    from .settings import get_default_apply_cellpose_model_settings
+    from .utils import save_settings, print_progress
+
+    def plot_cellpose_result(i, j, results_dir, img, pred, flow):
+
+        from .plot import generate_mask_random_cmap
+
+        fig, axs = plt.subplots(1, 4, figsize=(16, 4), gridspec_kw={'wspace': 0.1, 'hspace': 0.1})
+        cmap_pred = generate_mask_random_cmap(pred)
+
+        axs[0].imshow(img, cmap='gray')
+        axs[0].set_title('Image')
+        axs[0].axis('off')
+
+        axs[1].imshow(pred, cmap=cmap_pred, interpolation='nearest')
+        axs[1].set_title('Predicted Mask')
+        axs[1].axis('off')
+
+        axs[2].imshow(flow[2], cmap='gray')
+        axs[2].set_title('Cell Probability')
+        axs[2].axis('off')
+
+        axs[3].imshow(flow[0], cmap='gray')
+        axs[3].set_title('Flows')
+        axs[3].axis('off')
+
+        save_path = os.path.join(results_dir, f"cellpose_result_{i + j:03d}.png")
+        plt.savefig(save_path, dpi=200, bbox_inches='tight')
+        plt.show()
+        plt.close(fig)
+
+
+    settings = get_default_apply_cellpose_model_settings(settings)
+    save_settings(settings, name='apply_cellpose_model')
+
+    image_folder = os.path.join(settings['src'])
+    results_dir = os.path.join(settings['src'], 'results')
+    os.makedirs(results_dir, exist_ok=True)
+    print(f"Results will be saved in: {results_dir}")
+
+    image_files = sorted([os.path.join(image_folder, f) for f in os.listdir(image_folder) if f.endswith('.tif')])
+    print(f"Found {len(image_files)} images")
+
+    dummy_labels = [image_files[0]] * len(image_files)
+    dataset = CellposeLazyDataset(image_files, dummy_labels, settings, randomize=False, augment=False)
+
+    model = cp_models.CellposeModel(gpu=True, pretrained_model=settings['model_path'])
+    batch_size = settings['batch_size']
+    measurements = []
+
+    files_to_process = len(image_files)
+    time_ls = []
+
+    for i in range(0, len(dataset), batch_size):
+        start = time.time()
+        batch = [dataset[j] for j in range(i, min(i + batch_size, len(dataset)))]
+        images, _ = zip(*batch)
+
+        X = list(images)
+
+        print(settings['CP_probability'])
+        masks_pred, flows, _ = model.eval(x=list(images),
+                                          channels=[0, 0],
+                                          normalize=False,
+                                          diameter=30,
+                                          flow_threshold=settings['FT'],
+                                          cellprob_threshold=settings['CP_probability'],
+                                          rescale=None,
+                                          resample=True,
+                                          interp=True,
+                                          anisotropy=None,
+                                          min_size=5,
+                                          augment=True,
+                                          tile=True,
+                                          tile_overlap=0.2,
+                                          bsize=224)
+
+        for j, (img, pred, flow) in enumerate(zip(images, masks_pred, flows)):
+            fname = os.path.basename(image_files[i + j])
+
+            if settings.get('circularize', False):
+                h, w = pred.shape
+                Y, X = np.ogrid[:h, :w]
+                center_x, center_y = w / 2, h / 2
+                radius = min(center_x, center_y)
+                circular_mask = (X - center_x)**2 + (Y - center_y)**2 <= radius**2
+                pred = pred * circular_mask
+
+            if settings['save']:
+                plot_cellpose_result(i, j, results_dir, img, pred, flow)
+
+            props = regionprops(sklabel(pred))
+            for k, prop in enumerate(props):
+                measurements.append({
+                    'image': fname,
+                    'object_id': k + 1,
+                    'area': prop.area
+                })
+
+        stop = time.time()
+        duration = stop-start
+        files_processed = (i+1) * batch_size
+        time_ls.append(duration)
+        print_progress(files_processed, files_to_process, n_jobs=1, time_ls=None, batch_size=batch_size, operation_type="apply custom cellpose model")
+
+
+    # Write after each batch
+    df_measurements = pd.DataFrame(measurements)
+    df_measurements.to_csv(os.path.join(results_dir, 'measurements.csv'), index=False)
+    print("Saved object counts and areas to measurements.csv")
+
+    df_summary = df_measurements.groupby('image').agg(
+        object_count=('object_id', 'count'),
+        average_area=('area', 'mean')
+    ).reset_index()
+    df_summary.to_csv(os.path.join(results_dir, 'summary.csv'), index=False)
+    print("Saved object count and average area to summary.csv")
+
+def plot_cellpose_batch(images, labels):
+    from .plot import generate_mask_random_cmap
+
+    cmap_lbl = generate_mask_random_cmap(labels)
+    batch_size = len(images)
+    fig, axs = plt.subplots(2, batch_size, figsize=(4 * batch_size, 8))
+    for i in range(batch_size):
+        axs[0, i].imshow(images[i], cmap='gray')
+        axs[0, i].set_title(f'Image {i+1}')
+        axs[0, i].axis('off')
+        axs[1, i].imshow(labels[i], cmap=cmap_lbl, interpolation='nearest')
+        axs[1, i].set_title(f'Label {i+1}')
+        axs[1, i].axis('off')
+    plt.show()
+
+def analyze_percent_positive(settings):
+    from .io import _read_and_merge_data
+    from .utils import save_settings
+    from .settings import default_settings_analyze_percent_positive
+
+    settings = default_settings_analyze_percent_positive(settings)
+
+    def translate_well_in_df(csv_loc):
+        # Load and extract metadata
+        df = pd.read_csv(csv_loc)
+        df[['plateID', 'well']] = df['Renamed TIFF'].str.replace('.tif', '', regex=False).str.split('_', expand=True)[[0, 1]]
+        df['plate_well'] = df['plateID'] + '_' + df['well']
+
+        # Retain one row per plate_well
+        df_2 = df.drop_duplicates(subset='plate_well').copy()
+
+        # Translate well to row and column
+        df_2['rowID'] = 'r' + df_2['well'].str[0].map(lambda x: str(string.ascii_uppercase.index(x) + 1))
+        df_2['column_name'] = 'c' + df_2['well'].str[1:].astype(int).astype(str)
+
+        # Optional: add prcf ID (plate_row_column_field)
+        df_2['fieldID'] = 'f1'  # default or extract from filename if needed
+        df_2['prc'] = 'p' + df_2['plateID'].str.extract(r'(\d+)')[0] + '_' + df_2['rowID'] + '_' + df_2['column_name']
+
+        return df_2
+
+    def annotate_and_summarize(df, value_col, condition_col, well_col, threshold, annotation_col='annotation'):
+        """
+        Annotate and summarize a DataFrame based on a threshold.
+
+        Parameters:
+        - df: pandas.DataFrame
+        - value_col: str, column name to apply threshold on
+        - condition_col: str, column name for experimental condition
+        - well_col: str, column name for wells
+        - threshold: float, threshold value for annotation
+        - annotation_col: str, name of the new annotation column
+
+        Returns:
+        - df: annotated DataFrame
+        - summary_df: DataFrame with counts and fractions per condition and well
+        """
+        # Annotate
+        df[annotation_col] = np.where(df[value_col] > threshold, 'above', 'below')
+
+        # Count per condition and well
+        count_df = df.groupby([condition_col, well_col, annotation_col]).size().unstack(fill_value=0)
+
+        # Calculate total and fractions
+        count_df['total'] = count_df.sum(axis=1)
+        count_df['fraction_above'] = count_df.get('above', 0) / count_df['total']
+        count_df['fraction_below'] = count_df.get('below', 0) / count_df['total']
+
+        return df, count_df.reset_index()
+
+    save_settings(settings, name='analyze_percent_positive', show=False)
+
+    df, _ = _read_and_merge_data(locs=[settings['src']+'/measurements/measurements.db'],
+                                 tables=settings['tables'],
+                                 verbose=True,
+                                 nuclei_limit=None,
+                                 pathogen_limit=None)
+
+    df['condition'] = 'none'
+
+    if not settings['filter_1'] is None:
+        df = df[df[settings['filter_1'][0]]>settings['filter_1'][1]]
+
+    condition_col = 'condition'
+    well_col = 'prc'
+
+    df, count_df = annotate_and_summarize(df, settings['value_col'], condition_col, well_col, settings['threshold'], annotation_col='annotation')
+    count_df[['plateID', 'rowID', 'column_name']] = count_df['prc'].str.split('_', expand=True)
+
+    csv_loc = os.path.join(settings['src'], 'rename_log.csv')
+    csv_out_loc = os.path.join(settings['src'], 'result.csv')
+    translate_df = translate_well_in_df(csv_loc)
+
+    merged = pd.merge(count_df, translate_df, on=['rowID', 'column_name'], how='inner')
+
+    merged = merged[['plate_y', 'well', 'plate_well','fieldID','rowID','column_name','prc_x','Original File','Renamed TIFF','above','below','fraction_above','fraction_below']]
+    merged[[f'part{i}' for i in range(merged['Original File'].str.count('_').max() + 1)]] = merged['Original File'].str.split('_', expand=True)
+    merged.to_csv(csv_out_loc, index=False)
+    display(merged)
+    return merged
+
 def analyze_recruitment(settings):
     """
     Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.
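The new CellposeLazyDataset reports eight times as many items when augment=True and decodes each flat index into a base image plus one of eight orientations (identity, 90/180/270 rotations, two flips, and two flip-plus-rotation combinations). A minimal sketch of that index arithmetic, independent of the image I/O:

    # How a flat index maps to (base image, augmentation) when augment=True,
    # mirroring CellposeLazyDataset.__len__ and __getitem__ above.
    n_images, n_aug = 3, 8
    pairs = [(idx // n_aug, idx % n_aug) for idx in range(n_images * n_aug)]
    # Every base image is visited once under each of the 8 augmentations:
    assert pairs == [(b, a) for b in range(n_images) for a in range(n_aug)]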
@@ -136,7 +678,7 @@ def analyze_recruitment(settings):
 
 def analyze_plaques(settings):
 
-    from .cellpose import identify_masks_finetune
+    from .spacr_cellpose import identify_masks_finetune
     from .settings import get_analyze_plaque_settings
     from .utils import save_settings, download_models
     from spacr import __file__ as spacr_path
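This import tracks the module rename in file 50 above (spacr/cellpose.py to spacr/spacr_cellpose.py), which plausibly avoids a local spacr.cellpose module shadowing, or being confused with, the third-party cellpose package imported at the top of this file. Downstream code would import the helper from its new location:

    # 0.5.0 location of the finetuning helper (renamed module):
    from spacr.spacr_cellpose import identify_masks_finetune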
@@ -198,147 +740,6 @@ def analyze_plaques(settings):
 
     print(f"Analysis completed and saved to database '{db_name}'.")
 
-def train_cellpose(settings):
-
-    from .io import _load_normalized_images_and_labels, _load_images_and_labels
-    from .settings import get_train_cellpose_default_settings
-    from .utils import save_settings
-
-    settings = get_train_cellpose_default_settings(settings)
-
-    img_src = settings['img_src']
-    mask_src = os.path.join(img_src, 'masks')
-    test_img_src = settings['test_img_src']
-    test_mask_src = settings['test_mask_src']
-
-    if settings['resize']:
-        target_height = settings['width_height'][1]
-        target_width = settings['width_height'][0]
-
-    if settings['test']:
-        test_img_src = os.path.join(os.path.dirname(settings['img_src']), 'test')
-        test_mask_src = os.path.join(settings['test_img_src'], 'mask')
-
-    test_images, test_masks, test_image_names, test_mask_names = None,None,None,None
-    print(settings)
-
-    if settings['from_scratch']:
-        model_name=f"scratch_{settings['model_name']}_{settings['model_type']}_e{settings['n_epochs']}_X{target_width}_Y{target_height}.CP_model"
-    else:
-        if settings['resize']:
-            model_name=f"{settings['model_name']}_{settings['model_type']}_e{settings['n_epochs']}_X{target_width}_Y{target_height}.CP_model"
-        else:
-            model_name=f"{settings['model_name']}_{settings['model_type']}_e{settings['n_epochs']}.CP_model"
-
-    model_save_path = os.path.join(settings['mask_src'], 'models', 'cellpose_model')
-    print(model_save_path)
-    os.makedirs(model_save_path, exist_ok=True)
-
-    save_settings(settings, name=model_name)
-
-    if settings['from_scratch']:
-        model = cp_models.CellposeModel(gpu=True, model_type=settings['model_type'], diam_mean=settings['diameter'], pretrained_model=None)
-    else:
-        model = cp_models.CellposeModel(gpu=True, model_type=settings['model_type'])
-
-    if settings['normalize']:
-
-        image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
-        label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
-        images, masks, image_names, mask_names, orig_dims = _load_normalized_images_and_labels(image_files,
-                                                                                               label_files,
-                                                                                               settings['channels'],
-                                                                                               settings['percentiles'],
-                                                                                               settings['invert'],
-                                                                                               settings['verbose'],
-                                                                                               settings['remove_background'],
-                                                                                               settings['background'],
-                                                                                               settings['Signal_to_noise'],
-                                                                                               settings['target_height'],
-                                                                                               settings['target_width'])
-        images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
-
-        if settings['test']:
-            test_image_files = [os.path.join(test_img_src, f) for f in os.listdir(test_img_src) if f.endswith('.tif')]
-            test_label_files = [os.path.join(test_mask_src, f) for f in os.listdir(test_mask_src) if f.endswith('.tif')]
-            test_images, test_masks, test_image_names, test_mask_names = _load_normalized_images_and_labels(test_image_files,
-                                                                                                            test_label_files,
-                                                                                                            settings['channels'],
-                                                                                                            settings['percentiles'],
-                                                                                                            settings['invert'],
-                                                                                                            settings['verbose'],
-                                                                                                            settings['remove_background'],
-                                                                                                            settings['background'],
-                                                                                                            settings['Signal_to_noise'],
-                                                                                                            settings['target_height'],
-                                                                                                            settings['target_width'])
-            test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
-
-    else:
-        images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, settings['invert'])
-        images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
-
-        if settings['test']:
-            test_images, test_masks, test_image_names, test_mask_names = _load_images_and_labels(test_img_src,
-                                                                                                 test_mask_src,
-                                                                                                 settings['invert'])
-
-            test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
-
-    #if resize:
-    #    images, masks = resize_images_and_labels(images, masks, target_height, target_width, show_example=True)
-
-    if settings['model_type'] == 'cyto':
-        cp_channels = [0,1]
-    if settings['model_type'] == 'cyto2':
-        cp_channels = [0,2]
-    if settings['model_type'] == 'nucleus':
-        cp_channels = [0,0]
-    if settings['grayscale']:
-        cp_channels = [0,0]
-        images = [np.squeeze(img) if img.ndim == 3 and 1 in img.shape else img for img in images]
-
-    masks = [np.squeeze(mask) if mask.ndim == 3 and 1 in mask.shape else mask for mask in masks]
-
-    print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {masks[0].shape}, image type: masks[0].shape')
-    save_every = int(settings['n_epochs']/10)
-    if save_every < 10:
-        save_every = settings['n_epochs']
-
-    train_cp.train_seg(model.net,
-                       train_data=images,
-                       train_labels=masks,
-                       train_files=image_names,
-                       train_labels_files=mask_names,
-                       train_probs=None,
-                       test_data=test_images,
-                       test_labels=test_masks,
-                       test_files=test_image_names,
-                       test_labels_files=test_mask_names,
-                       test_probs=None,
-                       load_files=True,
-                       batch_size=settings['batch_size'],
-                       learning_rate=settings['learning_rate'],
-                       n_epochs=settings['n_epochs'],
-                       weight_decay=settings['weight_decay'],
-                       momentum=0.9,
-                       SGD=False,
-                       channels=cp_channels,
-                       channel_axis=None,
-                       normalize=False,
-                       compute_flows=False,
-                       save_path=model_save_path,
-                       save_every=save_every,
-                       nimg_per_epoch=None,
-                       nimg_test_per_epoch=None,
-                       rescale=settings['rescale'],
-                       #scale_range=None,
-                       #bsize=224,
-                       min_train_masks=1,
-                       model_name=settings['model_name'])
-
-    return print(f"Model saved at: {model_save_path}/{model_name}")
-
 def count_phenotypes(settings):
     from .io import _read_db
 
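The removed trainer loaded every image up front through the io helpers and forwarded some twenty keyword arguments to train_cp.train_seg; its replacement near the top of this diff builds one augmented batch through CellposeLazyDataset and calls train_seg with a far shorter argument list. A hedged sketch of the new entry point; every settings value here is a placeholder, and the real defaults come from get_train_cellpose_default_settings:

    from spacr.submodules import train_cellpose

    settings = {'src': '/data/experiment',  # expects train/images and train/masks below it
                'model_name': 'my_model', 'normalize': True,
                'target_size': 512, 'augment': True, 'n_epochs': 100,
                'batch_size': 8, 'learning_rate': 0.01, 'weight_decay': 1e-4}
    train_cellpose(settings)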
@@ -350,17 +751,17 @@ def count_phenotypes(settings):
     unique_values_count = df[settings['annotation_column']].nunique(dropna=True)
     print(f"Unique values in {settings['annotation_column']} (excluding NaN): {unique_values_count}")
 
-    # Count unique values in 'value' column, grouped by 'plate', 'row_name', 'column'
-    grouped_unique_count = df.groupby(['plate', 'row_name', 'column'])[settings['annotation_column']].nunique(dropna=True).reset_index(name='unique_count')
+    # Count unique values in 'value' column, grouped by 'plateID', 'rowID', 'columnID'
+    grouped_unique_count = df.groupby(['plateID', 'rowID', 'columnID'])[settings['annotation_column']].nunique(dropna=True).reset_index(name='unique_count')
     display(grouped_unique_count)
 
     save_path = os.path.join(settings['src'], 'phenotype_counts.csv')
 
     # Group by plate, row, and column, then count the occurrences of each unique value
-    grouped_counts = df.groupby(['plate', 'row_name', 'column', 'value']).size().reset_index(name='count')
+    grouped_counts = df.groupby(['plateID', 'rowID', 'columnID', 'value']).size().reset_index(name='count')
 
     # Pivot the DataFrame so that unique values are columns and their counts are in the rows
-    pivot_df = grouped_counts.pivot_table(index=['plate', 'row_name', 'column'], columns='value', values='count', fill_value=0)
+    pivot_df = grouped_counts.pivot_table(index=['plateID', 'rowID', 'columnID'], columns='value', values='count', fill_value=0)
 
     # Flatten the multi-level columns
     pivot_df.columns = [f"value_{int(col)}" for col in pivot_df.columns]
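count_phenotypes now keys every aggregation on plateID/rowID/columnID. The groupby-size-pivot pattern it relies on is standard pandas; a self-contained toy run with made-up annotation values:

    import pandas as pd

    df = pd.DataFrame({'plateID': ['p1'] * 4, 'rowID': ['r1', 'r1', 'r2', 'r2'],
                       'columnID': ['c1'] * 4, 'value': [0, 1, 1, 1]})
    counts = df.groupby(['plateID', 'rowID', 'columnID', 'value']).size().reset_index(name='count')
    pivot = counts.pivot_table(index=['plateID', 'rowID', 'columnID'],
                               columns='value', values='count', fill_value=0)
    pivot.columns = [f"value_{int(col)}" for col in pivot.columns]
    print(pivot)  # one row per well; value_0/value_1 columns hold the counts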
@@ -382,20 +783,20 @@ def count_phenotypes(settings):
 def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),'r2':(90,10),'r3':(80,20),'r4':(80,20),'r5':(70,30),'r6':(70,30),'r7':(60,40),'r8':(60,40),'r9':(50,50),'r10':(50,50),'r11':(40,60),'r12':(40,60),'r13':(30,70),'r14':(30,70),'r15':(20,80),'r16':(20,80)},
                             pc_grna='TGGT1_220950_1', nc_grna='TGGT1_233460_4',
                             y_columns=['class_1_fraction', 'TGGT1_220950_1_fraction', 'nc_fraction'],
-                            column='column', value='c3', plate=None, save_paths=None):
+                            column='columnID', value='c3', plate=None, save_paths=None):
 
     def calculate_well_score_fractions(df, class_columns='cv_predictions'):
-        if all(col in df.columns for col in ['plate', 'row_name', 'column']):
-            df['prc'] = df['plate'] + '_' + df['row_name'] + '_' + df['column']
+        if all(col in df.columns for col in ['plateID', 'rowID', 'columnID']):
+            df['prc'] = df['plateID'] + '_' + df['rowID'] + '_' + df['columnID']
         else:
-            raise ValueError("Cannot find 'plate', 'row_name', or 'column' in df.columns")
-        prc_summary = df.groupby(['plate', 'row_name', 'column', 'prc']).size().reset_index(name='total_rows')
-        well_counts = (df.groupby(['plate', 'row_name', 'column', 'prc', class_columns])
+            raise ValueError("Cannot find 'plateID', 'rowID', or 'columnID' in df.columns")
+        prc_summary = df.groupby(['plateID', 'rowID', 'columnID', 'prc']).size().reset_index(name='total_rows')
+        well_counts = (df.groupby(['plateID', 'rowID', 'columnID', 'prc', class_columns])
                        .size()
                        .unstack(fill_value=0)
                        .reset_index()
                        .rename(columns={0: 'class_0', 1: 'class_1'}))
-        summary_df = pd.merge(prc_summary, well_counts, on=['plate', 'row_name', 'column', 'prc'], how='left')
+        summary_df = pd.merge(prc_summary, well_counts, on=['plateID', 'rowID', 'columnID', 'prc'], how='left')
         summary_df['class_0_fraction'] = summary_df['class_0'] / summary_df['total_rows']
         summary_df['class_1_fraction'] = summary_df['class_1'] / summary_df['total_rows']
         return summary_df
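calculate_well_score_fractions concatenates plateID, rowID, and columnID into a prc well key and turns per-object class predictions into per-well fractions; the unstack call is what widens the class counts into columns. The same idea on toy data:

    import pandas as pd

    df = pd.DataFrame({'prc': ['p1_r1_c1'] * 3 + ['p1_r2_c1'] * 2,
                       'cv_predictions': [0, 1, 1, 0, 0]})
    counts = (df.groupby(['prc', 'cv_predictions']).size()
                .unstack(fill_value=0)
                .rename(columns={0: 'class_0', 1: 'class_1'}))
    counts['class_1_fraction'] = counts['class_1'] / (counts['class_0'] + counts['class_1'])
    print(counts)  # p1_r1_c1 -> 2/3, p1_r2_c1 -> 0.0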
@@ -490,8 +891,8 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
         return result
 
     def calculate_well_read_fraction(df, count_column='count'):
-        if all(col in df.columns for col in ['plate', 'row_name', 'column']):
-            df['prc'] = df['plate'] + '_' + df['row_name'] + '_' + df['column']
+        if all(col in df.columns for col in ['plateID', 'rowID', 'columnID']):
+            df['prc'] = df['plateID'] + '_' + df['rowID'] + '_' + df['columnID']
         else:
             raise ValueError("Cannot find plate, row or column in df.columns")
         grouped_df = df.groupby('prc')[count_column].sum().reset_index()
@@ -507,21 +908,17 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
         for i, reads_csv_temp in enumerate(reads_csv):
             reads_df_temp = pd.read_csv(reads_csv_temp)
             scores_df_temp = pd.read_csv(scores_csv[i])
-            reads_df_temp['plate'] = f"plate{i+1}"
-            scores_df_temp['plate'] = f"plate{i+1}"
+            reads_df_temp['plateID'] = f"plate{i+1}"
+            scores_df_temp['plateID'] = f"plate{i+1}"
 
+            if 'column' in reads_df_temp.columns:
+                reads_df_temp = reads_df_temp.rename(columns={'column': 'columnID'})
             if 'column_name' in reads_df_temp.columns:
-                reads_df_temp = reads_df_temp.rename(columns={'column_name': 'column'})
-            if 'column_name' in reads_df_temp.columns:
-                reads_df_temp = reads_df_temp.rename(columns={'column_name': 'column'})
-            if 'column_name' in scores_df_temp.columns:
-                scores_df_temp = scores_df_temp.rename(columns={'column_name': 'column'})
-            if 'column_name' in scores_df_temp.columns:
-                scores_df_temp = scores_df_temp.rename(columns={'column_name': 'column'})
-            if 'row_name' in reads_df_temp.columns:
-                reads_df_temp = reads_df_temp.rename(columns={'row_name': 'row_name'})
+                reads_df_temp = reads_df_temp.rename(columns={'column_name': 'columnID'})
+            if 'row' in reads_df_temp.columns:
+                reads_df_temp = reads_df_temp.rename(columns={'row_name': 'rowID'})
             if 'row_name' in scores_df_temp.columns:
-                scores_df_temp = scores_df_temp.rename(columns={'row_name': 'row_name'})
+                scores_df_temp = scores_df_temp.rename(columns={'row_name': 'rowID'})
 
             reads_ls.append(reads_df_temp)
             scores_ls.append(scores_df_temp)
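The rewritten block maps the legacy spellings ('column', 'column_name', 'row_name') onto the 0.5.0 identifiers before the frames are concatenated; note that the new `if 'row' in reads_df_temp.columns` guard still renames 'row_name', so the guard and the rename key do not match. A table-driven sketch that avoids per-case guards (illustrative, not the package's code, and assuming at most one legacy spelling per axis is present):

    LEGACY_TO_050 = {'column': 'columnID', 'column_name': 'columnID',
                     'row': 'rowID', 'row_name': 'rowID', 'plate': 'plateID'}

    def normalize_ids(df):
        # DataFrame.rename silently skips absent keys, so one call covers all cases.
        return df.rename(columns={k: v for k, v in LEGACY_TO_050.items()
                                  if k in df.columns})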
@@ -535,8 +932,8 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
         reads_df = pd.read_csv(reads_csv)
         scores_df = pd.read_csv(scores_csv)
         if plate != None:
-            reads_df['plate'] = plate
-            scores_df['plate'] = plate
+            reads_df['plateID'] = plate
+            scores_df['plateID'] = plate
 
     reads_df = calculate_well_read_fraction(reads_df)
     scores_df = calculate_well_score_fractions(scores_df)
@@ -548,7 +945,7 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
 
     df_emp = pd.DataFrame([(key, val[0], val[1], val[0] / (val[0] + val[1]), val[1] / (val[0] + val[1])) for key, val in empirical_dict.items()],columns=['key', 'value1', 'value2', 'pc_fraction', 'nc_fraction'])
 
-    df = pd.merge(df, df_emp, left_on='row_name', right_on='key')
+    df = pd.merge(df, df_emp, left_on='rowID', right_on='key')
 
     if any in y_columns not in df.columns:
         print(f"columns in dataframe:")
@@ -620,7 +1017,7 @@ def interperate_vision_model(settings={}):
         else:
             return None
 
-    from spacr.plot import spacrGraph
+    from .plot import spacrGraph
 
     df[name] = df['feature'].apply(lambda x: find_feature_class(x, feature_groups))
 
@@ -698,11 +1095,17 @@ def interperate_vision_model(settings={}):
     # Clean and align columns for merging
     df['object_label'] = df['object_label'].str.replace('o', '')
 
-    if 'row_name' not in scores_df.columns:
-        scores_df['row_name'] = scores_df['row']
+    if 'rowID' not in scores_df.columns:
+        if 'row' in scores_df.columns:
+            scores_df['rowID'] = scores_df['row']
+        if 'row_name' in scores_df.columns:
+            scores_df['rowID'] = scores_df['row_name']
 
-    if 'column_name' not in scores_df.columns:
-        scores_df['column_name'] = scores_df['col']
+    if 'columnID' not in scores_df.columns:
+        if 'column_name' in scores_df.columns:
+            scores_df['columnID'] = scores_df['column_name']
+        if 'column' in scores_df.columns:
+            scores_df['columnID'] = scores_df['column']
 
     if 'object_label' not in scores_df.columns:
         scores_df['object_label'] = scores_df['object']
@@ -714,14 +1117,14 @@ def interperate_vision_model(settings={}):
         scores_df['object_label'] = scores_df['object'].astype(str)
 
     # Ensure all join columns have the same data type in both DataFrames
-    df[['plate', 'row_name', 'column_name', 'field', 'object_label']] = df[['plate', 'row_name', 'column_name', 'field', 'object_label']].astype(str)
-    scores_df[['plate', 'row_name', 'column_name', 'field', 'object_label']] = scores_df[['plate', 'row_name', 'column_name', 'field', 'object_label']].astype(str)
+    df[['plateID', 'rowID', 'column_name', 'fieldID', 'object_label']] = df[['plateID', 'rowID', 'column_name', 'fieldID', 'object_label']].astype(str)
+    scores_df[['plateID', 'rowID', 'column_name', 'fieldID', 'object_label']] = scores_df[['plateID', 'rowID', 'column_name', 'fieldID', 'object_label']].astype(str)
 
     # Select only the necessary columns from scores_df for merging
-    scores_df = scores_df[['plate', 'row_name', 'column_name', 'field', 'object_label', settings['score_column']]]
+    scores_df = scores_df[['plateID', 'rowID', 'column_name', 'fieldID', 'object_label', settings['score_column']]]
 
     # Now merge DataFrames
-    merged_df = pd.merge(df, scores_df, on=['plate', 'row_name', 'column_name', 'field', 'object_label'], how='inner')
+    merged_df = pd.merge(df, scores_df, on=['plateID', 'rowID', 'column_name', 'fieldID', 'object_label'], how='inner')
 
     # Separate numerical features and the score column
     X = merged_df.select_dtypes(include='number').drop(columns=[settings['score_column']])
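Casting every join key to str before the merge matters because pandas will not match keys across dtypes; merging an int64 object_label against a string one raises in recent pandas (and can silently return no rows in older versions). A toy demonstration of the pattern used here:

    import pandas as pd

    left = pd.DataFrame({'object_label': [1, 2], 'x': [0.1, 0.2]})
    right = pd.DataFrame({'object_label': ['1', '2'], 'score': [0.9, 0.8]})

    left['object_label'] = left['object_label'].astype(str)  # align dtypes first
    merged = pd.merge(left, right, on='object_label', how='inner')
    print(len(merged))  # 2 rows; without the cast the keys would not line up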
@@ -997,8 +1400,8 @@ def analyze_endodyogeny(settings):
     output['data'] = df
 
 
-    if settings['level'] == 'plate':
-        prc_column = 'plate'
+    if settings['level'] == 'plateID':
+        prc_column = 'plateID'
     else:
         prc_column = 'prc'
 
@@ -1144,28 +1547,28 @@ def generate_score_heatmap(settings):
     def group_cv_score(csv, plate=1, column='c3', data_column='pred'):
 
         df = pd.read_csv(csv)
-        if 'col' in df.columns:
-            df = df[df['col']==column]
+        if 'columnID' in df.columns:
+            df = df[df['columnID']==column]
         elif 'column' in df.columns:
-            df['col'] = df['column']
-            df = df[df['col']==column]
+            df['columnID'] = df['column']
+            df = df[df['columnID']==column]
         if not plate is None:
-            df['plate'] = f"plate{plate}"
-        grouped_df = df.groupby(['plate', 'row', 'col'])[data_column].mean().reset_index()
-        grouped_df['prc'] = grouped_df['plate'].astype(str) + '_' + grouped_df['row'].astype(str) + '_' + grouped_df['col'].astype(str)
+            df['plateID'] = f"plate{plate}"
+        grouped_df = df.groupby(['plateID', 'rowID', 'columnID'])[data_column].mean().reset_index()
+        grouped_df['prc'] = grouped_df['plateID'].astype(str) + '_' + grouped_df['rowID'].astype(str) + '_' + grouped_df['columnID'].astype(str)
         return grouped_df
 
     def calculate_fraction_mixed_condition(csv, plate=1, column='c3', control_sgrnas = ['TGGT1_220950_1', 'TGGT1_233460_4']):
         df = pd.read_csv(csv)
         df = df[df['column_name']==column]
         if plate not in df.columns:
-            df['plate'] = f"plate{plate}"
+            df['plateID'] = f"plate{plate}"
         df = df[df['grna_name'].str.match(f'^{control_sgrnas[0]}$|^{control_sgrnas[1]}$')]
-        grouped_df = df.groupby(['plate', 'row_name', 'column_name'])['count'].sum().reset_index()
+        grouped_df = df.groupby(['plateID', 'rowID', 'columnID'])['count'].sum().reset_index()
         grouped_df = grouped_df.rename(columns={'count': 'total_count'})
-        merged_df = pd.merge(df, grouped_df, on=['plate', 'row_name', 'column_name'])
+        merged_df = pd.merge(df, grouped_df, on=['plateID', 'rowID', 'column_name'])
         merged_df['fraction'] = merged_df['count'] / merged_df['total_count']
-        merged_df['prc'] = merged_df['plate'].astype(str) + '_' + merged_df['row_name'].astype(str) + '_' + merged_df['column_name'].astype(str)
+        merged_df['prc'] = merged_df['plateID'].astype(str) + '_' + merged_df['rowID'].astype(str) + '_' + merged_df['column_name'].astype(str)
         return merged_df
 
     def plot_multi_channel_heatmap(df, column='c3', cmap='coolwarm'):
@@ -1177,17 +1580,17 @@ def generate_score_heatmap(settings):
         - column: Column to filter by (default is 'c3').
         """
         # Extract row number and convert to integer for sorting
-        df['row_num'] = df['row'].str.extract(r'(\d+)').astype(int)
+        df['row_num'] = df['rowID'].str.extract(r'(\d+)').astype(int)
 
         # Filter and sort by plate, row, and column
-        df = df[df['col'] == column]
-        df = df.sort_values(by=['plate', 'row_num', 'col'])
+        df = df[df['columnID'] == column]
+        df = df.sort_values(by=['plateID', 'row_num', 'columnID'])
 
         # Drop temporary 'row_num' column after sorting
         df = df.drop('row_num', axis=1)
 
         # Create a new column combining plate, row, and column for the index
-        df['plate_row_col'] = df['plate'] + '-' + df['row'] + '-' + df['col']
+        df['plate_row_col'] = df['plateID'] + '-' + df['rowID'] + '-' + df['columnID']
 
         # Set 'plate_row_col' as the index
         df.set_index('plate_row_col', inplace=True)
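Sorting rowID lexicographically would place r10 before r2, so plot_multi_channel_heatmap extracts the digits into a temporary integer column, sorts on it, and drops it. The same trick in isolation:

    import pandas as pd

    df = pd.DataFrame({'rowID': ['r10', 'r2', 'r1']})
    df['row_num'] = df['rowID'].str.extract(r'(\d+)').astype(int)
    df = df.sort_values('row_num').drop('row_num', axis=1)
    print(df['rowID'].tolist())  # ['r1', 'r2', 'r10']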
@@ -1244,11 +1647,11 @@ def generate_score_heatmap(settings):
         # Loop through all collected CSV files and process them
         for csv_file in ls:
             df = pd.read_csv(csv_file)  # Read CSV into DataFrame
-            df = df[df['col']==column]
+            df = df[df['columnID']==column]
             if not plate is None:
-                df['plate'] = f"plate{plate}"
-            # Group the data by 'plate', 'row', and 'col'
-            grouped_df = df.groupby(['plate', 'row', 'col'])[data_column].mean().reset_index()
+                df['plateID'] = f"plate{plate}"
+            # Group the data by 'plateID', 'rowID', and 'columnID'
+            grouped_df = df.groupby(['plateID', 'rowID', 'columnID'])[data_column].mean().reset_index()
             # Use the CSV filename to create a new column name
             folder_name = os.path.dirname(csv_file).replace(".csv", "")
             new_column_name = os.path.basename(f"{folder_name}_{data_column}")
@@ -1259,8 +1662,8 @@ def generate_score_heatmap(settings):
             if combined_df is None:
                 combined_df = grouped_df
             else:
-                combined_df = pd.merge(combined_df, grouped_df, on=['plate', 'row', 'col'], how='outer')
-        combined_df['prc'] = combined_df['plate'].astype(str) + '_' + combined_df['row'].astype(str) + '_' + combined_df['col'].astype(str)
+                combined_df = pd.merge(combined_df, grouped_df, on=['plateID', 'rowID', 'columnID'], how='outer')
+        combined_df['prc'] = combined_df['plateID'].astype(str) + '_' + combined_df['rowID'].astype(str) + '_' + combined_df['columnID'].astype(str)
         return combined_df
 
     def calculate_mae(df):
@@ -1282,16 +1685,16 @@ def generate_score_heatmap(settings):
         mae_df = pd.DataFrame(mae_data)
         return mae_df
 
-    result_df = combine_classification_scores(settings['folders'], settings['csv_name'], settings['data_column'], settings['plate'], settings['column'], )
-    df = calculate_fraction_mixed_condition(settings['csv'], settings['plate'], settings['column'], settings['control_sgrnas'])
+    result_df = combine_classification_scores(settings['folders'], settings['csv_name'], settings['data_column'], settings['plateID'], settings['columnID'], )
+    df = calculate_fraction_mixed_condition(settings['csv'], settings['plateID'], settings['columnID'], settings['control_sgrnas'])
     df = df[df['grna_name']==settings['fraction_grna']]
     fraction_df = df[['fraction', 'prc']]
     merged_df = pd.merge(fraction_df, result_df, on=['prc'])
-    cv_df = group_cv_score(settings['cv_csv'], settings['plate'], settings['column'], settings['data_column_cv'])
+    cv_df = group_cv_score(settings['cv_csv'], settings['plateID'], settings['columnID'], settings['data_column_cv'])
     cv_df = cv_df[[settings['data_column_cv'], 'prc']]
     merged_df = pd.merge(merged_df, cv_df, on=['prc'])
 
-    fig = plot_multi_channel_heatmap(merged_df, settings['column'], settings['cmap'])
+    fig = plot_multi_channel_heatmap(merged_df, settings['columnID'], settings['cmap'])
     if 'row_number' in merged_df.columns:
         merged_df = merged_df.drop('row_num', axis=1)
     mae_df = calculate_mae(merged_df)
@@ -1299,9 +1702,9 @@ def generate_score_heatmap(settings):
         mae_df = mae_df.drop('row_num', axis=1)
 
     if not settings['dst'] is None:
-        mae_dst = os.path.join(settings['dst'], f"mae_scores_comparison_plate_{settings['plate']}.csv")
-        merged_dst = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plate']}_data.csv")
-        heatmap_save = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plate']}.pdf")
+        mae_dst = os.path.join(settings['dst'], f"mae_scores_comparison_plate_{settings['plateID']}.csv")
+        merged_dst = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plateID']}_data.csv")
+        heatmap_save = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plateID']}.pdf")
         mae_df.to_csv(mae_dst, index=False)
         merged_df.to_csv(merged_dst, index=False)
         fig.savefig(heatmap_save, format='pdf', dpi=600, bbox_inches='tight')
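Callers of generate_score_heatmap must now supply 'plateID' and 'columnID' settings keys where 0.4.x used 'plate' and 'column'. A hedged sketch of a 0.5.0-style settings dict, with placeholder paths and values, using only the keys this diff shows the function reading:

    settings = {'folders': ['/data/run1', '/data/run2'],  # placeholder paths
                'csv_name': 'scores.csv', 'data_column': 'pred',
                'csv': '/data/reads.csv',
                'control_sgrnas': ['TGGT1_220950_1', 'TGGT1_233460_4'],
                'fraction_grna': 'TGGT1_220950_1',
                'cv_csv': '/data/cv_scores.csv', 'data_column_cv': 'pred',
                'plateID': 1,         # formerly 'plate'
                'columnID': 'c3',     # formerly 'column'
                'cmap': 'coolwarm', 'dst': '/data/output'}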