spacr 0.3.37__py3-none-any.whl → 0.3.41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/core.py +1 -1
- spacr/io.py +20 -13
- spacr/ml.py +33 -24
- spacr/plot.py +427 -37
- spacr/toxo.py +202 -16
- spacr/utils.py +14 -12
- {spacr-0.3.37.dist-info → spacr-0.3.41.dist-info}/METADATA +1 -1
- {spacr-0.3.37.dist-info → spacr-0.3.41.dist-info}/RECORD +12 -12
- {spacr-0.3.37.dist-info → spacr-0.3.41.dist-info}/LICENSE +0 -0
- {spacr-0.3.37.dist-info → spacr-0.3.41.dist-info}/WHEEL +0 -0
- {spacr-0.3.37.dist-info → spacr-0.3.41.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.37.dist-info → spacr-0.3.41.dist-info}/top_level.txt +0 -0
spacr/plot.py
CHANGED
@@ -13,10 +13,11 @@ from IPython.display import display
|
|
13
13
|
from skimage.segmentation import find_boundaries
|
14
14
|
from skimage import measure
|
15
15
|
from skimage.measure import find_contours, label, regionprops
|
16
|
+
import tifffile as tiff
|
16
17
|
|
17
18
|
from scipy.stats import normaltest, ttest_ind, mannwhitneyu, f_oneway, kruskal
|
18
19
|
from statsmodels.stats.multicomp import pairwise_tukeyhsd
|
19
|
-
from scipy.stats import ttest_ind, mannwhitneyu, levene, wilcoxon, kruskal
|
20
|
+
from scipy.stats import ttest_ind, mannwhitneyu, levene, wilcoxon, kruskal, normaltest, shapiro
|
20
21
|
import itertools
|
21
22
|
import pingouin as pg
|
22
23
|
|
@@ -25,13 +26,26 @@ from IPython.display import Image as ipyimage
|
|
25
26
|
|
26
27
|
import matplotlib.patches as patches
|
27
28
|
from collections import defaultdict
|
29
|
+
from matplotlib.gridspec import GridSpec
|
28
30
|
|
29
|
-
def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, pathogen_channel, figuresize=10,
|
31
|
+
def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, pathogen_channel, figuresize=10, percentiles=(2,98), thickness=3, save_pdf=True, mode='outlines', export_tiffs=False):
|
30
32
|
"""Plot image and mask overlays."""
|
31
33
|
|
32
|
-
def _plot_merged_plot(image, outlines, outline_colors, figuresize, thickness):
|
34
|
+
def _plot_merged_plot(image, outlines, outline_colors, figuresize, thickness, percentiles, mode='outlines'):
|
33
35
|
"""Plot the merged plot with overlay, image channels, and masks."""
|
34
36
|
|
37
|
+
def _generate_colored_mask(mask, alpha):
    """Build an RGBA overlay for a labelled mask.

    Labels are divided by the largest label and looked up in the colormap
    returned by ``generate_mask_random_cmap``. The alpha channel is then set to
    ``alpha`` where ``mask > 0`` and to 0 everywhere else.

    Args:
        mask: Array of labels; values above 0 are treated as foreground.
        alpha: Opacity written into the alpha channel of foreground pixels.

    Returns:
        The RGBA array produced by the colormap with its alpha channel
        overwritten. For a mask without any positive label an all-zero array
        of shape ``mask.shape + (4,)`` is returned.
    """
    mask = np.asarray(mask)
    peak = mask.max() if mask.size else 0
    if peak == 0:
        # Nothing to scale by: dividing by the maximum would be a 0/0 division
        # (NaN plus a RuntimeWarning). Every pixel ends up fully transparent
        # anyway, so return that result directly.
        return np.zeros(mask.shape + (4,), dtype=float)

    cmap = generate_mask_random_cmap(mask)
    rgba_mask = cmap(mask / peak)  # scale labels before the colormap lookup
    rgba_mask[..., 3] = np.where(mask > 0, alpha, 0)  # transparent background
    return rgba_mask
|
43
|
+
|
44
|
+
def _overlay_mask(image, mask):
    """Add an RGBA mask onto an image and clip the result to [0, 1].

    Args:
        image: Array the mask colours are added to.
        mask: Array whose last axis holds R, G, B and alpha.

    Returns:
        ``image`` plus the alpha-weighted mask colours, clipped to [0, 1].
    """
    colours = mask[..., :3]
    opacity = mask[..., 3:4]  # keep the last axis so it broadcasts over RGB
    blended = image + colours * opacity
    return np.clip(blended, 0, 1)
