spacr 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/plot.py ADDED
@@ -0,0 +1,1273 @@
1
+ import os,re, random, cv2, glob, time, math
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import matplotlib.pyplot as plt
6
+ import matplotlib as mpl
7
+ import scipy.ndimage as ndi
8
+ import seaborn as sns
9
+ import scipy.stats as stats
10
+ import statsmodels.api as sm
11
+
12
+ from IPython.display import display
13
+ from skimage.segmentation import find_boundaries
14
+ from skimage.measure import find_contours
15
+ from skimage.morphology import square, dilation
16
+ from skimage import measure
17
+
18
+ from ipywidgets import IntSlider, interact
19
+ from IPython.display import Image as ipyimage
20
+
21
+ from .logger import log_function_call
22
+
23
+
24
+ #from .io import _save_figure
25
+ #from .timelapse import _save_mask_timelapse_as_gif
26
+ #from .utils import normalize_to_dtype, _remove_outside_objects, _remove_multiobject_cells, _find_similar_sized_images, _remove_noninfected
27
+
28
def plot_masks(batch, masks, flows, cmap='inferno', figuresize=20, nr=1, file_type='.npz', print_object_number=True):
    """
    Plot the image channels, label mask and flow for images in a batch.

    One figure is produced per image, with one panel per image channel
    followed by a 'Mask' panel and a 'Flow' panel.

    Args:
        batch (numpy.ndarray): The batch of images. A 3-D array is treated as a
            single image and gets a leading batch axis added.
        masks (list or numpy.ndarray): The masks corresponding to the images.
            A non-list value is wrapped in a one-element list.
        flows (list or numpy.ndarray): The flows corresponding to the images.
            A non-list value is wrapped in a list; for a list, its first
            element is used as the per-image flow sequence.
        cmap (str, optional): Colormap for the image channels. Defaults to 'inferno'.
        figuresize (int, optional): Figure height; the width is four times this
            value. Defaults to 20.
        nr (int, optional): The maximum number of images to plot. Defaults to 1.
        file_type (str, optional): When equal to 'png', the first element of
            every flow entry is plotted. Defaults to '.npz'.
        print_object_number (bool, optional): Whether to print each object's
            label at its centre of mass on the mask. Defaults to True.

    Returns:
        None
    """
    if len(batch.shape) == 3:
        batch = np.expand_dims(batch, axis=0)
    if not isinstance(masks, list):
        masks = [masks]
    if not isinstance(flows, list):
        flows = [flows]
    else:
        flows = flows[0]
    if file_type == 'png':
        flows = [f[0] for f in flows]
    font = figuresize / 2

    for index, (image, mask, flow) in enumerate(zip(batch, masks, flows)):
        # Stop as soon as the requested number of figures has been drawn
        # instead of walking (and colour-mapping) the rest of the batch.
        if index >= nr:
            break

        # Select foreground labels by value, so a mask without any background
        # pixel does not silently lose its first object.
        labels = np.unique(mask)
        labels = labels[labels != 0]

        # One random opaque colour per object, with black for the background.
        random_colors = np.random.rand(len(labels) + 1, 4)
        random_colors[:, 3] = 1
        random_colors[0, :] = [0, 0, 0, 1]
        mask_cmap = mpl.colors.ListedColormap(random_colors)

        chans = image.shape[-1]
        fig, ax = plt.subplots(1, chans + 2, figsize=(4 * figuresize, figuresize))
        for v in range(chans):
            ax[v].imshow(image[..., v], cmap=cmap)
            ax[v].set_title('Image - Channel'+str(v))

        ax[chans].imshow(mask, cmap=mask_cmap)
        ax[chans].set_title('Mask')
        if print_object_number:
            for obj in labels:
                cy, cx = ndi.center_of_mass(mask == obj)
                ax[chans].text(cx, cy, str(obj), color='white', fontsize=font, ha='center', va='center')

        ax[chans+1].imshow(flow, cmap='viridis')
        ax[chans+1].set_title('Flow')
        plt.show()
    return
84
+
85
def _plot_4D_arrays(src, figuresize=10, cmap='inferno', nr_npz=1, nr=1):
    """
    Show images stored under the 'data' key of randomly chosen .npz files.

    Args:
        src (str): Directory that is searched for files ending in '.npz'.
        figuresize (int, optional): Height of each figure; the width is this
            value times the number of channels. Defaults to 10.
        cmap (str, optional): Colormap used for every channel. Defaults to 'inferno'.
        nr_npz (int, optional): Maximum number of .npz files to sample. Defaults to 1.
        nr (int, optional): Maximum number of images shown per file. Defaults to 1.

    Returns:
        None
    """
    npz_files = [os.path.join(src, name) for name in os.listdir(src) if name.endswith('.npz')]
    chosen = random.sample(npz_files, min(nr_npz, len(npz_files)))

    for npz_path in chosen:
        with np.load(npz_path) as archive:
            stack = archive['data']

        # The code indexes axis 0 as the image index and axis 3 as the channel.
        n_images = stack.shape[0]
        n_channels = stack.shape[3]

        for idx in range(min(nr, n_images)):
            img = stack[idx]

            # squeeze=False always yields a 2-D axes grid, so a single channel
            # needs no special case; row 0 holds one axes per channel.
            fig, grid = plt.subplots(1, n_channels, figsize=(n_channels * figuresize, figuresize), squeeze=False)
            axs = grid[0]

            for c in range(n_channels):
                axs[c].imshow(img[:, :, c], cmap=cmap)
                axs[c].set_title(f'Channel {c}', size=24)
                axs[c].axis('off')

            fig.tight_layout()
            plt.show()
    return
123
+
124
def generate_mask_random_cmap(mask):
    """
    Build a random colormap sized to the non-zero labels of a mask.

    Parameters:
        mask (numpy.ndarray): The input mask array.

    Returns:
        matplotlib.colors.ListedColormap: A colormap with one opaque random
        colour per non-zero label, preceded by opaque black in slot 0.
    """
    labels = np.unique(mask)
    n_entries = np.count_nonzero(labels) + 1  # +1 for the black slot 0

    rgba = np.random.rand(n_entries, 4)
    rgba[:, 3] = 1             # fully opaque
    rgba[0, :] = [0, 0, 0, 1]  # slot 0 is black
    return mpl.colors.ListedColormap(rgba)
141
+
142
def random_cmap(num_objects=100):
    """
    Build a random colormap with a black first entry.

    Parameters:
        num_objects (int): The number of objects to generate colors for.
            Default is 100.

    Returns:
        matplotlib.colors.ListedColormap: A colormap holding ``num_objects + 1``
        opaque colours, the first of which is black.
    """
    rgba = np.random.rand(num_objects + 1, 4)
    rgba[:, 3] = 1             # fully opaque
    rgba[0, :] = [0, 0, 0, 1]  # slot 0 is black
    return mpl.colors.ListedColormap(rgba)
157
+
158
def _generate_mask_random_cmap(mask):
    """
    Create a random colormap with one colour per non-zero label in a mask.

    Parameters:
        mask (ndarray): The mask array containing unique labels.

    Returns:
        ListedColormap: Opaque random colours, one per non-zero label, with
        opaque black as the first entry.
    """
    distinct = np.unique(mask)
    foreground_count = int((distinct != 0).sum())

    palette = np.random.rand(foreground_count + 1, 4)
    palette[:, 3] = 1             # alpha channel: fully opaque
    palette[0, :] = [0, 0, 0, 1]  # first entry is black
    return mpl.colors.ListedColormap(palette)
