spacr 0.3.2__py3-none-any.whl → 0.3.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/core.py +105 -1
- spacr/deep_spacr.py +171 -25
- spacr/io.py +80 -121
- spacr/ml.py +153 -66
- spacr/plot.py +429 -7
- spacr/settings.py +6 -5
- spacr/submodules.py +7 -6
- spacr/toxo.py +9 -4
- spacr/utils.py +152 -13
- {spacr-0.3.2.dist-info → spacr-0.3.22.dist-info}/METADATA +28 -25
- {spacr-0.3.2.dist-info → spacr-0.3.22.dist-info}/RECORD +15 -15
- {spacr-0.3.2.dist-info → spacr-0.3.22.dist-info}/LICENSE +0 -0
- {spacr-0.3.2.dist-info → spacr-0.3.22.dist-info}/WHEEL +0 -0
- {spacr-0.3.2.dist-info → spacr-0.3.22.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.2.dist-info → spacr-0.3.22.dist-info}/top_level.txt +0 -0
spacr/plot.py
CHANGED
@@ -16,8 +16,9 @@ from skimage.measure import find_contours, label, regionprops
|
|
16
16
|
|
17
17
|
from scipy.stats import normaltest, ttest_ind, mannwhitneyu, f_oneway, kruskal
|
18
18
|
from statsmodels.stats.multicomp import pairwise_tukeyhsd
|
19
|
+
from scipy.stats import ttest_ind, mannwhitneyu, levene, wilcoxon, kruskal
|
19
20
|
import itertools
|
20
|
-
|
21
|
+
import pingouin as pg
|
21
22
|
|
22
23
|
from ipywidgets import IntSlider, interact
|
23
24
|
from IPython.display import Image as ipyimage
|
@@ -1002,8 +1003,107 @@ def _display_gif(path):
|
|
1002
1003
|
"""
|
1003
1004
|
with open(path, 'rb') as file:
|
1004
1005
|
display(ipyimage(file.read()))
|
1006
|
+
|
1007
|
+
def _plot_recruitment_v2(df, df_type, channel_of_interest, columns=None, figuresize=10):
    """
    Plot recruitment data for different conditions using spacrGraph panels.

    Two figures are drawn and shown:

    1. One panel per compartment ('cell', 'nucleus', 'cytoplasm', 'pathogen'),
       each plotting ``<compartment>_channel_<channel_of_interest>_mean_intensity``
       grouped by the ``condition`` column.
    2. One panel per entry of ``columns`` followed by a fixed list of
       pathogen/cytoplasm summary columns, laid out on two rows.

    Args:
        df (DataFrame): The input DataFrame containing the recruitment data. It must
            contain a ``condition`` column and every column that is plotted.
        df_type (str): Text appended to the x-axis label ("pathogen <df_type>"), e.g. 'by PV'.
        channel_of_interest (int or str): Channel used to build the intensity column names.
        columns (list, optional): Additional columns to plot in the second figure.
            Defaults to None (no additional columns).
        figuresize (int, optional): Base figure size; the font size is figuresize / 2.
            Defaults to 10.

    Returns:
        None
    """
    # spacrGraph is defined in this same module and is resolved at call time,
    # so no (circular) self-import is needed.
    extra_columns = [] if columns is None else list(columns)

    color_list = [(55/255, 155/255, 155/255),
                  (155/255, 55/255, 155/255),
                  (55/255, 155/255, 255/255),
                  (255/255, 55/255, 155/255)]
    sns.set_palette(sns.color_palette(color_list))
    font = figuresize / 2

    # Figure 1: the channel of interest in each compartment.
    compartments = ['cell', 'nucleus', 'cytoplasm', 'pathogen']
    fig, axes = plt.subplots(nrows=1, ncols=len(compartments), figsize=(figuresize, figuresize / 4))
    for ax, compartment in zip(axes, compartments):
        data_column = f'{compartment}_channel_{channel_of_interest}_mean_intensity'
        spacrGraph(df, grouping_column='condition', data_column=data_column).create_plot(ax=ax)
        ax.set_xlabel(f'pathogen {df_type}', fontsize=font)
        ax.set_ylabel(data_column, fontsize=font)
        ax.tick_params(axis='both', which='major', labelsize=font)
        ax.tick_params(axis='x', labelrotation=45)

    # Only draw a legend if the last panel actually produced legend entries.
    handles, labels = axes[-1].get_legend_handles_labels()
    if handles:
        axes[-1].legend(handles, labels, bbox_to_anchor=(1.05, 0.5), loc='center left')

    plt.tight_layout()
    plt.show()

    # Figure 2: user-supplied columns plus fixed pathogen/cytoplasm summaries.
    summary_columns = extra_columns + ['pathogen_cytoplasm_mean_mean',
                                       'pathogen_cytoplasm_q75_mean',
                                       'pathogen_periphery_cytoplasm_mean_mean',
                                       'pathogen_outside_cytoplasm_mean_mean',
                                       'pathogen_outside_cytoplasm_q75_mean']

    width = figuresize * 2
    columns_per_row = math.ceil(len(summary_columns) / 2)
    height = (figuresize * 2) / columns_per_row

    fig, axes = plt.subplots(nrows=2, ncols=columns_per_row, figsize=(width, height * 2))
    axes = axes.flatten()

    print(f'{summary_columns}')
    for i, (ax, col) in enumerate(zip(axes, summary_columns)):
        spacrGraph(df, grouping_column='condition', data_column=col).create_plot(ax=ax)
        ax.set_xlabel(f'pathogen {df_type}', fontsize=font)
        ax.set_ylabel(f'{col}', fontsize=int(font * 2))
        ax.tick_params(axis='both', which='major', labelsize=font)
        ax.tick_params(axis='x', labelrotation=45)
        # NOTE(review): only the first six panels get a y-axis starting at 1;
        # the reason is not evident from this code.
        if i <= 5:
            ax.set_ylim(1, None)

    # Hide grid cells that received no column.
    for ax in axes[len(summary_columns):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()
|
1005
1105
|
|
1006
|
-
def _plot_recruitment(df, df_type, channel_of_interest,
|
1106
|
+
def _plot_recruitment(df, df_type, channel_of_interest, columns=[], figuresize=10):
|
1007
1107
|
"""
|
1008
1108
|
Plot recruitment data for different conditions and pathogens.
|
1009
1109
|
|
@@ -1156,10 +1256,6 @@ def _plot_controls(df, mask_chans, channel_of_interest, figuresize=5):
|
|
1156
1256
|
plt.tight_layout()
|
1157
1257
|
plt.show()
|
1158
1258
|
|
1159
|
-
###################################################
|
1160
|
-
# Classify
|
1161
|
-
###################################################
|
1162
|
-
|
1163
1259
|
def _imshow(img, labels, nrow=20, color='white', fontsize=12):
|
1164
1260
|
"""
|
1165
1261
|
Display multiple images in a grid with corresponding labels.
|
@@ -1997,4 +2093,330 @@ def create_grouped_plot(df, grouping_column, data_column, graph_type='bar', summ
|
|
1997
2093
|
# Show the plot
|
1998
2094
|
plt.show()
|
1999
2095
|
|
2000
|
-
return plt.gcf(), results_df
|
2096
|
+
return plt.gcf(), results_df
|
2097
|
+
|
2098
|
+
class spacrGraph:
    """Grouped plot of one numeric column plus a table of statistical tests.

    On construction the data are cleaned: rows with a missing group or value are
    dropped, and the grouping column becomes an ordered categorical. For
    ``representation='well'``, values are also aggregated per ``prc`` with
    ``summary_func``. :meth:`create_plot` runs normality, Levene, pairwise and
    optional post-hoc tests, stores them in ``results_df`` and draws the plot.

    Args:
        df (DataFrame): Input data.
        grouping_column (str): Column whose values define the groups on the x-axis.
        data_column (str): Numeric column that is plotted and tested.
        graph_type (str): 'bar', 'box', 'violin' or 'jitter'.
        summary_func (str or None): Aggregation used for bar heights and for the
            per-'prc' aggregation; None draws a strip plot instead.
        order (list, optional): Explicit category order; defaults to the sorted unique values.
        colors (list, optional): Colors used instead of the themed palette.
        output_dir (str): Directory used when ``save`` is True.
        save (bool): Save the figure and the test table after plotting.
        y_axis_start (float, optional): Lower y-axis limit.
        error_bar_type (str): 'std' or 'sem' (bar plots only).
        remove_outliers (bool): Hide per-group 1.5*IQR outliers from the plot.
            The statistical tests still use all data.
        theme (str): Seaborn palette name.
        representation (str): 'object' or 'well' (aggregate per 'prc').
        paired (bool): Use paired tests when exactly two groups are compared.
        all_to_all (bool): Run Tukey HSD across all pairs (normal data, more than 2 groups).
        compare_group: Reference group for Dunn's test when ``all_to_all`` is False.
    """

    def __init__(self, df, grouping_column, data_column, graph_type='bar', summary_func='mean',
                 order=None, colors=None, output_dir='./output', save=False, y_axis_start=None,
                 error_bar_type='std', remove_outliers=False, theme='pastel', representation='object',
                 paired=False, all_to_all=True, compare_group=None):
        self.df = df
        self.grouping_column = grouping_column
        self.data_column = data_column
        self.graph_type = graph_type
        self.summary_func = summary_func
        self.order = order
        self.colors = colors
        self.output_dir = output_dir
        self.save = save
        self.y_axis_start = y_axis_start
        self.error_bar_type = error_bar_type
        self.remove_outliers = remove_outliers
        self.theme = theme
        self.representation = representation
        self.paired = paired
        self.all_to_all = all_to_all
        self.compare_group = compare_group

        self.results_df = pd.DataFrame()
        self.sns_palette = None
        self.fig = None  # set by create_plot

        self._set_theme()
        self.raw_df = self.df.copy()  # unaggregated copy, used for the n_object counts
        self.df = self.preprocess_data()

    def _set_theme(self):
        """Build ``self.sns_palette`` from ``self.theme`` using a fixed color order."""
        # NOTE(review): index 9 appears twice, so the 3rd and 7th groups share a color; confirm intended.
        color_order = [0, 3, 9, 4, 6, 7, 9, 2] + list(range(1, 81))
        self.sns_palette = self._set_reordered_theme(self.theme, color_order, 100)

    def _set_reordered_theme(self, theme='muted', order=None, n_colors=100, show_theme=False):
        """Return the ``theme`` palette with ``n_colors`` entries, reordered by ``order`` if given."""
        palette = sns.color_palette(theme, n_colors)
        if order:
            reordered_palette = [palette[i] for i in order]
        else:
            reordered_palette = palette
        if show_theme:
            sns.palplot(reordered_palette)
            plt.show()
        return reordered_palette

    def preprocess_data(self):
        """Return a cleaned copy of ``self.df``.

        Rows with a missing group or value are dropped. With ``representation='well'``,
        values are aggregated per ('prc', group) using ``summary_func``. The grouping
        column is converted to an ordered categorical that follows ``order``, or the
        sorted unique values when ``order`` is not set.
        """
        df = self.df.dropna(subset=[self.grouping_column, self.data_column]).copy()

        if self.representation == 'well':
            df = df.groupby(['prc', self.grouping_column])[self.data_column].agg(self.summary_func).reset_index()

        categories = self.order if self.order else sorted(df[self.grouping_column].unique())
        df[self.grouping_column] = pd.Categorical(df[self.grouping_column], categories=categories, ordered=True)
        return df

    def remove_outliers_from_plot(self):
        """Return a copy of ``self.df`` without per-group 1.5*IQR outliers.

        ``self.df`` itself is not modified. Rows are selected positionally, so
        duplicate index labels never cause non-outlier rows to be dropped.
        """
        work = self.df[[self.grouping_column, self.data_column]].reset_index(drop=True)
        grouped = work.groupby(self.grouping_column, observed=True)[self.data_column]
        q1 = grouped.transform(lambda s: s.quantile(0.25))
        q3 = grouped.transform(lambda s: s.quantile(0.75))
        iqr = q3 - q1
        values = work[self.data_column]
        outside = (values < q1 - 1.5 * iqr) | (values > q3 + 1.5 * iqr)
        return self.df.loc[~outside.to_numpy()].copy()

    def perform_normality_tests(self):
        """Test each group for normality.

        D'Agostino-Pearson (``normaltest``) needs at least 8 observations. Groups
        with 3-7 observations fall back to Shapiro-Wilk, and smaller groups get a
        NaN result. A NaN p-value counts as "not normal".

        Returns:
            tuple: ``(is_normal, test_results)``. ``is_normal`` is True only if every
            group has p > 0.05. ``test_results`` is a list of result-row dicts.
        """
        from scipy.stats import shapiro

        unique_groups = self.df[self.grouping_column].unique()
        test_results = []
        p_values = []
        for group in unique_groups:
            data = self.df.loc[self.df[self.grouping_column] == group, self.data_column]
            raw_data = self.raw_df.loc[self.raw_df[self.grouping_column] == group, self.data_column]

            if len(data) >= 8:
                stat, p_value = normaltest(data)
                test_name = 'Normality test'
            elif len(data) >= 3:
                stat, p_value = shapiro(data)
                test_name = 'Normality test (Shapiro-Wilk)'
            else:
                stat, p_value = np.nan, np.nan
                test_name = 'Normality test'
            p_values.append(p_value)

            test_results.append({
                'Comparison': f'Normality test for {group}',
                'Test Statistic': stat,
                'p-value': p_value,
                'Test Name': test_name,
                'n_object': len(raw_data),  # raw sample size (objects/cells)
                'n_well': len(data) if self.representation == 'well' else np.nan,  # summarized size (wells)
            })

        is_normal = all(p > 0.05 for p in p_values)
        return is_normal, test_results

    def perform_levene_test(self, unique_groups):
        """Return Levene's ``(statistic, p_value)``, or ``(nan, nan)`` for fewer than two groups."""
        grouped_data = [self.df.loc[self.df[self.grouping_column] == group, self.data_column]
                        for group in unique_groups]
        if len(grouped_data) < 2:
            return np.nan, np.nan
        stat, p_value = levene(*grouped_data)
        return stat, p_value

    def _comparison_row(self, group1, group2, stat, p_value, test_name):
        """Return one result row for ``group1`` vs ``group2``.

        ``n_object`` counts matching rows of ``raw_df``. ``n_well`` counts matching
        rows of ``self.df`` and is only filled when ``representation == 'well'``.
        """
        col = self.grouping_column
        n_object = int((self.raw_df[col] == group1).sum() + (self.raw_df[col] == group2).sum())
        n_well = int((self.df[col] == group1).sum() + (self.df[col] == group2).sum())
        return {
            'Comparison': f'{group1} vs {group2}',
            'Test Statistic': stat,
            'p-value': p_value,
            'Test Name': test_name,
            'n_object': n_object,
            'n_well': n_well if self.representation == 'well' else np.nan,
        }

    def perform_statistical_tests(self, unique_groups, is_normal):
        """Run a two-sample test on every pair of groups.

        With exactly two groups, the test is chosen by normality and ``paired``:
        T-test, paired T-test (``ttest_rel``), Mann-Whitney U, or paired Wilcoxon.
        Paired tests pair observations by position, so both groups must have the
        same length. With more than two groups, each pair is tested with one-way
        ANOVA (normal) or Kruskal-Wallis; ``paired`` is ignored there.
        """
        from scipy.stats import ttest_rel

        two_groups = len(unique_groups) == 2
        use_paired = bool(self.paired) and two_groups
        if two_groups:
            if is_normal:
                stat_test, test_name = (ttest_rel, 'Paired T-test') if use_paired else (ttest_ind, 'T-test')
            else:
                stat_test, test_name = ((wilcoxon, 'Paired Wilcoxon test') if use_paired
                                        else (mannwhitneyu, 'Mann-Whitney U test'))
        elif is_normal:
            stat_test, test_name = f_oneway, 'One-way ANOVA'
        else:
            stat_test, test_name = kruskal, 'Kruskal-Wallis test'

        test_results = []
        for group1, group2 in itertools.combinations(unique_groups, 2):
            data1 = self.df.loc[self.df[self.grouping_column] == group1, self.data_column]
            data2 = self.df.loc[self.df[self.grouping_column] == group2, self.data_column]
            if use_paired:
                stat, p = stat_test(data1.to_numpy(), data2.to_numpy())
            else:
                stat, p = stat_test(data1, data2)
            test_results.append(self._comparison_row(group1, group2, stat, p, test_name))
        return test_results

    @staticmethod
    def _dunn_test(samples):
        """Dunn's (1964) rank test for every pair of groups.

        Args:
            samples (dict): Mapping group -> 1-D sequence of observations.

        Returns:
            dict: ``{(group1, group2): (z, p_unadjusted)}`` for every pair, in
            ``itertools.combinations`` order of the dict keys. p-values are two-sided.
        """
        from scipy.stats import norm, rankdata

        groups = list(samples)
        arrays = [np.asarray(samples[g], dtype=float) for g in groups]
        values = np.concatenate(arrays)
        sizes = np.array([len(a) for a in arrays])
        n_total = len(values)

        ranks = rankdata(values)
        ends = np.cumsum(sizes)
        starts = ends - sizes
        mean_ranks = {g: ranks[s:e].mean() for g, s, e in zip(groups, starts, ends)}
        size_of = dict(zip(groups, sizes))

        # Tie correction: sum(t^3 - t) / (12 (N - 1)) over tied value groups.
        _, tie_counts = np.unique(values, return_counts=True)
        tie_term = np.sum(tie_counts ** 3 - tie_counts) / (12.0 * (n_total - 1))
        variance_base = n_total * (n_total + 1) / 12.0 - tie_term

        results = {}
        for g1, g2 in itertools.combinations(groups, 2):
            se = np.sqrt(variance_base * (1.0 / size_of[g1] + 1.0 / size_of[g2]))
            z = (mean_ranks[g1] - mean_ranks[g2]) / se
            results[(g1, g2)] = (z, 2.0 * norm.sf(abs(z)))
        return results

    def perform_posthoc_tests(self, is_normal, unique_groups):
        """Run post-hoc tests when there are more than two groups.

        - Normal data with ``all_to_all``: Tukey HSD on all pairs.
        - ``all_to_all`` False and ``compare_group`` set: Dunn's test of every group
          against ``compare_group``, Bonferroni-adjusted over the reported comparisons.
        Otherwise an empty list is returned.
        """
        if is_normal and len(unique_groups) > 2 and self.all_to_all:
            tukey_result = pairwise_tukeyhsd(self.df[self.data_column], self.df[self.grouping_column], alpha=0.05)
            # Row 0 of the summary table is the header; data rows start with (group1, group2).
            return [self._comparison_row(row[0], row[1], None, p_value, 'Tukey HSD Post-hoc')
                    for row, p_value in zip(tukey_result._results_table.data[1:], tukey_result.pvalues)]

        if len(unique_groups) > 2 and not self.all_to_all and self.compare_group is not None:
            samples = {group: self.df.loc[self.df[self.grouping_column] == group, self.data_column].to_numpy()
                       for group in unique_groups}
            relevant = [(pair, zp) for pair, zp in self._dunn_test(samples).items()
                        if self.compare_group in pair]
            n_comparisons = len(relevant)
            return [self._comparison_row(g1, g2, z, min(p * n_comparisons, 1.0), 'Dunn’s Post-hoc')
                    for (g1, g2), (z, p) in relevant]
        return []

    @staticmethod
    def _comparison_sample_size(comparison, counts):
        """Sample size for a results row, looked up by exact group name.

        'Normality test for X' gives n(X). 'A vs B' gives n(A) + n(B). Any other
        row gives NaN. ``counts`` maps ``str(group)`` to its number of values.
        """
        prefix = 'Normality test for '
        if comparison.startswith(prefix):
            return counts.get(comparison[len(prefix):], np.nan)
        parts = comparison.split(' vs ')
        if len(parts) == 2 and parts[0] in counts and parts[1] in counts:
            return counts[parts[0]] + counts[parts[1]]
        return np.nan

    def create_plot(self, ax=None):
        """Run the statistical tests and draw the plot.

        Args:
            ax (matplotlib Axes, optional): Axes to draw on. If omitted, a new figure
                is created and shown with ``plt.show()``. If given, the caller owns
                the figure and is responsible for showing it.

        Raises:
            ValueError: If ``graph_type`` or ``error_bar_type`` is not recognised.
        """
        # Outliers are hidden from the plot only; the tests below use self.df.
        plot_df = self.remove_outliers_from_plot() if self.remove_outliers else self.df

        is_normal, normality_results = self.perform_normality_tests()

        unique_groups = self.df[self.grouping_column].unique()
        levene_stat, levene_p = self.perform_levene_test(unique_groups)
        levene_result = {
            'Comparison': 'Levene’s test for equal variance',
            'Test Statistic': levene_stat,
            'p-value': levene_p,
            'Test Name': 'Levene’s Test'
        }

        stat_results = self.perform_statistical_tests(unique_groups, is_normal)
        posthoc_results = self.perform_posthoc_tests(is_normal, unique_groups)
        self.results_df = pd.DataFrame(normality_results + [levene_result] + stat_results + posthoc_results)

        counts = self.df.groupby(self.grouping_column, observed=True)[self.data_column].count()
        counts_by_name = {str(group): int(n) for group, n in counts.items()}
        self.results_df['n'] = [self._comparison_sample_size(c, counts_by_name)
                                for c in self.results_df['Comparison']]

        # Figure width scales with the number of groups.
        num_groups = len(unique_groups)
        bar_width = 0.6
        spacing_between_groups = 0.3
        fig_width = max(num_groups, 1) * (bar_width + spacing_between_groups)
        fig_height = 6

        created_figure = ax is None
        if created_figure:
            self.fig, ax = plt.subplots(figsize=(fig_width, fig_height))
        else:
            self.fig = ax.figure

        # set_style only changes the axes style, unlike sns.set which resets the whole theme.
        sns.set_style("ticks")
        palette = self.colors if self.colors else self.sns_palette

        # Equal space between the y-axis and the first category, and after the last one.
        ax.set_xlim(-0.5, num_groups - 0.5)

        if self.summary_func is None:
            self._create_jitter_plot(ax, data=plot_df, palette=palette)
        elif self.graph_type == 'bar':
            self._create_bar_plot(bar_width, ax, data=plot_df, palette=palette)
        elif self.graph_type == 'box':
            self._create_box_plot(ax, data=plot_df, palette=palette)
        elif self.graph_type == 'violin':
            self._create_violin_plot(ax, data=plot_df, palette=palette)
        elif self.graph_type == 'jitter':
            self._create_jitter_plot(ax, data=plot_df, palette=palette)
        else:
            raise ValueError(f"Invalid graph_type: {self.graph_type}. Choose from 'bar', 'box', 'violin', or 'jitter'.")

        if self.y_axis_start is not None:
            ax.set_ylim(bottom=self.y_axis_start)

        ax.minorticks_on()
        ax.tick_params(axis='x', which='minor', bottom=False)  # no minor ticks on the categorical axis
        ax.tick_params(axis='x', which='major', length=6, width=2, direction='out')
        ax.tick_params(axis='y', which='major', length=6, width=2, direction='out')
        ax.tick_params(axis='y', which='minor', length=4, width=1, direction='out')
        sns.despine(ax=ax, top=True, right=True)

        if self.save:
            self._save_results()

        # Only show figures we created; showing a caller's figure mid-layout would
        # display (and, on inline backends, close) it before the other panels are drawn.
        if created_figure:
            plt.show()

    def get_figure(self):
        """Return the figure from the last ``create_plot`` call (None before that)."""
        return self.fig

    def _create_bar_plot(self, bar_width, ax, data=None, palette=None):
        """Draw bars of ``summary_func`` per group, with std or sem error bars.

        ``data`` and ``palette`` default to ``self.df`` and ``self.sns_palette``.
        """
        data = self.df if data is None else data
        palette = self.sns_palette if palette is None else palette
        # observed=False keeps every category so bar positions match seaborn's categorical axis.
        summary_df = data.groupby(self.grouping_column, observed=False)[self.data_column].agg(
            [self.summary_func, 'std', 'sem'])

        if self.error_bar_type == 'std':
            error_bars = summary_df['std']
        elif self.error_bar_type == 'sem':
            error_bars = summary_df['sem']
        else:
            raise ValueError(f"Invalid error_bar_type: {self.error_bar_type}. Choose either 'std' or 'sem'.")

        sns.barplot(x=self.grouping_column, y=self.summary_func, data=summary_df.reset_index(), ci=None,
                    palette=palette, width=bar_width, ax=ax)
        ax.errorbar(x=np.arange(len(summary_df)), y=summary_df[self.summary_func], yerr=error_bars,
                    fmt='none', c='black', capsize=5)

    def _create_jitter_plot(self, ax, data=None, palette=None):
        """Draw a jittered strip plot of the raw values per group."""
        data = self.df if data is None else data
        palette = self.sns_palette if palette is None else palette
        sns.stripplot(x=self.grouping_column, y=self.data_column, data=data, palette=palette,
                      jitter=True, alpha=0.6, ax=ax)

    def _create_box_plot(self, ax, data=None, palette=None):
        """Draw a box plot per group."""
        data = self.df if data is None else data
        palette = self.sns_palette if palette is None else palette
        sns.boxplot(x=self.grouping_column, y=self.data_column, data=data, palette=palette, ax=ax)

    def _create_violin_plot(self, ax, data=None, palette=None):
        """Draw a violin plot per group."""
        data = self.df if data is None else data
        palette = self.sns_palette if palette is None else palette
        sns.violinplot(x=self.grouping_column, y=self.data_column, data=data, palette=palette, ax=ax)

    def _save_results(self):
        """Write the figure to output_dir/grouped_plot.png and the tests to output_dir/test_results.csv."""
        os.makedirs(self.output_dir, exist_ok=True)
        plot_path = os.path.join(self.output_dir, 'grouped_plot.png')
        self.fig.savefig(plot_path)
        results_path = os.path.join(self.output_dir, 'test_results.csv')
        self.results_df.to_csv(results_path, index=False)
        print(f"Plot saved to {plot_path}")
        print(f"Test results saved to {results_path}")

    def get_results(self):
        """Return the results DataFrame from the last ``create_plot`` call."""
        return self.results_df
