spacr 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +37 -0
- spacr/__main__.py +15 -0
- spacr/annotate_app.py +495 -0
- spacr/cli.py +203 -0
- spacr/core.py +2250 -0
- spacr/gui_mask_app.py +247 -0
- spacr/gui_measure_app.py +214 -0
- spacr/gui_utils.py +488 -0
- spacr/io.py +2271 -0
- spacr/logger.py +20 -0
- spacr/mask_app.py +818 -0
- spacr/measure.py +1014 -0
- spacr/old_code.py +104 -0
- spacr/plot.py +1273 -0
- spacr/sim.py +1187 -0
- spacr/timelapse.py +576 -0
- spacr/train.py +494 -0
- spacr/umap.py +689 -0
- spacr/utils.py +2726 -0
- spacr/version.py +19 -0
- spacr-0.0.1.dist-info/LICENSE +21 -0
- spacr-0.0.1.dist-info/METADATA +64 -0
- spacr-0.0.1.dist-info/RECORD +26 -0
- spacr-0.0.1.dist-info/WHEEL +5 -0
- spacr-0.0.1.dist-info/entry_points.txt +5 -0
- spacr-0.0.1.dist-info/top_level.txt +1 -0
spacr/plot.py
ADDED
@@ -0,0 +1,1273 @@
|
|
1
|
+
import os,re, random, cv2, glob, time, math
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
import pandas as pd
|
5
|
+
import matplotlib.pyplot as plt
|
6
|
+
import matplotlib as mpl
|
7
|
+
import scipy.ndimage as ndi
|
8
|
+
import seaborn as sns
|
9
|
+
import scipy.stats as stats
|
10
|
+
import statsmodels.api as sm
|
11
|
+
|
12
|
+
from IPython.display import display
|
13
|
+
from skimage.segmentation import find_boundaries
|
14
|
+
from skimage.measure import find_contours
|
15
|
+
from skimage.morphology import square, dilation
|
16
|
+
from skimage import measure
|
17
|
+
|
18
|
+
from ipywidgets import IntSlider, interact
|
19
|
+
from IPython.display import Image as ipyimage
|
20
|
+
|
21
|
+
from .logger import log_function_call
|
22
|
+
|
23
|
+
|
24
|
+
#from .io import _save_figure
|
25
|
+
#from .timelapse import _save_mask_timelapse_as_gif
|
26
|
+
#from .utils import normalize_to_dtype, _remove_outside_objects, _remove_multiobject_cells, _find_similar_sized_images, _remove_noninfected
|
27
|
+
|
28
|
+
def plot_masks(batch, masks, flows, cmap='inferno', figuresize=20, nr=1, file_type='.npz', print_object_number=True):
    """
    Plot the masks and flows for a given batch of images.

    Args:
        batch (numpy.ndarray): The batch of images. A 3D array is treated as a single image.
        masks (list or numpy.ndarray): The masks corresponding to the images.
        flows (list or numpy.ndarray): The flows corresponding to the images.
        cmap (str, optional): The colormap to use for displaying the images. Defaults to 'inferno'.
        figuresize (int, optional): The size of the figure. Defaults to 20.
        nr (int, optional): The maximum number of images to plot. Defaults to 1.
        file_type (str, optional): The file type of the flows. Defaults to '.npz'.
        print_object_number (bool, optional): Whether to print the object number on the mask. Defaults to True.

    Returns:
        None
    """
    if len(batch.shape) == 3:
        batch = np.expand_dims(batch, axis=0)
    if not isinstance(masks, list):
        masks = [masks]
    if not isinstance(flows, list):
        flows = [flows]
    else:
        flows = flows[0]
    if file_type == 'png':
        flows = [f[0] for f in flows]  # assuming this is what you want to do when file_type is 'png'
    font = figuresize / 2

    for index, (image, mask, flow) in enumerate(zip(batch, masks, flows)):
        # Stop as soon as the requested number of images has been plotted
        # instead of building colormaps for images that are never shown.
        if index >= nr:
            break

        # Object labels are every non-zero value. Slicing with [1:] would
        # silently drop a real object when the mask has no background pixels.
        object_labels = np.unique(mask)
        object_labels = object_labels[object_labels != 0]

        # One random opaque colour per object, black for the lowest value.
        random_colors = np.random.rand(len(object_labels) + 1, 4)
        random_colors[:, 3] = 1
        random_colors[0, :] = [0, 0, 0, 1]
        mask_cmap = mpl.colors.ListedColormap(random_colors)

        chans = image.shape[-1]
        fig, ax = plt.subplots(1, chans + 2, figsize=(4 * figuresize, figuresize))
        for v in range(chans):
            ax[v].imshow(image[..., v], cmap=cmap)
            ax[v].set_title('Image - Channel' + str(v))

        ax[chans].imshow(mask, cmap=mask_cmap)
        ax[chans].set_title('Mask')
        if print_object_number:
            for obj in object_labels:
                cy, cx = ndi.center_of_mass(mask == obj)
                ax[chans].text(cx, cy, str(obj), color='white', fontsize=font, ha='center', va='center')

        ax[chans + 1].imshow(flow, cmap='viridis')
        ax[chans + 1].set_title('Flow')
        plt.show()
    return
|
84
|
+
|
85
|
+
def _plot_4D_arrays(src, figuresize=10, cmap='inferno', nr_npz=1, nr=1):
    """
    Show images stored in randomly chosen .npz files, one panel per channel.

    Each archive is read through its 'data' entry, which is indexed as
    (image, row, column, channel).

    Args:
        src (str): Directory that is searched for .npz files.
        figuresize (int, optional): Height of the figure; the width is this value times the channel count. Defaults to 10.
        cmap (str, optional): Colormap used for every channel. Defaults to 'inferno'.
        nr_npz (int, optional): How many .npz files to sample. Defaults to 1.
        nr (int, optional): How many images to show from each file. Defaults to 1.
    """
    npz_paths = []
    for entry in os.listdir(src):
        if entry.endswith('.npz'):
            npz_paths.append(os.path.join(src, entry))
    selected = random.sample(npz_paths, min(nr_npz, len(npz_paths)))

    for npz_path in selected:
        with np.load(npz_path) as data:
            stack = data['data']
        n_images, n_channels = stack.shape[0], stack.shape[3]

        for idx in range(min(nr, n_images)):
            img = stack[idx]

            # squeeze=False always yields a 2D axes grid, so the single-channel
            # case needs no special handling.
            fig, grid = plt.subplots(1, n_channels, figsize=(n_channels * figuresize, figuresize), squeeze=False)
            for c, panel in enumerate(grid[0]):
                panel.imshow(img[:, :, c], cmap=cmap)
                panel.set_title(f'Channel {c}', size=24)
                panel.axis('off')

            fig.tight_layout()
            plt.show()
    return
|
123
|
+
|
124
|
+
def generate_mask_random_cmap(mask):
    """
    Build a random colormap sized to the labels found in a mask.

    The colormap holds one opaque random colour per non-zero label plus a
    leading opaque black entry.

    Parameters:
        mask (numpy.ndarray): The input mask array.

    Returns:
        matplotlib.colors.ListedColormap: The random colormap.
    """
    labels = np.unique(mask)
    n_colors = int(np.count_nonzero(labels)) + 1
    palette = np.random.rand(n_colors, 4)
    palette[1:, 3] = 1            # every object colour is fully opaque
    palette[0] = (0, 0, 0, 1)     # first entry: opaque black
    return mpl.colors.ListedColormap(palette)
|
141
|
+
|
142
|
+
def random_cmap(num_objects=100):
    """
    Build a random colormap with a fixed number of entries.

    Parameters:
        num_objects (int): Number of random object colours; one extra black
            entry is placed first. Default is 100.

    Returns:
        random_cmap (matplotlib.colors.ListedColormap): A random colormap.
    """
    palette = np.random.rand(num_objects + 1, 4)
    palette[1:, 3] = 1            # object colours are fully opaque
    palette[0] = (0, 0, 0, 1)     # first entry: opaque black
    return mpl.colors.ListedColormap(palette)