175
+
176
def _get_colours_merged(outline_color):
    """
    Return three outline colours ordered according to a channel-order code.

    Parameters:
        outline_color (str): One of 'rgb', 'bgr', 'gbr' or 'rbg'. Any other
            value falls back to the 'rbg' ordering.

    Returns:
        list: Three ``[r, g, b]`` lists (components 0 or 1) in the requested order.
    """
    red, green, blue = [1, 0, 0], [0, 1, 0], [0, 0, 1]
    # Built on every call so callers never share (and cannot corrupt) the lists.
    orderings = {
        'rgb': [red, green, blue],
        'bgr': [blue, green, red],
        'gbr': [green, blue, red],
        'rbg': [red, blue, green],
    }
    for code, colours in orderings.items():
        if outline_color == code:
            return colours
    return orderings['rbg']
197
+
198
def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, mask_dims, filter_min_max, include_multinucleated, include_multiinfected):
    """
    Filter the objects of the mask channels of a stack and report the effect.

    Args:
        stack (numpy.ndarray): The input stack; mask channels are read and
            written along axis 2.
        cell_mask_dim (int): The channel index of the cell mask (or None).
        nucleus_mask_dim (int): The channel index of the nucleus mask (or None).
        pathogen_mask_dim (int): The channel index of the pathogen mask (or None).
        mask_dims (list): Channel indices of the masks to process.
        filter_min_max (list): Per-mask ``[min_area, max_area]`` pairs, in the
            same order as ``mask_dims``; only objects whose area lies strictly
            between the bounds are kept. None disables area filtering.
        include_multinucleated (bool): Whether to include multinucleated cells.
        include_multiinfected (bool): Whether to include multiinfected cells.

    Returns:
        numpy.ndarray: The filtered stack.
    """
    from .utils import _remove_outside_objects, _remove_multiobject_cells

    def _count_and_mean_area(label_image):
        """Return (object count, mean object area) of a label image."""
        props = measure.regionprops_table(label_image, properties=['label', 'area'])
        areas = props['area']
        # An empty mask has no mean area; report NaN without numpy's warning.
        mean_area = np.mean(areas) if len(areas) else float('nan')
        return len(props['label']), mean_area

    stack = _remove_outside_objects(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim)

    # mask_dim -> (count_before, area_before, count_after, area_after)
    stats = {}
    for i, mask_dim in enumerate(mask_dims):
        mask = np.take(stack, mask_dim, axis=2)
        count_before, area_before = _count_and_mean_area(mask)

        if filter_min_max is not None:
            min_max = filter_min_max[i]
            props = measure.regionprops_table(mask, properties=['label', 'area'])
            in_range = np.logical_and(props['area'] > min_max[0], props['area'] < min_max[1])
            valid_labels = props['label'][in_range]
            stack[:, :, mask_dim] = np.isin(mask, valid_labels) * mask

        count_after, area_after = _count_and_mean_area(stack[:, :, mask_dim])
        stats[mask_dim] = (count_before, area_before, count_after, area_after)

        if mask_dim == cell_mask_dim:
            # NOTE(review): the original passed object_dim=pathogen_mask_dim in
            # the multinucleated branch and object_dim=nucleus_mask_dim in the
            # multi-infected branch, i.e. the opposite of what each guard
            # checks. This assumes object_dim selects the mask whose objects
            # are counted per cell -- confirm against
            # utils._remove_multiobject_cells.
            if include_multinucleated is False and nucleus_mask_dim is not None:
                stack = _remove_multiobject_cells(stack, mask_dim, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, object_dim=nucleus_mask_dim)
            if include_multiinfected is False and cell_mask_dim is not None and pathogen_mask_dim is not None:
                stack = _remove_multiobject_cells(stack, mask_dim, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, object_dim=pathogen_mask_dim)

    # Report only masks that were actually processed; previously a configured
    # dim missing from mask_dims raised NameError here.
    if cell_mask_dim is not None and cell_mask_dim in stats:
        cell_count_before, cell_area_before, cell_count_after, cell_area_after = stats[cell_mask_dim]
        print(f'removed {cell_count_before-cell_count_after} cells, cell size from {cell_area_before} to {cell_area_after}')
    if nucleus_mask_dim is not None and nucleus_mask_dim in stats:
        nucleus_count_before, nucleus_area_before, nucleus_count_after, nucleus_area_after = stats[nucleus_mask_dim]
        print(f'removed {nucleus_count_before-nucleus_count_after} nucleus, nucleus size from {nucleus_area_before} to {nucleus_area_after}')
    if pathogen_mask_dim is not None and pathogen_mask_dim in stats:
        pathogen_count_before, pathogen_area_before, pathogen_count_after, pathogen_area_after = stats[pathogen_mask_dim]
        print(f'removed {pathogen_count_before-pathogen_count_after} pathogens, pathogen size from {pathogen_area_before} to {pathogen_area_after}')

    return stack
266
+
267
def _normalize_and_outline(image, remove_background, backgrounds, normalize, normalization_percentiles, overlay, overlay_chans, mask_dims, outline_colors, outline_thickness):
    """
    Optionally threshold and normalize an image, then draw mask outlines on it.

    Args:
        image (ndarray): The input image with channels on the last axis. It is
            modified in place when ``remove_background`` is true.
        remove_background (bool): Whether to zero values below the per-channel
            background level.
        backgrounds (list): Background level for each channel of ``image``,
            indexed by channel position.
        normalize (bool): Whether to pass the image through ``normalize_to_dtype``.
        normalization_percentiles (list): ``[q1, q2]`` forwarded to
            ``normalize_to_dtype``.
        overlay (bool): Whether to build the overlay image and the outlines.
        overlay_chans (list): Channel indices that form the overlay image.
        mask_dims (list): Channel indices of the masks to outline.
        outline_colors (list): Colours for the outlines; reused cyclically when
            there are more masks than colours.
        outline_thickness (int): Side length of the square footprint used to
            dilate (thicken) the outlines.

    Returns:
        tuple: ``(overlayed_image, image, outlines)``. When ``overlay`` is
        false the first and last elements are empty lists.
    """
    outlines = []

    if remove_background:
        for channel in range(image.shape[-1]):
            single_channel = image[:, :, channel]  # view into `image`
            single_channel[single_channel < backgrounds[channel]] = 0
            image[:, :, channel] = single_channel

    if normalize:
        # Imported lazily so the helper is only required when it is used.
        from .utils import normalize_to_dtype
        image = normalize_to_dtype(array=image, q1=normalization_percentiles[0], q2=normalization_percentiles[1])

    if not overlay:
        return [], image, []

    # Rescale the overlay channels to [0, 1]. A constant image is left at zero
    # instead of being turned into NaNs by a 0/0 division.
    rgb_image = np.take(image, overlay_chans, axis=-1).astype(float)
    rgb_image -= rgb_image.min()
    peak = rgb_image.max()
    if peak > 0:
        rgb_image /= peak

    overlayed_image = rgb_image.copy()
    for i, mask_dim in enumerate(mask_dims):
        mask = np.take(image, mask_dim, axis=2)
        outline = np.zeros_like(mask)

        # Trace every foreground object; select by value so that a mask with
        # no background pixel does not lose its first label.
        labels = np.unique(mask)
        for label in labels[labels != 0]:
            for contour in find_contours(mask == label, 0.5):
                contour = contour.astype(int)
                outline[contour[:, 0], contour[:, 1]] = label

        # Make the outline thicker.
        outline = dilation(outline, square(outline_thickness))
        outlines.append(outline)

        # All outline pixels of one mask share a colour, so paint them at once.
        overlayed_image[outline != 0] = outline_colors[i % len(outline_colors)]

    return overlayed_image, image, outlines
