spacr 0.3.38__py3-none-any.whl → 0.3.42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/core.py +1 -1
- spacr/io.py +20 -13
- spacr/measure.py +4 -4
- spacr/ml.py +53 -44
- spacr/plot.py +421 -37
- spacr/settings.py +18 -13
- spacr/toxo.py +223 -16
- spacr/utils.py +7 -5
- {spacr-0.3.38.dist-info → spacr-0.3.42.dist-info}/METADATA +1 -1
- {spacr-0.3.38.dist-info → spacr-0.3.42.dist-info}/RECORD +14 -14
- {spacr-0.3.38.dist-info → spacr-0.3.42.dist-info}/LICENSE +0 -0
- {spacr-0.3.38.dist-info → spacr-0.3.42.dist-info}/WHEEL +0 -0
- {spacr-0.3.38.dist-info → spacr-0.3.42.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.38.dist-info → spacr-0.3.42.dist-info}/top_level.txt +0 -0
spacr/plot.py
CHANGED
@@ -13,10 +13,11 @@ from IPython.display import display
|
|
13
13
|
from skimage.segmentation import find_boundaries
|
14
14
|
from skimage import measure
|
15
15
|
from skimage.measure import find_contours, label, regionprops
|
16
|
+
import tifffile as tiff
|
16
17
|
|
17
18
|
from scipy.stats import normaltest, ttest_ind, mannwhitneyu, f_oneway, kruskal
|
18
19
|
from statsmodels.stats.multicomp import pairwise_tukeyhsd
|
19
|
-
from scipy.stats import ttest_ind, mannwhitneyu, levene, wilcoxon, kruskal
|
20
|
+
from scipy.stats import ttest_ind, mannwhitneyu, levene, wilcoxon, kruskal, normaltest, shapiro
|
20
21
|
import itertools
|
21
22
|
import pingouin as pg
|
22
23
|
|
@@ -25,13 +26,26 @@ from IPython.display import Image as ipyimage
|
|
25
26
|
|
26
27
|
import matplotlib.patches as patches
|
27
28
|
from collections import defaultdict
|
29
|
+
from matplotlib.gridspec import GridSpec
|
28
30
|
|
29
|
-
def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, pathogen_channel, figuresize=10,
|
31
|
+
def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, pathogen_channel, figuresize=10, percentiles=(2,98), thickness=3, save_pdf=True, mode='outlines', export_tiffs=False):
|
30
32
|
"""Plot image and mask overlays."""
|
31
33
|
|
32
|
-
def _plot_merged_plot(image, outlines, outline_colors, figuresize, thickness):
|
34
|
+
def _plot_merged_plot(image, outlines, outline_colors, figuresize, thickness, percentiles, mode='outlines'):
|
33
35
|
"""Plot the merged plot with overlay, image channels, and masks."""
|
34
36
|
|
37
|
+
def _generate_colored_mask(mask, alpha):
    """Colour a label mask with a random colormap and make the background transparent.

    Args:
        mask: Label mask; pixels > 0 are treated as foreground.
        alpha: Opacity assigned to foreground pixels.

    Returns:
        An RGBA float array of shape ``mask.shape + (4,)``. Background pixels
        (``mask <= 0``) get alpha 0. A mask with no foreground yields an
        all-zero array without building a colormap.
    """
    mask = np.asarray(mask)
    peak = mask.max() if mask.size else 0
    if peak <= 0:
        # Nothing to colour; returning early avoids the 0/0 division below.
        return np.zeros(mask.shape + (4,), dtype=float)

    cmap = generate_mask_random_cmap(mask)
    rgba_mask = cmap(mask / peak)  # scale labels to [0, 1] before the colour lookup
    rgba_mask[..., 3] = np.where(mask > 0, alpha, 0)  # transparent outside the mask
    return rgba_mask
|
43
|
+
|
44
|
+
def _overlay_mask(image, mask):
    """Add an RGBA mask onto an RGB image and clip the result to [0, 1].

    ``mask[..., :3]`` is the colour and ``mask[..., 3]`` the per-pixel weight;
    the weighted colour is added to ``image`` (additive, not alpha-blended).
    """
    colour, weight = mask[..., :3], mask[..., 3:4]
    return np.clip(image + colour * weight, 0, 1)