|
157
|
+
|
158
|
+
def _generate_mask_random_cmap(mask):
    """
    Build a random colormap sized to the labels found in a mask.

    Private alias of generate_mask_random_cmap, kept so existing callers of
    the underscore-prefixed name continue to work.

    Parameters:
        mask (ndarray): The mask array containing unique labels.

    Returns:
        ListedColormap: One opaque random colour per non-zero label, preceded by opaque black.
    """
    return generate_mask_random_cmap(mask)
|
175
|
+
|
176
|
+
def _get_colours_merged(outline_color):
|
177
|
+
"""
|
178
|
+
Get the merged outline colors based on the specified outline color format.
|
179
|
+
|
180
|
+
Parameters:
|
181
|
+
outline_color (str): The outline color format. Can be one of 'rgb', 'bgr', 'gbr', or 'rbg'.
|
182
|
+
|
183
|
+
Returns:
|
184
|
+
list: A list of merged outline colors based on the specified format.
|
185
|
+
"""
|
186
|
+
if outline_color == 'rgb':
|
187
|
+
outline_colors = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] # rgb
|
188
|
+
elif outline_color == 'bgr':
|
189
|
+
outline_colors = [[0, 0, 1], [0, 1, 0], [1, 0, 0]] # bgr
|
190
|
+
elif outline_color == 'gbr':
|
191
|
+
outline_colors = [[0, 1, 0], [0, 0, 1], [1, 0, 0]] # gbr
|
192
|
+
elif outline_color == 'rbg':
|
193
|
+
outline_colors = [[1, 0, 0], [0, 0, 1], [0, 1, 0]] # rbg
|
194
|
+
else:
|
195
|
+
outline_colors = [[1, 0, 0], [0, 0, 1], [0, 1, 0]] # rbg
|
196
|
+
return outline_colors
|
197
|
+
|
198
|
+
def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, mask_dims, filter_min_max, include_multinucleated, include_multiinfected):
    """
    Filters objects in a plot based on various criteria.

    Objects are first passed through `_remove_outside_objects`; then, for
    every mask in `mask_dims`, objects whose area is not strictly inside the
    matching (min, max) pair are removed. For the cell mask,
    `_remove_multiobject_cells` is additionally applied when multinucleated
    or multiinfected cells are excluded. A short summary is printed per mask.

    Args:
        stack (numpy.ndarray): The input stack; masks are read and written along axis 2.
        cell_mask_dim (int): The dimension index of the cell mask.
        nucleus_mask_dim (int): The dimension index of the nucleus mask.
        pathogen_mask_dim (int): The dimension index of the pathogen mask.
        mask_dims (list): A list of dimension indices of the masks to filter.
        filter_min_max (list): (min, max) area pairs, one per entry of `mask_dims`,
            or None to skip the area filter.
        include_multinucleated (bool): Whether to include multinucleated cells.
        include_multiinfected (bool): Whether to include multiinfected cells.

    Returns:
        numpy.ndarray: The filtered stack of masks.
    """
    from .utils import _remove_outside_objects, _remove_multiobject_cells

    stack = _remove_outside_objects(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim)

    # mask_dim -> (count_before, count_after, mean_area_before, mean_area_after)
    summary = {}

    for i, mask_dim in enumerate(mask_dims):
        mask = np.take(stack, mask_dim, axis=2)
        props = measure.regionprops_table(mask, properties=['label', 'area'])
        avg_size_before = np.mean(props['area'])
        total_count_before = len(props['label'])

        if filter_min_max is not None:
            min_max = filter_min_max[i]
            within_range = np.logical_and(props['area'] > min_max[0], props['area'] < min_max[1])
            valid_labels = props['label'][within_range]
            stack[:, :, mask_dim] = np.isin(mask, valid_labels) * mask

        props_after = measure.regionprops_table(stack[:, :, mask_dim], properties=['label', 'area'])
        avg_size_after = np.mean(props_after['area'])
        total_count_after = len(props_after['label'])

        if mask_dim == cell_mask_dim:
            # NOTE(review): object_dim looks swapped between the two calls below
            # (the multinucleated filter passes the pathogen dim, the multiinfected
            # filter passes the nucleus dim). Left as found because
            # _remove_multiobject_cells is not visible here - confirm against it.
            if include_multinucleated is False and nucleus_mask_dim is not None:
                stack = _remove_multiobject_cells(stack, mask_dim, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, object_dim=pathogen_mask_dim)
            if include_multiinfected is False and cell_mask_dim is not None and pathogen_mask_dim is not None:
                stack = _remove_multiobject_cells(stack, mask_dim, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, object_dim=nucleus_mask_dim)

        # The statistics are taken before any multi-object removal above.
        summary[mask_dim] = (total_count_before, total_count_after, avg_size_before, avg_size_after)

    # Report only the masks that were actually measured; a mask dim that is set
    # but missing from mask_dims no longer triggers an UnboundLocalError.
    if cell_mask_dim is not None and cell_mask_dim in summary:
        count_before, count_after, area_before, area_after = summary[cell_mask_dim]
        print(f'removed {count_before-count_after} cells, cell size from {area_before} to {area_after}')
    if nucleus_mask_dim is not None and nucleus_mask_dim in summary:
        count_before, count_after, area_before, area_after = summary[nucleus_mask_dim]
        print(f'removed {count_before-count_after} nucleus, nucleus size from {area_before} to {area_after}')
    if pathogen_mask_dim is not None and pathogen_mask_dim in summary:
        count_before, count_after, area_before, area_after = summary[pathogen_mask_dim]
        print(f'removed {count_before-count_after} pathogens, pathogen size from {area_before} to {area_after}')

    return stack
|
266
|
+
|
267
|
+
def _normalize_and_outline(image, remove_background, backgrounds, normalize, normalization_percentiles, overlay, overlay_chans, mask_dims, outline_colors, outline_thickness):
    """
    Normalize and outline an image.

    Args:
        image (ndarray): The input stack, indexed (row, column, channel). It is
            modified in place when `remove_background` is set.
        remove_background (bool): Flag indicating whether to remove the background.
        backgrounds (list): Background value per channel, matched to channels by
            position; channels beyond the end of the list are left untouched.
        normalize (bool): Flag indicating whether to normalize the image.
        normalization_percentiles (list): Lower and upper percentile passed to `normalize_to_dtype`.
        overlay (bool): Flag indicating whether to overlay outlines onto the image.
        overlay_chans (list): List of channel indices to overlay.
        mask_dims (list): Channel indices that hold the masks.
        outline_colors (list): List of colors for the outlines (cycled over the masks).
        outline_thickness (int): Side length of the square used to dilate the outlines.

    Returns:
        tuple: (overlayed image, processed image, list of outlines). When
        `overlay` is False the first and last items are empty lists.
    """
    from .utils import normalize_to_dtype

    outlines = []
    if remove_background:
        # Pair channels with backgrounds by position. zip stops at the shorter
        # sequence, so a backgrounds list that covers only the leading channels
        # no longer raises IndexError on the remaining ones.
        for channel, background in zip(range(image.shape[-1]), backgrounds):
            single_channel = image[:, :, channel]
            single_channel[single_channel < background] = 0
            image[:, :, channel] = single_channel

    if normalize:
        image = normalize_to_dtype(array=image, q1=normalization_percentiles[0], q2=normalization_percentiles[1])

    rgb_image = np.take(image, overlay_chans, axis=-1).astype(float)
    rgb_image -= rgb_image.min()
    peak = rgb_image.max()
    if peak > 0:
        # A constant image has peak == 0; skip the division to avoid NaNs.
        rgb_image /= peak

    if not overlay:
        return [], image, []

    overlayed_image = rgb_image.copy()
    for i, mask_dim in enumerate(mask_dims):
        mask = np.take(image, mask_dim, axis=2)
        outline = np.zeros_like(mask)

        # Trace each object separately so the outline keeps the object's label.
        labels = np.unique(mask)
        for j in labels[labels != 0]:
            for contour in find_contours(mask == j, 0.5):
                contour = contour.astype(int)
                outline[contour[:, 0], contour[:, 1]] = j

        # Make the outline thicker
        outline = dilation(outline, square(outline_thickness))
        outlines.append(outline)

        # Every outlined pixel of this mask gets the same colour.
        overlayed_image[outline != 0] = outline_colors[i % len(outline_colors)]

    return overlayed_image, image, outlines