321
+
322
def _plot_merged_plot(overlay, image, stack, mask_dims, figuresize, overlayed_image, outlines, cmap, outline_colors, print_object_number):
    """
    Plot the optional overlay, every image channel with outlines, and the masks.

    Args:
        overlay (bool): Whether to show ``overlayed_image`` as the first panel.
        image (ndarray): Input image array with channels on the last axis.
        stack (ndarray): Stack from which the masks are taken along axis 2.
        mask_dims (list): Channel indices of the masks in ``stack``.
        figuresize (float): Figure height; the width is four times this value.
        overlayed_image (ndarray): Overlay image, used only when ``overlay`` is true.
        outlines (list): Outline label images, paired by position with
            ``outline_colors``.
        cmap (str): Accepted for interface compatibility; currently unused.
        outline_colors (list): Colours used to paint the outlines.
        print_object_number (bool): Whether to print object labels on the masks.

    Returns:
        fig (Figure): The generated matplotlib figure.
    """
    n_channels = image.shape[-1]
    ax_index = 1 if overlay else 0
    n_panels = n_channels + len(mask_dims) + ax_index

    # squeeze=False keeps `ax` indexable even when there is only one panel.
    fig, grid = plt.subplots(1, n_panels, figsize=(4 * figuresize, figuresize), squeeze=False)
    ax = grid[0]

    if overlay:
        ax[0].imshow(overlayed_image)
        ax[0].set_title('Overlayed Image')

    # Normalize and plot each channel with outlines.
    for v in range(n_channels):
        channel = image[..., v].astype(float)
        channel -= channel.min()
        span = channel.max()
        if span > 0:  # a constant channel stays at zero instead of becoming NaN
            channel /= span
        channel_rgb = np.dstack((channel, channel, channel))

        # Every outline pixel of one mask gets the same colour, so a single
        # boolean assignment replaces the per-label loop.
        for outline, color in zip(outlines, outline_colors):
            channel_rgb[outline != 0] = mpl.colors.to_rgb(color)

        ax[v + ax_index].imshow(channel_rgb)
        ax[v + ax_index].set_title('Image - Channel'+str(v))

    for i, mask_dim in enumerate(mask_dims):
        mask = np.take(stack, mask_dim, axis=2)
        panel = ax[i + n_channels + ax_index]
        panel.imshow(mask, cmap=_generate_mask_random_cmap(mask))
        panel.set_title('Mask '+ str(i))
        if print_object_number:
            labels = np.unique(mask)
            for obj in labels[labels != 0]:
                cy, cx = ndi.center_of_mass(mask == obj)
                panel.text(cx, cy, str(obj), color='white', fontsize=8, ha='center', va='center')

    plt.tight_layout()
    plt.show()
    return fig