|
48
|
+
|
35
49
|
def _normalize_image(image, percentiles=(2, 98)):
|
36
50
|
"""Normalize the image to the given percentiles."""
|
37
51
|
v_min, v_max = np.percentile(image, percentiles)
|
@@ -61,11 +75,15 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
|
|
61
75
|
# Plot each channel with its corresponding outlines
|
62
76
|
for v in range(num_channels):
|
63
77
|
channel_image = image[..., v]
|
64
|
-
channel_image_normalized = _normalize_image(channel_image)
|
78
|
+
channel_image_normalized = _normalize_image(channel_image, percentiles)
|
65
79
|
channel_image_rgb = np.dstack((channel_image_normalized, channel_image_normalized, channel_image_normalized))
|
66
80
|
|
67
81
|
for outline, color in zip(outlines, outline_colors):
|
68
|
-
|
82
|
+
if mode == 'outlines':
|
83
|
+
channel_image_rgb = _apply_contours(channel_image_rgb, outline, color, thickness)
|
84
|
+
else:
|
85
|
+
mask = _generate_colored_mask(outline, alpha=0.5)
|
86
|
+
channel_image_rgb = _overlay_mask(channel_image_rgb, mask)
|
69
87
|
|
70
88
|
ax[v].imshow(channel_image_rgb)
|
71
89
|
ax[v].set_title(f'Image - Channel {v}')
|
@@ -75,11 +93,15 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
|
|
75
93
|
rgb_channels = min(3, num_channels)
|
76
94
|
for i in range(rgb_channels):
|
77
95
|
channel_image = image[..., i]
|
78
|
-
channel_image_normalized = _normalize_image(channel_image)
|
96
|
+
channel_image_normalized = _normalize_image(channel_image, percentiles)
|
79
97
|
rgb_image[..., i] = channel_image_normalized
|
80
98
|
|
81
99
|
for outline, color in zip(outlines, outline_colors):
|
82
|
-
|
100
|
+
if mode == 'outlines':
|
101
|
+
rgb_image = _apply_contours(rgb_image, outline, color, thickness)
|
102
|
+
else:
|
103
|
+
mask = _generate_colored_mask(outline, alpha=0.5)
|
104
|
+
rgb_image = _overlay_mask(rgb_image, mask)
|
83
105
|
|
84
106
|
ax[-1].imshow(rgb_image)
|
85
107
|
ax[-1].set_title('Combined RGB Image')
|
@@ -96,8 +118,22 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
|
|
96
118
|
plt.show()
|
97
119
|
return fig
|
98
120
|
|
121
|
+
def _save_channels_as_tiff(stack, save_dir, filename):
    """Write each channel of ``stack`` (its last axis) to a separate grayscale TIFF.

    Files are named ``<filename>_channel_<i>.tiff`` inside ``save_dir``, which
    is created if needed; every written path is printed.
    """
    os.makedirs(save_dir, exist_ok=True)
    n_channels = stack.shape[-1]
    for channel_index in range(n_channels):
        out_path = os.path.join(save_dir, f"{filename}_channel_{channel_index}.tiff")
        tiff.imwrite(out_path, stack[..., channel_index], photometric='minisblack')
        print(f"Saved {out_path}")