|
321
|
+
|
322
|
+
def _plot_merged_plot(overlay, image, stack, mask_dims, figuresize, overlayed_image, outlines, cmap, outline_colors, print_object_number):
    """
    Plot the merged plot with overlay, image channels, and masks.

    Panels are laid out in one row: the optional overlay, then every channel
    of `image` in grayscale with the outlines drawn on top, then every mask
    in a random colormap.

    Args:
        overlay (bool): Flag indicating whether to show the overlayed image as the first panel.
        image (ndarray): Input image array, indexed (row, column, channel).
        stack (ndarray): Stack from which the masks are taken along axis 2.
        mask_dims (list): List of mask dimensions.
        figuresize (float): Height of the figure; the width is four times this value.
        overlayed_image (ndarray): Overlayed image array.
        outlines (list): List of outlines, paired with `outline_colors` by position.
        cmap (str): Colormap for the masks. Currently unused.
        outline_colors (list): List of outline colors.
        print_object_number (bool): Flag indicating whether to print object numbers on the masks.

    Returns:
        fig (Figure): The generated matplotlib figure.
    """
    n_channels = image.shape[-1]
    ax_index = 1 if overlay else 0
    n_panels = n_channels + len(mask_dims) + ax_index

    # squeeze=False keeps `ax` indexable even when only one panel is drawn.
    fig, grid = plt.subplots(1, n_panels, figsize=(4 * figuresize, figuresize), squeeze=False)
    ax = grid[0]

    if overlay:
        ax[0].imshow(overlayed_image)
        ax[0].set_title('Overlayed Image')

    # Normalize and plot each channel with outlines
    for v in range(n_channels):
        channel = image[..., v].astype(float)
        channel -= channel.min()
        peak = channel.max()
        if peak > 0:
            # A constant channel has peak == 0; skip the division to avoid NaNs.
            channel /= peak
        channel_rgb = np.dstack((channel, channel, channel))

        # Apply the outlines onto the RGB image
        for outline, color in zip(outlines, outline_colors):
            channel_rgb[outline != 0] = mpl.colors.to_rgb(color)

        ax[v + ax_index].imshow(channel_rgb)
        ax[v + ax_index].set_title('Image - Channel' + str(v))

    for i, mask_dim in enumerate(mask_dims):
        panel = ax[i + n_channels + ax_index]
        mask = np.take(stack, mask_dim, axis=2)
        panel.imshow(mask, cmap=_generate_mask_random_cmap(mask))
        panel.set_title('Mask ' + str(i))
        if print_object_number:
            # Label every non-zero object; [1:] would skip a real object when
            # the mask contains no background pixels.
            labels = np.unique(mask)
            for obj in labels[labels != 0]:
                cy, cx = ndi.center_of_mass(mask == obj)
                panel.text(cx, cy, str(obj), color='white', fontsize=8, ha='center', va='center')

    plt.tight_layout()
    plt.show()
    return fig
|
380
|
+
|
381
|
+
def plot_arrays(src, figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
    """
    Plot randomly selected arrays from a given directory.

    Parameters:
    - src (str): The directory path containing the .npy arrays.
    - figuresize (int): The size of the figure (default: 50).
    - cmap (str): The colormap to use for displaying the arrays (default: 'inferno').
    - nr (int): The number of arrays to plot; capped at the number of .npy files found (default: 1).
    - normalize (bool): Whether to normalize the arrays (default: True).
    - q1 (int): The lower percentile for normalization (default: 1).
    - q2 (int): The upper percentile for normalization (default: 99).

    Returns:
        None
    """
    from .utils import normalize_to_dtype

    paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npy')]
    # random.sample raises ValueError when asked for more items than exist,
    # so cap the request at the number of available files.
    paths = random.sample(paths, min(nr, len(paths)))

    for path in paths:
        print(f'Image path:{path}')
        img = np.load(path)
        if normalize:
            img = normalize_to_dtype(array=img, q1=q1, q2=q2)

        if len(img.shape) > 2:
            array_nr = img.shape[2]
            # squeeze=False keeps the axes indexable for a single-channel stack.
            fig, grid = plt.subplots(1, array_nr, figsize=(figuresize, figuresize), squeeze=False)
            for channel, panel in enumerate(grid[0]):
                panel.imshow(img[:, :, channel], cmap=plt.get_cmap(cmap))
                panel.set_title('Channel ' + str(channel), size=24)
                panel.axis('off')
        else:
            fig, ax = plt.subplots(1, 1, figsize=(figuresize, figuresize))
            ax.imshow(img, cmap=plt.get_cmap(cmap))
            ax.set_title('Channel 0', size=24)
            ax.axis('off')

        fig.tight_layout()
        plt.show()
    return
|
428
|
+
|
429
|
+
def plot_merged(src, settings):
    """
    Plot the merged images after applying various filters and modifications.

    Args:
        src (str): Directory containing the merged .npy stacks.
        settings (dict): The settings for the plot. Keys read here:
            'figuresize', 'outline_color', 'cell_mask_dim', 'nucleus_mask_dim',
            'pathogen_mask_dim', 'verbose', 'include_noninfected',
            'include_multiinfected', 'include_multinucleated', 'filter_min_max',
            'remove_background', 'backgrounds', 'normalize',
            'normalization_percentiles', 'overlay', 'overlay_chans',
            'outline_thickness', 'nr', 'cmap', 'print_object_number'.
            The dict is not modified.

    Returns:
        matplotlib.figure.Figure or None: The last figure that was drawn, or
        None when nothing was plotted.
    """
    from .utils import _remove_noninfected

    outline_colors = _get_colours_merged(settings['outline_color'])

    cell_mask_dim = settings['cell_mask_dim']
    nucleus_mask_dim = settings['nucleus_mask_dim']
    pathogen_mask_dim = settings['pathogen_mask_dim']
    mask_dims = [dim for dim in (cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim) if dim is not None]

    if settings['verbose']:
        display(settings)

    # Without a pathogen mask there is nothing to count, so multi-infection
    # filtering is disabled. A local is used so the caller's dict stays untouched.
    include_multiinfected = settings['include_multiinfected']
    if pathogen_mask_dim is None:
        include_multiinfected = True

    fig = None
    plotted = 0
    for file in os.listdir(src):
        # Check the limit before doing any work, so no extra stack is loaded
        # and processed only to be discarded.
        if plotted >= settings['nr']:
            break
        if not file.endswith('.npy'):
            continue

        path = os.path.join(src, file)
        stack = np.load(path)
        print(f'Loaded: {path}')

        if not settings['include_noninfected']:
            if pathogen_mask_dim is not None and cell_mask_dim is not None:
                stack = _remove_noninfected(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim)

        if include_multiinfected is not True or settings['include_multinucleated'] is not True or settings['filter_min_max'] is not None:
            stack = _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, mask_dims, settings['filter_min_max'], settings['include_multinucleated'], include_multiinfected)

        print('stack.shape', stack.shape)
        overlayed_image, image, outlines = _normalize_and_outline(stack, settings['remove_background'], settings['backgrounds'], settings['normalize'], settings['normalization_percentiles'], settings['overlay'], settings['overlay_chans'], mask_dims, outline_colors, settings['outline_thickness'])

        fig = _plot_merged_plot(settings['overlay'], image, stack, mask_dims, settings['figuresize'], overlayed_image, outlines, settings['cmap'], outline_colors, settings['print_object_number'])
        plotted += 1

    return fig