380
+
381
def plot_arrays(src, figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
    """
    Plot randomly selected .npy arrays from a given directory.

    Parameters:
    - src (str): The directory path containing the arrays.
    - figuresize (int): The size of the figure (default: 50).
    - cmap (str): The colormap to use for displaying the arrays (default: 'inferno').
    - nr (int): The maximum number of arrays to plot; capped at the number of
      .npy files found (default: 1).
    - normalize (bool): Whether to normalize the arrays (default: True).
    - q1 (int): The lower percentile for normalization (default: 1).
    - q2 (int): The upper percentile for normalization (default: 99).

    Returns:
    None
    """
    from .utils import normalize_to_dtype

    paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npy')]
    # Cap the sample size: random.sample raises ValueError when asked for more
    # items than are available.
    paths = random.sample(paths, min(nr, len(paths)))

    for path in paths:
        print(f'Image path:{path}')
        img = np.load(path)
        if normalize:
            img = normalize_to_dtype(array=img, q1=q1, q2=q2)

        if len(img.shape) > 2:
            array_nr = img.shape[2]
            # squeeze=False keeps the axes indexable for single-channel arrays.
            fig, grid = plt.subplots(1, array_nr, figsize=(figuresize, figuresize), squeeze=False)
            for channel, axis in enumerate(grid[0]):
                plane = np.take(img, [channel], axis=2)
                axis.imshow(plane, cmap=plt.get_cmap(cmap))
                axis.set_title('Channel '+str(channel), size=24)
                axis.axis('off')
        else:
            fig, ax = plt.subplots(1, 1, figsize=(figuresize, figuresize))
            ax.imshow(img, cmap=plt.get_cmap(cmap))
            ax.set_title('Channel 0', size=24)
            ax.axis('off')

        fig.tight_layout()
        plt.show()
    return
428
+
429
def plot_merged(src, settings):
    """
    Plot merged stacks from a directory after optional filtering and outlining.

    Args:
        src (str): Directory containing the .npy stacks to plot.
        settings (dict): The settings for the plot. Keys read here:
            'outline_color', 'cell_mask_dim', 'nucleus_mask_dim',
            'pathogen_mask_dim', 'verbose', 'include_noninfected',
            'include_multiinfected', 'include_multinucleated',
            'filter_min_max', 'remove_background', 'backgrounds', 'normalize',
            'normalization_percentiles', 'overlay', 'overlay_chans',
            'outline_thickness', 'nr', 'figuresize', 'cmap' and
            'print_object_number'. 'include_multiinfected' is set to True in
            place when no pathogen mask is configured.

    Returns:
        matplotlib.figure.Figure or None: The last figure that was plotted, or
        None when nothing was plotted.
    """
    from .utils import _remove_noninfected

    outline_colors = _get_colours_merged(settings['outline_color'])

    mask_dims = [settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim']]
    mask_dims = [element for element in mask_dims if element is not None]

    if settings['verbose']:
        display(settings)

    if settings['pathogen_mask_dim'] is None:
        settings['include_multiinfected'] = True

    fig = None
    plotted = 0
    for file in os.listdir(src):
        # Stop before loading another stack once enough figures were produced.
        if plotted >= settings['nr']:
            break
        # Only .npy stacks can be handled; skip anything else in the folder.
        if not file.endswith('.npy'):
            continue

        path = os.path.join(src, file)
        stack = np.load(path)
        print(f'Loaded: {path}')

        if not settings['include_noninfected']:
            if settings['pathogen_mask_dim'] is not None and settings['cell_mask_dim'] is not None:
                stack = _remove_noninfected(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'])

        if settings['include_multiinfected'] is not True or settings['include_multinucleated'] is not True or settings['filter_min_max'] is not None:
            stack = _filter_objects_in_plot(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'], mask_dims, settings['filter_min_max'], settings['include_multinucleated'], settings['include_multiinfected'])

        print('stack.shape', stack.shape)
        overlayed_image, image, outlines = _normalize_and_outline(stack, settings['remove_background'], settings['backgrounds'], settings['normalize'], settings['normalization_percentiles'], settings['overlay'], settings['overlay_chans'], mask_dims, outline_colors, settings['outline_thickness'])

        fig = _plot_merged_plot(settings['overlay'], image, stack, mask_dims, settings['figuresize'], overlayed_image, outlines, settings['cmap'], outline_colors, settings['print_object_number'])
        plotted += 1

    # Returned on every path: previously the figure was only returned when an
    # extra file existed beyond `nr`, and `fig` was unbound when nr == 0.
    return fig
475
+
476
def _plot_images_on_grid(image_files, channel_indices, um_per_pixel, scale_bar_length_um=5, fontsize=8, show_filename=True, channel_names=None, plot=False):
    """
    Plot a near-square grid of images with a scale bar and optional channel names.

    Args:
        image_files (list): List of image file paths.
        channel_indices (list): Channel indices to select from multi-channel
            images. One index gives a grayscale panel, two are averaged into a
            grayscale panel, and more are shown as a colour image. None keeps
            the image as read.
        um_per_pixel (float): Micrometers per pixel.
        scale_bar_length_um (float, optional): Length of the scale bar in micrometers. Defaults to 5.
        fontsize (int, optional): Font size for titles and channel names. Defaults to 8.
        show_filename (bool, optional): Whether to show the image file names as titles. Defaults to True.
        channel_names (list, optional): Names written along the top edge of the
            figure; the first three are coloured red, green and blue, any
            further ones white. Defaults to None.
        plot (bool, optional): Whether to display the plot. Defaults to False.

    Returns:
        matplotlib.figure.Figure: The generated figure object.

    Raises:
        ValueError: If ``image_files`` is empty.
    """
    print(f'scale bar represents {scale_bar_length_um} um')
    nr_of_images = len(image_files)
    if nr_of_images == 0:
        raise ValueError('image_files is empty; there is nothing to plot')

    cols = int(np.ceil(np.sqrt(nr_of_images)))
    rows = int(np.ceil(nr_of_images / cols))
    # squeeze=False guarantees an array of axes even for a 1x1 grid.
    fig, axes = plt.subplots(rows, cols, figsize=(20, 20), facecolor='black', squeeze=False)
    fig.patch.set_facecolor('black')
    axes = axes.flatten()

    # Hide every axis frame up-front; this also covers unused grid cells.
    for ax in axes:
        ax.axis('off')

    # Calculate the scale bar length in pixels.
    scale_bar_length_px = int(scale_bar_length_um / um_per_pixel)

    for ax, image_file in zip(axes, image_files):
        img_array = cv2.imread(image_file, cv2.IMREAD_UNCHANGED)
        if img_array is None:
            # cv2.imread signals an unreadable file by returning None.
            print(f'Could not read image: {image_file}')
            continue

        if img_array.ndim == 3 and img_array.shape[2] >= 3:
            img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2RGB)

        # Channel selection only applies to multi-channel images.
        if channel_indices is not None and img_array.ndim == 3:
            if len(channel_indices) == 1:  # single channel (grayscale)
                img_array = img_array[:, :, channel_indices[0]]
                cmap = 'gray'
            elif len(channel_indices) == 2:  # two channels, averaged
                img_array = np.mean(img_array[:, :, channel_indices], axis=2)
                cmap = 'gray'
            else:  # three or more channels
                img_array = img_array[:, :, channel_indices]
                cmap = None
        else:
            cmap = None if img_array.ndim == 3 else 'gray'

        # Normalize integer images to [0, 1] based on dtype.
        if img_array.dtype == np.uint16:
            img_array = img_array.astype(np.float32) / 65535.0
        elif img_array.dtype == np.uint8:
            img_array = img_array.astype(np.float32) / 255.0

        ax.imshow(img_array, cmap=cmap)
        ax.axis('off')
        if show_filename:
            ax.set_title(os.path.basename(image_file), color='white', fontsize=fontsize, pad=20)
        # Scale bar, 10 px in from the left and up from the bottom edge.
        ax.plot([10, 10 + scale_bar_length_px], [img_array.shape[0] - 10] * 2, lw=2, color='white')

    # Channel names belong to the figure, so they are drawn once rather than
    # once per image (which also clobbered the image loop index).
    if channel_names:
        channel_colors = ['red', 'green', 'blue']
        current_offset = 0.02  # starting offset from the left side of the figure
        increment = 0.05       # fixed step between consecutive names
        for name_index, channel_name in enumerate(channel_names):
            color = channel_colors[name_index] if name_index < len(channel_colors) else 'white'
            fig.text(current_offset, 0.99, channel_name, color=color, fontsize=fontsize,
                     verticalalignment='top', horizontalalignment='left',
                     bbox=dict(facecolor='black', edgecolor='none', pad=3))
            current_offset += increment

    plt.tight_layout(pad=3)
    if plot:
        plt.show()
    return fig
553
+
554
def _save_scimg_plot(src, nr_imgs=16, channel_indices=[0,1,2], um_per_pixel=0.1, scale_bar_length_um=10, standardize=True, fontsize=8, show_filename=True, channel_names=None, dpi=300, plot=False, i=1, all_folders=1):
    """
    Build and save image grids of single-cell images: one for all channels
    together and one per individual channel.

    Args:
        src (str): The source directory path; also passed to ``_save_figure``.
        nr_imgs (int, optional): The number of images to visualize. Defaults to 16.
        channel_indices (list, optional): List of channel indices to visualize. Defaults to [0,1,2].
        um_per_pixel (float, optional): Micrometers per pixel. Defaults to 0.1.
        scale_bar_length_um (float, optional): Length of the scale bar in micrometers. Defaults to 10.
        standardize (bool, optional): Whether to restrict the images to
            similarly sized ones. Defaults to True.
        fontsize (int, optional): Font size for the filename. Defaults to 8.
        show_filename (bool, optional): Whether to show the filename on the image. Defaults to True.
        channel_names (list, optional): Channel names, used only for the
            all-channels figure. Defaults to None.
        dpi (int, optional): Not used by this function. Defaults to 300.
        plot (bool, optional): Whether to plot the images. Defaults to False.
        i (int, optional): Not used by this function. Defaults to 1.
        all_folders (int, optional): Not used by this function. Defaults to 1.

    Returns:
        None
    """
    from .io import _save_figure

    def _visualize_scimgs(src, channel_indices=None, um_per_pixel=0.1, scale_bar_length_um=10, show_filename=True, standardize=True, nr_imgs=None, fontsize=8, channel_names=None, plot=False):
        """
        Collect the image files in ``src`` and plot them on a grid.

        Args:
            src (str): The source directory path.
            channel_indices (list, optional): List of channel indices to visualize. Defaults to None.
            um_per_pixel (float, optional): Micrometers per pixel. Defaults to 0.1.
            scale_bar_length_um (float, optional): Length of the scale bar in micrometers. Defaults to 10.
            show_filename (bool, optional): Whether to show the filename on the image. Defaults to True.
            standardize (bool, optional): Whether to restrict the images to
                similarly sized ones. Defaults to True.
            nr_imgs (int, optional): Maximum number of images; a fixed-seed
                random subset is drawn when more are available. Defaults to None.
            fontsize (int, optional): Font size for the filename. Defaults to 8.
            channel_names (list, optional): List of channel names. Defaults to None.
            plot (bool, optional): Whether to plot the images. Defaults to False.

        Returns:
            matplotlib.figure.Figure: The figure object containing the plotted images.
        """
        from .utils import _find_similar_sized_images

        # Every entry of `src` whose lower-cased name has an image extension.
        image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.gif')
        selected = [entry for entry in glob.glob(os.path.join(src, '*')) if entry.lower().endswith(image_suffixes)]

        if standardize:
            selected = _find_similar_sized_images(selected)

        # Draw a reproducible subset only when there are more files than requested.
        if nr_imgs is not None and nr_imgs < len(selected):
            random.seed(42)
            selected = random.sample(selected, nr_imgs)

        return _plot_images_on_grid(selected, channel_indices, um_per_pixel, scale_bar_length_um, fontsize, show_filename, channel_names, plot)

    # Positional arguments shared by the combined and the per-channel figures.
    common = (um_per_pixel, scale_bar_length_um, show_filename, standardize, nr_imgs, fontsize)

    combined = _visualize_scimgs(src, channel_indices, *common, channel_names, plot)
    _save_figure(combined, src, text='all_channels')

    for channel in channel_indices:
        single = _visualize_scimgs(src, [channel], *common, channel_names=None, plot=plot)
        _save_figure(single, src, text=f'channel_{channel}')

    return
649
+
650
def _plot_cropped_arrays(stack, figuresize=20,cmap='inferno'):
    """
    Plot a 2-D array, or every channel (axis 2) of a higher-dimensional array,
    and print how long the plotting took.

    Args:
        stack (ndarray): The array to be plotted.
        figuresize (int, optional): The size of the figure. Defaults to 20.
        cmap (str, optional): The colormap to be used. Defaults to 'inferno'.

    Returns:
        None
    """
    start = time.time()
    n_dims = len(stack.shape)

    if n_dims == 2:
        f, a = plt.subplots(1, 1, figsize=(figuresize, figuresize))
        a.imshow(stack, cmap=plt.get_cmap(cmap))
        a.set_title('Channel one', size=18)
        a.axis('off')
        f.tight_layout()
        plt.show()
    elif n_dims > 2:
        anr = stack.shape[2]
        # squeeze=False keeps the axes indexable when there is a single
        # channel, which previously failed on `a[channel]`.
        f, grid = plt.subplots(1, anr, figsize=(figuresize, figuresize), squeeze=False)
        for channel, axis in enumerate(grid[0]):
            axis.imshow(stack[:, :, channel], cmap=plt.get_cmap(cmap))
            axis.set_title('Channel '+str(channel), size=18)
            axis.axis('off')
        f.tight_layout()
        plt.show()

    stop = time.time()
    duration = stop - start
    print('plot_cropped_arrays', duration)
    return
685
+
686
def _visualize_and_save_timelapse_stack_with_tracks(masks, tracks_df, save, src, name, plot, filenames, object_type, mode='btrack', interactive=False):
    """
    Visualize and/or save a timelapse stack of label masks with tracks.

    Args:
        masks (list): Label masks, one per frame of the timelapse stack.
        tracks_df (pandas.DataFrame): Track information; the columns
            'track_id', 'x' and 'y' are read here.
        save (bool): Whether to save the timelapse as a GIF.
        src (str): Source path; the GIF is written to 'movies/gif' under its
            parent directory.
        name (str): Name used in the GIF file name.
        plot (bool): Whether to display the timelapse.
        filenames (list): Filenames corresponding to each frame, forwarded to
            the GIF writer.
        object_type (str): Type of object being tracked, used in the GIF file name.
        mode (str, optional): Tracking mode. Not used by this function.
            Defaults to 'btrack'.
        interactive (bool, optional): With ``plot``, show a frame slider
            instead of the saved GIF. Defaults to False.

    Returns:
        None
    """
    from .timelapse import _save_mask_timelapse_as_gif

    # Cast to a plain int so the colour table can be sized from it whatever
    # the dtype of the masks.
    highest_label = int(max(np.max(mask) for mask in masks))

    # Random opaque colour for each label, with black for the background.
    random_colors = np.random.rand(highest_label + 1, 4)
    random_colors[:, 3] = 1          # full opacity
    random_colors[0] = [0, 0, 0, 1]  # background colour
    # Use the explicitly imported matplotlib.colors module instead of the
    # incidental `plt.cm.colors` attribute.
    cmap = mpl.colors.ListedColormap(random_colors)
    # The normalization range covers all labels so colours stay stable
    # across frames.
    norm = mpl.colors.Normalize(vmin=0, vmax=highest_label)

    def _view_frame_with_tracks(frame=0):
        """
        Display one frame with label numbers and all tracks overlaid.

        Parameters:
            frame (int): The frame number to display.

        Returns:
            None
        """
        fig, ax = plt.subplots(figsize=(50, 50))
        current_mask = masks[frame]
        ax.imshow(current_mask, cmap=cmap, norm=norm)
        ax.set_title(f'Frame: {frame}')

        # Annotate each object with its label at the mean pixel position.
        for label_value in np.unique(current_mask):
            if label_value == 0:
                continue  # skip background
            y, x = np.mean(np.where(current_mask == label_value), axis=1)
            ax.text(x, y, str(label_value), color='white', fontsize=24, ha='center', va='center')

        # Overlay every track (not only the positions up to this frame).
        for track in tracks_df['track_id'].unique():
            _track = tracks_df[tracks_df['track_id'] == track]
            ax.plot(_track['x'], _track['y'], '-k', linewidth=1)

        ax.axis('off')
        plt.show()

    if plot and interactive:
        interact(_view_frame_with_tracks, frame=IntSlider(min=0, max=len(masks)-1, step=1, value=0))

    if save:
        # Save as gif.
        gif_path = os.path.join(os.path.dirname(src), 'movies', 'gif')
        os.makedirs(gif_path, exist_ok=True)
        save_path_gif = os.path.join(gif_path, f'timelapse_masks_{object_type}_{name}.gif')
        _save_mask_timelapse_as_gif(masks, tracks_df, save_path_gif, cmap, norm, filenames)
        if plot and not interactive:
            _display_gif(save_path_gif)
757
+
758
def _display_gif(path):
    """
    Show a GIF file through IPython's display machinery.

    Parameters:
        path (str): Location of the GIF file on disk.

    Returns:
        None
    """
    # Read the raw bytes first, then hand them to IPython once the file is closed.
    with open(path, 'rb') as gif_file:
        gif_bytes = gif_file.read()
    display(ipyimage(gif_bytes))
770
+
771
def _plot_recruitment(df, df_type, channel_of_interest, target, columns=None, figuresize=50):
    """
    Plot recruitment data for different conditions and pathogens.

    Two figures are shown: one row of four bar plots with the mean intensity of
    the channel of interest in the cell, nucleus, cytoplasm and pathogen
    compartments, then a two-row grid with one bar plot per recruitment column.

    Args:
        df (DataFrame): Input data. Must contain 'condition', 'pathogen', the four
            '<compartment>_channel_<channel_of_interest>_mean_intensity' columns and
            every column plotted in the second figure.
        df_type (str): Label describing the DataFrame; used in the x-axis labels.
        channel_of_interest (str or int): Channel used to build the column names.
        target (str): Not used by this function.
        columns (list, optional): Extra columns to plot before the built-in
            recruitment columns. Defaults to None (no extra columns).
        figuresize (int, optional): Base size of the figures. Defaults to 50.

    Returns:
        None
    """
    # `None` replaces the former mutable `[]` default; copying avoids aliasing the caller's list.
    columns = [] if columns is None else list(columns)

    def _drop_legend(axis):
        # seaborn leaves `legend_` as None when it draws no legend; guard the removal.
        legend = axis.get_legend()
        if legend is not None:
            legend.remove()

    color_list = [(55/255, 155/255, 155/255),
                  (155/255, 55/255, 155/255),
                  (55/255, 155/255, 255/255),
                  (255/255, 55/255, 155/255)]

    sns.set_palette(sns.color_palette(color_list))
    font = figuresize/2
    width = figuresize
    height = figuresize/4

    # Figure 1: one panel per compartment.
    fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(width, height))
    for ax, compartment in zip(axes, ('cell', 'nucleus', 'cytoplasm', 'pathogen')):
        intensity_col = f'{compartment}_channel_{channel_of_interest}_mean_intensity'
        sns.barplot(ax=ax, data=df, x='condition', y=intensity_col, hue='pathogen', capsize=.1, ci='sd', dodge=False)
        ax.set_xlabel(f'pathogen {df_type}', fontsize=font)
        ax.set_ylabel(intensity_col, fontsize=font)

    for ax in axes:
        _drop_legend(ax)

    # A single shared legend is drawn to the right of the last panel.
    handles, labels = axes[3].get_legend_handles_labels()
    axes[3].legend(handles, labels, bbox_to_anchor=(1.05, 0.5), loc='center left')
    for ax in axes:
        ax.tick_params(axis='both', which='major', labelsize=font)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45)

    plt.tight_layout()
    plt.show()

    # Figure 2: user columns first, then the built-in recruitment columns.
    columns = columns + ['pathogen_cytoplasm_mean_mean', 'pathogen_cytoplasm_q75_mean', 'pathogen_periphery_cytoplasm_mean_mean', 'pathogen_outside_cytoplasm_mean_mean', 'pathogen_outside_cytoplasm_q75_mean']
    columns = columns + [f'pathogen_slope_channel_{channel_of_interest}', f'pathogen_cell_distance_channel_{channel_of_interest}', f'nucleus_cell_distance_channel_{channel_of_interest}']

    width = figuresize*2
    columns_per_row = math.ceil(len(columns) / 2)
    height = (figuresize*2)/columns_per_row

    fig, axes = plt.subplots(nrows=2, ncols=columns_per_row, figsize=(width, height * 2))
    axes = axes.flatten()

    print(f'{columns}')

    for i, col in enumerate(columns):
        ax = axes[i]
        sns.barplot(ax=ax, data=df, x='condition', y=f'{col}', hue='pathogen', capsize=.1, ci='sd', dodge=False)
        ax.set_xlabel(f'pathogen {df_type}', fontsize=font)
        ax.set_ylabel(f'{col}', fontsize=int(font*2))
        _drop_legend(ax)
        ax.tick_params(axis='both', which='major', labelsize=font)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
        # NOTE(review): the first six panels get a y-axis starting at 1 -- presumably
        # because they are ratio-like columns; confirm, since extra `columns` shift the positions.
        if i <= 5:
            ax.set_ylim(1, None)

    # Hide grid cells that have no column to show.
    for i in range(len(columns), len(axes)):
        axes[i].axis('off')

    plt.tight_layout()
    plt.show()
857
+
858
def _plot_controls(df, mask_chans, channel_of_interest, figuresize=5):
    """
    Plot per-condition, per-channel mean intensities of the four compartments.

    For every value of df['condition'] and every channel, a bar chart shows the
    mean (with standard deviation as error bar) of the cell, nucleus, pathogen
    and cytoplasm '<compartment>_channel_<chan>_mean_intensity' columns.

    Args:
        df (pandas.DataFrame): Data with a 'condition' column and the intensity columns.
        mask_chans (list): Mask channels. The list is not modified.
        channel_of_interest (int): Channel of interest, counted together with mask_chans.
        figuresize (int, optional): Base size of each panel. Defaults to 5.

    Returns:
        None
    """
    # Build a new list instead of appending to the caller's list.
    mask_chans = list(mask_chans) + [channel_of_interest]
    # With up to four channels, the channel values are replaced by positions 0..n-1;
    # longer lists are left unchanged.
    if len(mask_chans) <= 4:
        mask_chans = list(range(len(mask_chans)))

    components = ["cell", "nucleus", "pathogen", "cytoplasm"]
    controls_cols = [[f'{component}_channel_{chan}_mean_intensity' for component in components]
                     for chan in mask_chans]

    unique_conditions = df['condition'].unique().tolist()

    # squeeze=False keeps `axes` two-dimensional even for a single condition,
    # so the condition no longer has to be duplicated to make 2D indexing work.
    # NOTE(review): one more column than channels is created and its panels stay empty -- confirm intent.
    fig, axes = plt.subplots(len(unique_conditions), len(mask_chans)+1,
                             figsize=(figuresize*len(mask_chans), figuresize*len(unique_conditions)),
                             squeeze=False)

    # RGB colour tuples (scaled to the 0-1 range), one per compartment.
    color_list = [(55/255, 155/255, 155/255),
                  (155/255, 55/255, 155/255),
                  (55/255, 155/255, 255/255),
                  (255/255, 55/255, 155/255)]

    for idx_condition, condition in enumerate(unique_conditions):
        df_temp = df[df['condition'] == condition]
        for idx_channel, control_cols_c in enumerate(controls_cols):
            data = []
            std_dev = []
            for control_col in control_cols_c:
                if control_col in df_temp.columns:
                    mean_intensity = df_temp[control_col].mean()
                    mean_intensity = 0 if np.isnan(mean_intensity) else mean_intensity
                    data.append(mean_intensity)
                    std_dev.append(df_temp[control_col].std())
                else:
                    # A missing column is drawn as an empty bar so the number of
                    # heights always matches the four component labels.
                    data.append(0)
                    std_dev.append(0)

            current_axis = axes[idx_condition][idx_channel]
            current_axis.bar(components, data, yerr=std_dev, capsize=4, color=color_list)
            current_axis.set_xlabel('Component')
            current_axis.set_ylabel('Mean Intensity')
            current_axis.set_title(f'Condition: {condition} - Channel {idx_channel}')
    plt.tight_layout()
    plt.show()
923
+
924
+ ###################################################
925
+ # Classify
926
+ ###################################################
927
+
928
def _imshow(img, labels, nrow=20, color='white', fontsize=12):
    """
    Display multiple images in a grid with corresponding labels.

    The images are tiled left-to-right, top-to-bottom onto one RGB canvas and
    each label is written near the top-left corner of its tile.

    Args:
        img (list): Images in channel-first layout (3, height, width); all tiles
            use the height and width of the first image.
        labels (list): One label per image; its length decides how many images are shown.
        nrow (int, optional): Number of images per row in the grid. Defaults to 20.
        color (str, optional): Color of the label text. Defaults to 'white'.
        fontsize (int, optional): Font size of the label text. Defaults to 12.

    Returns:
        None
    """
    n_images = len(labels)
    n_col = nrow
    n_row = int(np.ceil(n_images / n_col))
    img_height = img[0].shape[1]
    img_width = img[0].shape[2]
    canvas = np.zeros((img_height * n_row, img_width * n_col, 3))
    for idx in range(n_images):
        i, j = divmod(idx, n_col)
        # Channel-first -> channel-last before pasting the tile.
        canvas[i * img_height:(i + 1) * img_height, j * img_width:(j + 1) * img_width] = np.transpose(img[idx], (1, 2, 0))
    plt.figure(figsize=(50, 50))
    # pyplot has no `_imshow`; the former `plt._imshow` call raised AttributeError.
    plt.imshow(canvas)
    plt.axis("off")
    for i, label in enumerate(labels):
        row, col = divmod(i, n_col)
        x = col * img_width + 2
        y = row * img_height + 15
        plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')
    plt.show()
960
+
961
def _plot_histograms_and_stats(df):
    """
    Print summary statistics and show a histogram of 'pred' for each condition.

    Args:
        df (pandas.DataFrame): Data with a 'condition' column and a 'pred' column.
            Values above 0.5 are counted as positive, values at or below 0.5 as negative.

    Returns:
        None
    """
    for condition in df['condition'].unique():
        subset = df[df['condition'] == condition]
        preds = subset['pred']

        # Summary numbers for this condition.
        mean_pred = preds.mean()
        over_0_5 = sum(preds > 0.5)
        under_0_5 = sum(preds <= 0.5)

        print(f"Condition: {condition}")
        print(f"Number of rows: {len(subset)}")
        print(f"Mean of pred: {mean_pred}")
        print(f"Count of pred values over 0.5: {over_0_5}")
        print(f"Count of pred values under 0.5: {under_0_5}")
        print(f"Percent positive: {(over_0_5/(over_0_5+under_0_5))*100}")
        print(f"Percent negative: {(under_0_5/(over_0_5+under_0_5))*100}")
        print('-'*40)

        # Histogram with the mean marked by a dashed red line.
        plt.figure(figsize=(10,6))
        plt.hist(preds, bins=30, edgecolor='black')
        plt.axvline(mean_pred, color='red', linestyle='dashed', linewidth=1, label=f"Mean = {mean_pred:.2f}")
        plt.title(f'Histogram for pred - Condition: {condition}')
        plt.xlabel('Pred Value')
        plt.ylabel('Count')
        plt.legend()
        plt.show()
991
+
992
def _show_residules(model):
    """
    Show residual diagnostics for a fitted model.

    Draws a histogram of the residuals, a QQ plot and a residuals-vs-fitted
    scatter plot, then prints the Shapiro-Wilk normality test result.

    Args:
        model: Fitted model object exposing `resid` and `fittedvalues`.

    Returns:
        None
    """
    def _label_and_show(title, xlabel=None, ylabel=None):
        # Decorate the current axes and render the figure.
        plt.title(title)
        if xlabel is not None:
            plt.xlabel(xlabel)
        if ylabel is not None:
            plt.ylabel(ylabel)
        plt.show()

    resid = model.resid

    # Distribution of the residuals.
    plt.hist(resid, bins=30)
    _label_and_show('Histogram of Residuals', 'Residual Value', 'Frequency')

    # Quantiles of the residuals against a fitted normal distribution.
    sm.qqplot(resid, fit=True, line='45')
    _label_and_show('QQ Plot')

    # Residuals against fitted values, with a reference line at zero.
    plt.scatter(model.fittedvalues, resid)
    plt.axhline(y=0, color='red')
    _label_and_show('Residuals vs. Fitted Values', 'Fitted values', 'Residuals')

    # Shapiro-Wilk test for normality.
    W, p_value = stats.shapiro(resid)
    print(f'Shapiro-Wilk Test W-statistic: {W}, p-value: {p_value}')
1020
+
1021
def _reg_v_plot(df, grouping, variable, plate_number):
    """
    Draw a volcano plot of effect size against -log10(p).

    Points are coloured by the sign of the effect and every row with p < 0.05
    is annotated with its index label. A dashed line marks p = 0.05.

    Args:
        df (pandas.DataFrame): Must contain 'effect' and 'p'. A '-log10(p)' column
            is added to this DataFrame in place.
        grouping: Not used by this function.
        variable: Not used by this function.
        plate_number: Not used by this function.

    Returns:
        None
    """
    # Side effect kept from the original: the column is written onto the caller's frame.
    df['-log10(p)'] = -np.log10(df['p'])

    plt.figure(figsize=(40, 30))
    plt.scatter(df['effect'], df['-log10(p)'], c=np.sign(df['effect']), cmap='coolwarm')
    plt.title('Volcano Plot', fontsize=12)
    plt.xlabel('Coefficient', fontsize=12)
    plt.ylabel('-log10(P-value)', fontsize=12)

    # Label only the significant points.
    significant = df[df['p'] < 0.05]
    for idx, effect, p_value in zip(significant.index, significant['effect'], significant['p']):
        plt.text(effect, -np.log10(p_value), idx, fontsize=12, ha='center', va='bottom', color='black')

    # Horizontal reference line for p = 0.05.
    plt.axhline(y=-np.log10(0.05), color='gray', linestyle='--')
    plt.show()
1038
+
1039
def generate_plate_heatmap(df, plate_number, variable, grouping, min_max):
    """
    Aggregate one plate's wells into a row-by-column matrix for a heatmap.

    Args:
        df (pandas.DataFrame): Data with a 'prc' column of the form
            '<plate>_<row>_<col>' and the column named by `variable`.
            The input frame is not modified.
        plate_number (str): Plate identifier to keep (first part of 'prc').
        variable (str): Column to aggregate per well.
        grouping (str): Aggregation to apply: 'mean', 'sum' or 'count'.
        min_max: 'all' or 'plate' for the matrix minimum and maximum, 'allq' for
            its 0.2 and 0.98 quantiles; any other value is returned unchanged.

    Returns:
        tuple: (plate_map, min_max) where plate_map is the pivoted DataFrame
        (rows x columns, missing wells filled with 0) and min_max is the range.

    Raises:
        ValueError: If `grouping` is not one of the supported aggregations.
    """
    wells = df.copy()  # never write the helper columns onto the caller's frame

    # Split 'plate_row_col' into its three parts.
    plate_ids, well_rows, well_cols = zip(*wells['prc'].str.split('_'))
    wells['plate'] = plate_ids
    wells['row'] = well_rows
    wells['col'] = well_cols

    wells = wells[wells['plate'] == plate_number].copy()

    # Ordered categories keep wells in plate order: rows r1..r16, columns c1..c27.
    wells['row'] = pd.Categorical(wells['row'], categories=[f'r{i}' for i in range(1, 17)], ordered=True)
    wells['col'] = pd.Categorical(wells['col'], categories=[f'c{i}' for i in range(1, 28)], ordered=True)

    # observed=True restricts the groups to wells that are present.
    per_well = wells.groupby(['row', 'col'], observed=True)

    if grouping in ('mean', 'sum', 'count'):
        plate = getattr(per_well[variable], grouping)().reset_index()
    else:
        raise ValueError(f"Unsupported grouping: {grouping}")

    plate_map = pd.pivot_table(plate, values=variable, index='row', columns='col').fillna(0)

    # 'all' and 'plate' both resolve to the range of this plate's matrix.
    if min_max == 'all' or min_max == 'plate':
        min_max = [plate_map.min().min(), plate_map.max().max()]
    elif min_max == 'allq':
        min_max = np.quantile(plate_map.values, [0.2, 0.98])

    return plate_map, min_max
1075
+
1076
def _plot_plates(df, variable, grouping, min_max, cmap):
    """
    Draw one heatmap per plate, four plates per row.

    Args:
        df (pandas.DataFrame): Data with a 'prc' column of the form
            '<plate>_<row>_<col>' and the column named by `variable`.
        variable (str): Column to aggregate per well.
        grouping (str): Aggregation passed to generate_plate_heatmap.
        min_max: Colour-range selector passed to generate_plate_heatmap.
        cmap: Colormap passed to seaborn's heatmap.

    Returns:
        None
    """
    plates = df['prc'].str.split('_', expand=True)[0].unique()
    n_rows, n_cols = (len(plates) + 3) // 4, 4
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
    ax = ax.flatten()

    for index, plate in enumerate(plates):
        plate_map, min_max_values = generate_plate_heatmap(df, plate, variable, grouping, min_max)

        # Use the range computed for this plate; it was previously discarded in
        # favour of a hard-coded 0..2. The old limits remain the fallback when
        # no (min, max) pair can be unpacked.
        vmin, vmax = 0, 2
        if not isinstance(min_max_values, str):
            try:
                vmin, vmax = min_max_values
            except (TypeError, ValueError):
                pass

        sns.heatmap(plate_map, cmap=cmap, vmin=vmin, vmax=vmax, ax=ax[index])
        ax[index].set_title(plate)

    # Remove the unused panels of the last row.
    for i in range(len(plates), n_rows * n_cols):
        fig.delaxes(ax[i])

    plt.subplots_adjust(wspace=0.1, hspace=0.4)
    plt.show()
    return
1093
+
1094
+ #from finetune cellpose
1095
+ #def plot_arrays(src, figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
1096
+ # paths = []
1097
+ # for file in os.listdir(src):
1098
+ # if file.endswith('.tif') or file.endswith('.tiff'):
1099
+ # path = os.path.join(src, file)
1100
+ # paths.append(path)
1101
+ # paths = random.sample(paths, nr)
1102
+ # for path in paths:
1103
+ # print(f'Image path:{path}')
1104
+ # img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
1105
+ # if normalize:
1106
+ # img = normalize_to_dtype(array=img, q1=q1, q2=q2)
1107
+ # dim = img.shape
1108
+ # if len(img.shape) > 2:
1109
+ # array_nr = img.shape[2]
1110
+ # fig, axs = plt.subplots(1, array_nr, figsize=(figuresize, figuresize))
1111
+ # for channel in range(array_nr):
1112
+ # i = np.take(img, [channel], axis=2)
1113
+ # axs[channel].imshow(i, cmap=plt.get_cmap(cmap))
1114
+ # axs[channel].set_title('Channel '+str(channel), size=24)
1115
+ # axs[channel].axis('off')
1116
+ # else:
1117
+ # fig, ax = plt.subplots(1, 1, figsize=(figuresize, figuresize))
1118
+ # ax.imshow(img, cmap=plt.get_cmap(cmap))
1119
+ # ax.set_title('Channel 0', size=24)
1120
+ # ax.axis('off')
1121
+ # fig.tight_layout()
1122
+ # plt.show()
1123
+ # return
1124
+
1125
def print_mask_and_flows(stack, mask, flows, overlay=False):
    """
    Show an image, its mask and its flow field side by side.

    Args:
        stack (numpy.ndarray): Image; a trailing singleton channel is squeezed away.
            Must be 2D or 3D afterwards.
        mask (numpy.ndarray): Label mask for the image.
        flows (list): Non-empty list whose first element is a 2D or 3D array;
            for 3D, only its first channel is shown.
        overlay (bool, optional): If True, draw the mask semi-transparently with
            red outlines on top of the image; otherwise show the mask alone.
            Defaults to False.

    Raises:
        ValueError: If the stack is not 2D/3D or the flows have an unexpected structure.
    """
    fig, axs = plt.subplots(1, 3, figsize=(30, 10))
    image_ax, mask_ax, flow_ax = axs

    if stack.shape[-1] == 1:
        stack = np.squeeze(stack)

    # Panel 1: the image itself.
    if stack.ndim not in (2, 3):
        raise ValueError("Unexpected stack dimensionality.")
    image_cmap = 'gray' if stack.ndim == 2 else None
    image_ax.imshow(stack, cmap=image_cmap)
    image_ax.set_title('Original Image')
    image_ax.axis('off')

    # Panel 2: mask alone, or mask over the image.
    if not overlay:
        mask_ax.imshow(mask, cmap='gray')
    else:
        mask_cmap = generate_mask_random_cmap(mask)
        mask_overlay = np.ma.masked_where(mask == 0, mask)  # hide the background
        outlines = find_boundaries(mask, mode='thick')
        # The stack is known to be 2D or 3D here (checked above).
        mask_ax.imshow(stack, cmap=image_cmap)
        mask_ax.imshow(mask_overlay, cmap=mask_cmap, alpha=0.5)
        mask_ax.contour(outlines, colors='r', linewidths=2)
    mask_ax.set_title('Mask with Overlay' if overlay else 'Mask')
    mask_ax.axis('off')

    # Panel 3: first flow array (first channel when it is 3D).
    if not (flows and isinstance(flows, list) and flows[0].ndim in [2, 3]):
        raise ValueError("Unexpected flow dimensionality or structure.")
    flow_image = flows[0]
    if flow_image.ndim == 3:
        flow_image = flow_image[:, :, 0]
    flow_ax.imshow(flow_image, cmap='jet')
    flow_ax.set_title('Flows')
    flow_ax.axis('off')

    fig.tight_layout()
    plt.show()
1173
+
1174
def plot_resize(images, resized_images, labels, resized_labels):
    """
    Show the first image and label before and after resizing in a 2x2 grid.

    Args:
        images: Original images; only images[0] is shown.
        resized_images: Resized images; only resized_images[0] is shown.
        labels: Original labels; only labels[0] is shown (in grayscale).
        resized_labels: Resized labels; only resized_labels[0] is shown (in grayscale).

    Returns:
        None
    """
    fig, ax = plt.subplots(2, 2, figsize=(20, 20))

    # Top row: images. 2D arrays are drawn in grayscale, anything else as-is.
    image_panels = ((ax[0, 0], images, 'Original Image'),
                    (ax[0, 1], resized_images, 'Resized Image'))
    for axis, collection, caption in image_panels:
        picture = collection[0]
        if picture.ndim == 2:
            axis.imshow(picture, cmap='gray')
        else:
            axis.imshow(picture)
        axis.set_title(caption)

    # Bottom row: labels, always drawn in grayscale.
    label_panels = ((ax[1, 0], labels, 'Original Label'),
                    (ax[1, 1], resized_labels, 'Resized Label'))
    for axis, collection, caption in label_panels:
        axis.imshow(collection[0], cmap='gray')
        axis.set_title(caption)

    plt.show()
1197
+
1198
def normalize_and_visualize(image, normalized_image, title=""):
    """
    Show an image next to its normalized version.

    3D inputs are reduced to their mean over the last axis for display;
    everything is drawn in grayscale.

    Args:
        image (numpy.ndarray): Original image.
        normalized_image (numpy.ndarray): Normalized image.
        title (str, optional): Text appended to the panel titles. Defaults to "".

    Returns:
        None
    """
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))

    panels = ((ax[0], image, "Original "), (ax[1], normalized_image, "Normalized "))
    for axis, picture, prefix in panels:
        # Collapse the channel axis so multi-channel images can be shown as one plane.
        shown = np.mean(picture, axis=-1) if picture.ndim == 3 else picture
        axis.imshow(shown, cmap='gray')
        axis.set_title(prefix + title)
        axis.axis('off')

    plt.show()
