spacr 0.3.2__py3-none-any.whl → 0.3.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/plot.py CHANGED
@@ -16,8 +16,9 @@ from skimage.measure import find_contours, label, regionprops
16
16
 
17
17
  from scipy.stats import normaltest, ttest_ind, mannwhitneyu, f_oneway, kruskal
18
18
  from statsmodels.stats.multicomp import pairwise_tukeyhsd
19
+ from scipy.stats import ttest_ind, mannwhitneyu, levene, wilcoxon, kruskal
19
20
  import itertools
20
-
21
+ import pingouin as pg
21
22
 
22
23
  from ipywidgets import IntSlider, interact
23
24
  from IPython.display import Image as ipyimage
@@ -1002,8 +1003,107 @@ def _display_gif(path):
1002
1003
  """
1003
1004
  with open(path, 'rb') as file:
1004
1005
  display(ipyimage(file.read()))
1006
+
1007
+ def _plot_recruitment_v2(df, df_type, channel_of_interest, columns=[], figuresize=10):
1008
+ """
1009
+ Plot recruitment data for different conditions and pathogens.
1010
+
1011
+ Args:
1012
+ df (DataFrame): The input DataFrame containing the recruitment data.
1013
+ df_type (str): Label describing the aggregation level of df (e.g., 'by PV', 'by well'); it is appended to each x-axis label.
1014
+ channel_of_interest (str): The channel of interest for plotting.
1015
+ target (str): The target variable for plotting.
1016
+ columns (list, optional): Additional columns to plot. Defaults to an empty list.
1017
+ figuresize (int, optional): The size of the figure. Defaults to 10.
1018
+
1019
+ Returns:
1020
+ None
1021
+ """
1022
+
1023
+ from .plot import spacrGraph
1024
+
1025
+ color_list = [(55/255, 155/255, 155/255),
1026
+ (155/255, 55/255, 155/255),
1027
+ (55/255, 155/255, 255/255),
1028
+ (255/255, 55/255, 155/255)]
1029
+
1030
+ sns.set_palette(sns.color_palette(color_list))
1031
+ font = figuresize/2
1032
+ width=figuresize
1033
+ height=figuresize/4
1034
+
1035
+ # Create the subplots
1036
+ fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(width, height))
1037
+
1038
+ # Plot for 'cell_channel' on axes[0]
1039
+ plotter_cell = spacrGraph(df,grouping_column='condition', data_column=f'cell_channel_{channel_of_interest}_mean_intensity')
1040
+ plotter_cell.create_plot(ax=axes[0])
1041
+ axes[0].set_xlabel(f'pathogen {df_type}', fontsize=font)
1042
+ axes[0].set_ylabel(f'cell_channel_{channel_of_interest}_mean_intensity', fontsize=font)
1043
+
1044
+ # Plot for 'nucleus_channel' on axes[1]
1045
+ plotter_nucleus = spacrGraph(df,grouping_column='condition', data_column=f'nucleus_channel_{channel_of_interest}_mean_intensity')
1046
+ plotter_nucleus.create_plot(ax=axes[1])
1047
+ axes[1].set_xlabel(f'pathogen {df_type}', fontsize=font)
1048
+ axes[1].set_ylabel(f'nucleus_channel_{channel_of_interest}_mean_intensity', fontsize=font)
1049
+
1050
+ # Plot for 'cytoplasm_channel' on axes[2]
1051
+ plotter_cytoplasm = spacrGraph(df, grouping_column='condition', data_column=f'cytoplasm_channel_{channel_of_interest}_mean_intensity')
1052
+ plotter_cytoplasm.create_plot(ax=axes[2])
1053
+ axes[2].set_xlabel(f'pathogen {df_type}', fontsize=font)
1054
+ axes[2].set_ylabel(f'cytoplasm_channel_{channel_of_interest}_mean_intensity', fontsize=font)
1055
+
1056
+ # Plot for 'pathogen_channel' on axes[3]
1057
+ plotter_pathogen = spacrGraph(df, grouping_column='condition', data_column=f'pathogen_channel_{channel_of_interest}_mean_intensity')
1058
+ plotter_pathogen.create_plot(ax=axes[3])
1059
+ axes[3].set_xlabel(f'pathogen {df_type}', fontsize=font)
1060
+ axes[3].set_ylabel(f'pathogen_channel_{channel_of_interest}_mean_intensity', fontsize=font)
1061
+
1062
+ #axes[0].legend_.remove()
1063
+ #axes[1].legend_.remove()
1064
+ #axes[2].legend_.remove()
1065
+ #axes[3].legend_.remove()
1066
+
1067
+ handles, labels = axes[3].get_legend_handles_labels()
1068
+ axes[3].legend(handles, labels, bbox_to_anchor=(1.05, 0.5), loc='center left')
1069
+ for i in [0,1,2,3]:
1070
+ axes[i].tick_params(axis='both', which='major', labelsize=font)
1071
+ axes[i].set_xticklabels(axes[i].get_xticklabels(), rotation=45)
1072
+
1073
+ plt.tight_layout()
1074
+ plt.show()
1075
+
1076
+ columns = columns + ['pathogen_cytoplasm_mean_mean', 'pathogen_cytoplasm_q75_mean', 'pathogen_periphery_cytoplasm_mean_mean', 'pathogen_outside_cytoplasm_mean_mean', 'pathogen_outside_cytoplasm_q75_mean']
1077
+ #columns = columns + [f'pathogen_slope_channel_{channel_of_interest}', f'pathogen_cell_distance_channel_{channel_of_interest}', f'nucleus_cell_distance_channel_{channel_of_interest}']
1078
+
1079
+ width = figuresize*2
1080
+ columns_per_row = math.ceil(len(columns) / 2)
1081
+ height = (figuresize*2)/columns_per_row
1082
+
1083
+ fig, axes = plt.subplots(nrows=2, ncols=columns_per_row, figsize=(width, height * 2))
1084
+ axes = axes.flatten()
1085
+
1086
+ print(f'{columns}')
1087
+ for i, col in enumerate(columns):
1088
+ ax = axes[i]
1089
+ plotter_col = spacrGraph(df, grouping_column='condition', data_column=col)
1090
+ plotter_col.create_plot(ax=ax)
1091
+ ax.set_xlabel(f'pathogen {df_type}', fontsize=font)
1092
+ ax.set_ylabel(f'{col}', fontsize=int(font * 2))
1093
+ #ax.legend_.remove()
1094
+ ax.tick_params(axis='both', which='major', labelsize=font)
1095
+ ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
1096
+ if i <= 5:
1097
+ ax.set_ylim(1, None)
1098
+
1099
+ # Turn off any unused axes
1100
+ for i in range(len(columns), len(axes)):
1101
+ axes[i].axis('off')
1102
+
1103
+ plt.tight_layout()
1104
+ plt.show()
1005
1105
 
1006
- def _plot_recruitment(df, df_type, channel_of_interest, target, columns=[], figuresize=10):
1106
+ def _plot_recruitment(df, df_type, channel_of_interest, columns=[], figuresize=10):
1007
1107
  """
1008
1108
  Plot recruitment data for different conditions and pathogens.
1009
1109
 
@@ -1156,10 +1256,6 @@ def _plot_controls(df, mask_chans, channel_of_interest, figuresize=5):
1156
1256
  plt.tight_layout()
1157
1257
  plt.show()
1158
1258
 
1159
- ###################################################
1160
- # Classify
1161
- ###################################################
1162
-
1163
1259
  def _imshow(img, labels, nrow=20, color='white', fontsize=12):
1164
1260
  """
1165
1261
  Display multiple images in a grid with corresponding labels.
@@ -1997,4 +2093,330 @@ def create_grouped_plot(df, grouping_column, data_column, graph_type='bar', summ
1997
2093
  # Show the plot
1998
2094
  plt.show()
1999
2095
 
2000
- return plt.gcf(), results_df
2096
+ return plt.gcf(), results_df
2097
+
2098
+ class spacrGraph:
2099
+ def __init__(self, df, grouping_column, data_column, graph_type='bar', summary_func='mean',
2100
+ order=None, colors=None, output_dir='./output', save=False, y_axis_start=None,
2101
+ error_bar_type='std', remove_outliers=False, theme='pastel', representation='object',
2102
+ paired=False, all_to_all=True, compare_group=None):
2103
+ """
2104
+ Class for creating grouped plots with optional statistical tests and data preprocessing.
2105
+ """
2106
+ self.df = df
2107
+ self.grouping_column = grouping_column
2108
+ self.data_column = data_column
2109
+ self.graph_type = graph_type
2110
+ self.summary_func = summary_func
2111
+ self.order = order
2112
+ self.colors = colors
2113
+ self.output_dir = output_dir
2114
+ self.save = save
2115
+ self.y_axis_start = y_axis_start
2116
+ self.error_bar_type = error_bar_type
2117
+ self.remove_outliers = remove_outliers
2118
+ self.theme = theme
2119
+ self.representation = representation
2120
+ self.paired = paired
2121
+ self.all_to_all = all_to_all
2122
+ self.compare_group = compare_group
2123
+
2124
+ self.results_df = pd.DataFrame()
2125
+ self.sns_palette = None
2126
+ self.fig = None # To store the generated figure
2127
+
2128
+ # Preprocess and set palette
2129
+ self._set_theme()
2130
+ self.raw_df = self.df.copy() # Preserve the raw data for n_object count
2131
+ self.df = self.preprocess_data()
2132
+
2133
+ def _set_theme(self):
2134
+ """Set the Seaborn theme and reorder colors if necessary."""
2135
+ integer_list = list(range(1, 81))
2136
+ color_order = [0, 3, 9, 4, 6, 7, 9, 2] + integer_list
2137
+ self.sns_palette = self._set_reordered_theme(self.theme, color_order, 100)
2138
+
2139
+ def _set_reordered_theme(self, theme='muted', order=None, n_colors=100, show_theme=False):
2140
+ """Set and reorder the Seaborn color palette."""
2141
+ palette = sns.color_palette(theme, n_colors)
2142
+ if order:
2143
+ reordered_palette = [palette[i] for i in order]
2144
+ else:
2145
+ reordered_palette = palette
2146
+ if show_theme:
2147
+ sns.palplot(reordered_palette)
2148
+ plt.show()
2149
+ return reordered_palette
2150
+
2151
+ def preprocess_data(self):
2152
+ """Preprocess the data: remove NaNs, sort/order the grouping column, and optionally group by 'prc'."""
2153
+ df = self.df.dropna(subset=[self.grouping_column, self.data_column])
2154
+
2155
+ # Group by 'prc' column if representation is 'well'
2156
+ if self.representation == 'well':
2157
+ df = df.groupby(['prc', self.grouping_column])[self.data_column].agg(self.summary_func).reset_index()
2158
+
2159
+ if self.order:
2160
+ df[self.grouping_column] = pd.Categorical(df[self.grouping_column], categories=self.order, ordered=True)
2161
+ else:
2162
+ df[self.grouping_column] = pd.Categorical(df[self.grouping_column], categories=sorted(df[self.grouping_column].unique()), ordered=True)
2163
+
2164
+ return df
2165
+
2166
+ def remove_outliers_from_plot(self):
2167
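+ """Return a copy of self.df with per-group outliers (values outside 1.5 * IQR of the group's quartiles) dropped; self.df and self.raw_df are not modified here."""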
2168
+ filtered_df = self.df.copy()
2169
+ unique_groups = filtered_df[self.grouping_column].unique()
2170
+ for group in unique_groups:
2171
+ group_data = filtered_df[filtered_df[self.grouping_column] == group][self.data_column]
2172
+ q1 = group_data.quantile(0.25)
2173
+ q3 = group_data.quantile(0.75)
2174
+ iqr = q3 - q1
2175
+ lower_bound = q1 - 1.5 * iqr
2176
+ upper_bound = q3 + 1.5 * iqr
2177
+ filtered_df = filtered_df.drop(filtered_df[(filtered_df[self.grouping_column] == group) & ((filtered_df[self.data_column] < lower_bound) | (filtered_df[self.data_column] > upper_bound))].index)
2178
+ return filtered_df
2179
+
2180
+ def perform_normality_tests(self):
2181
+ """Perform normality tests for each group."""
2182
+ unique_groups = self.df[self.grouping_column].unique()
2183
+ grouped_data = [self.df.loc[self.df[self.grouping_column] == group, self.data_column] for group in unique_groups]
2184
+ raw_grouped_data = [self.raw_df.loc[self.raw_df[self.grouping_column] == group, self.data_column] for group in unique_groups]
2185
+
2186
+ normal_p_values = [normaltest(data).pvalue for data in grouped_data]
2187
+ normal_stats = [normaltest(data).statistic for data in grouped_data]
2188
+ is_normal = all(p > 0.05 for p in normal_p_values)
2189
+
2190
+ test_results = []
2191
+ for group, stat, p_value in zip(unique_groups, normal_stats, normal_p_values):
2192
+ test_results.append({
2193
+ 'Comparison': f'Normality test for {group}',
2194
+ 'Test Statistic': stat,
2195
+ 'p-value': p_value,
2196
+ 'Test Name': 'Normality test',
2197
+ 'n_object': len(raw_grouped_data[unique_groups.tolist().index(group)]), # Raw sample size (objects/cells)
2198
+ 'n_well': len(grouped_data[unique_groups.tolist().index(group)]) if self.representation == 'well' else np.nan # Summarized size (wells)
2199
+ })
2200
+ return is_normal, test_results
2201
+
2202
+ def perform_levene_test(self, unique_groups):
2203
+ """Perform Levene's test for equal variance."""
2204
+ grouped_data = [self.df.loc[self.df[self.grouping_column] == group, self.data_column] for group in unique_groups]
2205
+ stat, p_value = levene(*grouped_data)
2206
+ return stat, p_value
2207
+
2208
+ def perform_statistical_tests(self, unique_groups, is_normal):
2209
+ """Perform statistical tests based on the number of groups, normality, and paired flag."""
2210
+ if len(unique_groups) == 2:
2211
+ if is_normal:
2212
+ if self.paired:
2213
+ stat_test = pg.ttest # Paired T-test
2214
+ test_name = 'Paired T-test'
2215
+ else:
2216
+ stat_test = ttest_ind
2217
+ test_name = 'T-test'
2218
+ else:
2219
+ if self.paired:
2220
+ stat_test = pg.wilcoxon # Paired Wilcoxon test
2221
+ test_name = 'Paired Wilcoxon test'
2222
+ else:
2223
+ stat_test = mannwhitneyu
2224
+ test_name = 'Mann-Whitney U test'
2225
+ else:
2226
+ if is_normal:
2227
+ stat_test = f_oneway
2228
+ test_name = 'One-way ANOVA'
2229
+ else:
2230
+ stat_test = kruskal
2231
+ test_name = 'Kruskal-Wallis test'
2232
+
2233
+ comparisons = list(itertools.combinations(unique_groups, 2))
2234
+ test_results = []
2235
+ for (group1, group2) in comparisons:
2236
+ data1 = self.df[self.df[self.grouping_column] == group1][self.data_column]
2237
+ data2 = self.df[self.df[self.grouping_column] == group2][self.data_column]
2238
+ raw_data1 = self.raw_df[self.raw_df[self.grouping_column] == group1][self.data_column]
2239
+ raw_data2 = self.raw_df[self.raw_df[self.grouping_column] == group2][self.data_column]
2240
+
2241
+ if self.paired:
2242
+ stat, p = stat_test(data1, data2, paired=True)
2243
+ else:
2244
+ stat, p = stat_test(data1, data2)
2245
+
2246
+ test_results.append({
2247
+ 'Comparison': f'{group1} vs {group2}',
2248
+ 'Test Statistic': stat,
2249
+ 'p-value': p,
2250
+ 'Test Name': test_name,
2251
+ 'n_object': len(raw_data1) + len(raw_data2), # Raw sample size (objects/cells)
2252
+ 'n_well': len(data1) + len(data2) if self.representation == 'well' else np.nan # Summarized size (wells)
2253
+ })
2254
+ return test_results
2255
+
2256
+ def perform_posthoc_tests(self, is_normal, unique_groups):
2257
+ """Perform post-hoc tests for multiple groups based on all_to_all flag."""
2258
+ if is_normal and len(unique_groups) > 2 and self.all_to_all:
2259
+ # Tukey HSD Post-hoc when comparing all to all
2260
+ tukey_result = pairwise_tukeyhsd(self.df[self.data_column], self.df[self.grouping_column], alpha=0.05)
2261
+ posthoc_results = []
2262
+ for comparison, p_value in zip(tukey_result._results_table.data[1:], tukey_result.pvalues):
2263
+ raw_data1 = self.raw_df[self.raw_df[self.grouping_column] == comparison[0]][self.data_column]
2264
+ raw_data2 = self.raw_df[self.raw_df[self.grouping_column] == comparison[1]][self.data_column]
2265
+
2266
+ posthoc_results.append({
2267
+ 'Comparison': f'{comparison[0]} vs {comparison[1]}',
2268
+ 'Test Statistic': None, # Tukey does not provide a test statistic
2269
+ 'p-value': p_value,
2270
+ 'Test Name': 'Tukey HSD Post-hoc',
2271
+ 'n_object': len(raw_data1) + len(raw_data2),
2272
+ 'n_well': len(self.df[self.df[self.grouping_column] == comparison[0]]) + len(self.df[self.df[self.grouping_column] == comparison[1]])
2273
+ })
2274
+ return posthoc_results
2275
+
2276
+ elif len(unique_groups) > 2 and not self.all_to_all and self.compare_group:
2277
+ # Dunn's post-hoc test using Pingouin
2278
+ dunn_result = pg.pairwise_tests(data=self.df, dv=self.data_column, between=self.grouping_column, padjust='bonf', test='dunn')
2279
+ posthoc_results = []
2280
+ for idx, row in dunn_result.iterrows():
2281
+ if row['A'] == self.compare_group or row['B'] == self.compare_group:
2282
+ posthoc_results.append({
2283
+ 'Comparison': f"{row['A']} vs {row['B']}",
2284
+ 'Test Statistic': row['T'], # Test statistic from Dunn's test
2285
+ 'p-value': row['p-val'],
2286
+ 'Test Name': 'Dunn’s Post-hoc',
2287
+ 'n_object': None,
2288
+ 'n_well': None
2289
+ })
2290
+ return posthoc_results
2291
+ return []
2292
+
2293
+ def create_plot(self, ax=None):
2294
+ """Create and display the plot based on the chosen graph type."""
2295
+ # Optional: Remove outliers for plotting
2296
+ if self.remove_outliers:
2297
+ self.df = self.remove_outliers_from_plot()
2298
+
2299
+ # Perform normality tests
2300
+ is_normal, normality_results = self.perform_normality_tests()
2301
+
2302
+ # Perform Levene's test for equal variance
2303
+ unique_groups = self.df[self.grouping_column].unique()
2304
+ levene_stat, levene_p = self.perform_levene_test(unique_groups)
2305
+ levene_result = {
2306
+ 'Comparison': 'Levene’s test for equal variance',
2307
+ 'Test Statistic': levene_stat,
2308
+ 'p-value': levene_p,
2309
+ 'Test Name': 'Levene’s Test'
2310
+ }
2311
+
2312
+ # Perform statistical tests
2313
+ stat_results = self.perform_statistical_tests(unique_groups, is_normal)
2314
+
2315
+ # Perform post-hoc tests if applicable
2316
+ posthoc_results = self.perform_posthoc_tests(is_normal, unique_groups)
2317
+
2318
+ # Combine all test results
2319
+ self.results_df = pd.DataFrame(normality_results + [levene_result] + stat_results + posthoc_results)
2320
+
2321
+ # Add sample size column
2322
+ sample_sizes = self.df.groupby(self.grouping_column)[self.data_column].count().reset_index(name='n')
2323
+ self.results_df['n'] = self.results_df['Comparison'].apply(
2324
+ lambda x: next((sample_sizes[sample_sizes[self.grouping_column] == g]['n'].values[0] for g in sample_sizes[self.grouping_column] if g in x), np.nan)
2325
+ )
2326
+
2327
+ # Dynamically set figure dimensions based on the number of unique groups
2328
+ num_groups = len(unique_groups)
2329
+ bar_width = 0.6 # Set the desired thickness of each bar
2330
+ spacing_between_groups = 0.3 # Set the desired spacing between bars and axis
2331
+
2332
+ fig_width = num_groups * (bar_width + spacing_between_groups) # Dynamically calculate the figure width
2333
+ fig_height = 6 # Fixed height for the plot
2334
+
2335
+ if ax is None:
2336
+ self.fig, ax = plt.subplots(figsize=(fig_width, fig_height)) # Store the figure in self.fig
2337
+ else:
2338
+ self.fig = ax.figure # Store the figure if ax is provided
2339
+
2340
+ sns.set(style="ticks")
2341
+ color_palette = self.sns_palette if not self.colors else self.colors
2342
+
2343
+ # Calculate x-axis limits to ensure equal space between the bars and the y-axis
2344
+ xlim_lower = -0.5 # Ensures space between the y-axis and the first category
2345
+ xlim_upper = num_groups - 0.5 # Ensures space after the last category
2346
+ ax.set_xlim(xlim_lower, xlim_upper)
2347
+
2348
+ if self.summary_func is None:
2349
+ sns.stripplot(x=self.grouping_column, y=self.data_column, data=self.df, palette=color_palette, jitter=True, alpha=0.6, ax=ax)
2350
+ elif self.graph_type == 'bar':
2351
+ self._create_bar_plot(bar_width, ax)
2352
+ elif self.graph_type == 'box':
2353
+ self._create_box_plot(ax)
2354
+ elif self.graph_type == 'violin':
2355
+ self._create_violin_plot(ax)
2356
+ elif self.graph_type == 'jitter':
2357
+ self._create_jitter_plot(ax)
2358
+ else:
2359
+ raise ValueError(f"Invalid graph_type: {self.graph_type}. Choose from 'bar', 'box', 'violin', or 'jitter'.")
2360
+
2361
+ # Set y-axis start
2362
+ if self.y_axis_start is not None:
2363
+ ax.set_ylim(bottom=self.y_axis_start)
2364
+
2365
+ # Add ticks, remove grid, and save plot
2366
+ ax.minorticks_on()
2367
+ ax.tick_params(axis='x', which='minor', bottom=False) # Disable minor ticks on x-axis
2368
+ ax.tick_params(axis='x', which='major', length=6, width=2, direction='out')
2369
+ ax.tick_params(axis='y', which='major', length=6, width=2, direction='out')
2370
+ ax.tick_params(axis='y', which='minor', length=4, width=1, direction='out')
2371
+ sns.despine(ax=ax, top=True, right=True)
2372
+
2373
+ if self.save:
2374
+ self._save_results()
2375
+
2376
+ plt.show() # Ensure the plot is shown, but plt.show() doesn't clear the figure context
2377
+
2378
+ def get_figure(self):
2379
+ """Return the generated figure."""
2380
+ return self.fig
2381
+
2382
+ def _create_bar_plot(self, bar_width, ax):
2383
+ """Helper method to create a bar plot with consistent bar thickness and centered error bars."""
2384
+ summary_df = self.df.groupby(self.grouping_column)[self.data_column].agg([self.summary_func, 'std', 'sem'])
2385
+
2386
+ if self.error_bar_type == 'std':
2387
+ error_bars = summary_df['std']
2388
+ elif self.error_bar_type == 'sem':
2389
+ error_bars = summary_df['sem']
2390
+ else:
2391
+ raise ValueError(f"Invalid error_bar_type: {self.error_bar_type}. Choose either 'std' or 'sem'.")
2392
+
2393
+ sns.barplot(x=self.grouping_column, y=self.summary_func, data=summary_df.reset_index(), ci=None, palette=self.sns_palette, width=bar_width, ax=ax)
2394
+
2395
+ # Plot the error bars
2396
+ ax.errorbar(x=np.arange(len(summary_df)), y=summary_df[self.summary_func], yerr=error_bars, fmt='none', c='black', capsize=5)
2397
+
2398
+ def _create_jitter_plot(self, ax):
2399
+ """Helper method to create a jitter plot (strip plot)."""
2400
+ sns.stripplot(x=self.grouping_column, y=self.data_column, data=self.df, palette=self.sns_palette, jitter=True, alpha=0.6, ax=ax)
2401
+
2402
+ def _create_box_plot(self, ax):
2403
+ """Helper method to create a box plot."""
2404
+ sns.boxplot(x=self.grouping_column, y=self.data_column, data=self.df, palette=self.sns_palette, ax=ax)
2405
+
2406
+ def _create_violin_plot(self, ax):
2407
+ """Helper method to create a violin plot."""
2408
+ sns.violinplot(x=self.grouping_column, y=self.data_column, data=self.df, palette=self.sns_palette, ax=ax)
2409
+
2410
+ def _save_results(self):
2411
+ """Helper method to save the plot and results."""
2412
+ os.makedirs(self.output_dir, exist_ok=True)
2413
+ plot_path = os.path.join(self.output_dir, 'grouped_plot.png')
2414
+ self.fig.savefig(plot_path)
2415
+ results_path = os.path.join(self.output_dir, 'test_results.csv')
2416
+ self.results_df.to_csv(results_path, index=False)
2417
+ print(f"Plot saved to {plot_path}")
2418
+ print(f"Test results saved to {results_path}")
2419
+
2420
+ def get_results(self):
2421
+ """Return the results dataframe."""
2422
+ return self.results_df
spacr/settings.py CHANGED
@@ -404,7 +404,7 @@ def deep_spacr_defaults(settings):
404
404
  settings.setdefault('sample',None)
405
405
  settings.setdefault('experiment','exp.')
406
406
  settings.setdefault('score_threshold',0.5)
407
- settings.setdefault('tar_path','path')
407
+ settings.setdefault('dataset','path')
408
408
  settings.setdefault('model_path','path')
409
409
  settings.setdefault('file_type','cell_png')
410
410
  settings.setdefault('generate_training_dataset', True)
@@ -461,7 +461,7 @@ def get_analyze_recruitment_default_settings(settings):
461
461
  settings.setdefault('pathogen_mask_dim',6)
462
462
  settings.setdefault('channel_of_interest',2)
463
463
  settings.setdefault('plot',True)
464
- settings.setdefault('plot_nr',10)
464
+ settings.setdefault('plot_nr',3)
465
465
  settings.setdefault('plot_control',True)
466
466
  settings.setdefault('figuresize',10)
467
467
  settings.setdefault('uninfected',True)
@@ -534,6 +534,7 @@ def get_perform_regression_default_settings(settings):
534
534
  settings.setdefault('random_row_column_effects',False)
535
535
  settings.setdefault('alpha',1)
536
536
  settings.setdefault('fraction_threshold',0.1)
537
+ settings.setdefault('location_column','column')
537
538
  settings.setdefault('nc','c1')
538
539
  settings.setdefault('pc','c2')
539
540
  settings.setdefault('other','c3')
@@ -855,7 +856,7 @@ expected_types = {
855
856
  'reverse_complement':bool,
856
857
  'file_type':str,
857
858
  'model_path':str,
858
- 'tar_path':str,
859
+ 'dataset':str,
859
860
  'score_threshold':float,
860
861
  'sample':None,
861
862
  'file_metadata':None,
@@ -881,7 +882,7 @@ expected_types = {
881
882
  "train_DL_model":bool,
882
883
  }
883
884
 
884
- categories = {"Paths":[ "src", "grna", "barcodes", "custom_model_path", "tar_path","model_path","grna_csv","row_csv","column_csv"],
885
+ categories = {"Paths":[ "src", "grna", "barcodes", "custom_model_path", "dataset","model_path","grna_csv","row_csv","column_csv"],
885
886
  "General": ["metadata_type", "custom_regex", "experiment", "channels", "magnification", "channel_dims", "apply_model_to_dataset", "generate_training_dataset", "train_DL_model", "segmentation_mode"],
886
887
  "Cellpose":["from_scratch", "n_epochs", "width_height", "model_name", "custom_model", "resample", "rescale", "CP_prob", "flow_threshold", "percentiles", "circular", "invert", "diameter", "grayscale", "background", "Signal_to_noise", "resize", "target_height", "target_width"],
887
888
  "Cell": ["cell_intensity_range", "cell_size_range", "cell_chann_dim", "cell_channel", "cell_background", "cell_Signal_to_noise", "cell_CP_prob", "cell_FT", "remove_background_cell", "cell_min_size", "cell_mask_dim", "cytoplasm", "cytoplasm_min_size", "include_uninfected", "merge_edge_pathogen_cells", "adjust_cells", "cells", "cell_loc"],
@@ -1202,7 +1203,7 @@ def generate_fields(variables, scrollable_frame):
1202
1203
  "complevel": "int - level of compression (0-9). Higher is slower and yealds smaller files",
1203
1204
  "file_type": "str - type of file to process",
1204
1205
  "model_path": "str - path to the model",
1205
- "tar_path": "str - path to the tar file with image dataset",
1206
+ "dataset": "str - file name of the tar file with image dataset",
1206
1207
  "score_threshold": "float - threshold for classification",
1207
1208
  "sample": "str - number of images to sample for tar dataset (including both classes). Default: None",
1208
1209
  "file_metadata": "str - string that must be present in image path to be included in the dataset",
spacr/submodules.py CHANGED
@@ -36,7 +36,7 @@ def analyze_recruitment(settings={}):
36
36
  sns.color_palette("mako", as_cmap=True)
37
37
  print(f"channel:{settings['channel_of_interest']} = {settings['target']}")
38
38
 
39
- df, _ = _read_and_merge_data(db_loc=[settings['src']+'/measurements/measurements.db'],
39
+ df, _ = _read_and_merge_data(locs=[settings['src']+'/measurements/measurements.db'],
40
40
  tables=['cell', 'nucleus', 'pathogen','cytoplasm'],
41
41
  verbose=True,
42
42
  nuclei_limit=settings['nuclei_limit'],
@@ -89,15 +89,16 @@ def analyze_recruitment(settings={}):
89
89
 
90
90
  if not settings['cell_chann_dim'] is None:
91
91
  df = _object_filter(df, 'cell', settings['cell_size_range'], settings['cell_intensity_range'], mask_chans, 0)
92
- if not settings['target_intensity_min'] is None:
93
- df = df[df[f"cell_channel_{settings['channel_of_interest']}_percentile_95'] > settings['target_intensity_min"]]
92
+ if not settings['target_intensity_min'] is None or not settings['target_intensity_min'] is 0:
93
+ df = df[df[f"cell_channel_{settings['channel_of_interest']}_percentile_95"] > settings['target_intensity_min']]
94
94
  print(f"After channel {settings['channel_of_interest']} filtration", len(df))
95
95
  if not settings['nucleus_chann_dim'] is None:
96
96
  df = _object_filter(df, 'nucleus', settings['nucleus_size_range'], settings['nucleus_intensity_range'], mask_chans, 1)
97
97
  if not settings['pathogen_chann_dim'] is None:
98
98
  df = _object_filter(df, 'pathogen', settings['pathogen_size_range'], settings['pathogen_intensity_range'], mask_chans, 2)
99
99
 
100
- df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity']/df[f'cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
100
+ df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity"]/df[f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
101
+
101
102
  for chan in settings['channel_dims']:
102
103
  df = _calculate_recruitment(df, channel=chan)
103
104
  print(f'calculated recruitment for: {len(df)} rows')
@@ -114,9 +115,9 @@ def analyze_recruitment(settings={}):
114
115
  _plot_controls(df, mask_chans, settings['channel_of_interest'], figuresize=5)
115
116
 
116
117
  print(f'PV level: {len(df)} rows')
117
- _plot_recruitment(df, 'by PV', settings['channel_of_interest'], settings['target'], settings['figuresize'])
118
+ _plot_recruitment(df, 'by PV', settings['channel_of_interest'], columns=[], figuresize=settings['figuresize'])
118
119
  print(f'well level: {len(df_well)} rows')
119
- _plot_recruitment(df_well, 'by well', settings['channel_of_interest'], settings['target'], settings['figuresize'])
120
+ _plot_recruitment(df_well, 'by well', settings['channel_of_interest'], columns=[], figuresize=settings['figuresize'])
120
121
  cells,wells = _results_to_csv(settings['src'], df, df_well)
121
122
 
122
123
  return [cells,wells]
spacr/toxo.py CHANGED
@@ -112,10 +112,15 @@ def go_term_enrichment_by_column(significant_df, metadata_path, go_term_columns=
112
112
  - Plot the enrichment score vs -log10(p-value).
113
113
  """
114
114
 
115
- significant_df['variable'].fillna(significant_df['feature'], inplace=True)
116
- split_columns = significant_df['variable'].str.split('_', expand=True)
117
- significant_df['gene_nr'] = split_columns[0]
118
- gene_list = significant_df['gene_nr'].to_list()
115
+ #significant_df['variable'].fillna(significant_df['feature'], inplace=True)
116
+ #split_columns = significant_df['variable'].str.split('_', expand=True)
117
+ #significant_df['gene_nr'] = split_columns[0]
118
+ #gene_list = significant_df['gene_nr'].to_list()
119
+
120
+ significant_df = significant_df.dropna(subset=['n_gene'])
121
+ significant_df = significant_df[significant_df['n_gene'] != None]
122
+
123
+ gene_list = significant_df['n_gene'].to_list()
119
124
 
120
125
  # Load metadata
121
126
  metadata = pd.read_csv(metadata_path)