|
475
|
+
|
476
|
+
def _plot_images_on_grid(image_files, channel_indices, um_per_pixel, scale_bar_length_um=5, fontsize=8, show_filename=True, channel_names=None, plot=False):
    """
    Plots a grid of images with optional scale bar and channel names.

    Args:
        image_files (list): List of image file paths.
        channel_indices (list): List of channel indices to select from the images,
            or None to show the images as loaded. Ignored for 2D (grayscale) images.
        um_per_pixel (float): Micrometers per pixel.
        scale_bar_length_um (float, optional): Length of the scale bar in micrometers. Defaults to 5.
        fontsize (int, optional): Font size for the image titles. Defaults to 8.
        show_filename (bool, optional): Whether to show the image file names as titles. Defaults to True.
        channel_names (list, optional): List of channel names. Defaults to None.
        plot (bool, optional): Whether to display the plot. Defaults to False.

    Returns:
        matplotlib.figure.Figure: The generated figure object.
    """
    print(f'scale bar represents {scale_bar_length_um} um')
    nr_of_images = len(image_files)

    # At least a 1x1 grid, so an empty file list still yields a valid figure.
    cols = max(1, int(np.ceil(np.sqrt(nr_of_images))))
    rows = max(1, int(np.ceil(nr_of_images / cols)))

    # squeeze=False always returns a 2D array of axes, so flatten() also works
    # for a single image.
    fig, axes = plt.subplots(rows, cols, figsize=(20, 20), facecolor='black', squeeze=False)
    fig.patch.set_facecolor('black')
    axes = axes.flatten()

    # Calculate the scale bar length in pixels
    scale_bar_length_px = int(scale_bar_length_um / um_per_pixel)

    channel_colors = ['red', 'green', 'blue']
    for ax, image_file in zip(axes, image_files):
        img_array = cv2.imread(image_file, cv2.IMREAD_UNCHANGED)
        if img_array is None:
            # cv2.imread signals an unreadable file by returning None.
            print(f'Could not read image: {image_file}')
            ax.axis('off')
            continue

        if img_array.ndim == 3 and img_array.shape[2] >= 3:
            img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2RGB)

        # Handle different channel selections
        if channel_indices is not None and img_array.ndim == 3:
            if len(channel_indices) == 1:    # Single channel (grayscale)
                img_array = img_array[:, :, channel_indices[0]]
                cmap = 'gray'
            elif len(channel_indices) == 2:  # Dual channels
                img_array = np.mean(img_array[:, :, channel_indices], axis=2)
                cmap = 'gray'
            else:                            # RGB or more channels
                img_array = img_array[:, :, channel_indices]
                cmap = None
        else:
            cmap = None if img_array.ndim == 3 else 'gray'

        # Normalize based on dtype
        if img_array.dtype == np.uint16:
            img_array = img_array.astype(np.float32) / 65535.0
        elif img_array.dtype == np.uint8:
            img_array = img_array.astype(np.float32) / 255.0

        ax.imshow(img_array, cmap=cmap)
        ax.axis('off')
        if show_filename:
            ax.set_title(os.path.basename(image_file), color='white', fontsize=fontsize, pad=20)
        # Add scale bar
        ax.plot([10, 10 + scale_bar_length_px], [img_array.shape[0] - 10] * 2, lw=2, color='white')

    # Add channel names at the top if specified
    initial_offset = 0.02  # Starting offset from the left side of the figure
    increment = 0.05       # Fixed increment for each subsequent channel name, adjust based on figure width
    if channel_names:
        current_offset = initial_offset
        for name_index, channel_name in enumerate(channel_names):
            color = channel_colors[name_index] if name_index < len(channel_colors) else 'white'
            fig.text(current_offset, 0.99, channel_name, color=color, fontsize=fontsize,
                     verticalalignment='top', horizontalalignment='left',
                     bbox=dict(facecolor='black', edgecolor='none', pad=3))
            current_offset += increment

    # Hide the grid cells that did not receive an image. The slice is based on
    # the image count, not on a loop variable the channel-name loop could reuse.
    for unused_ax in axes[nr_of_images:]:
        unused_ax.axis('off')

    plt.tight_layout(pad=3)
    if plot:
        plt.show()
    return fig
|
553
|
+
|
554
|
+
def _save_scimg_plot(src, nr_imgs=16, channel_indices=[0,1,2], um_per_pixel=0.1, scale_bar_length_um=10, standardize=True, fontsize=8, show_filename=True, channel_names=None, dpi=300, plot=False, i=1, all_folders=1):
    """
    Render grids of single-cell images and save them.

    One figure is saved for all requested channels together (text
    'all_channels'), followed by one figure per channel (text
    'channel_<index>').

    Args:
        src (str): The source directory path.
        nr_imgs (int, optional): The number of images to visualize. Defaults to 16.
        channel_indices (list, optional): List of channel indices to visualize. Defaults to [0,1,2].
        um_per_pixel (float, optional): Micrometers per pixel. Defaults to 0.1.
        scale_bar_length_um (float, optional): Length of the scale bar in micrometers. Defaults to 10.
        standardize (bool, optional): Whether to keep only similarly sized images. Defaults to True.
        fontsize (int, optional): Font size for the filename. Defaults to 8.
        show_filename (bool, optional): Whether to show the filename on the image. Defaults to True.
        channel_names (list, optional): Channel names shown on the combined figure only. Defaults to None.
        dpi (int, optional): Not used by this function. Defaults to 300.
        plot (bool, optional): Whether to plot the images. Defaults to False.
        i (int, optional): Not used by this function. Defaults to 1.
        all_folders (int, optional): Not used by this function. Defaults to 1.

    Returns:
        None
    """
    from .io import _save_figure

    def _visualize_scimgs(src, channel_indices=None, um_per_pixel=0.1, scale_bar_length_um=10, show_filename=True, standardize=True, nr_imgs=None, fontsize=8, channel_names=None, plot=False):
        """
        Collect the image files in `src` and draw them on a grid.

        Args:
            src (str): The source directory path.
            channel_indices (list, optional): List of channel indices to visualize. Defaults to None.
            um_per_pixel (float, optional): Micrometers per pixel. Defaults to 0.1.
            scale_bar_length_um (float, optional): Length of the scale bar in micrometers. Defaults to 10.
            show_filename (bool, optional): Whether to show the filename on the image. Defaults to True.
            standardize (bool, optional): Whether to keep only similarly sized images. Defaults to True.
            nr_imgs (int, optional): The number of images to visualize. Defaults to None.
            fontsize (int, optional): Font size for the filename. Defaults to 8.
            channel_names (list, optional): List of channel names. Defaults to None.
            plot (bool, optional): Whether to plot the images. Defaults to False.

        Returns:
            matplotlib.figure.Figure: The figure object containing the plotted images.
        """
        from .utils import _find_similar_sized_images

        def _generate_filelist(src):
            """
            List the image files directly inside `src`.

            Args:
                src (str): The source directory path.

            Returns:
                list: Paths whose lower-cased name ends in a known image extension.
            """
            image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.gif')
            candidates = glob.glob(os.path.join(src, '*'))
            return [candidate for candidate in candidates if candidate.lower().endswith(image_extensions)]

        def _random_sample(file_list, nr_imgs=None):
            """
            Pick a reproducible random subset of the given files.

            Args:
                file_list (list): A list of file names.
                nr_imgs (int, optional): The number of files to select. If None, or not
                    smaller than the list, all files are returned. Defaults to None.

            Returns:
                list: The selected file names.
            """
            if nr_imgs is None or nr_imgs >= len(file_list):
                return file_list
            # Fixed seed: the same subset is chosen on every call.
            random.seed(42)
            return random.sample(file_list, nr_imgs)

        image_files = _generate_filelist(src)

        if standardize:
            image_files = _find_similar_sized_images(image_files)

        if nr_imgs is not None:
            image_files = _random_sample(image_files, nr_imgs)

        return _plot_images_on_grid(image_files, channel_indices, um_per_pixel, scale_bar_length_um, fontsize, show_filename, channel_names, plot)

    # Combined figure with every requested channel.
    combined_fig = _visualize_scimgs(src, channel_indices, um_per_pixel, scale_bar_length_um, show_filename, standardize, nr_imgs, fontsize, channel_names, plot)
    _save_figure(combined_fig, src, text='all_channels')

    # One figure per individual channel, without channel names.
    for channel in channel_indices:
        channel_fig = _visualize_scimgs(src, [channel], um_per_pixel, scale_bar_length_um, show_filename, standardize, nr_imgs, fontsize, channel_names=None, plot=plot)
        _save_figure(channel_fig, src, text=f'channel_{channel}')

    return