1216
+
1217
def visualize_masks(mask1, mask2, mask3, title="Masks Comparison"):
    """
    Show three masks side by side, each with its own random colormap.

    Args:
        mask1 (numpy.ndarray): First mask, titled 'Mask 1'.
        mask2 (numpy.ndarray): Second mask, titled 'Mask 2'.
        mask3 (numpy.ndarray): Third mask, titled 'Mask 3'.
        title (str, optional): Figure title. Defaults to "Masks Comparison".

    Returns:
        None
    """
    fig, axs = plt.subplots(1, 3, figsize=(30, 10))
    # The loop variable must not be called `title`: it used to overwrite the
    # parameter, so the figure title was always 'Mask 3'.
    for ax, mask, panel_title in zip(axs, [mask1, mask2, mask3], ['Mask 1', 'Mask 2', 'Mask 3']):
        cmap = generate_mask_random_cmap(mask)
        if np.isin(mask, [0, 1]).all():
            # Binary mask: no normalization needed.
            ax.imshow(mask, cmap=cmap)
        else:
            # Map labels 0..max onto the colormap.
            norm = plt.Normalize(vmin=0, vmax=mask.max())
            ax.imshow(mask, cmap=cmap, norm=norm)
        ax.set_title(panel_title)
        ax.axis('off')
    plt.suptitle(title)
    plt.show()