|
48
|
+
|
35
49
|
def _normalize_image(image, percentiles=(2, 98)):
|
36
50
|
"""Normalize the image to the given percentiles."""
|
37
51
|
v_min, v_max = np.percentile(image, percentiles)
|
@@ -61,11 +75,15 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
|
|
61
75
|
# Plot each channel with its corresponding outlines
|
62
76
|
for v in range(num_channels):
|
63
77
|
channel_image = image[..., v]
|
64
|
-
channel_image_normalized = _normalize_image(channel_image)
|
78
|
+
channel_image_normalized = _normalize_image(channel_image, percentiles)
|
65
79
|
channel_image_rgb = np.dstack((channel_image_normalized, channel_image_normalized, channel_image_normalized))
|
66
80
|
|
67
81
|
for outline, color in zip(outlines, outline_colors):
|
68
|
-
|
82
|
+
if mode == 'outlines':
|
83
|
+
channel_image_rgb = _apply_contours(channel_image_rgb, outline, color, thickness)
|
84
|
+
else:
|
85
|
+
mask = _generate_colored_mask(outline, alpha=0.5)
|
86
|
+
channel_image_rgb = _overlay_mask(channel_image_rgb, mask)
|
69
87
|
|
70
88
|
ax[v].imshow(channel_image_rgb)
|
71
89
|
ax[v].set_title(f'Image - Channel {v}')
|
@@ -75,11 +93,15 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
|
|
75
93
|
rgb_channels = min(3, num_channels)
|
76
94
|
for i in range(rgb_channels):
|
77
95
|
channel_image = image[..., i]
|
78
|
-
channel_image_normalized = _normalize_image(channel_image)
|
96
|
+
channel_image_normalized = _normalize_image(channel_image, percentiles)
|
79
97
|
rgb_image[..., i] = channel_image_normalized
|
80
98
|
|
81
99
|
for outline, color in zip(outlines, outline_colors):
|
82
|
-
|
100
|
+
if mode == 'outlines':
|
101
|
+
rgb_image = _apply_contours(rgb_image, outline, color, thickness)
|
102
|
+
else:
|
103
|
+
mask = _generate_colored_mask(outline, alpha=0.5)
|
104
|
+
rgb_image = _overlay_mask(rgb_image, mask)
|
83
105
|
|
84
106
|
ax[-1].imshow(rgb_image)
|
85
107
|
ax[-1].set_title('Combined RGB Image')
|
@@ -96,8 +118,22 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
|
|
96
118
|
plt.show()
|
97
119
|
return fig
|
98
120
|
|
121
|
+
def _save_channels_as_tiff(stack, save_dir, filename):
    """Write every slice along the last axis of ``stack`` to its own TIFF.

    Files are named ``<filename>_channel_<index>.tiff`` inside ``save_dir``,
    which is created when missing. Each written path is printed.
    """
    os.makedirs(save_dir, exist_ok=True)
    n_channels = stack.shape[-1]
    for idx in range(n_channels):
        out_path = os.path.join(save_dir, f"{filename}_channel_{idx}.tiff")
        tiff.imwrite(out_path, stack[..., idx], photometric='minisblack')
        print(f"Saved {out_path}")