|
649
|
+
|
650
|
+
def _plot_cropped_arrays(stack, figuresize=20,cmap='inferno'):
    """
    Show a 2-D array, or every slice along axis 2 of a larger array, as images.

    Args:
        stack (ndarray): The array to be plotted. A 2-D array is drawn as a
            single image; an array with more than two dimensions is drawn as
            one panel per index along axis 2.
        figuresize (int, optional): Width and height of the figure. Defaults to 20.
        cmap (str, optional): Name of the colormap. Defaults to 'inferno'.

    Returns:
        None
    """
    start = time.time()
    n_dims = len(stack.shape)
    colormap = plt.get_cmap(cmap)

    if n_dims == 2:
        fig, ax = plt.subplots(1, 1, figsize=(figuresize, figuresize))
        ax.imshow(stack, cmap=colormap)
        ax.set_title('Channel one', size=18)
        ax.axis('off')
        fig.tight_layout()
        plt.show()
    elif n_dims > 2:
        n_channels = stack.shape[2]
        # squeeze=False keeps the axes in an array even for a single channel;
        # without it subplots returns a bare Axes that cannot be indexed.
        fig, grid = plt.subplots(1, n_channels, figsize=(figuresize, figuresize), squeeze=False)
        for channel, ax in enumerate(grid[0]):
            ax.imshow(stack[:, :, channel], cmap=colormap)
            ax.set_title('Channel '+str(channel), size=18)
            ax.axis('off')
        fig.tight_layout()
        plt.show()

    duration = time.time() - start
    print('plot_cropped_arrays', duration)
    return
|
685
|
+
|
686
|
+
def _visualize_and_save_timelapse_stack_with_tracks(masks, tracks_df, save, src, name, plot, filenames, object_type, mode='btrack', interactive=False):
    """
    Show a stack of label masks with track overlays and optionally save it as a GIF.

    Args:
        masks (sequence of ndarray): One label mask per frame; label 0 is treated as background.
        tracks_df (pandas.DataFrame): Track table; the 'track_id', 'x' and 'y' columns are read here.
        save (bool): If True, write the GIF to <dirname(src)>/movies/gif/.
        src (str): Path whose parent directory receives the 'movies/gif' folder.
        name (str): Suffix used in the GIF file name.
        plot (bool): If True, display the result. With interactive=True a frame
            slider is shown; otherwise the saved GIF is displayed, which only
            happens when save is also True.
        filenames (list): Passed through to _save_mask_timelapse_as_gif.
        object_type (str): Used in the GIF file name.
        mode (str, optional): Not used inside this function. Defaults to 'btrack'.
        interactive (bool, optional): Show an ipywidgets frame slider instead of the GIF. Defaults to False.
    """

    from .timelapse import _save_mask_timelapse_as_gif

    # Convert to a plain int: arithmetic on a NumPy scalar can wrap around for
    # small unsigned dtypes (e.g. uint8 label 255 + 1), and a float maximum is
    # not accepted as an array dimension.
    highest_label = int(max(np.max(mask) for mask in masks))

    # One random opaque colour per label; index 0 (background) is black.
    random_colors = np.random.rand(highest_label + 1, 4)
    random_colors[:, 3] = 1
    random_colors[0] = [0, 0, 0, 1]
    cmap = mpl.colors.ListedColormap(random_colors)
    # Normalise over the full label range so each label maps to its own colour.
    norm = mpl.colors.Normalize(vmin=0, vmax=highest_label)

    def _view_frame_with_tracks(frame=0):
        """
        Display one frame with label numbers and all tracks overlaid.

        Parameters:
            frame (int): Index of the frame to display.

        Returns:
            None
        """
        fig, ax = plt.subplots(figsize=(50, 50))
        current_mask = masks[frame]
        ax.imshow(current_mask, cmap=cmap, norm=norm)
        ax.set_title(f'Frame: {frame}')

        # Write each label number at the centroid of its pixels.
        for label_value in np.unique(current_mask):
            if label_value == 0:
                continue  # background
            y, x = np.mean(np.where(current_mask == label_value), axis=1)
            ax.text(x, y, str(label_value), color='white', fontsize=24, ha='center', va='center')

        # Every track is drawn in full on every frame.
        for track in tracks_df['track_id'].unique():
            _track = tracks_df[tracks_df['track_id'] == track]
            ax.plot(_track['x'], _track['y'], '-k', linewidth=1)

        ax.axis('off')
        plt.show()

    if plot and interactive:
        interact(_view_frame_with_tracks, frame=IntSlider(min=0, max=len(masks)-1, step=1, value=0))

    if save:
        gif_path = os.path.join(os.path.dirname(src), 'movies', 'gif')
        os.makedirs(gif_path, exist_ok=True)
        save_path_gif = os.path.join(gif_path, f'timelapse_masks_{object_type}_{name}.gif')
        _save_mask_timelapse_as_gif(masks, tracks_df, save_path_gif, cmap, norm, filenames)
        # The GIF can only be shown once it exists, hence inside the save branch.
        if plot and not interactive:
            _display_gif(save_path_gif)
|
757
|
+
|
758
|
+
def _display_gif(path):
    """
    Read an image file and render it through IPython's display machinery.

    Parameters:
        path (str): Location of the GIF file on disk.

    Returns:
        None
    """
    with open(path, 'rb') as gif_file:
        gif_bytes = gif_file.read()
    display(ipyimage(gif_bytes))
|
770
|
+
|
771
|
+
def _plot_recruitment(df, df_type, channel_of_interest, target, columns=None, figuresize=50):
    """
    Plot recruitment data for different conditions and pathogens.

    Two figures are shown: one row of four bar plots with the mean intensity of
    the channel of interest in the cell, nucleus, cytoplasm and pathogen, and a
    two-row grid with one bar plot per additional measurement column.

    Args:
        df (DataFrame): Input data. Must contain 'condition', 'pathogen' and every plotted column.
        df_type (str): Text appended to the x-axis label ('pathogen <df_type>').
        channel_of_interest (int or str): Channel used to build the intensity column names.
        target (str): Not used inside this function.
        columns (list, optional): Extra columns plotted before the built-in
            measurement columns in the second figure. Defaults to None (no extras).
        figuresize (int, optional): Base size of the figures. Defaults to 50.

    Returns:
        None
    """

    def _drop_legend(ax):
        # seaborn only attaches a legend when it has hue levels to show.
        if ax.legend_ is not None:
            ax.legend_.remove()

    color_list = [(55/255, 155/255, 155/255),
                  (155/255, 55/255, 155/255),
                  (55/255, 155/255, 255/255),
                  (255/255, 55/255, 155/255)]

    sns.set_palette(sns.color_palette(color_list))
    font = figuresize/2
    width = figuresize
    height = figuresize/4

    # NOTE(review): recent seaborn releases replace `ci` with `errorbar`; the
    # original keyword is kept here so older installations keep working.
    fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(width, height))
    for ax, compartment in zip(axes, ('cell', 'nucleus', 'cytoplasm', 'pathogen')):
        intensity_col = f'{compartment}_channel_{channel_of_interest}_mean_intensity'
        sns.barplot(ax=ax, data=df, x='condition', y=intensity_col, hue='pathogen', capsize=.1, ci='sd', dodge=False)
        ax.set_xlabel(f'pathogen {df_type}', fontsize=font)
        ax.set_ylabel(intensity_col, fontsize=font)
        _drop_legend(ax)

    # Re-create a single legend to the right of the last panel.
    handles, labels = axes[3].get_legend_handles_labels()
    axes[3].legend(handles, labels, bbox_to_anchor=(1.05, 0.5), loc='center left')
    for ax in axes:
        ax.tick_params(axis='both', which='major', labelsize=font)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45)

    plt.tight_layout()
    plt.show()

    # Build a new list so neither the caller's list nor a shared default is modified.
    columns = list(columns) if columns is not None else []
    columns = columns + ['pathogen_cytoplasm_mean_mean', 'pathogen_cytoplasm_q75_mean', 'pathogen_periphery_cytoplasm_mean_mean', 'pathogen_outside_cytoplasm_mean_mean', 'pathogen_outside_cytoplasm_q75_mean']
    columns = columns + [f'pathogen_slope_channel_{channel_of_interest}', f'pathogen_cell_distance_channel_{channel_of_interest}', f'nucleus_cell_distance_channel_{channel_of_interest}']

    width = figuresize*2
    columns_per_row = math.ceil(len(columns) / 2)
    height = (figuresize*2)/columns_per_row

    fig, axes = plt.subplots(nrows=2, ncols=columns_per_row, figsize=(width, height * 2))
    axes = axes.flatten()

    print(f'{columns}')

    for i, col in enumerate(columns):
        ax = axes[i]
        sns.barplot(ax=ax, data=df, x='condition', y=f'{col}', hue='pathogen', capsize=.1, ci='sd', dodge=False)
        ax.set_xlabel(f'pathogen {df_type}', fontsize=font)
        ax.set_ylabel(f'{col}', fontsize=int(font*2))
        _drop_legend(ax)
        ax.tick_params(axis='both', which='major', labelsize=font)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
        # The first six panels start their y-axis at 1.
        if i <= 5:
            ax.set_ylim(1, None)

    # Hide the grid cells that have no column to show.
    for i in range(len(columns), len(axes)):
        axes[i].axis('off')

    plt.tight_layout()
    plt.show()