|
129
|
+
|
99
130
|
stack = np.load(file)
|
100
131
|
|
132
|
+
if export_tiffs:
|
133
|
+
save_dir = os.path.join(os.path.dirname(os.path.dirname(file)), 'results', os.path.splitext(os.path.basename(file))[0], 'tiff')
|
134
|
+
filename = os.path.splitext(os.path.basename(file))[0]
|
135
|
+
_save_channels_as_tiff(stack, save_dir, filename)
|
136
|
+
|
101
137
|
# Convert to float for normalization and ensure correct handling of both 8-bit and 16-bit arrays
|
102
138
|
if stack.dtype == np.uint16:
|
103
139
|
stack = stack.astype(np.float32)
|
@@ -128,7 +164,7 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
|
|
128
164
|
outlines.append(np.take(stack, cell_mask_dim, axis=2))
|
129
165
|
outline_colors.append('red')
|
130
166
|
|
131
|
-
fig = _plot_merged_plot(image=image, outlines=outlines, outline_colors=outline_colors, figuresize=figuresize, thickness=thickness)
|
167
|
+
fig = _plot_merged_plot(image=image, outlines=outlines, outline_colors=outline_colors, figuresize=figuresize, thickness=thickness, percentiles=percentiles, mode=mode)
|
132
168
|
|
133
169
|
return fig
|
134
170
|
|
@@ -1691,17 +1727,25 @@ def plot_object_outlines(src, objects=['nucleus','cell','pathogen'], channels=[0
|
|
1691
1727
|
overlay=True,
|
1692
1728
|
max_nr=10,
|
1693
1729
|
randomize=True)
|
1694
|
-
|
1730
|
+
|
1695
1731
|
def volcano_plot(coef_df, filename='volcano_plot.pdf'):
|
1732
|
+
palette = {
|
1733
|
+
'pc': 'red',
|
1734
|
+
'nc': 'green',
|
1735
|
+
'control': 'blue',
|
1736
|
+
'other': 'gray'
|
1737
|
+
}
|
1738
|
+
|
1696
1739
|
# Create the volcano plot
|
1697
1740
|
plt.figure(figsize=(10, 6))
|
1698
1741
|
sns.scatterplot(
|
1699
1742
|
data=coef_df,
|
1700
1743
|
x='coefficient',
|
1701
1744
|
y='-log10(p_value)',
|
1702
|
-
hue='
|
1703
|
-
palette=
|
1745
|
+
hue='condition',
|
1746
|
+
palette=palette
|
1704
1747
|
)
|
1748
|
+
|
1705
1749
|
plt.title('Volcano Plot of Coefficients')
|
1706
1750
|
plt.xlabel('Coefficient')
|
1707
1751
|
plt.ylabel('-log10(p-value)')
|
@@ -2098,7 +2142,7 @@ class spacrGraph:
|
|
2098
2142
|
def __init__(self, df, grouping_column, data_column, graph_type='bar', summary_func='mean',
|
2099
2143
|
order=None, colors=None, output_dir='./output', save=False, y_lim=None,
|
2100
2144
|
error_bar_type='std', remove_outliers=False, theme='pastel', representation='object',
|
2101
|
-
paired=False, all_to_all=True, compare_group=None):
|
2145
|
+
paired=False, all_to_all=True, compare_group=None, graph_name=None):
|
2102
2146
|
|
2103
2147
|
"""
|
2104
2148
|
Class for creating grouped plots with optional statistical tests and data preprocessing.
|
@@ -2121,11 +2165,14 @@ class spacrGraph:
|
|
2121
2165
|
self.all_to_all = all_to_all
|
2122
2166
|
self.compare_group = compare_group
|
2123
2167
|
self.y_lim = y_lim
|
2168
|
+
self.graph_name = graph_name
|
2169
|
+
|
2170
|
+
|
2124
2171
|
self.results_df = pd.DataFrame()
|
2125
2172
|
self.sns_palette = None
|
2126
2173
|
self.fig = None
|
2127
2174
|
|
2128
|
-
self.results_name = str(self.data_column[0])+'_'+str(self.grouping_column)+'_'+str(self.graph_type)
|
2175
|
+
self.results_name = str(self.graph_name)+'_'+str(self.data_column[0])+'_'+str(self.grouping_column)+'_'+str(self.graph_type)
|
2129
2176
|
|
2130
2177
|
self._set_theme()
|
2131
2178
|
self.raw_df = self.df.copy()
|
@@ -2134,10 +2181,10 @@ class spacrGraph:
|
|
2134
2181
|
def _set_theme(self):
    """Build ``self.sns_palette`` from ``self.theme`` using a fixed colour order.

    Palette indices 7, 9, 4, 0, 3, 6, 2 come first, followed by 1..80; the
    palette is requested with 100 colours through ``_set_reordered_theme``.
    """
    preferred_first = [7, 9, 4, 0, 3, 6, 2]
    color_order = preferred_first + list(range(1, 81))
    self.sns_palette = self._set_reordered_theme(self.theme, color_order, 100)
|
2139
2186
|
|
2140
|
-
def _set_reordered_theme(self, theme='
|
2187
|
+
def _set_reordered_theme(self, theme='deep', order=None, n_colors=100, show_theme=False):
|
2141
2188
|
"""Set and reorder the Seaborn color palette."""
|
2142
2189
|
palette = sns.color_palette(theme, n_colors)
|
2143
2190
|
if order:
|
@@ -2182,20 +2229,36 @@ class spacrGraph:
|
|
2182
2229
|
"""Perform normality tests for each group and each data column."""
|
2183
2230
|
unique_groups = self.df[self.grouping_column].unique()
|
2184
2231
|
normality_results = []
|
2232
|
+
|
2185
2233
|
for column in self.data_column:
|
2186
|
-
|
2187
|
-
|
2188
|
-
|
2189
|
-
|
2190
|
-
|
2234
|
+
# Iterate over each group and its corresponding data
|
2235
|
+
for group in unique_groups:
|
2236
|
+
data = self.df.loc[self.df[self.grouping_column] == group, column]
|
2237
|
+
n_samples = len(data)
|
2238
|
+
|
2239
|
+
if n_samples >= 8:
|
2240
|
+
# Use D'Agostino-Pearson test for larger samples
|
2241
|
+
stat, p_value = normaltest(data)
|
2242
|
+
test_name = "D'Agostino-Pearson test"
|
2243
|
+
else:
|
2244
|
+
# Use Shapiro-Wilk test for smaller samples
|
2245
|
+
stat, p_value = shapiro(data)
|
2246
|
+
test_name = "Shapiro-Wilk test"
|
2247
|
+
|
2248
|
+
# Store the result for this group and column
|
2191
2249
|
normality_results.append({
|
2192
2250
|
'Comparison': f'Normality test for {group} on {column}',
|
2193
2251
|
'Test Statistic': stat,
|
2194
2252
|
'p-value': p_value,
|
2195
|
-
'Test Name':
|
2253
|
+
'Test Name': test_name,
|
2196
2254
|
'Column': column,
|
2197
|
-
'n':
|
2255
|
+
'n': n_samples # Sample size
|
2198
2256
|
})
|
2257
|
+
|
2258
|
+
# Check if all groups are normally distributed (p > 0.05)
|
2259
|
+
normal_p_values = [result['p-value'] for result in normality_results if result['Column'] == column]
|
2260
|
+
is_normal = all(p > 0.05 for p in normal_p_values)
|
2261
|
+
|
2199
2262
|
return is_normal, normality_results
|
2200
2263
|
|
2201
2264
|
def perform_levene_test(self, unique_groups):
|
@@ -2339,17 +2402,21 @@ class spacrGraph:
|
|
2339
2402
|
ax.text(x_pos, y_pos, text, ha='center', va='center', fontsize=12)
|
2340
2403
|
|
2341
2404
|
def _get_positions(self, ax):
    """Return the x coordinate of each plotted group for ``self.graph_type``.

    - 'bar', 'jitter_bar', 'violin': mean x of the first path's vertices for
      every artist in ``ax.collections`` that exposes ``get_paths``.
    - 'box', 'jitter_box': distinct mean x of every solid ('-') line in
      ``ax.lines`` (duplicates removed, order not guaranteed).
    - 'jitter': mean x offset of every non-empty collection.
    - 'line', 'line_std': no positions (empty list).
    Any other graph type is not handled; the function then fails with
    UnboundLocalError.
    """
    graph_kind = self.graph_type

    if graph_kind in ('bar', 'jitter_bar', 'violin'):
        positions = [np.mean(artist.get_paths()[0].vertices[:, 0])
                     for artist in ax.collections if hasattr(artist, 'get_paths')]
    elif graph_kind in ('box', 'jitter_box'):
        solid_line_centres = (ln.get_xdata().mean() for ln in ax.lines if ln.get_linestyle() == '-')
        positions = list(set(solid_line_centres))
    elif graph_kind == 'jitter':
        positions = [np.mean(coll.get_offsets()[:, 0])
                     for coll in ax.collections if coll.get_offsets().size > 0]
    elif graph_kind in ('line', 'line_std'):
        positions = []

    return positions
|
2354
2421
|
|
2355
2422
|
def _draw_comparison_lines(ax, x_positions):
|
@@ -2367,7 +2434,7 @@ class spacrGraph:
|
|
2367
2434
|
|
2368
2435
|
# Determine significance marker
|
2369
2436
|
if p_value <= 0.001:
|
2370
|
-
|
2437
|
+
signiresults_namecance = '***'
|
2371
2438
|
elif p_value <= 0.01:
|
2372
2439
|
significance = '**'
|
2373
2440
|
elif p_value <= 0.05:
|
@@ -2408,6 +2475,9 @@ class spacrGraph:
|
|
2408
2475
|
self.fig_width = (num_groups * self.bar_width) + (spacing_between_groups * num_groups)
|
2409
2476
|
self.fig_height = self.fig_width/2
|
2410
2477
|
|
2478
|
+
if self.graph_type in ['line','line_std']:
|
2479
|
+
self.fig_height, self.fig_width = 10, 10
|
2480
|
+
|
2411
2481
|
if ax is None:
|
2412
2482
|
self.fig, ax = plt.subplots(figsize=(self.fig_height, self.fig_width))
|
2413
2483
|
else:
|
@@ -2429,6 +2499,14 @@ class spacrGraph:
|
|
2429
2499
|
self._create_box_plot(ax)
|
2430
2500
|
elif self.graph_type == 'violin':
|
2431
2501
|
self._create_violin_plot(ax)
|
2502
|
+
elif self.graph_type == 'jitter_box':
|
2503
|
+
self._create_jitter_box_plot(ax)
|
2504
|
+
elif self.graph_type == 'jitter_bar':
|
2505
|
+
self._create_jitter_bar_plot(ax)
|
2506
|
+
elif self.graph_type == 'line':
|
2507
|
+
self._create_line_graph(ax)
|
2508
|
+
elif self.graph_type == 'line_std':
|
2509
|
+
self._create_line_with_std_area(ax)
|
2432
2510
|
else:
|
2433
2511
|
raise ValueError(f"Unknown graph type: {self.graph_type}")
|
2434
2512
|
|
@@ -2441,14 +2519,17 @@ class spacrGraph:
|
|
2441
2519
|
|
2442
2520
|
sns.despine(ax=ax, top=True, right=True)
|
2443
2521
|
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), title='Data Column') # Move the legend outside the plot
|
2444
|
-
|
2522
|
+
|
2523
|
+
if not self.graph_type in ['line','line_std']:
|
2524
|
+
ax.set_xlabel('')
|
2525
|
+
|
2445
2526
|
x_positions = _get_positions(self, ax)
|
2446
2527
|
|
2447
|
-
if len(self.data_column) == 1:
|
2528
|
+
if len(self.data_column) == 1 and not self.graph_type in ['line','line_std']:
|
2448
2529
|
ax.legend().remove()
|
2449
2530
|
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
|
2450
2531
|
|
2451
|
-
elif len(self.data_column) > 1:
|
2532
|
+
elif len(self.data_column) > 1 and not self.graph_type in ['line','line_std']:
|
2452
2533
|
ax.set_xticks([])
|
2453
2534
|
ax.tick_params(bottom=False)
|
2454
2535
|
ax.set_xticklabels([])
|
@@ -2524,7 +2605,54 @@ class spacrGraph:
|
|
2524
2605
|
handles, labels = ax.get_legend_handles_labels()
|
2525
2606
|
unique_labels = dict(zip(labels, handles))
|
2526
2607
|
ax.legend(unique_labels.values(), unique_labels.keys(), loc='best')
|
2527
|
-
|
2608
|
+
|
2609
|
+
def _create_line_graph(self, ax):
    """Draw one line per group, plotting ``data_column[1]`` against ``data_column[0]``.

    Lines are coloured by ``self.grouping_column`` using ``self.sns_palette``.

    Args:
        ax: Matplotlib axes to draw into.

    Raises:
        ValueError: If fewer than two data columns are configured, or if the
            x, y or grouping column is missing from ``self.df``.
    """
    # Fail with a clear message instead of an IndexError on data_column[1].
    if len(self.data_column) < 2:
        raise ValueError("Line graphs need two data columns: [x_column, y_column].")
    x_axis_column, y_axis_column = self.data_column[0], self.data_column[1]

    required_columns = (x_axis_column, y_axis_column, self.grouping_column)
    missing = [col for col in required_columns if col not in self.df.columns]
    if missing:
        raise ValueError(f"Column '{missing[0]}' not found in DataFrame.")

    # hue = grouping column -> one line per group.
    sns.lineplot(data=self.df, x=x_axis_column, y=y_axis_column, hue=self.grouping_column,
                 palette=self.sns_palette, ax=ax, marker='o', linewidth=1, markersize=6)

    ax.set_xlabel(f"{x_axis_column}")
    ax.set_ylabel(f"{y_axis_column}")
|
2631
|
+
|
2632
|
+
def _create_line_with_std_area(self, ax):
    """Plot the mean of ``data_column[1]`` per ``data_column[0]`` value with a ±1 std band.

    All rows sharing an x value are pooled. The mean is drawn as a blue line
    labelled ``mean_<y>``, and a translucent band spans mean - std to mean + std.

    Args:
        ax: Matplotlib axes to draw into.

    Raises:
        ValueError: If fewer than two data columns are configured, or if the
            x or y column is missing from ``self.df``.
    """
    if len(self.data_column) < 2:
        raise ValueError("Line graphs need two data columns: [x_column, y_column].")
    x_axis_column, y_axis_column = self.data_column[0], self.data_column[1]
    for col in (x_axis_column, y_axis_column):
        if col not in self.df.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame.")

    y_axis_column_mean = f"mean_{y_axis_column}"
    y_axis_column_std = f"std_{y_axis_column}"  # was "std_mean_<y>" by mistake

    # One row per x value with the mean and std of y; the aggfunc list yields
    # MultiIndex columns, flattened positionally to [x, mean, std].
    summary_df = self.df.pivot_table(index=x_axis_column, values=y_axis_column, aggfunc=['mean', 'std']).reset_index()
    summary_df.columns = [x_axis_column, y_axis_column_mean, y_axis_column_std]

    sns.lineplot(data=summary_df, x=x_axis_column, y=y_axis_column_mean, ax=ax, marker='o',
                 linewidth=1, markersize=0, color='blue', label=y_axis_column_mean)

    mean = summary_df[y_axis_column_mean]
    std = summary_df[y_axis_column_std]
    ax.fill_between(summary_df[x_axis_column], mean - std, mean + std, color='blue', alpha=0.1)

    ax.set_xlabel(f"{x_axis_column}")
    ax.set_ylabel(f"{y_axis_column}")
|
2655
|
+
|
2528
2656
|
def _create_box_plot(self, ax):
|
2529
2657
|
"""Helper method to create a box plot with consistent spacing."""
|
2530
2658
|
# Combine grouping column and data column if needed
|
@@ -2574,6 +2702,68 @@ class spacrGraph:
|
|
2574
2702
|
unique_labels = dict(zip(labels, handles))
|
2575
2703
|
ax.legend(unique_labels.values(), unique_labels.keys(), loc='best')
|
2576
2704
|
|
2705
|
+
def _create_jitter_bar_plot(self, ax):
    """Draw a mean bar per x category with the raw observations jittered on top.

    With more than one data column, a ``'Combined Group'`` column
    ("<group> - <data column>") is added to ``self.df_melted`` and used as the
    x axis; otherwise ``self.grouping_column`` is the x axis. Error bars of
    type ``self.error_bar_type`` ('std' or 'sem') are drawn on each bar; any
    other value draws no error bars.

    Args:
        ax: Matplotlib axes to draw into.
    """
    multi_column = len(self.data_column) > 1
    if multi_column:
        self.df_melted['Combined Group'] = (self.df_melted[self.grouping_column].astype(str) + " - " + self.df_melted['Data Column'].astype(str))
        x_axis_column = 'Combined Group'
        ax.set_ylabel('Value')
    else:
        x_axis_column = self.grouping_column
        ax.set_ylabel(self.data_column[0])

    # Pin one category order and pass it to every artist below. It mirrors
    # seaborn's default (Categorical -> its categories, numeric -> sorted,
    # otherwise first appearance), so the layout is unchanged. Error bars are
    # then looked up per category instead of being zipped against a groupby
    # result that is sorted alphabetically.
    x_values = self.df_melted[x_axis_column]
    if isinstance(x_values.dtype, pd.CategoricalDtype):
        category_order = list(x_values.cat.categories)
    elif pd.api.types.is_numeric_dtype(x_values):
        category_order = sorted(x_values.dropna().unique())
    else:
        category_order = list(x_values.dropna().unique())

    summary_df = self.df_melted.groupby([x_axis_column]).agg(mean=('Value', 'mean'), std=('Value', 'std'), sem=('Value', 'sem'))

    sns.barplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette,
                ax=ax, dodge=self.jitter_bar_dodge, ci=None, order=category_order)
    sns.stripplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette,
                  dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax, alpha=0.6, edgecolor='white',
                  linewidth=1, size=16, order=category_order)

    bars = [patch for patch in ax.patches if isinstance(patch, plt.Rectangle)]

    if multi_column:
        # Widen each bar while keeping it centred where seaborn placed it.
        # The old code subtracted half the new width from the left edge,
        # which pushed every bar off its category position.
        target_width = self.bar_width * 2
        for bar in bars:
            centre = bar.get_x() + bar.get_width() / 2
            bar.set_width(target_width)
            bar.set_x(centre - target_width / 2)

    if self.error_bar_type in ('std', 'sem'):
        errors = summary_df[self.error_bar_type].reindex(category_order).tolist()
        for bar in bars:
            x_bar = bar.get_x() + bar.get_width() / 2
            # Assumes seaborn's categorical layout (category i centred at x = i).
            # NOTE(review): with a hue, dodged bars share their category's pooled error.
            category_index = int(round(x_bar))
            if 0 <= category_index < len(errors):
                ax.errorbar(x=x_bar, y=bar.get_height(), yerr=errors[category_index],
                            fmt='none', c='black', capsize=5, lw=2)

    ax.set_xlabel(self.grouping_column)
|
2741
|
+
|
2742
|
+
def _create_jitter_box_plot(self, ax):
    """Draw a box plot per x category with the individual points jittered on top.

    With more than one data column, a ``'Combined Group'`` column
    ("<group> - <data column>") is added to ``self.df_melted`` and used as the
    x axis (y label 'Value'). Otherwise the grouping column is the x axis and
    the single data column names the y axis. Duplicate legend entries from the
    two overlaid plots are collapsed.
    """
    if len(self.data_column) <= 1:
        x_axis_column = self.grouping_column
        ax.set_ylabel(self.data_column[0])
    else:
        melted = self.df_melted
        melted['Combined Group'] = melted[self.grouping_column].astype(str) + " - " + melted['Data Column'].astype(str)
        x_axis_column = 'Combined Group'
        ax.set_ylabel('Value')

    shared = dict(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, ax=ax)
    sns.boxplot(**shared)
    sns.stripplot(**shared, dodge=self.jitter_bar_dodge, jitter=self.bar_width, alpha=0.6,
                  edgecolor='white', linewidth=1, size=12)

    ax.set_xlabel(self.grouping_column)

    # Keep one legend entry per label (box and strip plots each add one).
    handles, labels = ax.get_legend_handles_labels()
    deduplicated = dict(zip(labels, handles))
    ax.legend(deduplicated.values(), deduplicated.keys(), loc='best')
|
2766
|
+
|
2577
2767
|
def _save_results(self):
|
2578
2768
|
"""Helper method to save the plot and results."""
|
2579
2769
|
os.makedirs(self.output_dir, exist_ok=True)
|
@@ -2594,14 +2784,14 @@ class spacrGraph:
|
|
2594
2784
|
|
2595
2785
|
def plot_data_from_db(settings):
|
2596
2786
|
from .io import _read_db, _read_and_merge_data
|
2597
|
-
from .utils import annotate_conditions
|
2787
|
+
from .utils import annotate_conditions, save_settings
|
2598
2788
|
"""
|
2599
2789
|
Extracts the specified table from the SQLite database and plots a specified column.
|
2600
2790
|
|
2601
2791
|
Args:
|
2602
2792
|
db_path (str): The path to the SQLite database.
|
2603
2793
|
table_names (str): The name of the table to extract.
|
2604
|
-
|
2794
|
+
data_column (str): The column to plot from the table.
|
2605
2795
|
|
2606
2796
|
Returns:
|
2607
2797
|
df (pd.DataFrame): The extracted table as a DataFrame.
|
@@ -2616,6 +2806,8 @@ def plot_data_from_db(settings):
|
|
2616
2806
|
else:
|
2617
2807
|
raise ValueError("src must be a string or a list of strings.")
|
2618
2808
|
|
2809
|
+
save_settings(settings, name=f"{settings['graph_name']}_plot_settings_db", show=True)
|
2810
|
+
|
2619
2811
|
dfs = []
|
2620
2812
|
for i, src in enumerate(srcs):
|
2621
2813
|
|
@@ -2643,6 +2835,7 @@ def plot_data_from_db(settings):
|
|
2643
2835
|
df = pd.concat(dfs, axis=0)
|
2644
2836
|
df['prc'] = df['plate'].astype(str) + '_' + df['row'].astype(str) + '_' + df['col'].astype(str)
|
2645
2837
|
df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
|
2838
|
+
df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
|
2646
2839
|
|
2647
2840
|
if settings['cell_plate_metadata'] != None:
|
2648
2841
|
df = df.dropna(subset='host_cell')
|
@@ -2653,24 +2846,91 @@ def plot_data_from_db(settings):
|
|
2653
2846
|
if settings['treatment_plate_metadata'] != None:
|
2654
2847
|
df = df.dropna(subset='treatment')
|
2655
2848
|
|
2656
|
-
df = df.dropna(subset=settings['
|
2849
|
+
df = df.dropna(subset=settings['data_column'])
|
2657
2850
|
df = df.dropna(subset=settings['grouping_column'])
|
2658
2851
|
|
2852
|
+
#df['class'] = df['png_path'].apply(lambda x: 'class_1' if 'class_1' in x else ('class_0' if 'class_0' in x else None))
|
2853
|
+
src = srcs[0]
|
2854
|
+
dst = os.path.join(src, 'results', settings['graph_name'])
|
2855
|
+
os.makedirs(dst, exist_ok=True)
|
2856
|
+
|
2857
|
+
spacr_graph = spacrGraph(
|
2858
|
+
df=df, # Your DataFrame
|
2859
|
+
grouping_column=settings['grouping_column'], # Column for grouping the data (x-axis)
|
2860
|
+
data_column=settings['data_column'], # Column for the data (y-axis)
|
2861
|
+
graph_type=settings['graph_type'], # Type of plot ('bar', 'box', 'violin', 'jitter')
|
2862
|
+
graph_name=settings['graph_name'], # Name of the plot
|
2863
|
+
summary_func='mean', # Function to summarize data (e.g., 'mean', 'median')
|
2864
|
+
colors=None, # Custom colors for the plot (optional)
|
2865
|
+
output_dir=dst, # Directory to save the plot and results
|
2866
|
+
save=settings['save'], # Whether to save the plot and results
|
2867
|
+
y_lim=settings['y_lim'], # Starting point for y-axis (optional)
|
2868
|
+
error_bar_type='std', # Type of error bar ('std' or 'sem')
|
2869
|
+
representation=settings['representation'],
|
2870
|
+
theme=settings['theme'], # Seaborn color palette theme (e.g., 'pastel', 'muted')
|
2871
|
+
)
|
2872
|
+
|
2873
|
+
# Create the plot
|
2874
|
+
spacr_graph.create_plot()
|
2875
|
+
|
2876
|
+
# Get the figure object if needed
|
2877
|
+
fig = spacr_graph.get_figure()
|
2878
|
+
plt.show()
|
2879
|
+
|
2880
|
+
# Optional: Get the results DataFrame containing statistical test results
|
2881
|
+
results_df = spacr_graph.get_results()
|
2882
|
+
return fig, results_df
|
2883
|
+
|
2884
|
+
def plot_data_from_csv(settings):
|
2885
|
+
from .io import _read_db, _read_and_merge_data
|
2886
|
+
from .utils import annotate_conditions, save_settings
|
2887
|
+
"""
|
2888
|
+
Extracts the specified table from the SQLite database and plots a specified column.
|
2889
|
+
|
2890
|
+
Args:
|
2891
|
+
db_path (str): The path to the SQLite database.
|
2892
|
+
table_names (str): The name of the table to extract.
|
2893
|
+
data_column (str): The column to plot from the table.
|
2894
|
+
|
2895
|
+
Returns:
|
2896
|
+
df (pd.DataFrame): The extracted table as a DataFrame.
|
2897
|
+
"""
|
2659
2898
|
|
2899
|
+
if isinstance(settings['src'], str):
|
2900
|
+
srcs = [settings['src']]
|
2901
|
+
elif isinstance(settings['src'], list):
|
2902
|
+
srcs = settings['src']
|
2903
|
+
else:
|
2904
|
+
raise ValueError("src must be a string or a list of strings.")
|
2905
|
+
|
2906
|
+
#save_settings(settings, name=f"{settings['graph_name']}_plot_settings_csv", show=True)
|
2660
2907
|
|
2908
|
+
dfs = []
|
2909
|
+
for i, src in enumerate(srcs):
|
2661
2910
|
|
2911
|
+
dft = pd.read_csv(src)
|
2912
|
+
if 'plate' not in dft.columns:
|
2913
|
+
dft['plate'] = f"plate{i+1}"
|
2914
|
+
dfs.append(dft)
|
2915
|
+
|
2916
|
+
df = pd.concat(dfs, axis=0)
|
2662
2917
|
#display(df)
|
2663
2918
|
|
2664
|
-
|
2919
|
+
df = df.dropna(subset=settings['data_column'])
|
2920
|
+
df = df.dropna(subset=settings['grouping_column'])
|
2921
|
+
src = srcs[0]
|
2922
|
+
dst = os.path.join(os.path.dirname(src), 'results', settings['graph_name'])
|
2923
|
+
os.makedirs(dst, exist_ok=True)
|
2665
2924
|
|
2666
2925
|
spacr_graph = spacrGraph(
|
2667
2926
|
df=df, # Your DataFrame
|
2668
2927
|
grouping_column=settings['grouping_column'], # Column for grouping the data (x-axis)
|
2669
|
-
data_column=settings['
|
2928
|
+
data_column=settings['data_column'], # Column for the data (y-axis)
|
2670
2929
|
graph_type=settings['graph_type'], # Type of plot ('bar', 'box', 'violin', 'jitter')
|
2930
|
+
graph_name=settings['graph_name'], # Name of the plot
|
2671
2931
|
summary_func='mean', # Function to summarize data (e.g., 'mean', 'median')
|
2672
2932
|
colors=None, # Custom colors for the plot (optional)
|
2673
|
-
output_dir=
|
2933
|
+
output_dir=dst, # Directory to save the plot and results
|
2674
2934
|
save=settings['save'], # Whether to save the plot and results
|
2675
2935
|
y_lim=settings['y_lim'], # Starting point for y-axis (optional)
|
2676
2936
|
error_bar_type='std', # Type of error bar ('std' or 'sem')
|
@@ -2687,5 +2947,129 @@ def plot_data_from_db(settings):
|
|
2687
2947
|
|
2688
2948
|
# Optional: Get the results DataFrame containing statistical test results
|
2689
2949
|
results_df = spacr_graph.get_results()
|
2690
|
-
|
2691
|
-
|
2950
|
+
return fig, results_df
|
2951
|
+
|
2952
|
+
def plot_region(settings):
    """Render and save three figures for one field of view (FOV).

    Keys read from ``settings``: 'src', 'name' (FOV file inside
    ``<src>/merged``), 'activation_db', 'activation_mode', 'percentiles',
    'channels', 'cell_channel', 'nucleus_channel', 'pathogen_channel', 'mode'
    and 'export_tiffs'.

    Figures:
        fig_1: image/mask overlay of the FOV (``plot_image_mask_overlay``).
        fig_2: grid of the FOV's PNGs listed in table ``png_list`` of
            ``<src>/measurements/measurements.db``.
        fig_3: grid of the FOV's PNGs listed in table ``<activation_mode>_list``
            of ``<src>/measurements/<activation_db>``.

    All three are saved as PDFs in ``<src>/results/<FOV name without extension>/``.

    Returns:
        tuple: ``(fig_1, fig_2, fig_3)``.
    """
    from .io import _read_db

    def _sort_paths_by_basename(paths):
        """Return ``paths`` sorted by file name, ignoring directories."""
        return sorted(paths, key=os.path.basename)

    def save_figure_as_pdf(fig, path):
        """Save ``fig`` as a 600-dpi PDF at ``path``, creating parent directories."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        fig.savefig(path, format='pdf', dpi=600, bbox_inches='tight')
        print(f"Saved {path}")

    fov_path = os.path.join(settings['src'], 'merged', settings['name'])
    name = os.path.splitext(settings['name'])[0]

    def _paths_for_fov(df):
        """Sorted ``png_path`` values of ``df`` whose path contains the FOV name."""
        # regex=False: match the FOV name literally. The pandas default treats
        # the pattern as a regular expression, so '+', '(' or '[' in a file
        # name would break or broaden the match.
        # NOTE(review): still a substring match - a name ending in "_1" may
        # also match paths of "_10"; confirm against the PNG naming scheme.
        matches = df[df['png_path'].str.contains(name, na=False, regex=False)]
        return _sort_paths_by_basename(matches['png_path'].tolist())

    db_path = os.path.join(settings['src'], 'measurements', 'measurements.db')
    png_paths = _paths_for_fov(_read_db(db_path, tables=['png_list'])[0])

    activation_mode = f"{settings['activation_mode']}_list"
    activation_db_path = os.path.join(settings['src'], 'measurements', settings['activation_db'])
    activation_paths = _paths_for_fov(_read_db(activation_db_path, tables=[activation_mode])[0])

    fig_3 = plot_image_grid(image_paths=activation_paths, percentiles=settings['percentiles'])
    fig_2 = plot_image_grid(image_paths=png_paths, percentiles=settings['percentiles'])
    fig_1 = plot_image_mask_overlay(file=fov_path,
                                    channels=settings['channels'],
                                    cell_channel=settings['cell_channel'],
                                    nucleus_channel=settings['nucleus_channel'],
                                    pathogen_channel=settings['pathogen_channel'],
                                    figuresize=10,
                                    percentiles=settings['percentiles'],
                                    thickness=3,
                                    save_pdf=False,
                                    mode=settings['mode'],
                                    export_tiffs=settings['export_tiffs'])

    dst = os.path.join(settings['src'], 'results', name)
    save_figure_as_pdf(fig_1, os.path.join(dst, f"{name}_mask_overlay.pdf"))
    save_figure_as_pdf(fig_2, os.path.join(dst, f"{name}_png_grid.pdf"))
    save_figure_as_pdf(fig_3, os.path.join(dst, f"{name}_activation_grid.pdf"))

    return fig_1, fig_2, fig_3
|
2998
|
+
|
2999
|
+
def plot_image_grid(image_paths, percentiles=(2, 98)):
    """
    Plots a square grid of images from a list of image paths.
    Unused subplots are filled with black, and padding is minimized.

    Each image is contrast-stretched channel by channel between the given
    percentiles before it is displayed.

    Parameters:
    - image_paths: List of paths to images to be displayed.
    - percentiles: (low, high) percentiles for the per-channel contrast
      stretch. Defaults to (2, 98).

    Returns:
    - fig: The generated matplotlib figure. With no paths, the figure holds a
      single black cell instead of raising.
    """

    from PIL import Image
    import matplotlib.pyplot as plt
    import math

    def _normalize_image(image, percentiles=(2, 98)):
        """Stretch each channel of ``image`` to [0, 1] between the given percentiles.

        Accepts a PIL image or a NumPy array and returns the same kind: an
        8-bit PIL image for PIL input, a float32 array otherwise. A channel
        with no intensity range (v_max == v_min) becomes 0 rather than NaN.
        """
        is_pil_image = isinstance(image, Image.Image)
        array = np.asarray(image, dtype=np.float32)
        single_channel = array.ndim == 2
        stack = array[..., np.newaxis] if single_channel else array

        normalized = np.zeros(stack.shape, dtype=np.float32)
        for c in range(stack.shape[-1]):
            channel = stack[..., c]
            v_min, v_max = np.percentile(channel, percentiles)
            if v_max > v_min:  # a flat channel would divide by zero and give NaN
                normalized[..., c] = np.clip((channel - v_min) / (v_max - v_min), 0, 1)

        if single_channel:
            normalized = normalized[..., 0]
        if is_pil_image:
            # Convert back to 8-bit (0-255) for PIL.
            return Image.fromarray((normalized * 255).astype(np.uint8))
        return normalized

    n_images = len(image_paths)
    # Smallest square grid that fits every image, with at least one cell.
    grid_size = max(1, math.ceil(math.sqrt(n_images)))

    # squeeze=False always returns a 2-D array of Axes. Without it, a 1x1 grid
    # returns a bare Axes object and .flatten() fails for a single image.
    fig, axs = plt.subplots(grid_size, grid_size,
                            figsize=(grid_size * 2, grid_size * 2),
                            facecolor='black',  # figure background shows between/behind cells
                            squeeze=False)
    axs = axs.flatten()

    for ax, img_path in zip(axs, image_paths):
        # The with-block closes the file handle once the pixels are read.
        with Image.open(img_path) as opened:
            if opened.mode in ('RGBA', 'LA'):
                # Drop the alpha band. Stretched like an intensity channel, a
                # constant (fully opaque) alpha would become 0, i.e. transparent.
                img = opened.convert('RGB' if opened.mode == 'RGBA' else 'L')
            else:
                img = opened
            img = _normalize_image(img, percentiles)
        ax.imshow(img)
        ax.axis('off')

    # Fill any unused cells with a black square.
    for ax in axs[n_images:]:
        ax.imshow(np.zeros((1, 1)), cmap='gray', vmin=0, vmax=1)
        ax.axis('off')

    # Remove all padding between and around the cells.
    fig.subplots_adjust(wspace=0, hspace=0, left=0, right=1, top=1, bottom=0)

    return fig
|