|
spacr/settings.py
CHANGED
@@ -404,7 +404,7 @@ def deep_spacr_defaults(settings):
|
|
404
404
|
settings.setdefault('sample',None)
|
405
405
|
settings.setdefault('experiment','exp.')
|
406
406
|
settings.setdefault('score_threshold',0.5)
|
407
|
-
settings.setdefault('
|
407
|
+
settings.setdefault('dataset','path')
|
408
408
|
settings.setdefault('model_path','path')
|
409
409
|
settings.setdefault('file_type','cell_png')
|
410
410
|
settings.setdefault('generate_training_dataset', True)
|
@@ -461,7 +461,7 @@ def get_analyze_recruitment_default_settings(settings):
|
|
461
461
|
settings.setdefault('pathogen_mask_dim',6)
|
462
462
|
settings.setdefault('channel_of_interest',2)
|
463
463
|
settings.setdefault('plot',True)
|
464
|
-
settings.setdefault('plot_nr',
|
464
|
+
settings.setdefault('plot_nr',3)
|
465
465
|
settings.setdefault('plot_control',True)
|
466
466
|
settings.setdefault('figuresize',10)
|
467
467
|
settings.setdefault('uninfected',True)
|
@@ -534,6 +534,7 @@ def get_perform_regression_default_settings(settings):
|
|
534
534
|
settings.setdefault('random_row_column_effects',False)
|
535
535
|
settings.setdefault('alpha',1)
|
536
536
|
settings.setdefault('fraction_threshold',0.1)
|
537
|
+
settings.setdefault('location_column','column')
|
537
538
|
settings.setdefault('nc','c1')
|
538
539
|
settings.setdefault('pc','c2')
|
539
540
|
settings.setdefault('other','c3')
|
@@ -855,7 +856,7 @@ expected_types = {
|
|
855
856
|
'reverse_complement':bool,
|
856
857
|
'file_type':str,
|
857
858
|
'model_path':str,
|
858
|
-
'
|
859
|
+
'dataset':str,
|
859
860
|
'score_threshold':float,
|
860
861
|
'sample':None,
|
861
862
|
'file_metadata':None,
|
@@ -881,7 +882,7 @@ expected_types = {
|
|
881
882
|
"train_DL_model":bool,
|
882
883
|
}
|
883
884
|
|
884
|
-
categories = {"Paths":[ "src", "grna", "barcodes", "custom_model_path", "
|
885
|
+
categories = {"Paths":[ "src", "grna", "barcodes", "custom_model_path", "dataset","model_path","grna_csv","row_csv","column_csv"],
|
885
886
|
"General": ["metadata_type", "custom_regex", "experiment", "channels", "magnification", "channel_dims", "apply_model_to_dataset", "generate_training_dataset", "train_DL_model", "segmentation_mode"],
|
886
887
|
"Cellpose":["from_scratch", "n_epochs", "width_height", "model_name", "custom_model", "resample", "rescale", "CP_prob", "flow_threshold", "percentiles", "circular", "invert", "diameter", "grayscale", "background", "Signal_to_noise", "resize", "target_height", "target_width"],
|
887
888
|
"Cell": ["cell_intensity_range", "cell_size_range", "cell_chann_dim", "cell_channel", "cell_background", "cell_Signal_to_noise", "cell_CP_prob", "cell_FT", "remove_background_cell", "cell_min_size", "cell_mask_dim", "cytoplasm", "cytoplasm_min_size", "include_uninfected", "merge_edge_pathogen_cells", "adjust_cells", "cells", "cell_loc"],
|
@@ -1202,7 +1203,7 @@ def generate_fields(variables, scrollable_frame):
|
|
1202
1203
|
"complevel": "int - level of compression (0-9). Higher is slower and yealds smaller files",
|
1203
1204
|
"file_type": "str - type of file to process",
|
1204
1205
|
"model_path": "str - path to the model",
|
1205
|
-
"
|
1206
|
+
"dataset": "str - file name of the tar file with image dataset",
|
1206
1207
|
"score_threshold": "float - threshold for classification",
|
1207
1208
|
"sample": "str - number of images to sample for tar dataset (including both classes). Default: None",
|
1208
1209
|
"file_metadata": "str - string that must be present in image path to be included in the dataset",
|
spacr/submodules.py
CHANGED
@@ -36,7 +36,7 @@ def analyze_recruitment(settings={}):
|
|
36
36
|
sns.color_palette("mako", as_cmap=True)
|
37
37
|
print(f"channel:{settings['channel_of_interest']} = {settings['target']}")
|
38
38
|
|
39
|
-
df, _ = _read_and_merge_data(
|
39
|
+
df, _ = _read_and_merge_data(locs=[settings['src']+'/measurements/measurements.db'],
|
40
40
|
tables=['cell', 'nucleus', 'pathogen','cytoplasm'],
|
41
41
|
verbose=True,
|
42
42
|
nuclei_limit=settings['nuclei_limit'],
|
@@ -89,15 +89,16 @@ def analyze_recruitment(settings={}):
|
|
89
89
|
|
90
90
|
if not settings['cell_chann_dim'] is None:
|
91
91
|
df = _object_filter(df, 'cell', settings['cell_size_range'], settings['cell_intensity_range'], mask_chans, 0)
|
92
|
-
if not settings['target_intensity_min'] is None:
|
93
|
-
df = df[df[f"cell_channel_{settings['channel_of_interest']}_percentile_95
|
92
|
+
if not settings['target_intensity_min'] is None or not settings['target_intensity_min'] is 0:
|
93
|
+
df = df[df[f"cell_channel_{settings['channel_of_interest']}_percentile_95"] > settings['target_intensity_min']]
|
94
94
|
print(f"After channel {settings['channel_of_interest']} filtration", len(df))
|
95
95
|
if not settings['nucleus_chann_dim'] is None:
|
96
96
|
df = _object_filter(df, 'nucleus', settings['nucleus_size_range'], settings['nucleus_intensity_range'], mask_chans, 1)
|
97
97
|
if not settings['pathogen_chann_dim'] is None:
|
98
98
|
df = _object_filter(df, 'pathogen', settings['pathogen_size_range'], settings['pathogen_intensity_range'], mask_chans, 2)
|
99
99
|
|
100
|
-
df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity
|
100
|
+
df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity"]/df[f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
|
101
|
+
|
101
102
|
for chan in settings['channel_dims']:
|
102
103
|
df = _calculate_recruitment(df, channel=chan)
|
103
104
|
print(f'calculated recruitment for: {len(df)} rows')
|
@@ -114,9 +115,9 @@ def analyze_recruitment(settings={}):
|
|
114
115
|
_plot_controls(df, mask_chans, settings['channel_of_interest'], figuresize=5)
|
115
116
|
|
116
117
|
print(f'PV level: {len(df)} rows')
|
117
|
-
_plot_recruitment(df, 'by PV', settings['channel_of_interest'],
|
118
|
+
_plot_recruitment(df, 'by PV', settings['channel_of_interest'], columns=[], figuresize=settings['figuresize'])
|
118
119
|
print(f'well level: {len(df_well)} rows')
|
119
|
-
_plot_recruitment(df_well, 'by well', settings['channel_of_interest'],
|
120
|
+
_plot_recruitment(df_well, 'by well', settings['channel_of_interest'], columns=[], figuresize=settings['figuresize'])
|
120
121
|
cells,wells = _results_to_csv(settings['src'], df, df_well)
|
121
122
|
|
122
123
|
return [cells,wells]
|
spacr/toxo.py
CHANGED
@@ -112,10 +112,15 @@ def go_term_enrichment_by_column(significant_df, metadata_path, go_term_columns=
|
|
112
112
|
- Plot the enrichment score vs -log10(p-value).
|
113
113
|
"""
|
114
114
|
|
115
|
-
significant_df['variable'].fillna(significant_df['feature'], inplace=True)
|
116
|
-
split_columns = significant_df['variable'].str.split('_', expand=True)
|
117
|
-
significant_df['gene_nr'] = split_columns[0]
|
118
|
-
gene_list = significant_df['gene_nr'].to_list()
|
115
|
+
#significant_df['variable'].fillna(significant_df['feature'], inplace=True)
|
116
|
+
#split_columns = significant_df['variable'].str.split('_', expand=True)
|
117
|
+
#significant_df['gene_nr'] = split_columns[0]
|
118
|
+
#gene_list = significant_df['gene_nr'].to_list()
|
119
|
+
|
120
|
+
significant_df = significant_df.dropna(subset=['n_gene'])
|
121
|
+
significant_df = significant_df[significant_df['n_gene'] != None]
|
122
|
+
|
123
|
+
gene_list = significant_df['n_gene'].to_list()
|
119
124
|
|
120
125
|
# Load metadata
|
121
126
|
metadata = pd.read_csv(metadata_path)
|