|
857
|
+
|
858
|
+
def _plot_controls(df, mask_chans, channel_of_interest, figuresize=5):
    """
    Plot per-compartment mean intensities for each channel and condition.

    One bar chart is drawn per (condition, channel) pair, with bars for the
    cell, nucleus, pathogen and cytoplasm mean-intensity columns and their
    standard deviation as error bars.

    Args:
        df (pandas.DataFrame): Data with a 'condition' column. Intensity columns
            named '<compartment>_channel_<n>_mean_intensity' are used when present.
        mask_chans (list): Mask channels. Together with channel_of_interest only
            their count matters when it is at most four: channels 0..count-1
            are plotted. The list itself is not modified.
        channel_of_interest (int): The channel of interest.
        figuresize (int, optional): Size of one panel. Defaults to 5.

    Returns:
        None
    """
    components = ["cell", "nucleus", "pathogen", "cytoplasm"]

    # Copy rather than append in place so the caller's list is left untouched.
    channels = list(mask_chans) + [channel_of_interest]
    if len(channels) <= 4:
        channels = list(range(len(channels)))

    controls_cols = [[f'{component}_channel_{chan}_mean_intensity' for component in components]
                     for chan in channels]

    unique_conditions = df['condition'].unique().tolist()

    # squeeze=False always yields a 2-D axes array, so a single condition no
    # longer has to be plotted twice just to make axes[row][col] indexable.
    # NOTE(review): the grid has one more column than there are channels, so the
    # last column stays empty; kept as in the original layout.
    fig, axes = plt.subplots(len(unique_conditions), len(channels)+1,
                             figsize=(figuresize*len(channels), figuresize*len(unique_conditions)),
                             squeeze=False)

    # RGB colour tuples scaled to the 0-1 range, one per component.
    color_list = [(55/255, 155/255, 155/255),
                  (155/255, 55/255, 155/255),
                  (55/255, 155/255, 255/255),
                  (255/255, 55/255, 155/255)]

    for idx_condition, condition in enumerate(unique_conditions):
        df_temp = df[df['condition'] == condition]
        for idx_channel, control_cols_c in enumerate(controls_cols):
            data = []
            std_dev = []
            for control_col in control_cols_c:
                if control_col in df_temp.columns:
                    mean_intensity = df_temp[control_col].mean()
                    data.append(0 if np.isnan(mean_intensity) else mean_intensity)
                    std_dev.append(df_temp[control_col].std())
                else:
                    # Keep four entries so the bars stay aligned with the labels.
                    data.append(0)
                    std_dev.append(0)

            current_axis = axes[idx_condition][idx_channel]
            current_axis.bar(components, data, yerr=std_dev,
                             capsize=4, color=color_list)
            current_axis.set_xlabel('Component')
            current_axis.set_ylabel('Mean Intensity')
            current_axis.set_title(f'Condition: {condition} - Channel {idx_channel}')
    plt.tight_layout()
    plt.show()
|
923
|
+
|
924
|
+
###################################################
|
925
|
+
# Classify
|
926
|
+
###################################################
|
927
|
+
|
928
|
+
def _imshow(img, labels, nrow=20, color='white', fontsize=12):
    """
    Display multiple images in a grid with corresponding labels.

    The images are pasted onto one three-channel canvas, so each image is
    expected to be channel-first with three channels (its axes 1 and 2 are used
    as height and width).

    Args:
        img (sequence): Images to display; img[0] defines the tile size.
        labels (list): Label for each image; its length sets the number of tiles.
        nrow (int, optional): Number of images per row in the grid. Defaults to 20.
        color (str, optional): Color of the label text. Defaults to 'white'.
        fontsize (int, optional): Font size of the label text. Defaults to 12.
    """
    n_images = len(labels)
    n_col = nrow
    n_row = int(np.ceil(n_images / n_col))
    img_height = img[0].shape[1]
    img_width = img[0].shape[2]

    canvas = np.zeros((img_height * n_row, img_width * n_col, 3))
    for idx in range(n_images):
        i, j = divmod(idx, n_col)
        # Channel-first tile -> channel-last slot on the canvas.
        canvas[i * img_height:(i + 1) * img_height, j * img_width:(j + 1) * img_width] = np.transpose(img[idx], (1, 2, 0))

    plt.figure(figsize=(50, 50))
    # pyplot has no `_imshow`; the public function is `imshow`.
    plt.imshow(canvas)
    plt.axis("off")
    for idx, label in enumerate(labels):
        row, col = divmod(idx, n_col)
        # Small offset so the text sits inside the top-left corner of its tile.
        x = col * img_width + 2
        y = row * img_height + 15
        plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')
    plt.show()
|
960
|
+
|
961
|
+
def _plot_histograms_and_stats(df):
    """
    Print summary statistics of 'pred' per condition and plot its histogram.

    Args:
        df (pandas.DataFrame): Data with 'condition' and 'pred' columns.

    Returns:
        None
    """
    for condition in df['condition'].unique():
        subset = df[df['condition'] == condition]
        preds = subset['pred']

        mean_pred = preds.mean()
        over_0_5 = int((preds > 0.5).sum())
        under_0_5 = int((preds <= 0.5).sum())
        # NaN predictions fall in neither bucket, so the total can be zero
        # (also for a NaN condition, whose subset is empty).
        n_valid = over_0_5 + under_0_5
        if n_valid:
            percent_positive = (over_0_5/n_valid)*100
            percent_negative = (under_0_5/n_valid)*100
        else:
            percent_positive = percent_negative = float('nan')

        print(f"Condition: {condition}")
        print(f"Number of rows: {len(subset)}")
        print(f"Mean of pred: {mean_pred}")
        print(f"Count of pred values over 0.5: {over_0_5}")
        print(f"Count of pred values under 0.5: {under_0_5}")
        print(f"Percent positive: {percent_positive}")
        print(f"Percent negative: {percent_negative}")
        print('-'*40)

        if not n_valid:
            # Nothing finite to bin; skip the histogram instead of failing.
            continue

        plt.figure(figsize=(10,6))
        plt.hist(preds, bins=30, edgecolor='black')
        plt.axvline(mean_pred, color='red', linestyle='dashed', linewidth=1, label=f"Mean = {mean_pred:.2f}")
        plt.title(f'Histogram for pred - Condition: {condition}')
        plt.xlabel('Pred Value')
        plt.ylabel('Count')
        plt.legend()
        plt.show()