1232
+
1233
def plot_comparison_results(comparison_results):
    """
    Plot segmentation comparison metrics as box plots with overlaid points.

    The results are melted to long form and split by metric name into four
    panels: Jaccard index, Dice coefficient, boundary F1 score and average precision.

    Args:
        comparison_results: Data accepted by pandas.DataFrame with a 'filename'
            column and one column per metric; metric columns are matched by the
            substrings 'jaccard', 'dice', 'boundary_f1' and 'average_precision'.

    Returns:
        matplotlib.figure.Figure: The figure holding the four panels.
    """
    results = pd.DataFrame(comparison_results)
    long_form = pd.melt(results, id_vars=['filename'], var_name='metric', value_name='value')

    # (substring of the metric name, panel title, y-axis label)
    panels = [
        ('jaccard', 'Jaccard Index by Comparison', 'Jaccard Index'),
        ('dice', 'Dice Coefficient by Comparison', 'Dice Coefficient'),
        ('boundary_f1', 'Boundary F1 Score by Comparison', 'Boundary F1 Score'),
        ('average_precision', 'Average Precision by Comparison', 'Average Precision'),
    ]
    subsets = [long_form[long_form['metric'].str.contains(key)] for key, _, _ in panels]

    fig, axs = plt.subplots(1, 4, figsize=(40, 10))

    for ax, subset, (_, heading, y_label) in zip(axs, subsets, panels):
        # Box plot for the distribution, strip plot for the individual files.
        sns.boxplot(data=subset, x='metric', y='value', ax=ax, color='lightgrey')
        sns.stripplot(data=subset, x='metric', y='value', ax=ax, jitter=True, alpha=0.6)
        ax.set_title(heading)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
        ax.set_xlabel('Comparison')
        ax.set_ylabel(y_label)

    plt.tight_layout()
    plt.show()
    return fig