spacr 0.3.47__py3-none-any.whl → 0.3.52__py3-none-any.whl
This diff compares the contents of two package versions publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
- spacr/chat_bot.py +31 -0
- spacr/gui_elements.py +33 -7
- spacr/gui_utils.py +11 -12
- spacr/measure.py +4 -1
- spacr/ml.py +453 -141
- spacr/plot.py +612 -52
- spacr/sequencing.py +5 -2
- spacr/settings.py +15 -31
- spacr/toxo.py +447 -159
- spacr/utils.py +35 -4
- {spacr-0.3.47.dist-info → spacr-0.3.52.dist-info}/METADATA +3 -1
- {spacr-0.3.47.dist-info → spacr-0.3.52.dist-info}/RECORD +16 -15
- {spacr-0.3.47.dist-info → spacr-0.3.52.dist-info}/LICENSE +0 -0
- {spacr-0.3.47.dist-info → spacr-0.3.52.dist-info}/WHEEL +0 -0
- {spacr-0.3.47.dist-info → spacr-0.3.52.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.47.dist-info → spacr-0.3.52.dist-info}/top_level.txt +0 -0
spacr/plot.py (CHANGED)
```diff
@@ -16,6 +16,7 @@ from skimage import measure
 from skimage.measure import find_contours, label, regionprops
 from skimage.segmentation import mark_boundaries
 from skimage.transform import resize as sk_resize
+import scikit_posthocs as sp
 
 import tifffile as tiff
 
```
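
The new top-level `import scikit_posthocs as sp` makes `scikit-posthocs` an import-time dependency of `spacr.plot`; it backs the Dunn's post-hoc test added to `spacrGraph.perform_posthoc_tests` further down.
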
```diff
@@ -32,7 +33,340 @@ import matplotlib.patches as patches
 from collections import defaultdict
 from matplotlib.gridspec import GridSpec
 
-
+#filter_dict={'cell':[(0,100000), (0, 65000)],'nucleus':[(3000,100000), (1500, 65000)],'pathogen':[(500,100000), (0, 65000)]}
+def plot_image_mask_overlay(
+    file,
+    channels,
+    cell_channel,
+    nucleus_channel,
+    pathogen_channel,
+    figuresize=10,
+    percentiles=(2, 98),
+    thickness=3,
+    save_pdf=True,
+    mode='outlines',
+    export_tiffs=False,
+    all_on_all=False,
+    all_outlines=False,
+    filter_dict=None
+):
+    """Plot image and mask overlays."""
+
+    def random_color_cmap(n_labels, seed=None):
+        """Generates a random color map for a given number of labels."""
+        if seed is not None:
+            np.random.seed(seed)
+        rand_colors = np.random.rand(n_labels, 3)
+        rand_colors = np.vstack([[0, 0, 0], rand_colors])  # Ensure background is black
+        cmap = ListedColormap(rand_colors)
+        return cmap
+
+    def _plot_merged_plot(
+        image,
+        outlines,
+        outline_colors,
+        figuresize,
+        thickness,
+        percentiles,
+        mode='outlines',
+        all_on_all=False,
+        all_outlines=False,
+        channels=None,
+        cell_channel=None,
+        nucleus_channel=None,
+        pathogen_channel=None,
+        cell_outlines=None,
+        nucleus_outlines=None,
+        pathogen_outlines=None,
+        save_pdf=True
+    ):
+        """Plot the merged plot with overlay, image channels, and masks."""
+
+        def _generate_colored_mask(mask, cmap):
+            """Generate a colored mask using the given colormap."""
+            mask_norm = mask / (mask.max() + 1e-5)  # Normalize mask
+            colored_mask = cmap(mask_norm)
+            colored_mask[..., 3] = np.where(mask > 0, 1, 0)  # Alpha channel
+            return colored_mask
+
+        def _overlay_mask(image, mask):
+            """Overlay the colored mask onto the original image."""
+            combined = np.clip(image * (1 - mask[..., 3:]) + mask[..., :3] * mask[..., 3:], 0, 1)
+            return combined
+
+        def _normalize_image(image, percentiles):
+            """Normalize the image based on given percentiles."""
+            v_min, v_max = np.percentile(image, percentiles)
+            image_normalized = np.clip((image - v_min) / (v_max - v_min + 1e-5), 0, 1)
+            return image_normalized
+
+        def _generate_contours(mask):
+            """Generate contours from the mask using OpenCV."""
+            contours, _ = cv2.findContours(
+                mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+            )
+            return contours
+
+        def _apply_contours(image, mask, color, thickness):
+            """Apply contours to the image."""
+            unique_labels = np.unique(mask)
+            for label in unique_labels:
+                if label == 0:
+                    continue  # Skip background
+                label_mask = (mask == label).astype(np.uint8)
+                contours = _generate_contours(label_mask)
+                cv2.drawContours(
+                    image, contours, -1, mpl.colors.to_rgb(color), thickness
+                )
+            return image
+
+        num_channels = image.shape[-1]
+        fig, ax = plt.subplots(1, num_channels + 1, figsize=(4 * figuresize, figuresize))
+
+        # Identify channels without associated outlines
+        channels_with_outlines = []
+        if cell_channel is not None:
+            channels_with_outlines.append(cell_channel)
+        if nucleus_channel is not None:
+            channels_with_outlines.append(nucleus_channel)
+        if pathogen_channel is not None:
+            channels_with_outlines.append(pathogen_channel)
+
+        for v in range(num_channels):
+            channel_image = image[..., v]
+            channel_image_normalized = _normalize_image(channel_image, percentiles)
+            channel_image_rgb = np.dstack([channel_image_normalized] * 3)
+
+            current_channel = channels[v]
+
+            if all_on_all:
+                # Apply all outlines to all channels
+                for outline, color in zip(outlines, outline_colors):
+                    if mode == 'outlines':
+                        channel_image_rgb = _apply_contours(
+                            channel_image_rgb, outline, color, thickness
+                        )
+                    else:
+                        cmap = random_color_cmap(int(outline.max() + 1), random.randint(0, 100))
+                        mask = _generate_colored_mask(outline, cmap)
+                        channel_image_rgb = _overlay_mask(channel_image_rgb, mask)
+            elif current_channel in channels_with_outlines:
+                # Apply only the relevant outline to each channel
+                outline = None
+                color = None
+
+                if current_channel == cell_channel and cell_outlines is not None:
+                    outline = cell_outlines
+                elif current_channel == nucleus_channel and nucleus_outlines is not None:
+                    outline = nucleus_outlines
+                elif current_channel == pathogen_channel and pathogen_outlines is not None:
+                    outline = pathogen_outlines
+
+                if outline is not None:
+                    if mode == 'outlines':
+                        # Use magenta color when all_on_all=False
+                        channel_image_rgb = _apply_contours(
+                            channel_image_rgb, outline, '#FF00FF', thickness
+                        )
+                    else:
+                        cmap = random_color_cmap(int(outline.max() + 1), random.randint(0, 100))
+                        mask = _generate_colored_mask(outline, cmap)
+                        channel_image_rgb = _overlay_mask(channel_image_rgb, mask)
+            else:
+                # Channel without associated outlines
+                if all_outlines:
+                    # Apply all outlines with specified colors
+                    for outline, color in zip(outlines, ['blue', 'red', 'green']):
+                        if mode == 'outlines':
+                            channel_image_rgb = _apply_contours(
+                                channel_image_rgb, outline, color, thickness
+                            )
+                        else:
+                            cmap = random_color_cmap(int(outline.max() + 1), random.randint(0, 100))
+                            mask = _generate_colored_mask(outline, cmap)
+                            channel_image_rgb = _overlay_mask(channel_image_rgb, mask)
+
+            ax[v].imshow(channel_image_rgb)
+            ax[v].set_title(f'Image - Channel {current_channel}')
+
+        # Create an image combining all objects filled with colors
+        combined_mask = np.zeros_like(outlines[0])
+        for outline in outlines:
+            combined_mask = np.maximum(combined_mask, outline)
+
+        cmap = random_color_cmap(int(combined_mask.max() + 1), random.randint(0, 100))
+        mask = _generate_colored_mask(combined_mask, cmap)
+        blank_image = np.zeros((*combined_mask.shape, 3))
+        filled_image = _overlay_mask(blank_image, mask)
+
+        ax[-1].imshow(filled_image)
+        ax[-1].set_title('Combined Objects Image')
+
+        plt.tight_layout()
+
+        # Save the figure as a PDF
+        if save_pdf:
+            pdf_dir = os.path.join(
+                os.path.dirname(os.path.dirname(file)), 'results', 'overlay'
+            )
+            os.makedirs(pdf_dir, exist_ok=True)
+            pdf_path = os.path.join(
+                pdf_dir, os.path.basename(file).replace('.npy', '.pdf')
+            )
+            fig.savefig(pdf_path, format='pdf')
+
+        plt.show()
+        return fig
+
+    def _save_channels_as_tiff(stack, save_dir, filename):
+        """Save each channel in the stack as a grayscale TIFF."""
+        os.makedirs(save_dir, exist_ok=True)
+        for i in range(stack.shape[-1]):
+            channel = stack[..., i]
+            tiff_path = os.path.join(save_dir, f"{filename}_channel_{i}.tiff")
+            tiff.imwrite(tiff_path, channel.astype(np.uint16), photometric='minisblack')
+            print(f"Saved {tiff_path}")
+
+    def _filter_object(mask, intensity_image, min_max_area=(0, 10000000), min_max_intensity=(0, 65000), type_='object'):
+        """
+        Filter objects in a mask based on their area (size) and mean intensity.
+
+        Args:
+            mask (ndarray): The input mask.
+            intensity_image (ndarray): The corresponding intensity image.
+            min_max_area (tuple): A tuple (min_area, max_area) specifying the minimum and maximum area thresholds.
+            min_max_intensity (tuple): A tuple (min_intensity, max_intensity) specifying the minimum and maximum intensity thresholds.
+
+        Returns:
+            ndarray: The filtered mask.
+        """
+        original_dtype = mask.dtype
+        mask_int = mask.astype(np.int64)
+        intensity_image = intensity_image.astype(np.float64)
+        # Compute properties for each labeled object
+        unique_labels = np.unique(mask_int)
+        unique_labels = unique_labels[unique_labels != 0]  # Exclude background
+        num_objects_before = len(unique_labels)
+
+        # Initialize lists to store area and intensity for each object
+        areas = []
+        mean_intensities = []
+        labels_to_keep = []
+
+        for label in unique_labels:
+            label_mask = (mask_int == label)
+            area = np.sum(label_mask)
+            mean_intensity = np.mean(intensity_image[label_mask])
+
+            areas.append(area)
+            mean_intensities.append(mean_intensity)
+
+            # Check if the object meets both area and intensity criteria
+            if (min_max_area[0] <= area <= min_max_area[1]) and (min_max_intensity[0] <= mean_intensity <= min_max_intensity[1]):
+                labels_to_keep.append(label)
+
+        # Convert lists to numpy arrays for easier computation
+        areas = np.array(areas)
+        mean_intensities = np.array(mean_intensities)
+        num_objects_after = len(labels_to_keep)
+        # Compute average area and intensity before and after filtering
+        avg_area_before = areas.mean() if num_objects_before > 0 else 0
+        avg_intensity_before = mean_intensities.mean() if num_objects_before > 0 else 0
+        areas_after = areas[np.isin(unique_labels, labels_to_keep)]
+        mean_intensities_after = mean_intensities[np.isin(unique_labels, labels_to_keep)]
+        avg_area_after = areas_after.mean() if num_objects_after > 0 else 0
+        avg_intensity_after = mean_intensities_after.mean() if num_objects_after > 0 else 0
+        print(f"Before filtering {type_}: {num_objects_before} objects")
+        print(f"Average area {type_}: {avg_area_before:.2f} pixels, Average intensity: {avg_intensity_before:.2f}")
+        print(f"After filtering {type_}: {num_objects_after} objects")
+        print(f"Average area {type_}: {avg_area_after:.2f} pixels, Average intensity: {avg_intensity_after:.2f}")
+        mask_filtered = np.zeros_like(mask_int)
+        for label in labels_to_keep:
+            mask_filtered[mask_int == label] = label
+        mask_filtered = mask_filtered.astype(original_dtype)
+        return mask_filtered
+
+    stack = np.load(file)
+
+    if export_tiffs:
+        save_dir = os.path.join(
+            os.path.dirname(os.path.dirname(file)),
+            'results',
+            os.path.splitext(os.path.basename(file))[0],
+            'tiff'
+        )
+        filename = os.path.splitext(os.path.basename(file))[0]
+        _save_channels_as_tiff(stack, save_dir, filename)
+
+    # Convert to float for normalization and ensure correct handling of arrays
+    if stack.dtype in (np.uint16, np.uint8):
+        stack = stack.astype(np.float32)
+
+    image = stack[..., channels]
+    outlines = []
+    outline_colors = []
+
+    # Define variables to hold individual outlines
+    cell_outlines = None
+    nucleus_outlines = None
+    pathogen_outlines = None
+
+    if pathogen_channel is not None:
+        pathogen_mask_dim = -1
+        pathogen_outlines = np.take(stack, pathogen_mask_dim, axis=2)
+        if not filter_dict is None:
+            pathogen_intensity = np.take(stack, pathogen_channel, axis=2)
+            pathogen_outlines = _filter_object(pathogen_outlines, pathogen_intensity, filter_dict['pathogen'][0], filter_dict['pathogen'][1], type_='pathogen')
+
+        outlines.append(pathogen_outlines)
+        outline_colors.append('green')
+
+    if nucleus_channel is not None:
+        nucleus_mask_dim = -2 if pathogen_channel is not None else -1
+        nucleus_outlines = np.take(stack, nucleus_mask_dim, axis=2)
+        if not filter_dict is None:
+            nucleus_intensity = np.take(stack, nucleus_channel, axis=2)
+            nucleus_outlines = _filter_object(nucleus_outlines, nucleus_intensity, filter_dict['nucleus'][0], filter_dict['nucleus'][1], type_='nucleus')
+        outlines.append(nucleus_outlines)
+        outline_colors.append('blue')
+
+    if cell_channel is not None:
+        if nucleus_channel is not None and pathogen_channel is not None:
+            cell_mask_dim = -3
+        elif nucleus_channel is not None or pathogen_channel is not None:
+            cell_mask_dim = -2
+        else:
+            cell_mask_dim = -1
+        cell_outlines = np.take(stack, cell_mask_dim, axis=2)
+        if not filter_dict is None:
+            cell_intensity = np.take(stack, cell_channel, axis=2)
+            cell_outlines = _filter_object(cell_outlines, cell_intensity, filter_dict['cell'][0], filter_dict['cell'][1], type_='cell')
+        outlines.append(cell_outlines)
+        outline_colors.append('red')
+
+    fig = _plot_merged_plot(
+        image=image,
+        outlines=outlines,
+        outline_colors=outline_colors,
+        figuresize=figuresize,
+        thickness=thickness,
+        percentiles=percentiles,  # Pass percentiles to the plotting function
+        mode=mode,
+        all_on_all=all_on_all,
+        all_outlines=all_outlines,
+        channels=channels,
+        cell_channel=cell_channel,
+        nucleus_channel=nucleus_channel,
+        pathogen_channel=pathogen_channel,
+        cell_outlines=cell_outlines,
+        nucleus_outlines=nucleus_outlines,
+        pathogen_outlines=pathogen_outlines,
+        save_pdf=save_pdf
+    )
+
+    return fig
+
+def plot_image_mask_overlay_v1(file, channels, cell_channel, nucleus_channel, pathogen_channel, figuresize=10, percentiles=(2,98), thickness=3, save_pdf=True, mode='outlines', export_tiffs=False):
     """Plot image and mask overlays."""
 
     def _plot_merged_plot(image, outlines, outline_colors, figuresize, thickness, percentiles, mode='outlines'):
```
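
For orientation, here is a minimal usage sketch of the new `filter_dict` argument, mirroring the commented-out example at the top of the hunk. The path and channel indices are hypothetical; the function expects a merged `.npy` stack whose trailing planes are the cell, nucleus, and pathogen label masks (see the `*_mask_dim` logic above).

```python
from spacr.plot import plot_image_mask_overlay

# Hypothetical merged stack: channels 0-2 hold intensity images; the last
# planes hold the cell/nucleus/pathogen label masks.
fig = plot_image_mask_overlay(
    file='/data/plate1/merged/plate1_r1_c1.npy',  # hypothetical path
    channels=[0, 1, 2],
    cell_channel=0,
    nucleus_channel=1,
    pathogen_channel=2,
    mode='outlines',   # any other value fills masks with random label colors
    all_on_all=False,  # draw each outline only on its own channel, in magenta
    # Per-object filters: {name: [(min_area, max_area), (min_intensity, max_intensity)]}
    filter_dict={
        'cell': [(0, 100000), (0, 65000)],
        'nucleus': [(3000, 100000), (1500, 65000)],
        'pathogen': [(500, 100000), (0, 65000)],
    },
)
```
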
```diff
@@ -1398,7 +1732,7 @@ def _plot_histograms_and_stats(df):
         print('-'*40)
 
         # Plot the histogram
-        plt.figure(figsize=(10,
+        plt.figure(figsize=(10,10))
         plt.hist(subset['pred'], bins=30, edgecolor='black')
         plt.axvline(mean_pred, color='red', linestyle='dashed', linewidth=1, label=f"Mean = {mean_pred:.2f}")
         plt.title(f'Histogram for pred - Condition: {condition}')
```
```diff
@@ -1455,12 +1789,16 @@ def _reg_v_plot(df, grouping, variable, plate_number):
     plt.show()
 
 def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_count):
+
+    if not isinstance(min_count, (int, float)):
+        min_count = 0
+
     df = df.copy()  # Work on a copy to avoid SettingWithCopyWarning
     df['plate'], df['row'], df['col'] = zip(*df['prc'].str.split('_'))
 
     # Filtering the dataframe based on the plate_number
     df = df[df['plate'] == plate_number].copy()  # Create another copy after filtering
-
+
     # Ensure proper ordering
     row_order = [f'r{i}' for i in range(1, 17)]
     col_order = [f'c{i}' for i in range(1, 28)]  # Exclude c15 as per your earlier code
```
```diff
@@ -1496,7 +1834,6 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_count):
     min_max = np.quantile(plate_map.values, [min_max[0], min_max[1]])
     if isinstance(min_max[0], (int)) and isinstance(min_max[1], (int)):
         min_max = [min_max[0], min_max[1]]
-
     return plate_map, min_max
 
 def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True, dst=None):
```
```diff
@@ -1516,10 +1853,14 @@ def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True, dst=None):
     plt.subplots_adjust(wspace=0.1, hspace=0.4)
 
     if not dst is None:
-
-
-
-
+        for i in range(0,1000):
+            filename = os.path.join(dst, f'plate_heatmap_{i}.pdf')
+            if os.path.exists(filename):
+                continue
+            else:
+                fig.savefig(filename, format='pdf')
+                print(f'Saved heatmap to {filename}')
+                break
     if verbose:
         plt.show()
     return fig
```
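
The new save branch in `plot_plates` probes `plate_heatmap_0.pdf`, `plate_heatmap_1.pdf`, and so on, and writes to the first name that does not exist yet, so repeated runs never overwrite earlier heatmaps. The same pattern in isolation (the helper name is ours, not spacr's):

```python
import os

def first_free_path(dst, stem='plate_heatmap', ext='pdf', limit=1000):
    """Return the first dst/stem_{i}.ext that does not already exist."""
    for i in range(limit):
        candidate = os.path.join(dst, f'{stem}_{i}.{ext}')
        if not os.path.exists(candidate):
            return candidate
    raise FileExistsError(f'all {limit} candidate names are taken in {dst}')
```
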
```diff
@@ -1886,22 +2227,77 @@ def volcano_plot(coef_df, filename='volcano_plot.pdf'):
     print(f'Saved Volcano plot: {filename}')
     plt.show()
 
-def plot_histogram(df,
+def plot_histogram(df, column, dst=None):
     # Plot histogram of the dependent variable
-
-
-
-    plt.
+    bar_color = (0/255, 155/255, 155/255)
+    plt.figure(figsize=(10, 10))
+    sns.histplot(df[column], kde=False, color=bar_color, edgecolor=None, alpha=0.6)
+    plt.title(f'Histogram of {column}')
+    plt.xlabel(column)
     plt.ylabel('Frequency')
 
     if not dst is None:
-        filename = os.path.join(dst, '
+        filename = os.path.join(dst, f'{column}_histogram.pdf')
         plt.savefig(filename, format='pdf')
         print(f'Saved histogram to {filename}')
 
     plt.show()
 
-def plot_lorenz_curves(csv_files, remove_keys=['TGGT1_220950_1', 'TGGT1_233460_4']):
+def plot_lorenz_curves(csv_files, name_column='grna_name', value_column='count', remove_keys=['TGGT1_220950_1', 'TGGT1_233460_4'], x_lim=[0.0,1], y_lim=[0,1], save=True):
+
+    def lorenz_curve(data):
+        """Calculate Lorenz curve."""
+        sorted_data = np.sort(data)
+        cumulative_data = np.cumsum(sorted_data)
+        lorenz_curve = cumulative_data / cumulative_data[-1]
+        lorenz_curve = np.insert(lorenz_curve, 0, 0)
+        return lorenz_curve
+
+    combined_data = []
+
+    plt.figure(figsize=(10, 10))
+
+    for idx, csv_file in enumerate(csv_files):
+        if idx == 1:
+            save_fldr = os.path.dirname(csv_file)
+            save_path = os.path.join(save_fldr, 'lorenz_curve.pdf')
+
+        df = pd.read_csv(csv_file)
+        for remove in remove_keys:
+            df = df[df[name_column] != remove]
+
+        values = df[value_column].values
+        combined_data.extend(values)
+
+        lorenz = lorenz_curve(values)
+        name = f"plate {idx+1}"
+        plt.plot(np.linspace(0, 1, len(lorenz)), lorenz, label=name)
+
+    # Plot combined Lorenz curve
+    combined_lorenz = lorenz_curve(np.array(combined_data))
+    plt.plot(np.linspace(0, 1, len(combined_lorenz)), combined_lorenz, label="Combined", linestyle='--', color='black')
+
+    if x_lim != None:
+        plt.xlim(x_lim)
+
+    if y_lim != None:
+        plt.ylim(y_lim)
+
+    plt.title('Lorenz Curves')
+    plt.xlabel('Cumulative Share of Individuals')
+    plt.ylabel('Cumulative Share of Value')
+    plt.legend()
+    plt.grid(False)
+
+    if save:
+        save_path = os.path.join(os.path.dirname(csv_files[0]), 'results')
+        os.makedirs(save_path, exist_ok=True)
+        save_file_path = os.path.join(save_path, 'lorenz_curve.pdf')
+        plt.savefig(save_file_path, format='pdf', bbox_inches='tight')
+        print(f"Saved Lorenz Curve: {save_file_path}")
+    plt.show()
+
+def plot_lorenz_curves_v1(csv_files, remove_keys=['TGGT1_220950_1', 'TGGT1_233460_4']):
 
     def lorenz_curve(data):
         """Calculate Lorenz curve."""
```
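
A worked example of the `lorenz_curve` helper used by both versions of the function, with toy counts of our choosing: for [1, 2, 3, 4] the sorted cumulative sums are [1, 3, 6, 10], so after normalization and the leading zero the curve is [0, 0.1, 0.3, 0.6, 1.0].

```python
import numpy as np

def lorenz_curve(data):
    """Cumulative share of the sorted values, as in spacr.plot."""
    sorted_data = np.sort(data)
    cumulative_data = np.cumsum(sorted_data)
    curve = cumulative_data / cumulative_data[-1]  # normalize to [0, 1]
    return np.insert(curve, 0, 0)                  # curve starts at the origin

print(lorenz_curve(np.array([1, 2, 3, 4])))  # [0.  0.1 0.3 0.6 1. ]
```
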
```diff
@@ -2358,22 +2754,33 @@ class spacrGraph:
         return filtered_df
 
     def perform_normality_tests(self):
-        """Perform normality tests for each group and
+        """Perform normality tests for each group and data column."""
         unique_groups = self.df[self.grouping_column].unique()
         normality_results = []
 
         for column in self.data_column:
-            # Iterate over each group and its corresponding data
             for group in unique_groups:
-                data = self.df.loc[self.df[self.grouping_column] == group, column]
+                data = self.df.loc[self.df[self.grouping_column] == group, column].dropna()
                 n_samples = len(data)
 
+                if n_samples < 3:
+                    # Skip test if there aren't enough data points
+                    print(f"Skipping normality test for group '{group}' on column '{column}' - Not enough data.")
+                    normality_results.append({
+                        'Comparison': f'Normality test for {group} on {column}',
+                        'Test Statistic': None,
+                        'p-value': None,
+                        'Test Name': 'Skipped',
+                        'Column': column,
+                        'n': n_samples
+                    })
+                    continue
+
+                # Choose the appropriate normality test based on the sample size
                 if n_samples >= 8:
-                    # Use D'Agostino-Pearson test for larger samples
                     stat, p_value = normaltest(data)
                     test_name = "D'Agostino-Pearson test"
                 else:
-                    # Use Shapiro-Wilk test for smaller samples
                     stat, p_value = shapiro(data)
                     test_name = "Shapiro-Wilk test"
 
```
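
The branching above reflects the constraints of the underlying scipy tests: `normaltest` (D'Agostino-Pearson) raises for n < 8, `shapiro` needs at least 3 points, and below that nothing can be tested. A standalone sketch of the same selection (the function name is ours):

```python
from scipy.stats import normaltest, shapiro

def test_normality(data):
    """Pick a normality test by sample size, mirroring the diff above."""
    n = len(data)
    if n < 3:
        return None, None, 'Skipped'   # too few points for any test
    if n >= 8:
        stat, p = normaltest(data)     # D'Agostino-Pearson
        return stat, p, "D'Agostino-Pearson test"
    stat, p = shapiro(data)            # Shapiro-Wilk for small samples
    return stat, p, 'Shapiro-Wilk test'
```
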
```diff
@@ -2384,11 +2791,11 @@ class spacrGraph:
                     'p-value': p_value,
                     'Test Name': test_name,
                     'Column': column,
-                    'n': n_samples
+                    'n': n_samples
                 })
 
             # Check if all groups are normally distributed (p > 0.05)
-            normal_p_values = [result['p-value'] for result in normality_results if result['Column'] == column]
+            normal_p_values = [result['p-value'] for result in normality_results if result['Column'] == column and result['p-value'] is not None]
             is_normal = all(p > 0.05 for p in normal_p_values)
 
         return is_normal, normality_results
```
```diff
@@ -2438,9 +2845,13 @@ class spacrGraph:
                 len(self.df[self.df[self.grouping_column] == unique_groups[1]])})
 
         return test_results
-
+
     def perform_posthoc_tests(self, is_normal, unique_groups):
         """Perform post-hoc tests for multiple groups based on all_to_all flag."""
+
+        from .utils import choose_p_adjust_method
+
+        posthoc_results = []
         if is_normal and len(unique_groups) > 2 and self.all_to_all:
             tukey_result = pairwise_tukeyhsd(self.df[self.data_column], self.df[self.grouping_column], alpha=0.05)
             posthoc_results = []
```
```diff
@@ -2456,22 +2867,40 @@ class spacrGraph:
                     'n_object': len(raw_data1) + len(raw_data2),
                     'n_well': len(self.df[self.df[self.grouping_column] == comparison[0]]) + len(self.df[self.df[self.grouping_column] == comparison[1]])})
             return posthoc_results
-
-        elif len(unique_groups) > 2 and
-
-
-            for
-
-
-
-
-
-
-
-
-
+
+        elif len(unique_groups) > 2 and self.all_to_all:
+            print('performing_dunns')
+
+            # Prepare data for Dunn's test in long format
+            long_data = self.df[[self.data_column[0], self.grouping_column]].dropna()
+
+            p_adjust_method = choose_p_adjust_method(num_groups=len(long_data[self.grouping_column].unique()), num_data_points=len(long_data) // len(long_data[self.grouping_column].unique()))
+
+            # Perform Dunn's test with Bonferroni correction
+            dunn_result = sp.posthoc_dunn(
+                long_data,
+                val_col=self.data_column[0],
+                group_col=self.grouping_column,
+                p_adjust=p_adjust_method
+            )
+
+            for group_a, group_b in zip(*np.triu_indices_from(dunn_result, k=1)):
+                raw_data1 = self.raw_df[self.raw_df[self.grouping_column] == dunn_result.index[group_a]][self.data_column]
+                raw_data2 = self.raw_df[self.raw_df[self.grouping_column] == dunn_result.columns[group_b]][self.data_column]
+
+                posthoc_results.append({
+                    'Comparison': f"{dunn_result.index[group_a]} vs {dunn_result.columns[group_b]}",
+                    'Test Statistic': None,  # Dunn's test does not return a specific test statistic
+                    'p-value': dunn_result.iloc[group_a, group_b],  # Extract the p-value from the matrix
+                    'Test Name': "Dunn's Post-hoc",
+                    'p_adjust_method': p_adjust_method,
+                    'n_object': len(raw_data1) + len(raw_data2),  # Total objects
+                    'n_well': len(self.df[self.df[self.grouping_column] == dunn_result.index[group_a]]) +
+                              len(self.df[self.grouping_column] == dunn_result.columns[group_b])})
+
             return posthoc_results
-
+
+        return posthoc_results
 
     def create_plot(self, ax=None):
         """Create and display the plot based on the chosen graph type."""
```
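
`scikit_posthocs.posthoc_dunn` returns a symmetric DataFrame of pairwise adjusted p-values, which is why the new code walks its upper triangle with `np.triu_indices_from`. A self-contained toy run (group names, values, and the 'holm' adjustment are our choices):

```python
import numpy as np
import pandas as pd
import scikit_posthocs as sp

rng = np.random.default_rng(0)
long_data = pd.DataFrame({
    'value': np.concatenate([rng.normal(0, 1, 20),
                             rng.normal(1, 1, 20),
                             rng.normal(3, 1, 20)]),
    'group': ['a'] * 20 + ['b'] * 20 + ['c'] * 20,
})

# Pairwise Dunn's test; p_adjust accepts statsmodels method names ('holm', 'bonferroni', ...).
dunn = sp.posthoc_dunn(long_data, val_col='value', group_col='group', p_adjust='holm')

# Walk the upper triangle of the symmetric p-value matrix, as the diff does.
for i, j in zip(*np.triu_indices_from(dunn, k=1)):
    print(f"{dunn.index[i]} vs {dunn.columns[j]}: p = {dunn.iloc[i, j]:.3g}")
```
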
```diff
@@ -2507,7 +2936,43 @@ class spacrGraph:
             transposed_table = list(map(list, zip(*table_data)))
             return row_labels, transposed_table
 
-
+
+        def _place_symbols(row_labels, transposed_table, x_positions, ax):
+            """
+            Places symbols and row labels aligned under the bars or jitter points on the graph.
+
+            Parameters:
+            - row_labels: List of row titles to be displayed along the y-axis.
+            - transposed_table: Data to be placed under each bar/jitter as symbols.
+            - x_positions: X-axis positions for each group to align the symbols.
+            - ax: The matplotlib Axes object where the plot is drawn.
+            """
+            # Get plot dimensions and adjust for different plot sizes
+            y_axis_min = ax.get_ylim()[0]  # Minimum y-axis value (usually 0)
+            symbol_start_y = y_axis_min - 0.05 * (ax.get_ylim()[1] - y_axis_min)  # Adjust a bit below the x-axis
+
+            # Calculate spacing for the table rows (adjust as needed)
+            y_spacing = 0.04  # Adjust this for better spacing between rows
+
+            # Determine the leftmost x-position for row labels (align with the y-axis)
+            label_x_pos = ax.get_xlim()[0] - 0.3  # Adjust offset from the y-axis
+
+            # Place row labels vertically aligned with symbols
+            for row_idx, title in enumerate(row_labels):
+                y_pos = symbol_start_y - (row_idx * y_spacing)  # Calculate vertical position for each label
+                ax.text(label_x_pos, y_pos, title, ha='right', va='center', fontsize=12, fontweight='regular')
+
+            # Place symbols under each bar or jitter point based on x-positions
+            for idx, (x_pos, column_data) in enumerate(zip(x_positions, transposed_table)):
+                for row_idx, text in enumerate(column_data):
+                    y_pos = symbol_start_y - (row_idx * y_spacing)  # Adjust vertical spacing for symbols
+                    ax.text(x_pos, y_pos, text, ha='center', va='center', fontsize=12, fontweight='regular')
+
+            # Redraw to apply changes
+            ax.figure.canvas.draw()
+
+
         def _place_symbols_v1(row_labels, transposed_table, x_positions, ax):
 
             # Get the bottom of the y-axis (y=0) in data coordinates and convert to display coordinates
             y_axis_min = ax.get_ylim()[0]  # Minimum y-axis value (usually 0)
```
```diff
@@ -2642,6 +3107,10 @@ class spacrGraph:
         else:
             raise ValueError(f"Unknown graph type: {self.graph_type}")
 
+        if len(self.data_column) == 1:
+            num_groups = len(self.df[self.grouping_column].unique())
+            self._standerdize_figure_format(ax=ax, num_groups=num_groups, graph_type=self.graph_type)
+
         # Set y-axis start
         if isinstance(self.y_lim, list):
             if len(self.y_lim) == 2:
```
```diff
@@ -2676,7 +3145,73 @@ class spacrGraph:
         if self.save:
             self._save_results()
 
-        ax.margins(x=0.12)
+        ax.margins(x=0.12)
+
+    def _standerdize_figure_format(self, ax, num_groups, graph_type):
+        """
+        Adjusts the figure layout (size, bar width, jitter, and spacing) based on the number of groups.
+
+        Parameters:
+        - ax: The matplotlib Axes object.
+        - num_groups: Number of unique groups.
+        - graph_type: The type of graph (e.g., 'bar', 'jitter', 'box', etc.).
+
+        Returns:
+        - None. Modifies the figure and Axes in place.
+        """
+        if graph_type in ['line', 'line_std']:
+            print("Skipping layout adjustment for line graphs.")
+            return  # Skip layout adjustment for line graphs
+
+        correction_factor = 4
+
+        # Set figure size to ensure it remains square with a minimum size
+        fig_size = max(6, num_groups * 2) / correction_factor
+        ax.figure.set_size_inches(fig_size, fig_size)
+
+        # Configure layout based on the number of groups
+        bar_width = min(0.8, 1.5 / num_groups) / correction_factor
+        jitter_amount = min(0.1, 0.2 / num_groups) / correction_factor
+        jitter_size = max(50 / num_groups, 200)
+
+        # Adjust axis limits to ensure bars are centered with respect to group labels
+        ax.set_xlim(-0.5, num_groups - 0.5)
+
+        # Set ticks to match the group labels in your DataFrame
+        group_labels = self.df[self.grouping_column].unique()
+        ax.set_xticks(range(len(group_labels)))
+        ax.set_xticklabels(group_labels, rotation=45, ha='right')
+
+        # Customize elements based on the graph type
+        if graph_type == 'bar':
+            # Adjust bars' width and position
+            for bar in ax.patches:
+                bar.set_width(bar_width)
+                bar.set_x(bar.get_x() - bar_width / 2)
+
+        elif graph_type in ['jitter', 'jitter_bar', 'jitter_box']:
+            # Adjust jitter points' position and size
+            for coll in ax.collections:
+                offsets = coll.get_offsets()
+                offsets[:, 0] += jitter_amount  # Shift jitter points slightly
+                coll.set_offsets(offsets)
+                coll.set_sizes([jitter_size] * len(offsets))  # Adjust point size dynamically
+
+        elif graph_type in ['box', 'violin']:
+            # Adjust box width for consistent spacing
+            for artist in ax.artists:
+                artist.set_width(bar_width)
+
+        # Adjust legend and axis labels
+        ax.tick_params(axis='x', labelsize=max(10, 15 - num_groups // 2))
+        ax.tick_params(axis='y', labelsize=max(10, 15 - num_groups // 2))
+
+        if ax.get_legend():
+            ax.get_legend().set_bbox_to_anchor((1.05, 1))  # loc='upper left', borderaxespad=0.
+            ax.get_legend().prop.set_size(max(8, 12 - num_groups // 3))
+
+        # Redraw the figure to apply changes
+        ax.figure.canvas.draw()
 
     def _create_bar_plot(self, ax):
         """Helper method to create a bar plot with consistent bar thickness and centered error bars."""
```
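
To make `_standerdize_figure_format` concrete: with `correction_factor = 4`, two groups give a 1.5 in square figure (`max(6, 4) / 4`) and a bar width of about 0.19 (`min(0.8, 0.75) / 4`), while eight groups give a 4 in figure and a bar width of about 0.047. A quick check of those formulas:

```python
correction_factor = 4
for num_groups in (2, 4, 8):
    fig_size = max(6, num_groups * 2) / correction_factor
    bar_width = min(0.8, 1.5 / num_groups) / correction_factor
    jitter_size = max(50 / num_groups, 200)
    print(f"{num_groups} groups: figure {fig_size:.2f} in, "
          f"bar width {bar_width:.3f}, jitter size {jitter_size}")
```

Note that `max(50 / num_groups, 200)` evaluates to 200 for every positive group count, so the jitter point size is effectively constant; `min` may have been intended.
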
```diff
@@ -2895,11 +3430,11 @@ class spacrGraph:
             bar.set_x(bar.get_x() - target_width / 2)
 
         # Adjust error bars alignment with bars
-        bars = [bar for bar in ax.patches if isinstance(bar, plt.Rectangle)]
-        for bar, (_, row) in zip(bars, summary_df.iterrows()):
-
-
-
+        #bars = [bar for bar in ax.patches if isinstance(bar, plt.Rectangle)]
+        #for bar, (_, row) in zip(bars, summary_df.iterrows()):
+        #    x_bar = bar.get_x() + bar.get_width() / 2
+        #    err = row[self.error_bar_type]
+        #    ax.errorbar(x=x_bar, y=bar.get_height(), yerr=err, fmt='none', c='black', capsize=5, lw=2)
 
         # Set legend and labels
         ax.set_xlabel(self.grouping_column)
```
```diff
@@ -3092,9 +3627,13 @@ def plot_data_from_csv(settings):
         dft = pd.read_csv(src)
         if 'plate' not in dft.columns:
             dft['plate'] = f"plate{i+1}"
+        dft['common'] = 'spacr'
         dfs.append(dft)
 
     df = pd.concat(dfs, axis=0)
+
+    display(df)
+
     df = df.dropna(subset=settings['data_column'])
     df = df.dropna(subset=settings['grouping_column'])
     src = srcs[0]
```
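
The bare `display(df)` added here assumes an IPython/Jupyter session, where `display` is injected as a builtin; a plain script would need the explicit import:

```python
# Only needed outside Jupyter, where `display` is not a builtin:
from IPython.display import display

import pandas as pd
display(pd.DataFrame({'plate': ['plate1'], 'common': ['spacr']}))
```
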
```diff
@@ -3141,23 +3680,39 @@ def plot_region(settings):
         print(f"Saved {path}")
 
     from .io import _read_db
+    from .utils import correct_paths
     fov_path = os.path.join(settings['src'], 'merged', settings['name'])
     name = os.path.splitext(settings['name'])[0]
 
     db_path = os.path.join(settings['src'], 'measurements', 'measurements.db')
     paths_df = _read_db(db_path, tables=['png_list'])[0]
+    paths_df, _ = correct_paths(df=paths_df, base_path=settings['src'], folder='data')
     paths_df = paths_df[paths_df['png_path'].str.contains(name, na=False)]
 
     activation_mode = f"{settings['activation_mode']}_list"
     activation_db_path = os.path.join(settings['src'], 'measurements', settings['activation_db'])
     activation_paths_df = _read_db(activation_db_path, tables=[activation_mode])[0]
+    activation_db = os.path.splitext(settings['activation_db'])[0]
+    base_path=os.path.join(settings['src'], 'datasets',activation_db)
+    activation_paths_df, _ = correct_paths(df=activation_paths_df, base_path=base_path, folder=settings['activation_mode'])
     activation_paths_df = activation_paths_df[activation_paths_df['png_path'].str.contains(name, na=False)]
 
     png_paths = _sort_paths_by_basename(paths_df['png_path'].tolist())
     activation_paths = _sort_paths_by_basename(activation_paths_df['png_path'].tolist())
 
-
-
+
+    if activation_paths:
+        fig_3 = plot_image_grid(image_paths=activation_paths, percentiles=settings['percentiles'])
+    else:
+        fig_3 = None
+        print(f"Could not find any cropped PNGs")
+    if png_paths:
+        fig_2 = plot_image_grid(image_paths=png_paths, percentiles=settings['percentiles'])
+    else:
+        fig_2 = None
+        print(f"Could not find any activation maps")
+
+    print('fov_path', fov_path)
     fig_1 = plot_image_mask_overlay(file=fov_path,
                                     channels=settings['channels'],
                                     cell_channel=settings['cell_channel'],
```
```diff
@@ -3166,14 +3721,18 @@ def plot_region(settings):
                                     figuresize=10,
                                     percentiles=settings['percentiles'],
                                     thickness=3,
-                                    save_pdf=
+                                    save_pdf=True,
                                     mode=settings['mode'],
                                     export_tiffs=settings['export_tiffs'])
 
     dst = os.path.join(settings['src'], 'results', name)
-
-
-
+
+    if not fig_1 == None:
+        save_figure_as_pdf(fig_1, os.path.join(dst, f"{name}_mask_overlay.pdf"))
+    if not fig_2 == None:
+        save_figure_as_pdf(fig_2, os.path.join(dst, f"{name}_png_grid.pdf"))
+    if not fig_3 == None:
+        save_figure_as_pdf(fig_3, os.path.join(dst, f"{name}_activation_grid.pdf"))
 
     return fig_1, fig_2, fig_3
 
```
```diff
@@ -3337,4 +3896,5 @@ def overlay_masks_on_images(img_folder, normalize=True, resize=True, save=False,
         plt.imshow(blended)
         plt.title(f"Overlay: {filename}")
         plt.axis('off')
-        plt.show()
+        plt.show()
+
```