|
991
|
+
|
992
|
+
def _show_residules(model):
    """
    Show residual diagnostics for a fitted model.

    Draws a histogram of the residuals, a QQ plot, and a residuals-versus-fitted
    scatter plot, then prints the Shapiro-Wilk normality test result.

    Args:
        model: Fitted model exposing `resid` and `fittedvalues`.

    Returns:
        None
    """
    resid = model.resid

    # Distribution of the residuals.
    plt.hist(resid, bins=30)
    plt.title('Histogram of Residuals')
    plt.xlabel('Residual Value')
    plt.ylabel('Frequency')
    plt.show()

    # Quantiles of the residuals against a fitted reference distribution.
    sm.qqplot(resid, fit=True, line='45')
    plt.title('QQ Plot')
    plt.show()

    # Residuals against fitted values, with a zero reference line.
    plt.scatter(model.fittedvalues, resid)
    plt.xlabel('Fitted values')
    plt.ylabel('Residuals')
    plt.title('Residuals vs. Fitted Values')
    plt.axhline(y=0, color='red')
    plt.show()

    # Shapiro-Wilk normality test on the residuals.
    shapiro_result = stats.shapiro(resid)
    W, p_value = shapiro_result
    print(f'Shapiro-Wilk Test W-statistic: {W}, p-value: {p_value}')
|
1020
|
+
|
1021
|
+
def _reg_v_plot(df, grouping, variable, plate_number):
    """
    Draw a volcano plot of effect size against -log10(p).

    A '-log10(p)' column is added to `df` in place. Points with p < 0.05 are
    annotated with their index value.

    Args:
        df (pandas.DataFrame): Data with 'effect' and 'p' columns.
        grouping: Not used inside this function.
        variable: Not used inside this function.
        plate_number: Not used inside this function.

    Returns:
        None
    """
    significance_level = 0.05
    df['-log10(p)'] = -np.log10(df['p'])

    plt.figure(figsize=(40, 30))
    effect_sign = np.sign(df['effect'])
    sc = plt.scatter(df['effect'], df['-log10(p)'], c=effect_sign, cmap='coolwarm')
    plt.title('Volcano Plot', fontsize=12)
    plt.xlabel('Coefficient', fontsize=12)
    plt.ylabel('-log10(P-value)', fontsize=12)

    # Annotate the points that pass the significance threshold.
    for idx, row in df.iterrows():
        if not row['p'] < significance_level:
            continue
        plt.text(row['effect'], -np.log10(row['p']), idx, fontsize=12, ha='center', va='bottom', color='black')

    # Horizontal guide at p = 0.05.
    plt.axhline(y=-np.log10(0.05), color='gray', linestyle='--')
    plt.show()
|
1038
|
+
|
1039
|
+
def generate_plate_heatmap(df, plate_number, variable, grouping, min_max):
    """
    Aggregate one plate's values into a row x column table for a heatmap.

    Args:
        df (pandas.DataFrame): Data with a 'prc' column of the form
            '<plate>_<row>_<col>' and the column named by `variable`. The input
            frame is not modified.
        plate_number (str): Plate identifier to keep (first part of 'prc').
        variable (str): Column to aggregate.
        grouping (str): Aggregation per well: 'mean', 'sum' or 'count'.
        min_max: Colour-range request. The strings 'all' and 'plate' give the
            minimum and maximum of this plate's table, 'allq' gives its 0.2 and
            0.98 quantiles; any other value is returned unchanged.
            NOTE(review): 'all' is computed from the selected plate only, so it
            is currently identical to 'plate' - confirm whether a range across
            all plates was intended.

    Returns:
        tuple: (plate_map, min_max) where plate_map is the pivoted DataFrame
        with missing wells filled with 0.

    Raises:
        ValueError: If `grouping` is not one of the supported names.
    """
    df = df.copy()  # Work on a copy so the caller's frame is untouched

    # Element-wise indexing tolerates an empty frame, which tuple-unpacking a
    # zip(*...) of the split lists does not.
    parts = df['prc'].str.split('_')
    df['plate'] = parts.str[0]
    df['row'] = parts.str[1]
    df['col'] = parts.str[2]

    df = df[df['plate'] == plate_number].copy()

    # Ordered categories so rows/columns sort as r1..r16 and c1..c27.
    row_order = [f'r{i}' for i in range(1, 17)]
    col_order = [f'c{i}' for i in range(1, 28)]

    df['row'] = pd.Categorical(df['row'], categories=row_order, ordered=True)
    df['col'] = pd.Categorical(df['col'], categories=col_order, ordered=True)

    # observed=True restricts the groups to wells that actually occur.
    grouped = df.groupby(['row', 'col'], observed=True)

    if grouping == 'mean':
        plate = grouped[variable].mean().reset_index()
    elif grouping == 'sum':
        plate = grouped[variable].sum().reset_index()
    elif grouping == 'count':
        plate = grouped[variable].count().reset_index()
    else:
        raise ValueError(f"Unsupported grouping: {grouping}")

    plate_map = pd.pivot_table(plate, values=variable, index='row', columns='col').fillna(0)

    # Only compare against the keywords when a string was passed; comparing an
    # array-like min_max with a string is an ambiguous element-wise test.
    if isinstance(min_max, str):
        if min_max in ('all', 'plate'):
            min_max = [plate_map.min().min(), plate_map.max().max()]
        elif min_max == 'allq':
            min_max = np.quantile(plate_map.values, [0.2, 0.98])

    return plate_map, min_max
|
1075
|
+
|
1076
|
+
def _plot_plates(df, variable, grouping, min_max, cmap):
    """
    Draw one heatmap per plate, four plates per figure row.

    Args:
        df (pandas.DataFrame): Data with a 'prc' column ('<plate>_<row>_<col>')
            and the column named by `variable`.
        variable (str): Column to aggregate per well.
        grouping (str): Aggregation passed to generate_plate_heatmap.
        min_max: Colour-range request passed to generate_plate_heatmap; the
            range it returns is used as the heatmap's vmin/vmax.
        cmap: Colormap passed to seaborn.heatmap.

    Returns:
        None
    """
    plates = df['prc'].str.split('_', expand=True)[0].unique()
    n_rows, n_cols = (len(plates) + 3) // 4, 4
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
    ax = ax.flatten()

    for index, plate in enumerate(plates):
        plate_map, min_max_values = generate_plate_heatmap(df, plate, variable, grouping, min_max)
        # Use the range computed for this plate instead of a fixed 0..2 scale.
        # If no usable (low, high) pair came back, fall back to the previous
        # fixed range.
        vmin, vmax = 0, 2
        if min_max_values is not None and not isinstance(min_max_values, str):
            try:
                vmin, vmax = min_max_values
            except (TypeError, ValueError):
                vmin, vmax = 0, 2
        sns.heatmap(plate_map, cmap=cmap, vmin=vmin, vmax=vmax, ax=ax[index])
        ax[index].set_title(plate)

    # Remove the unused cells in the last row of the grid.
    for i in range(len(plates), n_rows * n_cols):
        fig.delaxes(ax[i])

    plt.subplots_adjust(wspace=0.1, hspace=0.4)
    plt.show()
    return