|
129
|
+
|
99
130
|
stack = np.load(file)
|
100
131
|
|
132
|
+
if export_tiffs:
|
133
|
+
save_dir = os.path.join(os.path.dirname(os.path.dirname(file)), 'results', os.path.splitext(os.path.basename(file))[0], 'tiff')
|
134
|
+
filename = os.path.splitext(os.path.basename(file))[0]
|
135
|
+
_save_channels_as_tiff(stack, save_dir, filename)
|
136
|
+
|
101
137
|
# Convert to float for normalization and ensure correct handling of both 8-bit and 16-bit arrays
|
102
138
|
if stack.dtype == np.uint16:
|
103
139
|
stack = stack.astype(np.float32)
|
@@ -128,7 +164,7 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
|
|
128
164
|
outlines.append(np.take(stack, cell_mask_dim, axis=2))
|
129
165
|
outline_colors.append('red')
|
130
166
|
|
131
|
-
fig = _plot_merged_plot(image=image, outlines=outlines, outline_colors=outline_colors, figuresize=figuresize, thickness=thickness)
|
167
|
+
fig = _plot_merged_plot(image=image, outlines=outlines, outline_colors=outline_colors, figuresize=figuresize, thickness=thickness, percentiles=percentiles, mode=mode)
|
132
168
|
|
133
169
|
return fig
|
134
170
|
|
@@ -1691,17 +1727,25 @@ def plot_object_outlines(src, objects=['nucleus','cell','pathogen'], channels=[0
|
|
1691
1727
|
overlay=True,
|
1692
1728
|
max_nr=10,
|
1693
1729
|
randomize=True)
|
1694
|
-
|
1730
|
+
|
1695
1731
|
def volcano_plot(coef_df, filename='volcano_plot.pdf'):
|
1732
|
+
palette = {
|
1733
|
+
'pc': 'red',
|
1734
|
+
'nc': 'green',
|
1735
|
+
'control': 'blue',
|
1736
|
+
'other': 'gray'
|
1737
|
+
}
|
1738
|
+
|
1696
1739
|
# Create the volcano plot
|
1697
1740
|
plt.figure(figsize=(10, 6))
|
1698
1741
|
sns.scatterplot(
|
1699
1742
|
data=coef_df,
|
1700
1743
|
x='coefficient',
|
1701
1744
|
y='-log10(p_value)',
|
1702
|
-
hue='
|
1703
|
-
palette=
|
1745
|
+
hue='condition',
|
1746
|
+
palette=palette
|
1704
1747
|
)
|
1748
|
+
|
1705
1749
|
plt.title('Volcano Plot of Coefficients')
|
1706
1750
|
plt.xlabel('Coefficient')
|
1707
1751
|
plt.ylabel('-log10(p-value)')
|
@@ -2098,7 +2142,7 @@ class spacrGraph:
|
|
2098
2142
|
def __init__(self, df, grouping_column, data_column, graph_type='bar', summary_func='mean',
|
2099
2143
|
order=None, colors=None, output_dir='./output', save=False, y_lim=None,
|
2100
2144
|
error_bar_type='std', remove_outliers=False, theme='pastel', representation='object',
|
2101
|
-
paired=False, all_to_all=True, compare_group=None):
|
2145
|
+
paired=False, all_to_all=True, compare_group=None, graph_name=None):
|
2102
2146
|
|
2103
2147
|
"""
|
2104
2148
|
Class for creating grouped plots with optional statistical tests and data preprocessing.
|
@@ -2121,11 +2165,14 @@ class spacrGraph:
|
|
2121
2165
|
self.all_to_all = all_to_all
|
2122
2166
|
self.compare_group = compare_group
|
2123
2167
|
self.y_lim = y_lim
|
2168
|
+
self.graph_name = graph_name
|
2169
|
+
|
2170
|
+
|
2124
2171
|
self.results_df = pd.DataFrame()
|
2125
2172
|
self.sns_palette = None
|
2126
2173
|
self.fig = None
|
2127
2174
|
|
2128
|
-
self.results_name = str(self.data_column[0])+'_'+str(self.grouping_column)+'_'+str(self.graph_type)
|
2175
|
+
self.results_name = str(self.graph_name)+'_'+str(self.data_column[0])+'_'+str(self.grouping_column)+'_'+str(self.graph_type)
|
2129
2176
|
|
2130
2177
|
self._set_theme()
|
2131
2178
|
self.raw_df = self.df.copy()
|
@@ -2134,10 +2181,10 @@ class spacrGraph:
|
|
2134
2181
|
def _set_theme(self):
    """Store a reordered palette for ``self.theme`` in ``self.sns_palette``."""
    leading = [7, 9, 4, 0, 3, 6, 2]  # palette indices placed first
    trailing = list(range(1, 81))
    self.sns_palette = self._set_reordered_theme(self.theme, leading + trailing, 100)
|
2139
2186
|
|
2140
|
-
def _set_reordered_theme(self, theme='
|
2187
|
+
def _set_reordered_theme(self, theme='deep', order=None, n_colors=100, show_theme=False):
|
2141
2188
|
"""Set and reorder the Seaborn color palette."""
|
2142
2189
|
palette = sns.color_palette(theme, n_colors)
|
2143
2190
|
if order:
|
@@ -2152,10 +2199,12 @@ class spacrGraph:
|
|
2152
2199
|
def preprocess_data(self):
|
2153
2200
|
"""Preprocess the data: remove NaNs, sort/order the grouping column, and optionally group by 'prc'."""
|
2154
2201
|
# Remove NaNs in both the grouping column and each data column
|
2155
|
-
df = self.df.dropna(subset=[self.grouping_column] + self.data_column)
|
2202
|
+
df = self.df.dropna(subset=[self.grouping_column] + self.data_column)
|
2156
2203
|
# Group by 'prc' column if representation is 'well'
|
2157
2204
|
if self.representation == 'well':
|
2158
2205
|
df = df.groupby(['prc', self.grouping_column])[self.data_column].agg(self.summary_func).reset_index()
|
2206
|
+
if self.representation == 'plate':
|
2207
|
+
df = df.groupby(['plate', self.grouping_column])[self.data_column].agg(self.summary_func).reset_index()
|
2159
2208
|
if self.order:
|
2160
2209
|
df[self.grouping_column] = pd.Categorical(df[self.grouping_column], categories=self.order, ordered=True)
|
2161
2210
|
else:
|
@@ -2180,20 +2229,36 @@ class spacrGraph:
|
|
2180
2229
|
"""Perform normality tests for each group and each data column."""
|
2181
2230
|
unique_groups = self.df[self.grouping_column].unique()
|
2182
2231
|
normality_results = []
|
2232
|
+
|
2183
2233
|
for column in self.data_column:
|
2184
|
-
|
2185
|
-
|
2186
|
-
|
2187
|
-
|
2188
|
-
|
2234
|
+
# Iterate over each group and its corresponding data
|
2235
|
+
for group in unique_groups:
|
2236
|
+
data = self.df.loc[self.df[self.grouping_column] == group, column]
|
2237
|
+
n_samples = len(data)
|
2238
|
+
|
2239
|
+
if n_samples >= 8:
|
2240
|
+
# Use D'Agostino-Pearson test for larger samples
|
2241
|
+
stat, p_value = normaltest(data)
|
2242
|
+
test_name = "D'Agostino-Pearson test"
|
2243
|
+
else:
|
2244
|
+
# Use Shapiro-Wilk test for smaller samples
|
2245
|
+
stat, p_value = shapiro(data)
|
2246
|
+
test_name = "Shapiro-Wilk test"
|
2247
|
+
|
2248
|
+
# Store the result for this group and column
|
2189
2249
|
normality_results.append({
|
2190
2250
|
'Comparison': f'Normality test for {group} on {column}',
|
2191
2251
|
'Test Statistic': stat,
|
2192
2252
|
'p-value': p_value,
|
2193
|
-
'Test Name':
|
2253
|
+
'Test Name': test_name,
|
2194
2254
|
'Column': column,
|
2195
|
-
'n':
|
2255
|
+
'n': n_samples # Sample size
|
2196
2256
|
})
|
2257
|
+
|
2258
|
+
# Check if all groups are normally distributed (p > 0.05)
|
2259
|
+
normal_p_values = [result['p-value'] for result in normality_results if result['Column'] == column]
|
2260
|
+
is_normal = all(p > 0.05 for p in normal_p_values)
|
2261
|
+
|
2197
2262
|
return is_normal, normality_results
|
2198
2263
|
|
2199
2264
|
def perform_levene_test(self, unique_groups):
|
@@ -2337,17 +2402,21 @@ class spacrGraph:
|
|
2337
2402
|
ax.text(x_pos, y_pos, text, ha='center', va='center', fontsize=12)
|
2338
2403
|
|
2339
2404
|
def _get_positions(self, ax):
    """Collect x coordinates of the plotted elements for ``self.graph_type``.

    Args:
        self: Object exposing ``graph_type``.
        ax: Axes-like object whose ``collections`` / ``lines`` are inspected.

    Returns:
        A list of x positions. Line graphs ('line', 'line_std') and
        unrecognised graph types yield an empty list; previously an
        unrecognised type left the result unbound and raised
        ``UnboundLocalError``.
    """
    graph_type = self.graph_type

    if graph_type in ('bar', 'jitter_bar', 'violin'):
        # Bars and violins are located the same way: the mean x of the first
        # path of every collection that exposes paths.
        return [np.mean(artist.get_paths()[0].vertices[:, 0])
                for artist in ax.collections if hasattr(artist, 'get_paths')]

    if graph_type in ('box', 'jitter_box'):
        # De-duplicate, since several solid lines share the same mean x.
        return list({line.get_xdata().mean() for line in ax.lines if line.get_linestyle() == '-'})

    if graph_type == 'jitter':
        return [np.mean(points.get_offsets()[:, 0])
                for points in ax.collections if points.get_offsets().size > 0]

    # 'line', 'line_std' and anything unrecognised: no positions to report.
    return []
|
2352
2421
|
|
2353
2422
|
def _draw_comparison_lines(ax, x_positions):
|
@@ -2365,7 +2434,7 @@ class spacrGraph:
|
|
2365
2434
|
|
2366
2435
|
# Determine significance marker
|
2367
2436
|
if p_value <= 0.001:
|
2368
|
-
|
2437
|
+
signiresults_namecance = '***'
|
2369
2438
|
elif p_value <= 0.01:
|
2370
2439
|
significance = '**'
|
2371
2440
|
elif p_value <= 0.05:
|
@@ -2406,6 +2475,9 @@ class spacrGraph:
|
|
2406
2475
|
self.fig_width = (num_groups * self.bar_width) + (spacing_between_groups * num_groups)
|
2407
2476
|
self.fig_height = self.fig_width/2
|
2408
2477
|
|
2478
|
+
if self.graph_type in ['line','line_std']:
|
2479
|
+
self.fig_height, self.fig_width = 10, 10
|
2480
|
+
|
2409
2481
|
if ax is None:
|
2410
2482
|
self.fig, ax = plt.subplots(figsize=(self.fig_height, self.fig_width))
|
2411
2483
|
else:
|
@@ -2427,6 +2499,14 @@ class spacrGraph:
|
|
2427
2499
|
self._create_box_plot(ax)
|
2428
2500
|
elif self.graph_type == 'violin':
|
2429
2501
|
self._create_violin_plot(ax)
|
2502
|
+
elif self.graph_type == 'jitter_box':
|
2503
|
+
self._create_jitter_box_plot(ax)
|
2504
|
+
elif self.graph_type == 'jitter_bar':
|
2505
|
+
self._create_jitter_bar_plot(ax)
|
2506
|
+
elif self.graph_type == 'line':
|
2507
|
+
self._create_line_graph(ax)
|
2508
|
+
elif self.graph_type == 'line_std':
|
2509
|
+
self._create_line_with_std_area(ax)
|
2430
2510
|
else:
|
2431
2511
|
raise ValueError(f"Unknown graph type: {self.graph_type}")
|
2432
2512
|
|
@@ -2439,14 +2519,17 @@ class spacrGraph:
|
|
2439
2519
|
|
2440
2520
|
sns.despine(ax=ax, top=True, right=True)
|
2441
2521
|
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), title='Data Column') # Move the legend outside the plot
|
2442
|
-
|
2522
|
+
|
2523
|
+
if not self.graph_type in ['line','line_std']:
|
2524
|
+
ax.set_xlabel('')
|
2525
|
+
|
2443
2526
|
x_positions = _get_positions(self, ax)
|
2444
2527
|
|
2445
|
-
if len(self.data_column) == 1:
|
2528
|
+
if len(self.data_column) == 1 and not self.graph_type in ['line','line_std']:
|
2446
2529
|
ax.legend().remove()
|
2447
2530
|
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
|
2448
2531
|
|
2449
|
-
elif len(self.data_column) > 1:
|
2532
|
+
elif len(self.data_column) > 1 and not self.graph_type in ['line','line_std']:
|
2450
2533
|
ax.set_xticks([])
|
2451
2534
|
ax.tick_params(bottom=False)
|
2452
2535
|
ax.set_xticklabels([])
|
@@ -2522,7 +2605,54 @@ class spacrGraph:
|
|
2522
2605
|
handles, labels = ax.get_legend_handles_labels()
|
2523
2606
|
unique_labels = dict(zip(labels, handles))
|
2524
2607
|
ax.legend(unique_labels.values(), unique_labels.keys(), loc='best')
|
2525
|
-
|
2608
|
+
|
2609
|
+
def _create_line_graph(self, ax):
    """Draw one line per group on ``ax``.

    ``self.data_column[0]`` is plotted on the x-axis, ``self.data_column[1]``
    on the y-axis, and ``self.grouping_column`` selects the line colour.

    Raises:
        ValueError: If any of those three columns is missing from ``self.df``.
    """
    x_col = self.data_column[0]
    y_col = self.data_column[1]
    group_col = self.grouping_column

    # Fail on the first missing column before anything is drawn.
    available = self.df.columns
    for col in (x_col, y_col, group_col):
        if col not in available:
            raise ValueError(f"Column '{col}' not found in DataFrame.")

    sns.lineplot(
        data=self.df,
        x=x_col,
        y=y_col,
        hue=group_col,
        palette=self.sns_palette,
        ax=ax,
        marker='o',
        linewidth=1,
        markersize=6,
    )

    ax.set_xlabel(f"{x_col}")
    ax.set_ylabel(f"{y_col}")
|
2631
|
+
|
2632
|
+
def _create_line_with_std_area(self, ax):
    """Plot the mean of the y column per x value with a +/- one std band.

    ``self.data_column[0]`` supplies the x values and ``self.data_column[1]``
    the values that are averaged.
    """
    x_col = self.data_column[0]
    y_col = self.data_column[1]
    mean_col = f"mean_{y_col}"
    std_col = f"std_{mean_col}"

    # Aggregate mean and std per x value, then replace the MultiIndex columns
    # produced by the list-valued aggfunc with flat names.
    stats = self.df.pivot_table(index=x_col, values=y_col, aggfunc=['mean', 'std']).reset_index()
    stats.columns = [x_col, mean_col, std_col]

    sns.lineplot(data=stats, x=x_col, y=mean_col, ax=ax, marker='o', linewidth=1, markersize=0, color='blue', label=mean_col)

    # Shaded band: mean minus std up to mean plus std.
    lower = stats[mean_col] - stats[std_col]
    upper = stats[mean_col] + stats[std_col]
    ax.fill_between(stats[x_col], lower, upper, color='blue', alpha=0.1)

    ax.set_xlabel(f"{x_col}")
    ax.set_ylabel(f"{y_col}")
|
2655
|
+
|
2526
2656
|
def _create_box_plot(self, ax):
|
2527
2657
|
"""Helper method to create a box plot with consistent spacing."""
|
2528
2658
|
# Combine grouping column and data column if needed
|
@@ -2572,6 +2702,68 @@ class spacrGraph:
|
|
2572
2702
|
unique_labels = dict(zip(labels, handles))
|
2573
2703
|
ax.legend(unique_labels.values(), unique_labels.keys(), loc='best')
|
2574
2704
|
|
2705
|
+
def _create_jitter_bar_plot(self, ax):
    """Draw a bar plot with the individual points overlaid as a strip plot.

    With more than one data column the grouping column and the 'Data Column'
    values are merged into a 'Combined Group' column that becomes the x-axis.
    Error bars taken from ``self.error_bar_type`` are drawn on top of the bars.
    """
    melted = self.df_melted  # same object, so the new column lands on self.df_melted
    several_columns = len(self.data_column) > 1

    if several_columns:
        melted['Combined Group'] = (melted[self.grouping_column].astype(str) + " - " + melted['Data Column'].astype(str))
        x_axis_column = 'Combined Group'
        ax.set_ylabel('Value')
    else:
        x_axis_column = self.grouping_column
        ax.set_ylabel(self.data_column[0])

    # Per-group statistics used for the manual error bars below.
    summary_df = melted.groupby([x_axis_column]).agg(mean=('Value', 'mean'), std=('Value', 'std'), sem=('Value', 'sem')).reset_index()

    sns.barplot(data=melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, ax=ax, dodge=self.jitter_bar_dodge, ci=None)
    sns.stripplot(data=melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax, alpha=0.6, edgecolor='white', linewidth=1, size=16)

    if several_columns:
        # Give every bar the same width and shift it by half that width.
        target_width = self.bar_width * 2
        for rect in [p for p in ax.patches if isinstance(p, plt.Rectangle)]:
            rect.set_width(target_width)
            rect.set_x(rect.get_x() - target_width / 2)

    # One error bar per rectangle, paired with the summary rows in order.
    rectangles = [p for p in ax.patches if isinstance(p, plt.Rectangle)]
    for rect, (_, stats_row) in zip(rectangles, summary_df.iterrows()):
        centre = rect.get_x() + rect.get_width() / 2
        ax.errorbar(x=centre, y=rect.get_height(), yerr=stats_row[self.error_bar_type], fmt='none', c='black', capsize=5, lw=2)

    ax.set_xlabel(self.grouping_column)
|
2741
|
+
|
2742
|
+
def _create_jitter_box_plot(self, ax):
    """Draw a box plot with the individual points overlaid as a strip plot.

    With more than one data column the grouping column and the 'Data Column'
    values are merged into a 'Combined Group' column that becomes the x-axis.
    The legend is rebuilt so that every label appears only once.
    """
    melted = self.df_melted  # same object, so the new column lands on self.df_melted

    if len(self.data_column) > 1:
        melted['Combined Group'] = (melted[self.grouping_column].astype(str) + " - " + melted['Data Column'].astype(str))
        x_axis_column = 'Combined Group'
        ax.set_ylabel('Value')
    else:
        x_axis_column = self.grouping_column
        ax.set_ylabel(self.data_column[0])

    sns.boxplot(data=melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, ax=ax)
    sns.stripplot(data=melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax, alpha=0.6, edgecolor='white', linewidth=1, size=12)

    ax.set_xlabel(self.grouping_column)

    # Box and strip layers both register legend entries; keep one per label.
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys(), loc='best')
|
2766
|
+
|
2575
2767
|
def _save_results(self):
|
2576
2768
|
"""Helper method to save the plot and results."""
|
2577
2769
|
os.makedirs(self.output_dir, exist_ok=True)
|
@@ -2592,14 +2784,14 @@ class spacrGraph:
|
|
2592
2784
|
|
2593
2785
|
def plot_data_from_db(settings):
|
2594
2786
|
from .io import _read_db, _read_and_merge_data
|
2595
|
-
from .utils import annotate_conditions
|
2787
|
+
from .utils import annotate_conditions, save_settings
|
2596
2788
|
"""
|
2597
2789
|
Extracts the specified table from the SQLite database and plots a specified column.
|
2598
2790
|
|
2599
2791
|
Args:
|
2600
2792
|
db_path (str): The path to the SQLite database.
|
2601
2793
|
table_names (str): The name of the table to extract.
|
2602
|
-
|
2794
|
+
data_column (str): The column to plot from the table.
|
2603
2795
|
|
2604
2796
|
Returns:
|
2605
2797
|
df (pd.DataFrame): The extracted table as a DataFrame.
|
@@ -2614,6 +2806,8 @@ def plot_data_from_db(settings):
|
|
2614
2806
|
else:
|
2615
2807
|
raise ValueError("src must be a string or a list of strings.")
|
2616
2808
|
|
2809
|
+
save_settings(settings, name=f"{settings['graph_name']}_plot_settings_db", show=True)
|
2810
|
+
|
2617
2811
|
dfs = []
|
2618
2812
|
for i, src in enumerate(srcs):
|
2619
2813
|
|
@@ -2641,6 +2835,7 @@ def plot_data_from_db(settings):
|
|
2641
2835
|
df = pd.concat(dfs, axis=0)
|
2642
2836
|
df['prc'] = df['plate'].astype(str) + '_' + df['row'].astype(str) + '_' + df['col'].astype(str)
|
2643
2837
|
df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
|
2838
|
+
df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
|
2644
2839
|
|
2645
2840
|
if settings['cell_plate_metadata'] != None:
|
2646
2841
|
df = df.dropna(subset='host_cell')
|
@@ -2651,20 +2846,23 @@ def plot_data_from_db(settings):
|
|
2651
2846
|
if settings['treatment_plate_metadata'] != None:
|
2652
2847
|
df = df.dropna(subset='treatment')
|
2653
2848
|
|
2654
|
-
df = df.dropna(subset=settings['
|
2849
|
+
df = df.dropna(subset=settings['data_column'])
|
2655
2850
|
df = df.dropna(subset=settings['grouping_column'])
|
2656
|
-
#display(df)
|
2657
2851
|
|
2658
2852
|
#df['class'] = df['png_path'].apply(lambda x: 'class_1' if 'class_1' in x else ('class_0' if 'class_0' in x else None))
|
2853
|
+
src = srcs[0]
|
2854
|
+
dst = os.path.join(src, 'results', settings['graph_name'])
|
2855
|
+
os.makedirs(dst, exist_ok=True)
|
2659
2856
|
|
2660
2857
|
spacr_graph = spacrGraph(
|
2661
2858
|
df=df, # Your DataFrame
|
2662
2859
|
grouping_column=settings['grouping_column'], # Column for grouping the data (x-axis)
|
2663
|
-
data_column=settings['
|
2860
|
+
data_column=settings['data_column'], # Column for the data (y-axis)
|
2664
2861
|
graph_type=settings['graph_type'], # Type of plot ('bar', 'box', 'violin', 'jitter')
|
2862
|
+
graph_name=settings['graph_name'], # Name of the plot
|
2665
2863
|
summary_func='mean', # Function to summarize data (e.g., 'mean', 'median')
|
2666
2864
|
colors=None, # Custom colors for the plot (optional)
|
2667
|
-
output_dir=
|
2865
|
+
output_dir=dst, # Directory to save the plot and results
|
2668
2866
|
save=settings['save'], # Whether to save the plot and results
|
2669
2867
|
y_lim=settings['y_lim'], # Starting point for y-axis (optional)
|
2670
2868
|
error_bar_type='std', # Type of error bar ('std' or 'sem')
|
@@ -2681,5 +2879,197 @@ def plot_data_from_db(settings):
|
|
2681
2879
|
|
2682
2880
|
# Optional: Get the results DataFrame containing statistical test results
|
2683
2881
|
results_df = spacr_graph.get_results()
|
2882
|
+
return fig, results_df
|
2883
|
+
|
2884
|
+
def plot_data_from_csv(settings):
    """Load one or more CSV files and plot them with ``spacrGraph``.

    Args:
        settings (dict): Plot configuration. Keys read here:
            'src' (str or list of str): Path(s) of the CSV file(s).
            'data_column': Column(s) to plot; rows with NaN there are dropped.
            'grouping_column': Grouping column; rows with NaN there are dropped.
            'graph_name': Name of the results sub-folder and of the graph.
            'graph_type', 'save', 'y_lim', 'representation', 'theme':
                forwarded unchanged to ``spacrGraph``.

    Returns:
        tuple: ``(fig, results_df)`` as returned by ``spacrGraph.get_figure()``
        and ``spacrGraph.get_results()``.

    Raises:
        ValueError: If ``settings['src']`` is neither a string nor a list.
    """
    if isinstance(settings['src'], str):
        srcs = [settings['src']]
    elif isinstance(settings['src'], list):
        srcs = settings['src']
    else:
        raise ValueError("src must be a string or a list of strings.")

    dfs = []
    for i, src in enumerate(srcs):
        dft = pd.read_csv(src)
        # Files without a plate column get a synthetic one based on their
        # position in the source list.
        if 'plate' not in dft.columns:
            dft['plate'] = f"plate{i+1}"
        dfs.append(dft)

    df = pd.concat(dfs, axis=0)
    df = df.dropna(subset=settings['data_column'])
    df = df.dropna(subset=settings['grouping_column'])

    # Results are written next to the first CSV file.
    dst = os.path.join(os.path.dirname(srcs[0]), 'results', settings['graph_name'])
    os.makedirs(dst, exist_ok=True)

    spacr_graph = spacrGraph(
        df=df,
        grouping_column=settings['grouping_column'],  # x-axis groups
        data_column=settings['data_column'],  # plotted values
        graph_type=settings['graph_type'],
        graph_name=settings['graph_name'],
        summary_func='mean',
        colors=None,
        output_dir=dst,
        save=settings['save'],
        y_lim=settings['y_lim'],
        error_bar_type='std',
        representation=settings['representation'],
        theme=settings['theme'],
    )

    spacr_graph.create_plot()

    fig = spacr_graph.get_figure()
    plt.show()

    # Statistical test results gathered while the plot was created.
    results_df = spacr_graph.get_results()
    return fig, results_df
|
2951
|
+
|
2952
|
+
def plot_region(settings):
    """Plot a field of view, its object PNGs and its activation PNGs, and save each as a PDF.

    Args:
        settings (dict): Keys read here:
            'src': Experiment folder containing 'merged' and 'measurements'.
            'name': File name of the field of view inside 'merged'.
            'activation_db': Database file name inside 'measurements'.
            'activation_mode': Prefix of the '<mode>_list' table to read.
            'percentiles', 'channels', 'cell_channel', 'nucleus_channel',
            'pathogen_channel', 'mode', 'export_tiffs': forwarded to the
            plotting helpers.

    Returns:
        tuple: ``(fig_1, fig_2, fig_3)`` - mask overlay, PNG grid and
        activation grid. The PDFs are written to ``<src>/results/<name>/``.
    """
    from .io import _read_db

    def _sort_paths_by_basename(paths):
        """Return the paths sorted by file name rather than by full path."""
        return sorted(paths, key=os.path.basename)

    def save_figure_as_pdf(fig, path):
        """Save ``fig`` as a PDF at ``path``, creating the folder when missing."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        fig.savefig(path, format='pdf', dpi=600, bbox_inches='tight')
        print(f"Saved {path}")

    def _paths_matching_name(db_path, table):
        """Read ``table`` and return its sorted 'png_path' entries containing ``name``."""
        frame = _read_db(db_path, tables=[table])[0]
        # regex=False: the file name is a literal substring, so characters such
        # as '.', '+' or '(' must not be interpreted as a pattern.
        matches = frame['png_path'].str.contains(name, na=False, regex=False)
        return _sort_paths_by_basename(frame.loc[matches, 'png_path'].tolist())

    fov_path = os.path.join(settings['src'], 'merged', settings['name'])
    name = os.path.splitext(settings['name'])[0]

    db_path = os.path.join(settings['src'], 'measurements', 'measurements.db')
    png_paths = _paths_matching_name(db_path, 'png_list')

    activation_mode = f"{settings['activation_mode']}_list"
    activation_db_path = os.path.join(settings['src'], 'measurements', settings['activation_db'])
    activation_paths = _paths_matching_name(activation_db_path, activation_mode)

    fig_3 = plot_image_grid(image_paths=activation_paths, percentiles=settings['percentiles'])
    fig_2 = plot_image_grid(image_paths=png_paths, percentiles=settings['percentiles'])
    fig_1 = plot_image_mask_overlay(file=fov_path,
                                    channels=settings['channels'],
                                    cell_channel=settings['cell_channel'],
                                    nucleus_channel=settings['nucleus_channel'],
                                    pathogen_channel=settings['pathogen_channel'],
                                    figuresize=10,
                                    percentiles=settings['percentiles'],
                                    thickness=3,
                                    save_pdf=False,
                                    mode=settings['mode'],
                                    export_tiffs=settings['export_tiffs'])

    dst = os.path.join(settings['src'], 'results', name)
    save_figure_as_pdf(fig_1, os.path.join(dst, f"{name}_mask_overlay.pdf"))
    save_figure_as_pdf(fig_2, os.path.join(dst, f"{name}_png_grid.pdf"))
    save_figure_as_pdf(fig_3, os.path.join(dst, f"{name}_activation_grid.pdf"))

    return fig_1, fig_2, fig_3
|
2998
|
+
|
2999
|
+
def plot_image_grid(image_paths, percentiles):
    """Plot the images on the smallest square grid that fits them all.

    Every image is contrast-stretched per channel to ``percentiles`` before it
    is shown. Unused cells are filled with black and all padding is removed.

    Parameters:
    - image_paths: List of paths to the images to display.
    - percentiles: Lower and upper percentile passed to ``np.percentile``.

    Returns:
    - fig: The generated matplotlib figure.

    Raises:
    - ValueError: If ``image_paths`` is empty.
    """
    # Validate before importing the plotting libraries so that an empty list
    # fails with a clear message (a 0x0 grid is rejected by matplotlib anyway).
    N = len(image_paths)
    if N == 0:
        raise ValueError("image_paths must contain at least one image path.")

    from PIL import Image
    import matplotlib.pyplot as plt
    import math

    def _normalize_channel(channel, percentiles):
        """Stretch one 2D channel to [0, 1] between the given percentiles."""
        v_min, v_max = np.percentile(channel, percentiles)
        if v_max == v_min:
            # A flat channel has no range to stretch; avoid dividing by zero.
            return np.zeros(channel.shape, dtype=np.float32)
        return np.clip((channel - v_min) / (v_max - v_min), 0, 1)

    def _normalize_image(image, percentiles=(2, 98)):
        """Normalize each channel independently, preserving the input type (PIL.Image or numpy.ndarray)."""
        is_pil_image = isinstance(image, Image.Image)
        if is_pil_image:
            image = np.array(image)

        if image.ndim == 2:
            normalized_image = _normalize_channel(image, percentiles)
        else:
            normalized_image = np.zeros_like(image, dtype=np.float32)
            for c in range(image.shape[-1]):
                normalized_image[..., c] = _normalize_channel(image[..., c], percentiles)

        if is_pil_image:
            # PIL expects 8-bit data, so rescale [0, 1] to [0, 255].
            return Image.fromarray((normalized_image * 255).astype(np.uint8))
        return normalized_image

    # Smallest square grid that fits all images.
    grid_size = math.ceil(math.sqrt(N))

    # squeeze=False keeps ``axs`` a 2D array even for a 1x1 grid; without it a
    # single image yields a bare Axes object that cannot be flattened.
    fig, axs = plt.subplots(
        grid_size, grid_size,
        figsize=(grid_size * 2, grid_size * 2),
        facecolor='black',
        squeeze=False,
    )
    axs = axs.flatten()

    for ax, img_path in zip(axs, image_paths):
        # Close the file handle as soon as the pixel data has been copied.
        with Image.open(img_path) as raw:
            img = _normalize_image(raw, percentiles)
        ax.imshow(img)
        ax.axis('off')

    # Fill any unused cells with black.
    for ax in axs[N:]:
        ax.imshow([[0, 0, 0]], cmap='gray')
        ax.axis('off')

    # Remove all padding between and around the cells.
    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, top=1, bottom=0)

    return fig
|