spacr 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/io.py ADDED
@@ -0,0 +1,2271 @@
1
+ import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose
2
+ import numpy as np
3
+ import pandas as pd
4
+ from PIL import Image
5
+ from collections import defaultdict, Counter
6
+ from pathlib import Path
7
+ from functools import partial
8
+ from matplotlib.animation import FuncAnimation
9
+ from IPython.display import display
10
+ from skimage.util import img_as_uint
11
+ from skimage.exposure import rescale_intensity
12
+ from skimage import filters
13
+ import skimage.measure as measure
14
+ from skimage import exposure
15
+ import imageio.v2 as imageio2
16
+ import matplotlib.pyplot as plt
17
+ from io import BytesIO
18
+ from IPython.display import display, clear_output
19
+ from multiprocessing import Pool, cpu_count
20
+ from torch.utils.data import Dataset
21
+ import seaborn as sns
22
+ import matplotlib.pyplot as plt
23
+ from torchvision.transforms import ToTensor
24
+
25
+ from .logger import log_function_call
26
+
27
+ @log_function_call
28
+ def _load_images_and_labels(image_files, label_files, circular=False, invert=False, image_extension="*.tif", label_extension="*.tif"):
29
+
30
+ from .utils import invert_image, apply_mask
31
+
32
+ images = []
33
+ labels = []
34
+
35
+ if image_files is not None:
36
+ image_names = sorted([os.path.basename(f) for f in image_files])
37
+ else:
38
+ image_names = []
39
+
40
+ if label_files is not None:
41
+ label_names = sorted([os.path.basename(f) for f in label_files])
42
+ else:
43
+ label_names = []
44
+
45
+ if image_files is not None and label_files is not None:
46
+ for img_file, lbl_file in zip(image_files, label_files):
47
+ image = cellpose.imread(img_file)
48
+ if invert:
49
+ image = invert_image(image)
50
+ if circular:
51
+ image = apply_mask(image, output_value=0)
52
+ label = cellpose.imread(lbl_file)
53
+ if image.max() > 1:
54
+ image = image / image.max()
55
+ images.append(image)
56
+ labels.append(label)
57
+ elif image_files is not None:
58
+ for img_file in image_files:
59
+ image = cellpose.imread(img_file)
60
+ if invert:
61
+ image = invert_image(image)
62
+ if circular:
63
+ image = apply_mask(image, output_value=0)
64
+ if image.max() > 1:
65
+ image = image / image.max()
66
+ images.append(image)
67
+ elif label_files is not None:
68
+ for lbl_file in label_files:
69
+ label = cellpose.imread(lbl_file)
70
+ if circular:
71
+ label = apply_mask(label, output_value=0)
72
+ labels.append(label)
73
+
74
+ if image_files is not None:
75
+ image_dir = os.path.dirname(image_files[0])
76
+ else:
77
+ image_dir = None
78
+
79
+ if label_files is not None:
80
+ label_dir = os.path.dirname(label_files[0])
81
+ else:
82
+ label_dir = None
83
+
84
+ # Log the number of loaded images and labels
85
+ print(f'Loaded {len(images)} images and {len(labels)} labels from {image_dir} and {label_dir}')
86
+ if len(labels) > 0 and len(images) > 0:
87
+ print(f'image shape: {images[0].shape}, image dtype: {images[0].dtype}, mask shape: {labels[0].shape}, mask dtype: {labels[0].dtype}')
88
+ return images, labels, image_names, label_names
89
+
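A minimal usage sketch for the loader above (illustrative only, not part of the released file; the directory layout and paths are hypothetical):

from glob import glob
image_files = sorted(glob('/data/project/images/*.tif'))   # hypothetical location
label_files = sorted(glob('/data/project/masks/*.tif'))    # hypothetical location
images, labels, image_names, label_names = _load_images_and_labels(image_files, label_files, invert=False)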
90
+ @log_function_call
91
+ def _load_normalized_images_and_labels(image_files, label_files, signal_thresholds=[1000], channels=None, percentiles=None, circular=False, invert=False, visualize=False):
92
+
93
+ from .plot import normalize_and_visualize
94
+ from .utils import invert_image, apply_mask
95
+
96
+ if isinstance(signal_thresholds, int):
97
+ signal_thresholds = [signal_thresholds] * (len(channels) if channels is not None else 1)
98
+ elif not isinstance(signal_thresholds, list):
99
+ signal_thresholds = [signal_thresholds]
100
+
101
+ images = []
102
+ labels = []
103
+
104
+ num_channels = 4
105
+ percentiles_1 = [[] for _ in range(num_channels)]
106
+ percentiles_99 = [[] for _ in range(num_channels)]
107
+
108
+ image_names = [os.path.basename(f) for f in image_files]
109
+
110
+ if label_files is not None:
111
+ label_names = [os.path.basename(f) for f in label_files]
112
+
113
+ # Load images and check percentiles
114
+ for i,img_file in enumerate(image_files):
115
+ image = cellpose.imread(img_file)
116
+ if invert:
117
+ image = invert_image(image)
118
+ if circular:
119
+ image = apply_mask(image, output_value=0)
120
+
121
+ # If specific channels are specified, select them
122
+ if channels is not None and image.ndim == 3:
123
+ image = image[..., channels]
124
+
125
+ if image.ndim < 3:
126
+ image = np.expand_dims(image, axis=-1)
127
+
128
+ images.append(image)
129
+ if percentiles is None:
130
+ for c in range(image.shape[-1]):
131
+ p1 = np.percentile(image[..., c], 1)
132
+ percentiles_1[c].append(p1)
133
+ for percentile in [99, 99.9, 99.99, 99.999]:
134
+ p = np.percentile(image[..., c], percentile)
135
+ if p > signal_thresholds[min(c, len(signal_thresholds)-1)]:
136
+ percentiles_99[c].append(p)
137
+ break
138
+
139
+ if percentiles is not None:
140
+ normalized_images = []
141
+ for image in images:
142
+ normalized_image = np.zeros_like(image, dtype=np.float32)
143
+ for c in range(image.shape[-1]):
144
+ high_p = np.percentile(image[..., c], percentiles[1])
145
+ low_p = np.percentile(image[..., c], percentiles[0])
146
+ normalized_image[..., c] = rescale_intensity(image[..., c], in_range=(low_p, high_p), out_range=(0, 1))
147
+ normalized_images.append(normalized_image)
148
+ if visualize:
149
+ normalize_and_visualize(image, normalized_image, title=f"Channel {c+1} Normalized")
150
+
151
+ if percentiles is None:
152
+ # Calculate average percentiles for normalization
153
+ avg_p1 = [np.mean(p) for p in percentiles_1]
154
+ avg_p99 = [np.mean(p) if len(p) > 0 else np.mean(percentiles_1[i]) for i, p in enumerate(percentiles_99)]
155
+
156
+ normalized_images = []
157
+ for image in images:
158
+ normalized_image = np.zeros_like(image, dtype=np.float32)
159
+ for c in range(image.shape[-1]):
160
+ normalized_image[..., c] = rescale_intensity(image[..., c], in_range=(avg_p1[c], avg_p99[c]), out_range=(0, 1))
161
+ normalized_images.append(normalized_image)
162
+ if visualize:
163
+ normalize_and_visualize(image, normalized_image, title=f"Channel {c+1} Normalized")
164
+
165
+ if image_files is not None:
166
+ image_dir = os.path.dirname(image_files[0])
167
+ else:
168
+ image_dir = None
169
+
170
+ if label_files is not None:
+ label_dir = os.path.dirname(label_files[0])
171
+ for lbl_file in label_files:
172
+ labels.append(cellpose.imread(lbl_file))
173
+ else:
174
+ label_names = []
175
+ label_dir = None
176
+
177
+ print(f'Loaded and normalized {len(normalized_images)} images and {len(labels)} labels from {image_dir} and {label_dir}')
178
+
179
+ return normalized_images, labels, image_names, label_names
180
+
181
+ class MyDataset(Dataset):
182
+ """
183
+ Custom dataset class for loading and processing image data.
184
+
185
+ Args:
186
+ data_dir (str): The directory path where the data is stored.
187
+ loader_classes (list): List of class names.
188
+ transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. Default is None.
189
+ shuffle (bool, optional): Whether to shuffle the dataset. Default is True.
190
+ load_to_memory (bool, optional): Whether to load images into memory. Default is False.
191
+
192
+ Attributes:
193
+ data_dir (str): The directory path where the data is stored.
194
+ classes (list): List of class names.
195
+ transform (callable): A function/transform that takes in an PIL image and returns a transformed version.
196
+ shuffle (bool): Whether to shuffle the dataset.
197
+ load_to_memory (bool): Whether to load images into memory.
198
+ filenames (list): List of file paths.
199
+ labels (list): List of labels corresponding to each file.
200
+ images (list): List of loaded images.
201
+ image_cache (Cache): Cache object for storing loaded images.
202
+
203
+ Methods:
204
+ load_image: Load an image from file.
205
+ __len__: Get the length of the dataset.
206
+ shuffle_dataset: Shuffle the dataset.
207
+ __getitem__: Get an item from the dataset.
208
+
209
+ """
210
+
211
+ def __init__(self, data_dir, loader_classes, transform=None, shuffle=True, load_to_memory=False):
212
+ from .utils import Cache
213
+ self.data_dir = data_dir
214
+ self.classes = loader_classes
215
+ self.transform = transform
216
+ self.shuffle = shuffle
217
+ self.load_to_memory = load_to_memory
218
+ self.filenames = []
219
+ self.labels = []
220
+ self.images = []
221
+ self.image_cache = Cache(50)
222
+ for class_name in self.classes:
223
+ class_path = os.path.join(data_dir, class_name)
224
+ class_files = [os.path.join(class_path, f) for f in os.listdir(class_path) if os.path.isfile(os.path.join(class_path, f))]
225
+ self.filenames.extend(class_files)
226
+ self.labels.extend([self.classes.index(class_name)] * len(class_files))
227
+ if self.shuffle:
228
+ self.shuffle_dataset()
229
+ if self.load_to_memory:
230
+ self.images = [self.load_image(f) for f in self.filenames]
231
+
232
+ def load_image(self, img_path):
233
+ img = self.image_cache.get(img_path)
234
+ if img is None:
235
+ img = Image.open(img_path).convert('RGB')
236
+ self.image_cache.put(img_path, img)
237
+ return img
238
+
239
+ def __len__(self):
240
+ return len(self.filenames)
241
+
242
+ def shuffle_dataset(self):
243
+ combined = list(zip(self.filenames, self.labels))
244
+ random.shuffle(combined)
245
+ self.filenames, self.labels = zip(*combined)
246
+
247
+ def __getitem__(self, index):
248
+ label = self.labels[index]
249
+ filename = self.filenames[index]
250
+ if self.load_to_memory:
251
+ img = self.images[index]
252
+ else:
253
+ img = self.load_image(filename)
254
+ if self.transform is not None:
255
+ img = self.transform(img)
256
+ else:
257
+ img = ToTensor()(img)
258
+ return img, label, filename
259
+
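A minimal sketch of how this dataset could be consumed with a PyTorch DataLoader (illustrative; the directory and class names are hypothetical):

from torch.utils.data import DataLoader
dataset = MyDataset('/data/train', loader_classes=['infected', 'uninfected'])
loader = DataLoader(dataset, batch_size=32, num_workers=2)
for imgs, labels, filenames in loader:
    pass  # imgs: batch of tensors, labels: class indices, filenames: source paths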
260
+ class CombineLoaders:
261
+ """
262
+ A class that combines multiple data loaders into a single iterator.
263
+
264
+ Args:
265
+ train_loaders (list): A list of data loaders.
266
+
267
+ Attributes:
268
+ train_loaders (list): A list of data loaders.
269
+ loader_iters (list): A list of iterator objects for each data loader.
270
+
271
+ Methods:
272
+ __iter__(): Returns the iterator object itself.
273
+ __next__(): Returns the next batch from one of the data loaders.
274
+
275
+ Raises:
276
+ StopIteration: If all data loaders have been exhausted.
277
+
278
+ """
279
+
280
+ def __init__(self, train_loaders):
281
+ self.train_loaders = train_loaders
282
+ self.loader_iters = [iter(loader) for loader in train_loaders]
283
+
284
+ def __iter__(self):
285
+ return self
286
+
287
+ def __next__(self):
288
+ while self.loader_iters:
289
+ random.shuffle(self.loader_iters) # Shuffle the loader_iters list
290
+ for i, loader_iter in enumerate(self.loader_iters):
291
+ try:
292
+ batch = next(loader_iter)
293
+ return i, batch
294
+ except StopIteration:
295
+ self.loader_iters.pop(i)
296
+ continue
297
+ else:
298
+ break
299
+ raise StopIteration
300
+
301
+ class CombinedDataset(Dataset):
302
+ """
303
+ A dataset that combines multiple datasets into one.
304
+
305
+ Args:
306
+ datasets (list): A list of datasets to be combined.
307
+ shuffle (bool, optional): Whether to shuffle the combined dataset. Defaults to True.
308
+ """
309
+
310
+ def __init__(self, datasets, shuffle=True):
311
+ self.datasets = datasets
312
+ self.lengths = [len(dataset) for dataset in datasets]
313
+ self.total_length = sum(self.lengths)
314
+ self.shuffle = shuffle
315
+ if shuffle:
316
+ self.indices = list(range(self.total_length))
317
+ random.shuffle(self.indices)
318
+ else:
319
+ self.indices = None
320
+ def __getitem__(self, index):
321
+ if self.shuffle:
322
+ index = self.indices[index]
323
+ for dataset, length in zip(self.datasets, self.lengths):
324
+ if index < length:
325
+ return dataset[index]
326
+ index -= length
327
+ def __len__(self):
328
+ return self.total_length
329
+
330
+ class NoClassDataset(Dataset):
331
+ """
332
+ A custom dataset class for handling images without class labels.
333
+
334
+ Args:
335
+ data_dir (str): The directory path where the images are stored.
336
+ transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. Default is None.
337
+ shuffle (bool, optional): Whether to shuffle the dataset. Default is True.
338
+ load_to_memory (bool, optional): Whether to load all images into memory. Default is False.
339
+
340
+ Attributes:
341
+ data_dir (str): The directory path where the images are stored.
342
+ transform (callable): A function/transform that takes in an PIL image and returns a transformed version.
343
+ shuffle (bool): Whether to shuffle the dataset.
344
+ load_to_memory (bool): Whether to load all images into memory.
345
+ filenames (list): List of file paths for the images.
346
+ images (list): List of loaded images (if load_to_memory is True).
347
+
348
+ Methods:
349
+ load_image: Loads an image from the given file path.
350
+ __len__: Returns the number of images in the dataset.
351
+ shuffle_dataset: Shuffles the dataset.
352
+ __getitem__: Retrieves an image and its corresponding file path from the dataset.
353
+
354
+ """
355
+
356
+ def __init__(self, data_dir, transform=None, shuffle=True, load_to_memory=False):
357
+ self.data_dir = data_dir
358
+ self.transform = transform
359
+ self.shuffle = shuffle
360
+ self.load_to_memory = load_to_memory
361
+ self.filenames = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if os.path.isfile(os.path.join(data_dir, f))]
362
+ if self.shuffle:
363
+ self.shuffle_dataset()
364
+ if self.load_to_memory:
365
+ self.images = [self.load_image(f) for f in self.filenames]
366
+ #@lru_cache(maxsize=None)
367
+ def load_image(self, img_path):
368
+ img = Image.open(img_path).convert('RGB')
369
+ return img
370
+ def __len__(self):
371
+ return len(self.filenames)
372
+ def shuffle_dataset(self):
373
+ if self.shuffle:
374
+ random.shuffle(self.filenames)
375
+ def __getitem__(self, index):
376
+ if self.load_to_memory:
377
+ img = self.images[index]
378
+ else:
379
+ img = self.load_image(self.filenames[index])
380
+ if self.transform is not None:
381
+ img = self.transform(img)
382
+ else:
383
+ img = ToTensor()(img)
384
+ # Return both the image and its filename
385
+ return img, self.filenames[index]
386
+
387
+ class MyDataset(Dataset):
388
+ """
389
+ A custom dataset class for loading and processing image data.
390
+
391
+ Args:
392
+ data_dir (str): The directory path where the image data is stored.
393
+ loader_classes (list): A list of class names for the dataset.
394
+ transform (callable, optional): A function/transform to apply to the image data. Default is None.
395
+ shuffle (bool, optional): Whether to shuffle the dataset. Default is True.
396
+ pin_memory (bool, optional): Whether to pin the loaded images to memory. Default is False.
397
+ specific_files (list, optional): A list of specific file paths to include in the dataset. Default is None.
398
+ specific_labels (list, optional): A list of specific labels corresponding to the specific files. Default is None.
399
+ """
400
+
401
+ def __init__(self, data_dir, loader_classes, transform=None, shuffle=True, pin_memory=False, specific_files=None, specific_labels=None):
402
+ self.data_dir = data_dir
403
+ self.classes = loader_classes
404
+ self.transform = transform
405
+ self.shuffle = shuffle
406
+ self.pin_memory = pin_memory
407
+ self.filenames = []
408
+ self.labels = []
409
+
410
+ if specific_files and specific_labels:
411
+ self.filenames = specific_files
412
+ self.labels = specific_labels
413
+ else:
414
+ for class_name in self.classes:
415
+ class_path = os.path.join(data_dir, class_name)
416
+ class_files = [os.path.join(class_path, f) for f in os.listdir(class_path) if os.path.isfile(os.path.join(class_path, f))]
417
+ self.filenames.extend(class_files)
418
+ self.labels.extend([self.classes.index(class_name)] * len(class_files))
419
+
420
+ if self.shuffle:
421
+ self.shuffle_dataset()
422
+
423
+ if self.pin_memory:
424
+ self.images = [self.load_image(f) for f in self.filenames]
425
+
426
+ def load_image(self, img_path):
427
+ img = Image.open(img_path).convert('RGB')
428
+ return img
429
+
430
+ def __len__(self):
431
+ return len(self.filenames)
432
+
433
+ def shuffle_dataset(self):
434
+ combined = list(zip(self.filenames, self.labels))
435
+ random.shuffle(combined)
436
+ self.filenames, self.labels = zip(*combined)
437
+
438
+ def get_plate(self, filepath):
439
+ filename = os.path.basename(filepath) # Get just the filename from the full path
440
+ return filename.split('_')[0]
441
+
442
+ def __getitem__(self, index):
443
+ label = self.labels[index]
444
+ filename = self.filenames[index]
445
+ img = self.load_image(filename)
446
+ if self.transform:
447
+ img = self.transform(img)
448
+ return img, label, filename
449
+
450
+ class NoClassDataset(Dataset):
451
+ """
452
+ A custom dataset class for handling images without class labels.
453
+
454
+ Args:
455
+ data_dir (str): The directory path where the images are stored.
456
+ transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. Default is None.
457
+ shuffle (bool, optional): Whether to shuffle the dataset. Default is True.
458
+ load_to_memory (bool, optional): Whether to load all images into memory. Default is False.
459
+
460
+ Attributes:
461
+ data_dir (str): The directory path where the images are stored.
462
+ transform (callable): A function/transform that takes in an PIL image and returns a transformed version.
463
+ shuffle (bool): Whether to shuffle the dataset.
464
+ load_to_memory (bool): Whether to load all images into memory.
465
+ filenames (list): List of file paths of the images.
466
+ images (list): List of loaded images (if load_to_memory is True).
467
+
468
+ Methods:
469
+ load_image: Load an image from the given file path.
470
+ __len__: Get the length of the dataset.
471
+ shuffle_dataset: Shuffle the dataset.
472
+ __getitem__: Get an item (image and its filename) from the dataset.
473
+
474
+ """
475
+
476
+ def __init__(self, data_dir, transform=None, shuffle=True, load_to_memory=False):
477
+ self.data_dir = data_dir
478
+ self.transform = transform
479
+ self.shuffle = shuffle
480
+ self.load_to_memory = load_to_memory
481
+ self.filenames = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if os.path.isfile(os.path.join(data_dir, f))]
482
+ if self.shuffle:
483
+ self.shuffle_dataset()
484
+ if self.load_to_memory:
485
+ self.images = [self.load_image(f) for f in self.filenames]
486
+ #@lru_cache(maxsize=None)
487
+ def load_image(self, img_path):
488
+ img = Image.open(img_path).convert('RGB')
489
+ return img
490
+ def __len__(self):
491
+ return len(self.filenames)
492
+ def shuffle_dataset(self):
493
+ if self.shuffle:
494
+ random.shuffle(self.filenames)
495
+ def __getitem__(self, index):
496
+ if self.load_to_memory:
497
+ img = self.images[index]
498
+ else:
499
+ img = self.load_image(self.filenames[index])
500
+ if self.transform is not None:
501
+ img = self.transform(img)
502
+ else:
503
+ img = ToTensor()(img)
504
+ # Return both the image and its filename
505
+ return img, self.filenames[index]
506
+
507
+ class TarImageDataset(Dataset):
508
+ def __init__(self, tar_path, transform=None):
509
+ self.tar_path = tar_path
510
+ self.transform = transform
511
+
512
+ # Open the tar file just to build the list of members
513
+ with tarfile.open(self.tar_path, 'r') as f:
514
+ self.members = [m for m in f.getmembers() if m.isfile()]
515
+
516
+ def __len__(self):
517
+ return len(self.members)
518
+
519
+ def __getitem__(self, idx):
520
+ with tarfile.open(self.tar_path, 'r') as f:
521
+ m = self.members[idx]
522
+ img_file = f.extractfile(m)
523
+ img = Image.open(BytesIO(img_file.read())).convert("RGB")
524
+
525
+ if self.transform:
526
+ img = self.transform(img)
527
+
528
+ return img, m.name
529
+
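A short sketch of reading images straight from a tar archive with the dataset above (illustrative; the archive path is hypothetical):

from torchvision import transforms
tar_ds = TarImageDataset('/data/images.tar', transform=transforms.ToTensor())
img, member_name = tar_ds[0]  # first file in the archive as a CHW tensor plus its tar member name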
530
+ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=False, skip_mode='01', metadata_type='', img_format='.tif'):
531
+ """
532
+ Convert z-stack images to maximum intensity projection (MIP) images.
533
+
534
+ Args:
535
+ src (str): The source directory containing the z-stack images.
536
+ regex (str): The regular expression pattern used to match the filenames of the z-stack images.
537
+ batch_size (int, optional): The number of images to process in each batch. Defaults to 100.
538
+ pick_slice (bool, optional): Whether to pick a specific slice based on the provided skip mode. Defaults to False.
539
+ skip_mode (str, optional): The skip mode used to filter out specific slices. Defaults to '01'.
540
+ metadata_type (str, optional): The type of metadata associated with the images. Defaults to ''.
541
+
542
+ Returns:
543
+ None
544
+ """
545
+
546
+ from .utils import _extract_filename_metadata
547
+
548
+ regular_expression = re.compile(regex)
549
+ images_by_key = defaultdict(list)
550
+ stack_path = os.path.join(src, 'stack')
551
+ if not os.path.exists(stack_path) or (os.path.isdir(stack_path) and len(os.listdir(stack_path)) == 0):
552
+ all_filenames = [filename for filename in os.listdir(src) if filename.endswith(img_format)]
553
+ print(f'All_files:{len(all_filenames)} in {src}')
554
+ for i in range(0, len(all_filenames), batch_size):
555
+ batch_filenames = all_filenames[i:i+batch_size]
556
+ images_by_key = _extract_filename_metadata(batch_filenames, src, images_by_key, regular_expression, metadata_type, pick_slice, skip_mode)
558
+ if pick_slice:
559
+ for key in images_by_key:
560
+ plate, well, field, channel, mode = key
561
+ max_intensity_slice = max(images_by_key[key], key=lambda x: np.percentile(x, 90))
562
+ mip_image = Image.fromarray(max_intensity_slice)
563
+ output_dir = os.path.join(src, channel)
564
+ os.makedirs(output_dir, exist_ok=True)
565
+ output_filename = f'{plate}_{well}_{field}.tif'
566
+ output_path = os.path.join(output_dir, output_filename)
567
+
568
+ if os.path.exists(output_path):
569
+ print(f'WARNING: A file with the same name already exists at location {output_filename}')
570
+ else:
571
+ mip_image.save(output_path)
572
+ else:
573
+ for key, images in images_by_key.items():
574
+ mip = np.max(np.stack(images), axis=0)
575
+ mip_image = Image.fromarray(mip)
576
+ plate, well, field, channel = key[:4]
577
+ output_dir = os.path.join(src, channel)
578
+ os.makedirs(output_dir, exist_ok=True)
579
+ output_filename = f'{plate}_{well}_{field}.tif'
580
+ output_path = os.path.join(output_dir, output_filename)
581
+
582
+ if os.path.exists(output_path):
583
+ print(f'WARNING: A file with the same name already exists at location {output_filename}')
584
+ else:
585
+ mip_image.save(output_path)
586
+
587
+ images_by_key.clear()
588
+
589
+ # Move original images to a new directory
590
+ valid_exts = [img_format]
591
+ newpath = os.path.join(src, 'orig')
592
+ os.makedirs(newpath, exist_ok=True)
593
+ for filename in os.listdir(src):
594
+ if os.path.splitext(filename)[1] in valid_exts:
595
+ move = os.path.join(newpath, filename)
596
+ if os.path.exists(move):
597
+ print(f'WARNING: A file with the same name already exists at location {move}')
598
+ else:
599
+ shutil.move(os.path.join(src, filename), move)
600
+ return
601
+
602
+ def _merge_file(chan_dirs, stack_dir, file):
603
+ """
604
+ Merge multiple channels into a single stack and save it as a numpy array.
605
+
606
+ Args:
607
+ chan_dirs (list): List of directories containing channel images.
608
+ stack_dir (str): Directory to save the merged stack.
609
+ file (str): File name of the channel image.
610
+
611
+ Returns:
612
+ None
613
+ """
614
+ chan1 = cv2.imread(str(file), -1)
615
+ chan1 = np.expand_dims(chan1, axis=2)
616
+ new_file = stack_dir / (file.stem + '.npy')
617
+ if not new_file.exists():
618
+ stack_dir.mkdir(exist_ok=True)
619
+ channels = [chan1]
620
+ for chan_dir in chan_dirs[1:]:
621
+ img = cv2.imread(str(chan_dir / file.name), -1)
622
+ chan = np.expand_dims(img, axis=2)
623
+ channels.append(chan)
624
+ stack = np.concatenate(channels, axis=2)
625
+ np.save(new_file, stack)
626
+
627
+ def _is_dir_empty(dir_path):
628
+ """
629
+ Check if a directory is empty.
630
+
631
+ Args:
632
+ dir_path (str): The path to the directory.
633
+
634
+ Returns:
635
+ bool: True if the directory is empty, False otherwise.
636
+ """
637
+ return len(os.listdir(dir_path)) == 0
638
+
639
+ def _generate_time_lists(file_list):
640
+ """
641
+ Generate sorted lists of filenames grouped by plate, well, and field.
642
+
643
+ Args:
644
+ file_list (list): A list of filenames.
645
+
646
+ Returns:
647
+ list: A list of sorted file lists, where each file list contains filenames
648
+ belonging to the same plate, well, and field, sorted by timepoint.
649
+ """
650
+ file_dict = defaultdict(list)
651
+ for filename in file_list:
652
+ if filename.endswith('.npy'):
653
+ parts = filename.split('_')
654
+ if len(parts) >= 4:
655
+ plate, well, field = parts[:3]
656
+ try:
657
+ timepoint = int(parts[3].split('.')[0])
658
+ except ValueError:
659
+ continue # Skip file on conversion error
660
+ key = (plate, well, field)
661
+ file_dict[key].append((timepoint, filename))
662
+ else:
663
+ continue # Skip file if not correctly formatted
664
+
665
+ # Sort each list by timepoint, but keep them grouped
666
+ sorted_grouped_filenames = [sorted(files, key=lambda x: x[0]) for files in file_dict.values()]
667
+ # Extract just the filenames from each group
668
+ sorted_file_lists = [[filename for _, filename in group] for group in sorted_grouped_filenames]
669
+
670
+ return sorted_file_lists
671
+
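An illustrative trace of the grouping performed above, using hypothetical filenames that follow the plate_well_field_time.npy pattern:

files = ['p1_A01_f01_0.npy', 'p1_A01_f01_1.npy', 'p1_A02_f01_0.npy']
_generate_time_lists(files)
# -> [['p1_A01_f01_0.npy', 'p1_A01_f01_1.npy'], ['p1_A02_f01_0.npy']]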
672
+ def _move_to_chan_folder(src, regex, timelapse=False, metadata_type=''):
673
+
674
+ from .utils import _safe_int_convert, _convert_cq1_well_id
675
+
676
+ src_path = src
677
+ src = Path(src)
678
+ valid_exts = ['.tif', '.png']
679
+
680
+ if not (src / 'stack').exists():
681
+ for file in src.iterdir():
682
+ if file.is_file():
683
+ name, ext = file.stem, file.suffix
684
+ if ext in valid_exts:
685
+ metadata = re.match(regex, file.name)
686
+ try:
687
+ try:
688
+ plateID = metadata.group('plateID')
689
+ except:
690
+ plateID = src.name
691
+
692
+ wellID = metadata.group('wellID')
693
+ fieldID = metadata.group('fieldID')
694
+ chanID = metadata.group('chanID')
695
+ timeID = metadata.group('timeID')
696
+
697
+ if wellID[0].isdigit():
698
+ wellID = str(_safe_int_convert(wellID))
699
+ if fieldID[0].isdigit():
700
+ fieldID = str(_safe_int_convert(fieldID))
701
+ if chanID[0].isdigit():
702
+ chanID = str(_safe_int_convert(chanID))
703
+ if timeID[0].isdigit():
704
+ timeID = str(_safe_int_convert(timeID))
705
+
706
+ if metadata_type =='cq1':
707
+ orig_wellID = wellID
708
+ wellID = _convert_cq1_well_id(wellID)
709
+ print(f'Converted Well ID: {orig_wellID} to {wellID}')
710
+
711
+ newname = f"{plateID}_{wellID}_{fieldID}_{timeID if timelapse else ''}{ext}"
712
+ newpath = src / chanID
713
+ move = newpath / newname
714
+ if move.exists():
715
+ print(f'WARNING: A file with the same name already exists at location {move}')
716
+ else:
717
+ newpath.mkdir(exist_ok=True)
718
+ shutil.copy(file, move)
719
+ except:
720
+ print(f"Could not extract information from filename {name}{ext} with {regex}")
721
+
722
+ # Move original images to a new directory
723
+ valid_exts = ['.tif', '.png']
724
+ newpath = os.path.join(src_path, 'orig')
725
+ os.makedirs(newpath, exist_ok=True)
726
+ for filename in os.listdir(src_path):
727
+ if os.path.splitext(filename)[1] in valid_exts:
728
+ move = os.path.join(newpath, filename)
729
+ if os.path.exists(move):
730
+ print(f'WARNING: A file with the same name already exists at location {move}')
731
+ else:
732
+ shutil.move(os.path.join(src, filename), move)
733
+ return
734
+
735
+ def _merge_channels(src, plot=False):
736
+ from .plot import plot_arrays
737
+ """
738
+ Merge the channels in the given source directory and save the merged files in a 'stack' directory.
739
+
740
+ Args:
741
+ src (str): The path to the source directory containing the channel folders.
742
+ plot (bool, optional): Whether to plot the merged arrays. Defaults to False.
743
+
744
+ Returns:
745
+ None
746
+ """
747
+ src = Path(src)
748
+ stack_dir = src / 'stack'
749
+ chan_dirs = [d for d in src.iterdir() if d.is_dir() and d.name in ['01', '02', '03', '04', '00', '1', '2', '3', '4','0']]
750
+
751
+ chan_dirs.sort(key=lambda x: x.name)
752
+ print(f'List of folders in src: {[d.name for d in chan_dirs]}. Single channel folders.')
753
+ start_time = time.time()
754
+
755
+ # First directory and its files
756
+ dir_files = list(chan_dirs[0].iterdir())
757
+
758
+ # Create the 'stack' directory if it doesn't exist
759
+ stack_dir.mkdir(exist_ok=True)
760
+
761
+ if _is_dir_empty(stack_dir):
762
+ with Pool(cpu_count()) as pool:
763
+ merge_func = partial(_merge_file, chan_dirs, stack_dir)
764
+ pool.map(merge_func, dir_files)
765
+
766
+ avg_time = (time.time() - start_time) / len(dir_files)
767
+ print(f'Average Time: {avg_time:.3f} sec')
768
+
769
+ if plot:
770
+ plot_arrays(str(stack_dir))
771
+
772
+ return
773
+
774
+ def _mip_all(src, include_first_chan=True):
775
+
776
+ """
777
+ Generate maximum intensity projections (MIPs) for each NumPy array file in the specified directory.
778
+
779
+ Args:
780
+ src (str): The directory path containing the NumPy array files.
781
+ include_first_chan (bool, optional): Whether to include the first channel of the array in the MIP computation.
782
+ Defaults to True.
783
+
784
+ Returns:
785
+ None
786
+ """
787
+
788
+ from .utils import normalize_to_dtype
789
+
790
+ #print('========== generating MIPs ==========')
791
+ # Iterate over each file in the specified directory (src).
792
+ for filename in os.listdir(src):
793
+ # Check if the current file is a NumPy array file (with .npy extension).
794
+ if filename.endswith('.npy'):
795
+ # Load the array from the file.
796
+ array = np.load(os.path.join(src, filename))
797
+ # Normalize the array using custom parameters (q1=2, q2=98).
798
+ array = normalize_to_dtype(array, q1=2, q2=98, percentiles=None)
799
+
800
+ if array.ndim != 3: # Check if the array is not 3-dimensional.
801
+ # Log a message indicating a zero array will be generated due to unexpected dimensions.
802
+ print(f"Generating zero array for {filename} due to unexpected dimensions: {array.shape}")
803
+ # Create a zero array with the same height and width as the original array, but with a single depth layer.
804
+ zeros_array = np.zeros((array.shape[0], array.shape[1], 1))
805
+ # Concatenate the original array with the zero array along the depth axis.
806
+ concatenated = np.concatenate([array, zeros_array], axis=2)
807
+ else:
808
+ if include_first_chan:
809
+ # Compute the MIP for the entire array along the third axis.
810
+ mip = np.max(array, axis=2)
811
+ else:
812
+ # Compute the MIP excluding the first layer of the array along the depth axis.
813
+ mip = np.max(array[:, :, 1:], axis=2)
814
+ # Reshape the MIP to make it 3-dimensional.
815
+ mip = mip[:, :, np.newaxis]
816
+ # Concatenate the MIP with the original array.
817
+ concatenated = np.concatenate([array, mip], axis=2)
818
+ # save
819
+ np.save(os.path.join(src, filename), concatenated)
820
+ return
821
+
822
+ def _concatenate_channel(src, channels, randomize=True, timelapse=False, batch_size=100):
823
+ """
824
+ Concatenates channel data from multiple files and saves the concatenated data as numpy arrays.
825
+
826
+ Args:
827
+ src (str): The source directory containing the channel data files.
828
+ channels (list): The list of channel indices to be concatenated.
829
+ randomize (bool, optional): Whether to randomize the order of the files. Defaults to True.
830
+ timelapse (bool, optional): Whether the channel data is from a timelapse experiment. Defaults to False.
831
+ batch_size (int, optional): The number of files to be processed in each batch. Defaults to 100.
832
+
833
+ Returns:
834
+ str: The directory path where the concatenated channel data is saved.
835
+ """
836
+ channels = [item for item in channels if item is not None]
837
+ paths = []
838
+ index = 0
839
+ channel_stack_loc = os.path.join(os.path.dirname(src), 'channel_stack')
840
+ os.makedirs(channel_stack_loc, exist_ok=True)
841
+ if timelapse:
842
+ try:
843
+ time_stack_path_lists = _generate_time_lists(os.listdir(src))
844
+ for i, time_stack_list in enumerate(time_stack_path_lists):
845
+ stack_region = []
846
+ filenames_region = []
847
+ for idx, file in enumerate(time_stack_list):
848
+ path = os.path.join(src, file)
849
+ if idx == 0:
850
+ parts = file.split('_')
851
+ name = parts[0]+'_'+parts[1]+'_'+parts[2]
852
+ array = np.load(path)
853
+ array = np.take(array, channels, axis=2)
854
+ stack_region.append(array)
855
+ filenames_region.append(os.path.basename(path))
856
+ clear_output(wait=True)
857
+ print(f'\033[KRegion {i+1}/ {len(time_stack_path_lists)}', end='\r', flush=True)
858
+ stack = np.stack(stack_region)
859
+ save_loc = os.path.join(channel_stack_loc, f'{name}.npz')
860
+ np.savez(save_loc, data=stack, filenames=filenames_region)
861
+ print(save_loc)
862
+ del stack
863
+ except Exception as e:
864
+ print(f"Error processing files, make sure filenames metadata is structured plate_well_field_time.npy")
865
+ print(f"Error: {e}")
866
+ else:
867
+ for file in os.listdir(src):
868
+ if file.endswith('.npy'):
869
+ path = os.path.join(src, file)
870
+ paths.append(path)
871
+ if randomize:
872
+ random.shuffle(paths)
873
+ nr_files = len(paths)
874
+ batch_index = 0 # Added this to name the output files
875
+ stack_ls = []
876
+ filenames_batch = [] # to hold filenames of the current batch
877
+ for i, path in enumerate(paths):
878
+ array = np.load(path)
879
+ array = np.take(array, channels, axis=2)
880
+ stack_ls.append(array)
881
+ filenames_batch.append(os.path.basename(path)) # store the filename
882
+ clear_output(wait=True)
883
+ print(f'\033[KConcatenated: {i+1}/{nr_files} files', end='\r', flush=True)
884
+
885
+ if (i+1) % batch_size == 0 or i+1 == nr_files:
886
+ unique_shapes = {arr.shape[:-1] for arr in stack_ls}
887
+ if len(unique_shapes) > 1:
888
+ max_dims = np.max(np.array(list(unique_shapes)), axis=0)
889
+ clear_output(wait=True)
890
+ print(f'\033[KWarning: arrays with multiple shapes found in batch {i+1}. Padding arrays to max X,Y dimensions {max_dims}', end='\r', flush=True)
891
+ padded_stack_ls = []
892
+ for arr in stack_ls:
893
+ pad_width = [(0, max_dim - dim) for max_dim, dim in zip(max_dims, arr.shape[:-1])]
894
+ pad_width.append((0, 0))
895
+ padded_arr = np.pad(arr, pad_width)
896
+ padded_stack_ls.append(padded_arr)
897
+ stack = np.stack(padded_stack_ls)
898
+ else:
899
+ stack = np.stack(stack_ls)
900
+ save_loc = os.path.join(channel_stack_loc, f'stack_{batch_index}.npz')
901
+ np.savez(save_loc, data=stack, filenames=filenames_batch)
902
+ batch_index += 1 # increment this after each batch is saved
903
+ del stack # delete to free memory
904
+ stack_ls = [] # empty the list for the next batch
905
+ filenames_batch = [] # empty the filenames list for the next batch
906
+ padded_stack_ls = []
907
+ #print(f'\nAll files concatenated and saved to:{channel_stack_loc}')
908
+ return channel_stack_loc
909
+
910
+ def _get_lists_for_normalization(settings):
911
+ """
912
+ Get lists for normalization based on the provided settings.
913
+
914
+ Args:
915
+ settings (dict): A dictionary containing the settings for normalization.
916
+
917
+ Returns:
918
+ tuple: A tuple containing three lists - backgrounds, signal_to_noise, and signal_thresholds.
919
+ """
920
+
921
+ # Initialize the lists
922
+ backgrounds = []
923
+ signal_to_noise = []
924
+ signal_thresholds = []
925
+
926
+ # Iterate through the channels and append the corresponding values if the channel is not None
927
+ for ch in settings['channels']:
928
+ if ch == settings['nucleus_channel']:
929
+ backgrounds.append(settings['nucleus_background'])
930
+ signal_to_noise.append(settings['nucleus_Signal_to_noise'])
931
+ signal_thresholds.append(settings['nucleus_Signal_to_noise']*settings['nucleus_background'])
932
+ elif ch == settings['cell_channel']:
933
+ backgrounds.append(settings['cell_background'])
934
+ signal_to_noise.append(settings['cell_Signal_to_noise'])
935
+ signal_thresholds.append(settings['cell_Signal_to_noise']*settings['cell_background'])
936
+ elif ch == settings['pathogen_channel']:
937
+ backgrounds.append(settings['pathogen_background'])
938
+ signal_to_noise.append(settings['pathogen_Signal_to_noise'])
939
+ signal_thresholds.append(settings['pathogen_Signal_to_noise']*settings['pathogen_background'])
940
+ return backgrounds, signal_to_noise, signal_thresholds
941
+
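A sketch of the settings keys this helper reads, with made-up example values:

settings = {
    'channels': [0, 1, 2],
    'nucleus_channel': 0, 'nucleus_background': 100, 'nucleus_Signal_to_noise': 5,
    'cell_channel': 1, 'cell_background': 100, 'cell_Signal_to_noise': 5,
    'pathogen_channel': 2, 'pathogen_background': 200, 'pathogen_Signal_to_noise': 3,
}
backgrounds, signal_to_noise, signal_thresholds = _get_lists_for_normalization(settings)
# backgrounds -> [100, 100, 200], signal_thresholds -> [500, 500, 600]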
942
+ def _normalize_stack(src, backgrounds=[100,100,100], remove_background=False, lower_quantile=0.01, save_dtype=np.float32, signal_to_noise=[5,5,5], signal_thresholds=[1000,1000,1000], correct_illumination=False):
943
+ """
944
+ Normalize the stack of images.
945
+
946
+ Args:
947
+ src (str): The source directory containing the stack of images.
948
+ backgrounds (list, optional): Background values for each channel. Defaults to [100,100,100].
949
+ remove_background (bool, optional): Whether to remove background values. Defaults to False.
950
+ lower_quantile (float, optional): Lower quantile value for normalization. Defaults to 0.01.
951
+ save_dtype (numpy.dtype, optional): Data type for saving the normalized stack. Defaults to np.float32.
952
+ signal_to_noise (list, optional): Signal-to-noise ratio thresholds for each channel. Defaults to [5,5,5].
953
+ signal_thresholds (list, optional): Signal thresholds for each channel. Defaults to [1000,1000,1000].
954
+ correct_illumination (bool, optional): Whether to correct illumination. Defaults to False.
955
+
956
+ Returns:
957
+ None
958
+ """
959
+ paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
960
+ output_fldr = os.path.join(os.path.dirname(src), 'norm_channel_stack')
961
+ os.makedirs(output_fldr, exist_ok=True)
962
+ time_ls = []
963
+ for file_index, path in enumerate(paths):
964
+ with np.load(path) as data:
965
+ stack = data['data']
966
+ filenames = data['filenames']
967
+ normalized_stack = np.zeros_like(stack, dtype=stack.dtype)
968
+ file = os.path.basename(path)
969
+ name, _ = os.path.splitext(file)
970
+
971
+ for chan_index, channel in enumerate(range(stack.shape[-1])):
972
+ single_channel = stack[:, :, :, channel]
973
+ background = backgrounds[chan_index]
974
+ signal_threshold = signal_thresholds[chan_index]
975
+ #print(f'signal_threshold:{signal_threshold} in {signal_thresholds} for {chan_index}')
976
+
977
+ signal_2_noise = signal_to_noise[chan_index]
978
+ if remove_background:
979
+ single_channel[single_channel < background] = 0
980
+ if correct_illumination:
981
+ bg = filters.gaussian(single_channel, sigma=50)
982
+ single_channel = single_channel - bg
983
+
984
+ #Calculate the global lower and upper quantiles for non-zero pixels
985
+ non_zero_single_channel = single_channel[single_channel != 0]
986
+ global_lower = np.quantile(non_zero_single_channel, lower_quantile)
987
+ for upper_p in np.linspace(0.98, 1.0, num=100).tolist():
988
+ global_upper = np.quantile(non_zero_single_channel, upper_p)
989
+ if global_upper >= signal_threshold:
990
+ break
991
+
992
+ #Normalize the pixels in each image to the global quantiles and then dtype.
993
+ arr_2d_normalized = np.zeros_like(single_channel, dtype=single_channel.dtype)
994
+ signal_to_noise_ratio_ls = []
995
+ for array_index in range(single_channel.shape[0]):
996
+ start = time.time()
997
+ arr_2d = single_channel[array_index, :, :]
998
+ non_zero_arr_2d = arr_2d[arr_2d != 0]
999
+ if non_zero_arr_2d.size > 0:
1000
+ lower, upper = np.quantile(non_zero_arr_2d, (lower_quantile, upper_p))
1001
+ signal_to_noise_ratio = upper/lower
1002
+ else:
1003
+ signal_to_noise_ratio = 0
1004
+ signal_to_noise_ratio_ls.append(signal_to_noise_ratio)
1005
+ average_stnr = np.mean(signal_to_noise_ratio_ls) if len(signal_to_noise_ratio_ls) > 0 else 0
1006
+
1007
+ if signal_to_noise_ratio > signal_2_noise:
1008
+ arr_2d_rescaled = exposure.rescale_intensity(arr_2d, in_range=(lower, upper), out_range=(global_lower, global_upper))
1009
+ arr_2d_normalized[array_index, :, :] = arr_2d_rescaled
1010
+ else:
1011
+ arr_2d_normalized[array_index, :, :] = arr_2d
1012
+ stop = time.time()
1013
+ duration = (stop - start)*single_channel.shape[0]
1014
+ time_ls.append(duration)
1015
+ average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
1016
+ clear_output(wait=True)
1017
+ print(f'\033[KProgress: files {file_index+1}/{len(paths)}, channels:{chan_index}/{stack.shape[-1]-1}, arrays:{array_index+1}/{single_channel.shape[0]}, Signal:{upper:.1f}, noise:{lower:.1f}, Signal-to-noise:{average_stnr:.1f}, Time/channel:{average_time:.2f}sec', end='\r', flush=True)
1018
+ normalized_single_channel = exposure.rescale_intensity(arr_2d_normalized, out_range='dtype')
1019
+ normalized_stack[:, :, :, channel] = normalized_single_channel
1020
+ save_loc = output_fldr+'/'+name+'_norm_stack.npz'
1021
+ normalized_stack = normalized_stack.astype(save_dtype)
1022
+ np.savez(save_loc, data=normalized_stack, filenames=filenames)
1023
+ del normalized_stack, single_channel, normalized_single_channel, stack, filenames
1024
+ gc.collect()
1025
+ print(f'Saved stacks: {output_fldr}')
1026
+
1027
+ def _normalize_timelapse(src, lower_quantile=0.01, save_dtype=np.float32):
1028
+ """
1029
+ Normalize the timelapse data by rescaling the intensity values based on percentiles.
1030
+
1031
+ Args:
1032
+ src (str): The source directory containing the timelapse data files.
1033
+ lower_quantile (float, optional): The lower quantile used to calculate the intensity range. Defaults to 0.01.
1034
+ save_dtype (numpy.dtype, optional): The data type to save the normalized stack. Defaults to np.float32.
1035
+ """
1036
+ paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
1037
+ output_fldr = os.path.join(os.path.dirname(src), 'norm_channel_stack')
1038
+ os.makedirs(output_fldr, exist_ok=True)
1039
+
1040
+ for file_index, path in enumerate(paths):
1041
+ with np.load(path) as data:
1042
+ stack = data['data']
1043
+ filenames = data['filenames']
1044
+
1045
+ normalized_stack = np.zeros_like(stack, dtype=save_dtype)
1046
+ file = os.path.basename(path)
1047
+ name, _ = os.path.splitext(file)
1048
+
1049
+ for chan_index in range(stack.shape[-1]):
1050
+ single_channel = stack[:, :, :, chan_index]
1051
+
1052
+ for array_index in range(single_channel.shape[0]):
1053
+ arr_2d = single_channel[array_index]
1054
+ # Calculate the 2nd and 98th percentiles for this specific image
1055
+ q_low = np.percentile(arr_2d[arr_2d != 0], 2)
1056
+ q_high = np.percentile(arr_2d[arr_2d != 0], 98)
1057
+
1058
+ # Rescale intensity based on the calculated percentiles to fill the dtype range
1059
+ arr_2d_rescaled = exposure.rescale_intensity(arr_2d, in_range=(q_low, q_high), out_range='dtype')
1060
+ normalized_stack[array_index, :, :, chan_index] = arr_2d_rescaled
1061
+
1062
+ print(f'Progress: files {file_index+1}/{len(paths)}, channels:{chan_index+1}/{stack.shape[-1]}, arrays:{array_index+1}/{single_channel.shape[0]}', end='\r')
1063
+
1064
+ save_loc = os.path.join(output_fldr, f'{name}_norm_timelapse.npz')
1065
+ np.savez(save_loc, data=normalized_stack, filenames=filenames)
1066
+
1067
+ del normalized_stack, stack, filenames
1068
+ gc.collect()
1069
+
1070
+ print(f'\nSaved normalized stacks: {output_fldr}')
1071
+
1072
+
1073
+
1074
+ def _create_movies_from_npy_per_channel(src, fps=10):
1075
+ """
1076
+ Create movies from numpy files per channel.
1077
+
1078
+ Args:
1079
+ src (str): The source directory containing the numpy files.
1080
+ fps (int, optional): Frames per second for the output movies. Defaults to 10.
1081
+ """
1082
+
1083
+ from .timelapse import _npz_to_movie
1084
+
1085
+ master_path = os.path.dirname(src)
1086
+ save_path = os.path.join(master_path,'movies')
1087
+ os.makedirs(save_path, exist_ok=True)
1088
+ # Organize files by plate, well, field
1089
+ files = [f for f in os.listdir(src) if f.endswith('.npy')]
1090
+ organized_files = {}
1091
+ for f in files:
1092
+ match = re.match(r'(\w+)_(\w+)_(\w+)_(\d+)\.npy', f)
1093
+ if match:
1094
+ plate, well, field, time = match.groups()
1095
+ key = (plate, well, field)
1096
+ if key not in organized_files:
1097
+ organized_files[key] = []
1098
+ organized_files[key].append((int(time), os.path.join(src, f)))
1099
+ for key, file_list in organized_files.items():
1100
+ plate, well, field = key
1101
+ file_list.sort(key=lambda x: x[0])
1102
+ arrays = []
1103
+ filenames = []
1104
+ for f in file_list:
1105
+ array = np.load(f[1])
1106
+ #if array.dtype != np.uint8:
1107
+ # array = ((array - array.min()) / (array.max() - array.min()) * 255).astype(np.uint8)
1108
+ arrays.append(array)
1109
+ filenames.append(os.path.basename(f[1]))
1110
+ arrays = np.stack(arrays, axis=0)
1111
+ for channel in range(arrays.shape[-1]):
1112
+ # Extract the current channel for all time points
1113
+ channel_arrays = arrays[..., channel]
1114
+ # Flatten the channel data to compute global percentiles
1115
+ channel_data_flat = channel_arrays.reshape(-1)
1116
+ p1, p99 = np.percentile(channel_data_flat, [1, 99])
1117
+ # Normalize and rescale each array in the channel
1118
+ normalized_channel_arrays = [(np.clip((arr - p1) / (p99 - p1), 0, 1) * 255).astype(np.uint8) for arr in channel_arrays]
1119
+ # Convert the list of 2D arrays into a list of 3D arrays with a single channel
1120
+ normalized_channel_arrays_3d = [arr[..., np.newaxis] for arr in normalized_channel_arrays]
1121
+ # Save as movie for the current channel
1122
+ channel_save_path = os.path.join(save_path, f'{plate}_{well}_{field}_channel_{channel}.mp4')
1123
+ _npz_to_movie(normalized_channel_arrays_3d, filenames, channel_save_path, fps)
1124
+
1125
+ def preprocess_img_data(settings):
1126
+
1127
+ from .plot import plot_arrays, _plot_4D_arrays
1128
+
1129
+ """
1130
+ Preprocesses image data by converting z-stack images to maximum intensity projection (MIP) images.
1131
+
1132
+ Args:
+ settings (dict): Dictionary of preprocessing options; the entries below are read as keys of this dictionary.
1133
+ src (str): The source directory containing the z-stack images.
1134
+ metadata_type (str, optional): The type of metadata associated with the images. Defaults to 'cellvoyager'.
1135
+ custom_regex (str, optional): The custom regular expression pattern used to match the filenames of the z-stack images. Defaults to None.
1136
+ cmap (str, optional): The colormap used for plotting. Defaults to 'inferno'.
1137
+ figuresize (int, optional): The size of the figure for plotting. Defaults to 15.
1138
+ normalize (bool, optional): Whether to normalize the images. Defaults to False.
1139
+ nr (int, optional): The number of images to preprocess. Defaults to 1.
1140
+ plot (bool, optional): Whether to plot the images. Defaults to False.
1141
+ mask_channels (list, optional): The channels to use for masking. Defaults to [0, 1, 2].
1142
+ batch_size (list, optional): The number of images to process in each batch. Defaults to [100, 100, 100].
1143
+ timelapse (bool, optional): Whether the images are from a timelapse experiment. Defaults to False.
1144
+ remove_background (bool, optional): Whether to remove the background from the images. Defaults to False.
1145
+ backgrounds (int, optional): The number of background images to use for background removal. Defaults to 100.
1146
+ lower_quantile (float, optional): The lower quantile used for background removal. Defaults to 0.01.
1147
+ save_dtype (type, optional): The data type used for saving the preprocessed images. Defaults to np.float32.
1148
+ correct_illumination (bool, optional): Whether to correct the illumination of the images. Defaults to False.
1149
+ randomize (bool, optional): Whether to randomize the order of the images. Defaults to True.
1150
+ all_to_mip (bool, optional): Whether to convert all images to MIP. Defaults to False.
1151
+ pick_slice (bool, optional): Whether to pick a specific slice based on the provided skip mode. Defaults to False.
1152
+ skip_mode (str, optional): The skip mode used to filter out specific slices. Defaults to '01'.
1153
+ settings (dict, optional): Additional settings for preprocessing. Defaults to {}.
1154
+
1155
+ Returns:
1156
+ None
1157
+ """
1158
+ src = settings['src']
1159
+ valid_ext = ['tif', 'tiff', 'png', 'jpeg']
1160
+ files = os.listdir(src)
1161
+ extensions = [file.split('.')[-1] for file in files]
1162
+ extension_counts = Counter(extensions)
1163
+ most_common_extension = extension_counts.most_common(1)[0][0]
1164
+
1165
+ # Check if the most common extension is one of the specified image formats
1166
+ if most_common_extension in valid_ext:
1167
+ img_format = f'.{most_common_extension}'
1168
+ print(f'Found {extension_counts[most_common_extension]} {most_common_extension} files')
1169
+ else:
1170
+ print(f'Could not find any {valid_ext} files in {src}; the most common extension found was {most_common_extension}')
1171
+ return
1172
+
1173
+ cmap = 'inferno'
1174
+ figuresize = 20
1175
+ normalize = True
1176
+ save_dtype = 'uint16'
1177
+ correct_illumination = False
1178
+
1179
+ mask_channels = [settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel']]
1180
+ backgrounds = [settings['nucleus_background'], settings['pathogen_background'], settings['cell_background']]
1181
+
1182
+ metadata_type = settings['metadata_type']
1183
+ custom_regex = settings['custom_regex']
1184
+ nr = settings['examples_to_plot']
1185
+ plot = settings['plot']
1186
+ batch_size = settings['batch_size']
1187
+ timelapse = settings['timelapse']
1188
+ remove_background = settings['remove_background']
1189
+ lower_quantile = settings['lower_quantile']
1190
+ randomize = settings['randomize']
1191
+ all_to_mip = settings['all_to_mip']
1192
+ pick_slice = settings['pick_slice']
1193
+ skip_mode = settings['skip_mode']
1194
+
1195
+ if metadata_type == 'cellvoyager':
1196
+ regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1197
+ elif metadata_type == 'cq1':
1198
+ regex = f'W(?P<wellID>.*)F(?P<fieldID>.*)T(?P<timeID>.*)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1199
+ elif metadata_type == 'nikon':
1200
+ regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1201
+ elif metadata_type == 'zeis':
1202
+ regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1203
+ elif metadata_type == 'leica':
1204
+ regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1205
+ elif metadata_type == 'custom':
1206
+ regex = f'({custom_regex}){img_format}'
1207
+
1208
+ print(f'regex mode:{metadata_type} regex:{regex}')
1209
+
1210
+ if not os.path.exists(src+'/stack'):
1211
+ if timelapse:
1212
+ _move_to_chan_folder(src, regex, timelapse, metadata_type)
1213
+ else:
1214
+ #_z_to_mip(src, regex, batch_size, pick_slice, skip_mode, metadata_type, img_format)
1215
+ _rename_and_organize_image_files(src, regex, batch_size, pick_slice, skip_mode, metadata_type, img_format)
1216
+
1217
+ #Make sure no batches will be of only one image
1218
+ all_imgs = len(os.listdir(src+'/stack'))
1219
+ full_batches = all_imgs // batch_size
1220
+ last_batch_size = all_imgs % batch_size
1221
+
1222
+ # Check if the last batch is of size 1
1223
+ if last_batch_size == 1:
1224
+ # If there's only one batch and its size is 1, it's also an issue
1225
+ if full_batches == 0:
1226
+ raise ValueError("Only one batch of size 1 detected. Adjust the batch size.")
1227
+ # If the last batch would be of size 1, ask the user to adjust the batch size
1228
+ elif full_batches > 0:
1229
+ raise ValueError("Last batch of size 1 detected. Adjust the batch size.")
1230
+
1231
+ _merge_channels(src, plot=False)
1232
+ if timelapse:
1233
+ _create_movies_from_npy_per_channel(src+'/stack', fps=2)
1234
+
1235
+ if plot:
1236
+ print(f'plotting {nr} images from {src}/stack')
1237
+ plot_arrays(src+'/stack', figuresize, cmap, nr=nr, normalize=normalize)
1238
+ if all_to_mip:
1239
+ _mip_all(src+'/stack')
1240
+ if plot:
1241
+ print(f'plotting {nr} images from {src}/stack')
1242
+ plot_arrays(src+'/stack', figuresize, cmap, nr=nr, normalize=normalize)
1243
+ #nr_of_stacks = len(src+'/channel_stack')
1244
+
1245
+ _concatenate_channel(src+'/stack',
1246
+ channels=mask_channels,
1247
+ randomize=randomize,
1248
+ timelapse=timelapse,
1249
+ batch_size=batch_size)
1250
+
1251
+ if plot:
1252
+ print(f'plotting {nr} images from {src}/channel_stack')
1253
+ _plot_4D_arrays(src+'/channel_stack', figuresize, cmap, nr_npz=1, nr=nr)
1254
+ nr_of_chan_stacks = len(os.listdir(src+'/channel_stack'))
1255
+
1256
+ backgrounds, signal_to_noise, signal_thresholds = _get_lists_for_normalization(settings=settings)
1257
+
1258
+ if not timelapse:
1259
+ _normalize_stack(src+'/channel_stack',
1260
+ backgrounds=backgrounds,
1261
+ lower_quantile=lower_quantile,
1262
+ save_dtype=save_dtype,
1263
+ signal_thresholds=signal_thresholds,
1264
+ correct_illumination=correct_illumination,
1265
+ signal_to_noise=signal_to_noise,
1266
+ remove_background=remove_background)
1267
+ else:
1268
+ _normalize_timelapse(src+'/channel_stack', lower_quantile=lower_quantile, save_dtype=np.float32)
1269
+
1270
+ if plot:
1271
+ _plot_4D_arrays(src+'/norm_channel_stack', nr_npz=1, nr=nr)
1272
+
1273
+ return
1274
+
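A sketch of a settings dictionary covering the keys read by preprocess_img_data (all values illustrative; the source path is hypothetical):

settings = {
    'src': '/data/plate1',
    'metadata_type': 'cellvoyager', 'custom_regex': None,
    'examples_to_plot': 1, 'plot': False,
    'batch_size': 100, 'timelapse': False,
    'remove_background': False, 'lower_quantile': 0.01,
    'randomize': True, 'all_to_mip': False,
    'pick_slice': False, 'skip_mode': '01',
    'channels': [0, 1, 2],
    'nucleus_channel': 0, 'nucleus_background': 100, 'nucleus_Signal_to_noise': 5,
    'cell_channel': 1, 'cell_background': 100, 'cell_Signal_to_noise': 5,
    'pathogen_channel': 2, 'pathogen_background': 200, 'pathogen_Signal_to_noise': 3,
}
preprocess_img_data(settings)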
1275
+ def _check_masks(batch, batch_filenames, output_folder):
1276
+ """
1277
+ Check the masks in a batch and filter out the ones that already exist in the output folder.
1278
+
1279
+ Args:
1280
+ batch (list): List of masks.
1281
+ batch_filenames (list): List of filenames corresponding to the masks.
1282
+ output_folder (str): Path to the output folder.
1283
+
1284
+ Returns:
1285
+ tuple: A tuple containing the filtered batch (numpy array) and the filtered filenames (list).
1286
+ """
1287
+ # Create a mask for filenames that are already present in the output folder
1288
+ existing_files_mask = [not os.path.isfile(os.path.join(output_folder, filename)) for filename in batch_filenames]
1289
+
1290
+ # Use the mask to filter the batch and batch_filenames
1291
+ filtered_batch = [b for b, exists in zip(batch, existing_files_mask) if exists]
1292
+ filtered_filenames = [f for f, exists in zip(batch_filenames, existing_files_mask) if exists]
1293
+
1294
+ return np.array(filtered_batch), filtered_filenames
1295
+
1296
+
1297
+ def _get_avg_object_size(masks):
1298
+ """
1299
+ Calculate the average size of objects in a list of masks.
1300
+
1301
+ Parameters:
1302
+ masks (list): A list of masks representing objects.
1303
+
1304
+ Returns:
1305
+ float: The average size of objects in the masks. Returns 0 if no objects are found.
1306
+ """
1307
+ object_areas = []
1308
+ for mask in masks:
1309
+ # Check if the mask is a 2D or 3D array and is not empty
1310
+ if mask.ndim in [2, 3] and np.any(mask):
1311
+ properties = measure.regionprops(mask)
1312
+ object_areas += [prop.area for prop in properties]
1313
+ else:
1314
+ if not np.any(mask):
1315
+ print(f"Mask is empty. ")
1316
+ if mask.ndim not in [2, 3]:
1317
+ print(f"Mask is not in the correct format. dim: {mask.ndim}")
1318
+ continue
1319
+
1320
+ if object_areas:
1321
+ return sum(object_areas) / len(object_areas)
1322
+ else:
1323
+ return 0 # Return 0 if no objects are found
1324
+
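A small self-contained check of the average-object-size helper, using an illustrative labelled mask:

import numpy as np
mask = np.zeros((10, 10), dtype=int)
mask[1:4, 1:4] = 1   # object of area 9
mask[6:8, 6:8] = 2   # object of area 4
_get_avg_object_size([mask])  # -> 6.5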
1325
+ def _save_figure_v1(fig, src, text, dpi=300):
1326
+ """
1327
+ Save a figure to a specified location.
1328
+
1329
+ Parameters:
1330
+ fig (matplotlib.figure.Figure): The figure to be saved.
1331
+ src (str): The source file path.
1332
+ text (str): The text to be included in the figure name.
1333
+ dpi (int, optional): The resolution of the saved figure. Defaults to 300.
1334
+ """
1335
+ save_folder = os.path.dirname(src)
1336
+ obj_type = os.path.basename(src)
1337
+ name = os.path.basename(save_folder)
1338
+ save_folder = os.path.join(save_folder, 'figure')
1339
+ os.makedirs(save_folder, exist_ok=True)
1340
+ fig_name = f'{obj_type}_{name}_{text}.pdf'
1341
+ save_location = os.path.join(save_folder, fig_name)
1342
+ fig.savefig(save_location, bbox_inches='tight', dpi=dpi)
1343
+ print(f'Saved single cell figure: {save_location}')
1344
+ plt.close()
1345
+
1346
+ def _save_figure(fig, src, text, dpi=300, i=1, all_folders=1):
1347
+ """
1348
+ Save a figure to a specified location.
1349
+
1350
+ Parameters:
1351
+ fig (matplotlib.figure.Figure): The figure to be saved.
1352
+ src (str): The source file path.
1353
+ text (str): The text to be included in the figure name.
1354
+ dpi (int, optional): The resolution of the saved figure. Defaults to 300.
1355
+ """
1356
+ save_folder = os.path.dirname(src)
1357
+ obj_type = os.path.basename(src)
1358
+ name = os.path.basename(save_folder)
1359
+ save_folder = os.path.join(save_folder, 'figure')
1360
+ os.makedirs(save_folder, exist_ok=True)
1361
+ fig_name = f'{obj_type}_{name}_{text}.pdf'
1362
+ save_location = os.path.join(save_folder, fig_name)
1363
+ fig.savefig(save_location, bbox_inches='tight', dpi=dpi)
1364
+ clear_output(wait=True)
1365
+ print(f'\033[KProgress: {i}/{all_folders}, Saved single cell figure: {os.path.basename(save_location)}', end='\r', flush=True)
1366
+ # Close and delete the figure to free up memory
1367
+ plt.close(fig)
1368
+ del fig
1369
+ gc.collect()
1370
+
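A usage sketch for _save_figure (the src path is hypothetical): the figure is written as a PDF into a 'figure' folder created next to the object folder named by src.

    import matplotlib.pyplot as plt
    from spacr.io import _save_figure

    fig, ax = plt.subplots()
    ax.plot([0, 1], [0, 1])
    # Writes /tmp/experiment/figure/cell_experiment_example.pdf
    _save_figure(fig, src='/tmp/experiment/cell', text='example', i=1, all_folders=1)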
1371
+ def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus', 'pathogen', 'parasite', 'png_list']):
1372
+ """
1373
+ Reads and joins tables from a SQLite database.
1374
+
1375
+ Args:
1376
+ db_path (str): The path to the SQLite database file.
1377
+ table_names (list, optional): The names of the tables to read and join. Defaults to ['cell', 'cytoplasm', 'nucleus', 'pathogen', 'parasite', 'png_list'].
1378
+
1379
+ Returns:
1380
+ pandas.DataFrame: The joined DataFrame containing the data from the specified tables, or None if an error occurs.
1381
+ """
1382
+ conn = sqlite3.connect(db_path)
1383
+ dataframes = {}
1384
+ for table_name in table_names:
1385
+ try:
1386
+ dataframes[table_name] = pd.read_sql(f"SELECT * FROM {table_name}", conn)
1387
+ except (sqlite3.OperationalError, pd.io.sql.DatabaseError) as e:
1388
+ print(f"Table {table_name} not found in the database.")
1389
+ print(e)
1390
+ conn.close()
1391
+ if 'png_list' in dataframes:
1392
+ png_list_df = dataframes['png_list'][['cell_id', 'png_path', 'plate', 'row', 'col']].copy()
1393
+ png_list_df['cell_id'] = png_list_df['cell_id'].str[1:].astype(int)
1394
+ png_list_df.rename(columns={'cell_id': 'object_label'}, inplace=True)
1395
+ if 'cell' in dataframes:
1396
+ join_cols = ['object_label', 'plate', 'row', 'col']
1397
+ dataframes['cell'] = pd.merge(dataframes['cell'], png_list_df, on=join_cols, how='left')
1398
+ else:
1399
+ print("Cell table not found. Cannot join with png_list.")
1400
+ return None
1401
+ for entity in ['nucleus', 'pathogen', 'parasite']:
1402
+ if entity in dataframes:
1403
+ numeric_cols = dataframes[entity].select_dtypes(include=[np.number]).columns.tolist()
1404
+ non_numeric_cols = dataframes[entity].select_dtypes(exclude=[np.number]).columns.tolist()
1405
+ agg_dict = {col: 'mean' for col in numeric_cols}
1406
+ agg_dict.update({col: 'first' for col in non_numeric_cols if col not in ['cell_id', 'prcf']})
1407
+ grouping_cols = ['cell_id', 'prcf']
1408
+ agg_df = dataframes[entity].groupby(grouping_cols).agg(agg_dict)
1409
+ agg_df['count_' + entity] = dataframes[entity].groupby(grouping_cols).size()
1410
+ dataframes[entity] = agg_df
1411
+ joined_df = None
1412
+ if 'cell' in dataframes:
1413
+ joined_df = dataframes['cell']
1414
+ if 'cytoplasm' in dataframes:
1415
+ joined_df = pd.merge(joined_df, dataframes['cytoplasm'], on=['object_label', 'prcf'], how='left', suffixes=('', '_cytoplasm'))
1416
+ for entity in ['nucleus', 'pathogen']:
1417
+ if entity in dataframes:
1418
+ joined_df = pd.merge(joined_df, dataframes[entity], left_on=['object_label', 'prcf'], right_index=True, how='left', suffixes=('', f'_{entity}'))
1419
+ else:
1420
+ print("Cell table not found. Cannot proceed with joining.")
1421
+ return None
1422
+ return joined_df
1423
+
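A hedged usage sketch for _read_and_join_tables (the database path is hypothetical): missing tables are reported, and None is returned when no 'cell' table is available.

    from spacr.io import _read_and_join_tables

    df = _read_and_join_tables('/tmp/experiment/measurements/measurements.db',
                               table_names=['cell', 'cytoplasm', 'nucleus', 'png_list'])
    if df is not None:
        print(df.shape)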
1424
+ def _save_settings_to_db(settings):
1425
+ """
1426
+ Save the settings dictionary to a SQLite database.
1427
+
1428
+ Args:
1429
+ settings (dict): A dictionary containing the settings.
1430
+
1431
+ Returns:
1432
+ None
1433
+ """
1434
+ # Convert the settings dictionary into a DataFrame
1435
+ settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
1436
+ # Convert all values in the 'setting_value' column to strings
1437
+ settings_df['setting_value'] = settings_df['setting_value'].apply(str)
1438
+ display(settings_df)
1439
+ # Determine the directory path
1440
+ src = os.path.dirname(settings['input_folder'])
1441
+ directory = f'{src}/measurements'
1442
+ # Create the directory if it doesn't exist
1443
+ os.makedirs(directory, exist_ok=True)
1444
+ # Database connection and saving the settings DataFrame
1445
+ conn = sqlite3.connect(f'{directory}/measurements.db', timeout=5)
1446
+ settings_df.to_sql('settings', conn, if_exists='replace', index=False) # Replace the table if it already exists
1447
+ conn.close()
1448
+
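A usage sketch for _save_settings_to_db (the settings keys and the input_folder path are hypothetical): the settings are written to a 'settings' table in measurements/measurements.db under the parent of input_folder.

    from spacr.io import _save_settings_to_db

    settings = {'input_folder': '/tmp/experiment/stack',  # hypothetical path
                'channels': [0, 1, 2],
                'plot': False}
    _save_settings_to_db(settings)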
1449
+ def _save_mask_timelapse_as_gif(masks, tracks_df, path, cmap, norm, filenames):
1450
+ """
1451
+ Save a timelapse animation of masks as a GIF.
1452
+
1453
+ Parameters:
1454
+ - masks (list): List of mask frames.
1455
+ - tracks_df (pandas.DataFrame): DataFrame containing track information.
1456
+ - path (str): Path to save the GIF file.
1457
+ - cmap (str or matplotlib.colors.Colormap): Colormap for displaying the masks.
1458
+ - norm (matplotlib.colors.Normalize): Normalization for the colormap.
1459
+ - filenames (list): List of filenames corresponding to each mask frame.
1460
+
1461
+ Returns:
1462
+ None
1463
+ """
1464
+ # Set the face color for the figure to black
1465
+ fig, ax = plt.subplots(figsize=(50, 50), facecolor='black')
1466
+ ax.set_facecolor('black') # Set the axes background color to black
1467
+ ax.axis('off') # Turn off the axis
1468
+ plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0) # Adjust the subplot edges
1469
+
1470
+ filename_text_obj = None # Initialize a variable to keep track of the text object
1471
+
1472
+ def _update(frame):
1473
+ """
1474
+ Update the frame of the animation.
1475
+
1476
+ Parameters:
1477
+ - frame (int): The frame number to update.
1478
+
1479
+ Returns:
1480
+ None
1481
+ """
1482
+ nonlocal filename_text_obj # Reference the nonlocal variable to update it
1483
+ if filename_text_obj is not None:
1484
+ filename_text_obj.remove() # Remove the previous text object if it exists
1485
+
1486
+ ax.clear() # Clear the axis to draw the new frame
1487
+ ax.axis('off') # Ensure axis is still off after clearing
1488
+ current_mask = masks[frame]
1489
+ ax.imshow(current_mask, cmap=cmap, norm=norm)
1490
+ ax.set_title(f'Frame: {frame}', fontsize=24, color='white')
1491
+
1492
+ # Add the filename as text on the figure
1493
+ filename_text = filenames[frame] # Get the filename corresponding to the current frame
1494
+ filename_text_obj = fig.text(0.5, 0.01, filename_text, ha='center', va='center', fontsize=20, color='white') # Adjust text position, size, and color as needed
1495
+
1496
+ # Annotate each object with its label number from the mask
1497
+ for label_value in np.unique(current_mask):
1498
+ if label_value == 0: continue # Skip background
1499
+ y, x = np.mean(np.where(current_mask == label_value), axis=1)
1500
+ ax.text(x, y, str(label_value), color='white', fontsize=24, ha='center', va='center')
1501
+
1502
+ # Overlay tracks
1503
+ for track in tracks_df['track_id'].unique():
1504
+ _track = tracks_df[tracks_df['track_id'] == track]
1505
+ ax.plot(_track['x'], _track['y'], '-w', linewidth=1)
1506
+
1507
+ anim = FuncAnimation(fig, _update, frames=len(masks), blit=False)
1508
+ anim.save(path, writer='pillow', fps=2, dpi=80) # Adjust DPI for size/quality
1509
+ plt.close(fig)
1510
+ print(f'Saved timelapse to {path}')
1511
+
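A minimal sketch for _save_mask_timelapse_as_gif with synthetic frames and tracks (all values hypothetical; note the function renders at figsize=(50, 50), so even a few frames produce a large GIF):

    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from matplotlib import colors
    from spacr.io import _save_mask_timelapse_as_gif

    masks = [np.random.randint(0, 3, (64, 64)) for _ in range(3)]
    tracks_df = pd.DataFrame({'track_id': [1, 1, 1], 'x': [10, 12, 14], 'y': [20, 21, 22]})
    filenames = [f'frame_{i}.npy' for i in range(3)]
    _save_mask_timelapse_as_gif(masks, tracks_df, '/tmp/timelapse.gif',
                                cmap=plt.get_cmap('tab20'),
                                norm=colors.Normalize(vmin=0, vmax=2),
                                filenames=filenames)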
1512
+ def _save_object_counts_to_database(arrays, object_type, file_names, db_path, added_string):
1513
+ """
1514
+ Save the counts of unique objects in masks to a SQLite database.
1515
+
1516
+ Args:
1517
+ arrays (List[np.ndarray]): List of masks.
1518
+ object_type (str): Type of object.
1519
+ file_names (List[str]): List of file names corresponding to the masks.
1520
+ db_path (str): Path to the SQLite database.
1521
+ added_string (str): Additional string to append to the count type.
1522
+
1523
+ Returns:
1524
+ None
1525
+ """
1526
+ def _count_objects(mask):
1527
+ """Count unique objects in a mask, assuming 0 is the background."""
1528
+ unique, counts = np.unique(mask, return_counts=True)
1529
+ # Assuming 0 is the background label, remove it from the count
1530
+ if unique[0] == 0:
1531
+ return len(unique) - 1
1532
+ return len(unique)
1533
+
1534
+ records = []
1535
+ for mask, file_name in zip(arrays, file_names):
1536
+ object_count = _count_objects(mask)
1537
+ count_type = f"{object_type}{added_string}"
1538
+
1539
+ # Append a tuple of (file_name, count_type, object_count) to the records list
1540
+ records.append((file_name, count_type, object_count))
1541
+
1542
+ # Connect to the database
1543
+ conn = sqlite3.connect(db_path)
1544
+ cursor = conn.cursor()
1545
+
1546
+ # Create the table if it doesn't exist
1547
+ cursor.execute('''
1548
+ CREATE TABLE IF NOT EXISTS object_counts (
1549
+ file_name TEXT,
1550
+ count_type TEXT,
1551
+ object_count INTEGER,
1552
+ PRIMARY KEY (file_name, count_type)
1553
+ )
1554
+ ''')
1555
+
1556
+ # Batch insert or update the object counts
1557
+ cursor.executemany('''
1558
+ INSERT INTO object_counts (file_name, count_type, object_count)
1559
+ VALUES (?, ?, ?)
1560
+ ON CONFLICT(file_name, count_type) DO UPDATE SET
1561
+ object_count = excluded.object_count
1562
+ ''', records)
1563
+
1564
+ # Commit changes and close the database connection
1565
+ conn.commit()
1566
+ conn.close()
1567
+
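A usage sketch for _save_object_counts_to_database (masks and paths hypothetical): counts are upserted into an 'object_counts' table keyed by (file_name, count_type).

    import numpy as np
    from spacr.io import _save_object_counts_to_database

    masks = [np.array([[0, 1], [2, 2]]), np.zeros((2, 2), dtype=int)]  # 2 objects, then 0
    _save_object_counts_to_database(masks,
                                    object_type='cell',
                                    file_names=['img_001.npy', 'img_002.npy'],
                                    db_path='/tmp/measurements.db',
                                    added_string='_before_filtration')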
1568
+ def _create_database(db_path):
1569
+ """
1570
+ Creates a SQLite database at the specified path.
1571
+
1572
+ Args:
1573
+ db_path (str): The path where the database should be created.
1574
+
1575
+ Returns:
1576
+ None
1577
+ """
1578
+ conn = None
1579
+ try:
1580
+ conn = sqlite3.connect(db_path)
1581
+ except Exception as e:
1582
+ print(e)
1583
+ finally:
1584
+ if conn:
1585
+ conn.close()
1586
+
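A one-line usage sketch for _create_database (path hypothetical); connecting with sqlite3 creates the file if the parent directory exists.

    from spacr.io import _create_database

    _create_database('/tmp/spacr_example.db')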
1587
+ def _load_and_concatenate_arrays(src, channels, cell_chann_dim, nucleus_chann_dim, pathogen_chann_dim):
1588
+ """
1589
+ Load and concatenate arrays from multiple folders.
1590
+
1591
+ Args:
1592
+ src (str): The source directory containing the arrays.
1593
+ channels (list): List of channel indices to select from the arrays.
1594
+ cell_chann_dim (int): Dimension of the cell channel.
1595
+ nucleus_chann_dim (int): Dimension of the nucleus channel.
1596
+ pathogen_chann_dim (int): Dimension of the pathogen channel.
1597
+
1598
+ Returns:
1599
+ None
1600
+ """
1601
+ folder_paths = [os.path.join(src, 'stack')]
1602
+
1603
+ if cell_chann_dim is not None or os.path.exists(os.path.join(src, 'norm_channel_stack', 'cell_mask_stack')):
1604
+ folder_paths = folder_paths + [os.path.join(src, 'norm_channel_stack','cell_mask_stack')]
1605
+ if nucleus_chann_dim is not None or os.path.exists(os.path.join(src, 'norm_channel_stack', 'nucleus_mask_stack')):
1606
+ folder_paths = folder_paths + [os.path.join(src, 'norm_channel_stack','nucleus_mask_stack')]
1607
+ if pathogen_chann_dim is not None or os.path.exists(os.path.join(src, 'norm_channel_stack', 'pathogen_mask_stack')):
1608
+ folder_paths = folder_paths + [os.path.join(src, 'norm_channel_stack','pathogen_mask_stack')]
1609
+
1610
+ output_folder = src+'/merged'
1611
+ reference_folder = folder_paths[0]
1612
+ os.makedirs(output_folder, exist_ok=True)
1613
+
1614
+ count=0
1615
+ all_imgs = len(os.listdir(reference_folder))
1616
+
1617
+ # Iterate through each file in the reference folder
1618
+ for filename in os.listdir(reference_folder):
1619
+
1620
+ stack_ls = []
1621
+ array_path = []
1622
+
1623
+ if filename.endswith('.npy'):
1624
+ count+=1
1625
+ # Initialize the concatenated array with the array from the reference folder
1626
+ concatenated_array = np.load(os.path.join(reference_folder, filename))
1627
+ if channels is not None:
1628
+ concatenated_array = np.take(concatenated_array, channels, axis=2)
1629
+ stack_ls.append(concatenated_array)
1630
+ # For each of the other folders, load the array and concatenate it
1631
+ for folder in folder_paths[1:]:
1632
+ array_path = os.path.join(folder, filename)
1633
+ if os.path.isfile(array_path):
1634
+ array = np.load(array_path)
1635
+ if array.ndim == 2:
1636
+ array = np.expand_dims(array, axis=-1) # add an extra dimension if the array is 2D
1637
+ stack_ls.append(array)
1638
+
1639
+ stack_ls = [np.expand_dims(arr, axis=-1) if arr.ndim == 2 else arr for arr in stack_ls]
1640
+ unique_shapes = {arr.shape[:-1] for arr in stack_ls}
1641
+ if len(unique_shapes) > 1:
1642
+ #max_dims = np.max(np.array(list(unique_shapes)), axis=0)
1643
+ # Determine the maximum length of tuples in unique_shapes
1644
+ max_tuple_length = max(len(shape) for shape in unique_shapes)
1645
+ # Pad shorter tuples with zeros to make them all the same length
1646
+ padded_shapes = [shape + (0,) * (max_tuple_length - len(shape)) for shape in unique_shapes]
1647
+ # Now create a NumPy array and find the maximum dimensions
1648
+ max_dims = np.max(np.array(padded_shapes), axis=0)
1649
+ clear_output(wait=True)
1650
+ print(f'\033[KWarning: arrays with multiple shapes found. Padding arrays to max X,Y dimensions {max_dims}', end='\r', flush=True)
1651
+ padded_stack_ls = []
1652
+ for arr in stack_ls:
1653
+ pad_width = [(0, max_dim - dim) for max_dim, dim in zip(max_dims, arr.shape[:-1])]
1654
+ pad_width.append((0, 0))
1655
+ padded_arr = np.pad(arr, pad_width)
1656
+ padded_stack_ls.append(padded_arr)
1657
+ # Concatenate the padded arrays along the channel dimension (last dimension)
1658
+ stack = np.concatenate(padded_stack_ls, axis=-1)
1659
+
1660
+ else:
1661
+ stack = np.concatenate(stack_ls, axis=-1)
1662
+
1663
+ if stack.shape[-1] > concatenated_array.shape[-1]:
1664
+ output_path = os.path.join(output_folder, filename)
1665
+ np.save(output_path, stack)
1666
+
1667
+ clear_output(wait=True)
1668
+ #print(f'\033[KFiles merged: {count}/{all_imgs}', end='\r', flush=True)
1669
+ return
1670
+
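A hedged call sketch for _load_and_concatenate_arrays (the folder layout is an assumption): src/stack is expected to hold the channel .npy arrays, the *_mask_stack folders under src/norm_channel_stack hold matching masks, and the merged arrays are written to src/merged.

    from spacr.io import _load_and_concatenate_arrays

    _load_and_concatenate_arrays('/tmp/experiment',
                                 channels=[0, 1, 2],
                                 cell_chann_dim=3,
                                 nucleus_chann_dim=0,
                                 pathogen_chann_dim=2)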
1671
+ def _read_db(db_loc, tables):
1672
+ """
1673
+ Read data from a SQLite database.
1674
+
1675
+ Parameters:
1676
+ - db_loc (str): The location of the SQLite database file.
1677
+ - tables (list): A list of table names to read from.
1678
+
1679
+ Returns:
1680
+ - dfs (list): A list of pandas DataFrames, each containing the data from a table.
1681
+ """
1682
+ conn = sqlite3.connect(db_loc)
1683
+ dfs = []
1684
+ for table in tables:
1685
+ query = f'SELECT * FROM {table}'
1686
+ df = pd.read_sql_query(query, conn)
1687
+ dfs.append(df)
1688
+ conn.close()
1689
+ return dfs
1690
+
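A usage sketch for _read_db (database path and table names hypothetical); the DataFrames come back in the same order as the requested tables.

    from spacr.io import _read_db

    cell_df, nucleus_df = _read_db('/tmp/experiment/measurements/measurements.db',
                                   tables=['cell', 'nucleus'])
    print(len(cell_df), len(nucleus_df))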
1691
+ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=False, include_multiinfected=False, include_noninfected=False):
1692
+ """
1693
+ Read and merge data from SQLite databases and perform data preprocessing.
1694
+
1695
+ Parameters:
1696
+ - locs (list): A list of file paths to the SQLite database files.
1697
+ - tables (list): A list of table names to read from the databases.
1698
+ - verbose (bool): Whether to print verbose output. Default is False.
1699
+ - include_multinucleated (bool): Whether to include multinucleated cells. Default is False.
1700
+ - include_multiinfected (bool): Whether to include cells with multiple infections. Default is False.
1701
+ - include_noninfected (bool): Whether to include non-infected cells. Default is False.
1702
+
1703
+ Returns:
1704
+ - merged_df (pandas.DataFrame): The merged and preprocessed dataframe.
1705
+ - obj_df_ls (list): A list of pandas DataFrames, each containing the data for a specific object type.
1706
+ """
1707
+
1708
+ from .utils import _split_data
1709
+
1710
+ #Extract plate DataFrames
1711
+ all_dfs = []
1712
+ for loc in locs:
1713
+ db_dfs = _read_db(loc, tables)
1714
+ all_dfs.append(db_dfs)
1715
+
1716
+ #Extract Tables from DataFrames and concatenate rows
1717
+ for i, dfs in enumerate(all_dfs):
1718
+ if 'cell' in tables:
1719
+ cell = dfs[0]
1720
+ print(f'plate: {i+1} cells:{len(cell)}')
1721
+
1722
+ if 'nucleus' in tables:
1723
+ nucleus_plate = dfs[1]
1724
+ print(f'plate: {i+1} nucleus:{len(nucleus_plate)}')
1725
+
1726
+ if 'pathogen' in tables:
1727
+ pathogen = dfs[2]
1728
+
1729
+ print(f'plate: {i+1} pathogens:{len(pathogen)}')
1730
+ if 'cytoplasm' in tables:
1731
+ if not 'pathogen' in tables:
1732
+ cytoplasm = dfs[2]
1733
+ else:
1734
+ cytoplasm = dfs[3]
1735
+ print(f'plate: {i+1} cytoplasms: {len(cytoplasm)}')
1736
+
1737
+ if i > 0:
1738
+ if 'cell' in tables:
1739
+ cells = pd.concat([cells, cell], axis = 0)
1740
+ if 'nucleus' in tables:
1741
+ nucleus = pd.concat([nucleus, nucleus_plate], axis = 0)
1742
+ if 'pathogen' in tables:
1743
+ pathogens = pd.concat([pathogens, pathogen], axis = 0)
1744
+ if 'cytoplasm' in tables:
1745
+ cytoplasms = pd.concat([cytoplasms, cytoplasm], axis = 0)
1746
+ else:
1747
+ if 'cell' in tables:
1748
+ cells = cell.copy()
1749
+ if 'nucleus' in tables:
1750
+ nucleus = nucleus_plate.copy()
1751
+ if 'pathogen' in tables:
1752
+ pathogens = pathogen.copy()
1753
+ if 'cytoplasm' in tables:
1754
+ cytoplasms = cytoplasm.copy()
1755
+
1756
+ #Add an o in front of all object and cell labels to convert them to strings
1757
+ if 'cell' in tables:
1758
+ cells = cells.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
1759
+ cells = cells.assign(prcfo = lambda x: x['prcf'] + '_' + x['object_label'])
1760
+ cells_g_df, metadata = _split_data(cells, 'prcfo', 'object_label')
1761
+ print(f'cells: {len(cells)}')
1762
+ print(f'cells grouped: {len(cells_g_df)}')
1763
+ if 'cytoplasm' in tables:
1764
+ cytoplasms = cytoplasms.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
1765
+ cytoplasms = cytoplasms.assign(prcfo = lambda x: x['prcf'] + '_' + x['object_label'])
1766
+ cytoplasms_g_df, _ = _split_data(cytoplasms, 'prcfo', 'object_label')
1767
+ merged_df = cells_g_df.merge(cytoplasms_g_df, left_index=True, right_index=True)
1768
+ print(f'cytoplasms: {len(cytoplasms)}')
1769
+ print(f'cytoplasms grouped: {len(cytoplasms_g_df)}')
1770
+ if 'nucleus' in tables:
1771
+ nucleus = nucleus.dropna(subset=['cell_id'])
1772
+ nucleus = nucleus.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
1773
+ nucleus = nucleus.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
1774
+ nucleus = nucleus.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
1775
+ nucleus['nucleus_prcfo_count'] = nucleus.groupby('prcfo')['prcfo'].transform('count')
1776
+ if include_multinucleated == False:
1777
+ #nucleus = nucleus[~nucleus['prcfo'].duplicated()]
1778
+ nucleus = nucleus[nucleus['nucleus_prcfo_count']==1]
1779
+ nucleus_g_df, _ = _split_data(nucleus, 'prcfo', 'cell_id')
1780
+ print(f'nucleus: {len(nucleus)}')
1781
+ print(f'nucleus grouped: {len(nucleus_g_df)}')
1782
+ if 'cytoplasm' in tables:
1783
+ merged_df = merged_df.merge(nucleus_g_df, left_index=True, right_index=True)
1784
+ else:
1785
+ merged_df = cells_g_df.merge(nucleus_g_df, left_index=True, right_index=True)
1786
+ if 'pathogen' in tables:
1787
+ pathogens = pathogens.dropna(subset=['cell_id'])
1788
+ pathogens = pathogens.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
1789
+ pathogens = pathogens.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
1790
+ pathogens = pathogens.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
1791
+ pathogens['pathogen_prcfo_count'] = pathogens.groupby('prcfo')['prcfo'].transform('count')
1792
+ if include_noninfected == False:
1793
+ pathogens = pathogens[pathogens['pathogen_prcfo_count']>=1]
1794
+ if include_multiinfected == False:
1795
+ pathogens = pathogens[pathogens['pathogen_prcfo_count']<=1]
1796
+ pathogens_g_df, _ = _split_data(pathogens, 'prcfo', 'cell_id')
1797
+ print(f'pathogens: {len(pathogens)}')
1798
+ print(f'pathogens grouped: {len(pathogens_g_df)}')
1799
+ merged_df = merged_df.merge(pathogens_g_df, left_index=True, right_index=True)
1800
+
1801
+ #Add prc column (plate row column)
1802
+ metadata = metadata.assign(prc = lambda x: x['plate'] + '_' + x['row'] + '_' +x['col'])
1803
+
1804
+ #Count cells per well
1805
+ cells_well = pd.DataFrame(metadata.groupby('prc')['object_label'].nunique())
1806
+
1807
+ cells_well.reset_index(inplace=True)
1808
+ cells_well.rename(columns={'object_label': 'cells_per_well'}, inplace=True)
1809
+ metadata = pd.merge(metadata, cells_well, on='prc', how='inner', suffixes=('', '_drop_col'))
1810
+ object_label_cols = [col for col in metadata.columns if '_drop_col' in col]
1811
+ metadata.drop(columns=object_label_cols, inplace=True)
1812
+
1813
+ #Add prcfo column (plate row column field object)
1814
+ metadata = metadata.assign(prcfo = lambda x: x['plate'] + '_' + x['row'] + '_' +x['col']+ '_' +x['field']+ '_' +x['object_label'])
1815
+ metadata.set_index('prcfo', inplace=True)
1816
+
1817
+ merged_df = metadata.merge(merged_df, left_index=True, right_index=True)
1818
+
1819
+ merged_df = merged_df.dropna(axis=1)
1820
+
1821
+ print(f'Generated dataframe with: {len(merged_df.columns)} columns and {len(merged_df)} rows')
1822
+
1823
+ obj_df_ls = []
1824
+ if 'cell' in tables:
1825
+ obj_df_ls.append(cells)
1826
+ if 'cytoplasm' in tables:
1827
+ obj_df_ls.append(cytoplasms)
1828
+ if 'nucleus' in tables:
1829
+ obj_df_ls.append(nucleus)
1830
+ if 'pathogen' in tables:
1831
+ obj_df_ls.append(pathogens)
1832
+
1833
+ return merged_df, obj_df_ls
1834
+
1835
+ def _results_to_csv(src, df, df_well):
1836
+ """
1837
+ Save the given dataframes as CSV files in the specified directory.
1838
+
1839
+ Args:
1840
+ src (str): The directory path where the CSV files will be saved.
1841
+ df (pandas.DataFrame): The dataframe containing cell data.
1842
+ df_well (pandas.DataFrame): The dataframe containing well data.
1843
+
1844
+ Returns:
1845
+ tuple: A tuple containing the cell dataframe and well dataframe.
1846
+ """
1847
+ cells = df
1848
+ wells = df_well
1849
+ results_loc = src+'/results'
1850
+ wells_loc = results_loc+'/wells.csv'
1851
+ cells_loc = results_loc+'/cells.csv'
1852
+ os.makedirs(results_loc, exist_ok=True)
1853
+ wells.to_csv(wells_loc, index=True, header=True)
1854
+ cells.to_csv(cells_loc, index=True, header=True)
1855
+ return cells, wells
1856
+
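A usage sketch for _results_to_csv with tiny placeholder DataFrames (contents hypothetical); both tables are written under src/results.

    import pandas as pd
    from spacr.io import _results_to_csv

    df_cells = pd.DataFrame({'prcfo': ['p1_r1_c1_f1_o1'], 'area': [123.0]}).set_index('prcfo')
    df_wells = pd.DataFrame({'prc': ['p1_r1_c1'], 'cells_per_well': [42]}).set_index('prc')
    cells, wells = _results_to_csv('/tmp/experiment', df_cells, df_wells)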
1857
+ ###################################################
1858
+ # Classify
1859
+ ###################################################
1860
+
1861
+ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94]):
1862
+ """
1863
+ Save the model based on certain conditions during training.
1864
+
1865
+ Args:
1866
+ model (torch.nn.Module): The trained model to be saved.
1867
+ model_type (str): The type of the model.
1868
+ results_df (pandas.DataFrame): The dataframe containing the training results.
1869
+ dst (str): The destination directory to save the model.
1870
+ epoch (int): The current epoch number.
1871
+ epochs (int): The total number of epochs.
1872
+ intermedeate_save (list, optional): List of accuracy thresholds to trigger intermediate model saves.
1873
+ Defaults to [0.99, 0.98, 0.95, 0.94].
1874
+ """
1875
+
1876
+ if epoch % 100 == 0:
1877
+ torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}.pth')
1878
+
1879
+ if epoch == epochs:
1880
+ torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}.pth')
1881
+
1882
+ if results_df['neg_accuracy'].dropna().mean() >= intermedeate_save[0] and results_df['pos_accuracy'].dropna().mean() >= intermedeate_save[0]:
1883
+ percentile = str(intermedeate_save[0]*100)
1884
+ print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
1885
+ torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{str(percentile)}.pth')
1886
+
1887
+ elif results_df['neg_accuracy'].dropna().mean() >= intermedeate_save[1] and results_df['pos_accuracy'].dropna().mean() >= intermedeate_save[1]:
1888
+ percentile = str(intermedeate_save[1]*100)
1889
+ print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
1890
+ torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{str(percentile)}.pth')
1891
+
1892
+ elif results_df['neg_accuracy'].dropna().mean() >= intermedeate_save[2] and results_df['pos_accuracy'].dropna().mean() >= intermedeate_save[2]:
1893
+ percentile = str(intermedeate_save[2]*100)
1894
+ print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
1895
+ torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{str(percentile)}.pth')
1896
+
1897
+ elif results_df['neg_accuracy'].dropna().mean() >= intermedeate_save[3] and results_df['pos_accuracy'].dropna().mean() >= intermedeate_save[3]:
1898
+ percentile = str(intermedeate_save[3]*100)
1899
+ print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
1900
+ torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{str(percentile)}.pth')
1901
+
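A toy call sketch for _save_model (model, accuracies, and destination are placeholders): with mean neg/pos accuracy of 0.96/0.97 the 0.95 threshold branch fires, and epoch == epochs also triggers a final checkpoint. Note that dst must already exist, since the function does not create it.

    import os
    import pandas as pd
    import torch
    import torch.nn as nn
    from spacr.io import _save_model

    os.makedirs('/tmp/models', exist_ok=True)
    model = nn.Linear(4, 2)  # stand-in for a trained classifier
    results_df = pd.DataFrame({'neg_accuracy': [0.96], 'pos_accuracy': [0.97]})
    _save_model(model, 'toy', results_df, dst='/tmp/models', epoch=100, epochs=100)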
1902
+ def _save_progress(dst, results_df, train_metrics_df):
1903
+ """
1904
+ Save the progress of the classification model.
1905
+
1906
+ Parameters:
1907
+ dst (str): The destination directory to save the progress.
1908
+ results_df (pandas.DataFrame): The DataFrame containing accuracy, loss, and PRAUC.
1909
+ train_metrics_df (pandas.DataFrame): The DataFrame containing training metrics.
1910
+
1911
+ Returns:
1912
+ None
1913
+ """
1914
+ # Save accuracy, loss, PRAUC
1915
+ os.makedirs(dst, exist_ok=True)
1916
+ results_path = os.path.join(dst, 'acc_loss_prauc.csv')
1917
+ if not os.path.exists(results_path):
1918
+ results_df.to_csv(results_path, index=True, header=True, mode='w')
1919
+ else:
1920
+ results_df.to_csv(results_path, index=True, header=False, mode='a')
1921
+ training_metrics_path = os.path.join(dst, 'training_metrics.csv')
1922
+ if not os.path.exists(training_metrics_path):
1923
+ train_metrics_df.to_csv(training_metrics_path, index=True, header=True, mode='w')
1924
+ else:
1925
+ train_metrics_df.to_csv(training_metrics_path, index=True, header=False, mode='a')
1926
+ return
1927
+
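A usage sketch for _save_progress (column names below are placeholders): the two CSVs are created on the first call and appended to on later epochs.

    import pandas as pd
    from spacr.io import _save_progress

    results_df = pd.DataFrame({'epoch': [1], 'accuracy': [0.81], 'neg_accuracy': [0.78],
                               'pos_accuracy': [0.84], 'loss': [0.52], 'prauc': [0.70]})
    train_metrics_df = pd.DataFrame({'epoch': [1], 'train_loss': [0.61]})
    _save_progress('/tmp/models', results_df, train_metrics_df)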
1928
+ def _save_settings(settings, src):
1929
+ """
1930
+ Save the settings dictionary to a CSV file.
1931
+
1932
+ Parameters:
1933
+ - settings (dict): A dictionary containing the settings.
1934
+ - src (str): The source directory where the settings file will be saved.
1935
+
1936
+ Returns:
1937
+ None
1938
+ """
1939
+ dst = os.path.join(src,'model')
1940
+ settings_loc = os.path.join(dst,'settings.csv')
1941
+ os.makedirs(dst, exist_ok=True)
1942
+ settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
1943
+ display(settings_df)
1944
+ settings_df.to_csv(settings_loc, index=False)
1945
+ return
1946
+
1947
+
1948
+ def _copy_missclassified(df):
1949
+ misclassified = df[df['true_label'] != df['predicted_label']]
1950
+ for _, row in misclassified.iterrows():
1951
+ original_path = row['filename']
1952
+ filename = os.path.basename(original_path)
1953
+ dest_folder = os.path.dirname(os.path.dirname(original_path))
1954
+ if "pc" in original_path:
1955
+ new_path = os.path.join(dest_folder, "missclassified/pc", filename)
1956
+ else:
1957
+ new_path = os.path.join(dest_folder, "missclassified/nc", filename)
1958
+ os.makedirs(os.path.dirname(new_path), exist_ok=True)
1959
+ shutil.copy(original_path, new_path)
1960
+ print(f"Copied {len(misclassified)} misclassified images.")
1961
+ return
1962
+
1963
+ def _read_db(db_loc, tables):
1964
+ conn = sqlite3.connect(db_loc) # Create a connection to the database
1965
+ dfs = []
1966
+ for table in tables:
1967
+ query = f'SELECT * FROM {table}' # Write a SQL query to get the data from the database
1968
+ df = pd.read_sql_query(query, conn) # Use the read_sql_query function to get the data and save it as a DataFrame
1969
+ dfs.append(df)
1970
+ conn.close() # Close the connection
1971
+ return dfs
1972
+
1973
+ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=False, include_multiinfected=False, include_noninfected=False):
1974
+
1975
+ from .utils import _split_data
1976
+
1977
+ #Extract plate DataFrames
1978
+ all_dfs = []
1979
+ for loc in locs:
1980
+ db_dfs = _read_db(loc, tables)
1981
+ all_dfs.append(db_dfs)
1982
+
1983
+ #Extract Tables from DataFrames and concatenate rows
1984
+ for i, dfs in enumerate(all_dfs):
1985
+ if 'cell' in tables:
1986
+ cell = dfs[0]
1987
+ if verbose:
1988
+ print(f'plate: {i+1} cells:{len(cell)}')
1989
+ # TODO: see the pathogen logic below; copy the same logic to the other tables
1990
+ if 'nucleus' in tables:
1991
+ nucleus_plate = dfs[1]
1992
+ if verbose:
1993
+ print(f'plate: {i+1} nucleus:{len(nucleus_plate)}')
1994
+
1995
+ if 'pathogen' in tables:
1996
+ if len(tables) == 1:
1997
+ pathogen = dfs[0]
1998
+ print(len(pathogen))
1999
+ else:
2000
+ pathogen = dfs[2]
2001
+ if verbose:
2002
+ print(f'plate: {i+1} pathogens:{len(pathogen)}')
2003
+
2004
+ if 'cytoplasm' in tables:
2005
+ if not 'pathogen' in tables:
2006
+ cytoplasm = dfs[2]
2007
+ else:
2008
+ cytoplasm = dfs[3]
2009
+ if verbose:
2010
+ print(f'plate: {i+1} cytoplasms: {len(cytoplasm)}')
2011
+
2012
+ if i > 0:
2013
+ if 'cell' in tables:
2014
+ cells = pd.concat([cells, cell], axis = 0)
2015
+ if 'nucleus' in tables:
2016
+ nucleus = pd.concat([nucleus, nucleus_plate], axis = 0)
2017
+ if 'pathogen' in tables:
2018
+ pathogens = pd.concat([pathogens, pathogen], axis = 0)
2019
+ if 'cytoplasm' in tables:
2020
+ cytoplasms = pd.concat([cytoplasms, cytoplasm], axis = 0)
2021
+ else:
2022
+ if 'cell' in tables:
2023
+ cells = cell.copy()
2024
+ if 'nucleus' in tables:
2025
+ nucleus = nucleus_plate.copy()
2026
+ if 'pathogen' in tables:
2027
+ pathogens = pathogen.copy()
2028
+ if 'cytoplasm' in tables:
2029
+ cytoplasms = cytoplasm.copy()
2030
+
2031
+ #Add an o in front of all object and cell labels to convert them to strings
2032
+ if 'cell' in tables:
2033
+ cells = cells.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
2034
+ cells = cells.assign(prcfo = lambda x: x['prcf'] + '_' + x['object_label'])
2035
+ cells_g_df, metadata = _split_data(cells, 'prcfo', 'object_label')
2036
+ merged_df = cells_g_df.copy()
2037
+ if verbose:
2038
+ print(f'cells: {len(cells)}')
2039
+ print(f'cells grouped: {len(cells_g_df)}')
2040
+
2041
+ if 'cytoplasm' in tables:
2042
+ cytoplasms = cytoplasms.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
2043
+ cytoplasms = cytoplasms.assign(prcfo = lambda x: x['prcf'] + '_' + x['object_label'])
2044
+ cytoplasms_g_df, _ = _split_data(cytoplasms, 'prcfo', 'object_label')
2045
+ merged_df = cells_g_df.merge(cytoplasms_g_df, left_index=True, right_index=True)
2046
+ if verbose:
2047
+ print(f'cytoplasms: {len(cytoplasms)}')
2048
+ print(f'cytoplasms grouped: {len(cytoplasms_g_df)}')
2049
+
2050
+ if 'nucleus' in tables:
2051
+ if not 'cell' in tables:
2052
+ cells_g_df = pd.DataFrame()
2053
+ nucleus = nucleus.dropna(subset=['cell_id'])
2054
+ nucleus = nucleus.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
2055
+ nucleus = nucleus.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
2056
+ nucleus = nucleus.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
2057
+ nucleus['nucleus_prcfo_count'] = nucleus.groupby('prcfo')['prcfo'].transform('count')
2058
+ if include_multinucleated == False:
2059
+ nucleus = nucleus[nucleus['nucleus_prcfo_count']==1]
2060
+ nucleus_g_df, _ = _split_data(nucleus, 'prcfo', 'cell_id')
2061
+ if verbose:
2062
+ print(f'nucleus: {len(nucleus)}')
2063
+ print(f'nucleus grouped: {len(nucleus_g_df)}')
2064
+ if 'cytoplasm' in tables:
2065
+ merged_df = merged_df.merge(nucleus_g_df, left_index=True, right_index=True)
2066
+ else:
2067
+ merged_df = cells_g_df.merge(nucleus_g_df, left_index=True, right_index=True)
2068
+
2069
+ if 'pathogen' in tables:
2070
+ if not 'cell' in tables:
2071
+ cells_g_df = pd.DataFrame()
2072
+ merged_df = []
2073
+ try:
2074
+ pathogens = pathogens.dropna(subset=['cell_id'])
2075
+
2076
+ except KeyError:
2077
+ pathogens['cell_id'] = pathogens['object_label']
2078
+ pathogens = pathogens.dropna(subset=['cell_id'])
2079
+
2080
+ pathogens = pathogens.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
2081
+ pathogens = pathogens.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
2082
+ pathogens = pathogens.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
2083
+ pathogens['pathogen_prcfo_count'] = pathogens.groupby('prcfo')['prcfo'].transform('count')
2084
+ if include_noninfected == False:
2085
+ pathogens = pathogens[pathogens['pathogen_prcfo_count']>=1]
2086
+ if isinstance(include_multiinfected, bool):
2087
+ if include_multiinfected == False:
2088
+ pathogens = pathogens[pathogens['pathogen_prcfo_count']<=1]
2089
+ if isinstance(include_multiinfected, float):
2090
+ pathogens = pathogens[pathogens['pathogen_prcfo_count']<=include_multiinfected]
2091
+ if not 'cell' in tables:
2092
+ pathogens_g_df, metadata = _split_data(pathogens, 'prcfo', 'cell_id')
2093
+ else:
2094
+ pathogens_g_df, _ = _split_data(pathogens, 'prcfo', 'cell_id')
2095
+ if verbose:
2096
+ print(f'pathogens: {len(pathogens)}')
2097
+ print(f'pathogens grouped: {len(pathogens_g_df)}')
2098
+ if len(merged_df) == 0:
2099
+ merged_df = pathogens_g_df
2100
+ else:
2101
+ merged_df = merged_df.merge(pathogens_g_df, left_index=True, right_index=True)
2102
+
2103
+ #Add prc column (plate row column)
2104
+ metadata = metadata.assign(prc = lambda x: x['plate'] + '_' + x['row'] + '_' +x['col'])
2105
+
2106
+ #Count cells per well
2107
+ cells_well = pd.DataFrame(metadata.groupby('prc')['object_label'].nunique())
2108
+
2109
+ cells_well.reset_index(inplace=True)
2110
+ cells_well.rename(columns={'object_label': 'cells_per_well'}, inplace=True)
2111
+ metadata = pd.merge(metadata, cells_well, on='prc', how='inner', suffixes=('', '_drop_col'))
2112
+ object_label_cols = [col for col in metadata.columns if '_drop_col' in col]
2113
+ metadata.drop(columns=object_label_cols, inplace=True)
2114
+
2115
+ #Add prcfo column (plate row column field object)
2116
+ metadata = metadata.assign(prcfo = lambda x: x['plate'] + '_' + x['row'] + '_' +x['col']+ '_' +x['field']+ '_' +x['object_label'])
2117
+ metadata.set_index('prcfo', inplace=True)
2118
+
2119
+ merged_df = metadata.merge(merged_df, left_index=True, right_index=True)
2120
+
2121
+ merged_df = merged_df.dropna(axis=1)
2122
+ if verbose:
2123
+ print(f'Generated dataframe with: {len(merged_df.columns)} columns and {len(merged_df)} rows')
2124
+
2125
+ obj_df_ls = []
2126
+ if 'cell' in tables:
2127
+ obj_df_ls.append(cells)
2128
+ if 'cytoplasm' in tables:
2129
+ obj_df_ls.append(cytoplasms)
2130
+ if 'nucleus' in tables:
2131
+ obj_df_ls.append(nucleus)
2132
+ if 'pathogen' in tables:
2133
+ obj_df_ls.append(pathogens)
2134
+
2135
+ return merged_df, obj_df_ls
2136
+
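A hedged usage sketch for _read_and_merge_data (this second definition shadows the earlier one of the same name; paths are hypothetical). The table list must be ordered cell, nucleus, pathogen, cytoplasm so the positional indexing of the per-plate DataFrames lines up.

    from spacr.io import _read_and_merge_data

    merged_df, obj_dfs = _read_and_merge_data(
        locs=['/tmp/plate1/measurements/measurements.db',
              '/tmp/plate2/measurements/measurements.db'],
        tables=['cell', 'nucleus', 'pathogen', 'cytoplasm'],
        verbose=True,
        include_multinucleated=False,
        include_multiinfected=False,
        include_noninfected=False)
    print(merged_df.shape)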
2137
+ def _read_mask(mask_path):
2138
+ mask = imageio2.imread(mask_path)
2139
+ if mask.dtype != np.uint16:
2140
+ mask = img_as_uint(mask)
2141
+ return mask
2142
+