|
1093
|
+
|
1094
|
+
#from finetune cellpose
|
1095
|
+
#def plot_arrays(src, figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
|
1096
|
+
# paths = []
|
1097
|
+
# for file in os.listdir(src):
|
1098
|
+
# if file.endswith('.tif') or file.endswith('.tiff'):
|
1099
|
+
# path = os.path.join(src, file)
|
1100
|
+
# paths.append(path)
|
1101
|
+
# paths = random.sample(paths, nr)
|
1102
|
+
# for path in paths:
|
1103
|
+
# print(f'Image path:{path}')
|
1104
|
+
# img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
1105
|
+
# if normalize:
|
1106
|
+
# img = normalize_to_dtype(array=img, q1=q1, q2=q2)
|
1107
|
+
# dim = img.shape
|
1108
|
+
# if len(img.shape) > 2:
|
1109
|
+
# array_nr = img.shape[2]
|
1110
|
+
# fig, axs = plt.subplots(1, array_nr, figsize=(figuresize, figuresize))
|
1111
|
+
# for channel in range(array_nr):
|
1112
|
+
# i = np.take(img, [channel], axis=2)
|
1113
|
+
# axs[channel].imshow(i, cmap=plt.get_cmap(cmap))
|
1114
|
+
# axs[channel].set_title('Channel '+str(channel), size=24)
|
1115
|
+
# axs[channel].axis('off')
|
1116
|
+
# else:
|
1117
|
+
# fig, ax = plt.subplots(1, 1, figsize=(figuresize, figuresize))
|
1118
|
+
# ax.imshow(img, cmap=plt.get_cmap(cmap))
|
1119
|
+
# ax.set_title('Channel 0', size=24)
|
1120
|
+
# ax.axis('off')
|
1121
|
+
# fig.tight_layout()
|
1122
|
+
# plt.show()
|
1123
|
+
# return
|
1124
|
+
|
1125
|
+
def print_mask_and_flows(stack, mask, flows, overlay=False):
    """
    Draw an image, its mask and the first flow array side by side.

    Args:
        stack (ndarray): 2-D or 3-D image; a trailing singleton axis is squeezed away first.
        mask (ndarray): Mask to show. In overlay mode zero entries are hidden.
        flows (list): Non-empty list whose first element is a 2-D or 3-D array;
            for a 3-D array only [:, :, 0] is shown.
        overlay (bool, optional): Draw the mask and its outlines on top of the
            image instead of showing the mask alone. Defaults to False.

    Raises:
        ValueError: If `stack` is not 2-D/3-D, or `flows` does not have the
            structure described above.
    """
    fig, panels = plt.subplots(1, 3, figsize=(30, 10))
    image_ax, mask_ax, flow_ax = panels[0], panels[1], panels[2]

    if stack.shape[-1] == 1:
        stack = np.squeeze(stack)

    # Left panel: the image itself.
    if stack.ndim not in (2, 3):
        raise ValueError("Unexpected stack dimensionality.")
    if stack.ndim == 2:
        image_ax.imshow(stack, cmap='gray')
    else:
        image_ax.imshow(stack)
    image_ax.set_title('Original Image')
    image_ax.axis('off')

    # Middle panel: plain mask, or mask drawn over the image.
    if not overlay:
        mask_ax.imshow(mask, cmap='gray')
    else:
        label_cmap = generate_mask_random_cmap(mask)
        foreground = np.ma.masked_where(mask == 0, mask)  # hide background
        edges = find_boundaries(mask, mode='thick')
        if stack.ndim in (2, 3):
            mask_ax.imshow(stack, cmap='gray' if stack.ndim == 2 else None)
            mask_ax.imshow(foreground, cmap=label_cmap, alpha=0.5)
            mask_ax.contour(edges, colors='r', linewidths=2)
    mask_ax.set_title('Mask with Overlay' if overlay else 'Mask')
    mask_ax.axis('off')

    # Right panel: first flow array (first channel when it is 3-D).
    if not (flows and isinstance(flows, list) and flows[0].ndim in [2, 3]):
        raise ValueError("Unexpected flow dimensionality or structure.")
    flow_image = flows[0]
    if flow_image.ndim == 3:
        flow_image = flow_image[:, :, 0]
    flow_ax.imshow(flow_image, cmap='jet')
    flow_ax.set_title('Flows')
    flow_ax.axis('off')

    fig.tight_layout()
    plt.show()
|
1173
|
+
|
1174
|
+
def plot_resize(images, resized_images, labels, resized_labels):
    """
    Show the first image and label before and after resizing in a 2x2 grid.

    Args:
        images (sequence): Original images; only images[0] is shown.
        resized_images (sequence): Resized images; only resized_images[0] is shown.
        labels (sequence): Original labels; only labels[0] is shown.
        resized_labels (sequence): Resized labels; only resized_labels[0] is shown.

    Returns:
        None
    """
    fig, ax = plt.subplots(2, 2, figsize=(20, 20))

    # Top row: images. A 2-D array gets a gray colormap, anything else is
    # passed to imshow as-is.
    top_row = (
        (ax[0, 0], images[0], 'Original Image'),
        (ax[0, 1], resized_images[0], 'Resized Image'),
    )
    for axis, picture, heading in top_row:
        if picture.ndim == 2:
            axis.imshow(picture, cmap='gray')
        else:
            axis.imshow(picture)
        axis.set_title(heading)

    # Bottom row: labels, always drawn in gray.
    bottom_row = (
        (ax[1, 0], labels[0], 'Original Label'),
        (ax[1, 1], resized_labels[0], 'Resized Label'),
    )
    for axis, label_img, heading in bottom_row:
        axis.imshow(label_img, cmap='gray')
        axis.set_title(heading)

    plt.show()
|
1197
|
+
|
1198
|
+
def normalize_and_visualize(image, normalized_image, title=""):
    """
    Show an image next to its normalized version.

    A 3-D array is displayed as the mean over its last axis; anything else is
    displayed directly. Both panels use a gray colormap.

    Args:
        image (ndarray): Original image.
        normalized_image (ndarray): Normalized image.
        title (str, optional): Text appended to the 'Original ' / 'Normalized ' panel titles.

    Returns:
        None
    """
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))

    panels = (
        (ax[0], image, "Original "),
        (ax[1], normalized_image, "Normalized "),
    )
    for axis, picture, prefix in panels:
        # Collapse the channel axis for display only.
        shown = np.mean(picture, axis=-1) if picture.ndim == 3 else picture
        axis.imshow(shown, cmap='gray')
        axis.set_title(prefix + title)
        axis.axis('off')

    plt.show()
|
1216
|
+
|
1217
|
+
def visualize_masks(mask1, mask2, mask3, title="Masks Comparison"):
    """
    Show three masks side by side under a common figure title.

    Args:
        mask1 (ndarray): Mask for the first panel.
        mask2 (ndarray): Mask for the second panel.
        mask3 (ndarray): Mask for the third panel.
        title (str, optional): Figure-level title. Defaults to "Masks Comparison".

    Returns:
        None
    """
    fig, axs = plt.subplots(1, 3, figsize=(30, 10))
    # The per-panel name must not reuse `title`, otherwise the figure title
    # below would be overwritten by the last panel's name.
    for ax, mask, panel_title in zip(axs, [mask1, mask2, mask3], ['Mask 1', 'Mask 2', 'Mask 3']):
        cmap = generate_mask_random_cmap(mask)
        if np.isin(mask, [0, 1]).all():
            # Binary mask: no normalization needed.
            ax.imshow(mask, cmap=cmap)
        else:
            # Scale the colormap over 0..max label for display.
            norm = plt.Normalize(vmin=0, vmax=mask.max())
            ax.imshow(mask, cmap=cmap, norm=norm)
        ax.set_title(panel_title)
        ax.axis('off')
    plt.suptitle(title)
    plt.show()
|
1232
|
+
|
1233
|
+
def plot_comparison_results(comparison_results):
    """
    Plot box and strip charts for four families of comparison metrics.

    The results are melted to long form and split by metric name into Jaccard,
    Dice, boundary F1 and average precision panels.

    Args:
        comparison_results: Data accepted by pandas.DataFrame; must provide a
            'filename' column, all other columns are treated as metrics.

    Returns:
        matplotlib.figure.Figure: The figure holding the four panels.
    """
    results = pd.DataFrame(comparison_results)
    long_form = pd.melt(results, id_vars=['filename'], var_name='metric', value_name='value')

    # (substring of the metric name, panel title, y-axis label), in panel order
    panel_specs = (
        ('jaccard', 'Jaccard Index by Comparison', 'Jaccard Index'),
        ('dice', 'Dice Coefficient by Comparison', 'Dice Coefficient'),
        ('boundary_f1', 'Boundary F1 Score by Comparison', 'Boundary F1 Score'),
        ('average_precision', 'Average Precision by Comparison', 'Average Precision'),
    )
    subsets = [long_form[long_form['metric'].str.contains(pattern)] for pattern, _, _ in panel_specs]

    fig, axs = plt.subplots(1, 4, figsize=(40, 10))

    for panel, subset, (_, heading, y_label) in zip(axs, subsets, panel_specs):
        sns.boxplot(data=subset, x='metric', y='value', ax=panel, color='lightgrey')
        sns.stripplot(data=subset, x='metric', y='value', ax=panel, jitter=True, alpha=0.6)
        panel.set_title(heading)
        panel.set_xticklabels(panel.get_xticklabels(), rotation=45, horizontalalignment='right')
        panel.set_xlabel('Comparison')
        panel.set_ylabel(y_label)

    plt.tight_layout()
    plt.show()
    return fig
|