spacr 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/plot.py CHANGED
@@ -16,8 +16,9 @@ from skimage.measure import find_contours, label, regionprops
16
16
 
17
17
  from scipy.stats import normaltest, ttest_ind, mannwhitneyu, f_oneway, kruskal
18
18
  from statsmodels.stats.multicomp import pairwise_tukeyhsd
19
+ from scipy.stats import ttest_ind, mannwhitneyu, levene, wilcoxon, kruskal
19
20
  import itertools
20
-
21
+ import pingouin as pg
21
22
 
22
23
  from ipywidgets import IntSlider, interact
23
24
  from IPython.display import Image as ipyimage
@@ -1002,8 +1003,107 @@ def _display_gif(path):
1002
1003
  """
1003
1004
  with open(path, 'rb') as file:
1004
1005
  display(ipyimage(file.read()))
1006
+
1007
def _plot_recruitment_v2(df, df_type, channel_of_interest, columns=None, figuresize=10):
    """
    Plot recruitment data per condition using spacrGraph panels.

    Two figures are drawn and shown:

    1. One row of four panels with the mean intensity of ``channel_of_interest``
       in the cell, nucleus, cytoplasm and pathogen compartments.
    2. A two-row grid with one panel per entry of ``columns`` followed by a
       fixed set of pathogen/cytoplasm summary columns.

    Every panel groups ``df`` by its ``'condition'`` column.

    Args:
        df (DataFrame): Recruitment data. Must contain ``'condition'`` and every
            column that is plotted.
        df_type (str): Text used in the x-axis labels (``'pathogen {df_type}'``).
        channel_of_interest (int or str): Channel used to build the
            ``'{compartment}_channel_{channel_of_interest}_mean_intensity'``
            column names of the first figure.
        columns (list, optional): Extra columns to plot in the second figure.
            Defaults to None (no extra columns).
        figuresize (int or float, optional): Base figure size; the font size is
            ``figuresize / 2``. Defaults to 10.

    Returns:
        None
    """
    # spacrGraph is defined later in this same module, so it is resolved as a
    # module global at call time (no self-import needed).
    color_list = [(55/255, 155/255, 155/255),
                  (155/255, 55/255, 155/255),
                  (55/255, 155/255, 255/255),
                  (255/255, 55/255, 155/255)]
    sns.set_palette(sns.color_palette(color_list))
    font = figuresize / 2

    # Figure 1: one panel per compartment for the channel of interest.
    compartments = ['cell', 'nucleus', 'cytoplasm', 'pathogen']
    _, axes = plt.subplots(nrows=1, ncols=len(compartments), figsize=(figuresize, figuresize / 4))
    for ax, compartment in zip(axes, compartments):
        data_column = f'{compartment}_channel_{channel_of_interest}_mean_intensity'
        spacrGraph(df, grouping_column='condition', data_column=data_column).create_plot(ax=ax)
        ax.set_xlabel(f'pathogen {df_type}', fontsize=font)
        ax.set_ylabel(data_column, fontsize=font)
        ax.tick_params(axis='both', which='major', labelsize=font)
        # Rotate the existing tick labels without freezing them into a FixedFormatter.
        ax.tick_params(axis='x', labelrotation=45)

    # Only draw a legend when the last panel actually has labelled artists;
    # otherwise an empty legend box would be drawn.
    handles, labels = axes[-1].get_legend_handles_labels()
    if handles:
        axes[-1].legend(handles, labels, bbox_to_anchor=(1.05, 0.5), loc='center left')

    plt.tight_layout()
    plt.show()

    # Figure 2: user-supplied columns plus fixed summary columns, on two rows.
    summary_columns = ['pathogen_cytoplasm_mean_mean',
                       'pathogen_cytoplasm_q75_mean',
                       'pathogen_periphery_cytoplasm_mean_mean',
                       'pathogen_outside_cytoplasm_mean_mean',
                       'pathogen_outside_cytoplasm_q75_mean']
    columns = list(columns or []) + summary_columns

    columns_per_row = (len(columns) + 1) // 2  # ceil(len(columns) / 2)
    panel_height = (figuresize * 2) / columns_per_row
    _, axes = plt.subplots(nrows=2, ncols=columns_per_row, figsize=(figuresize * 2, panel_height * 2))
    axes = axes.flatten()

    for i, (ax, col) in enumerate(zip(axes, columns)):
        spacrGraph(df, grouping_column='condition', data_column=col).create_plot(ax=ax)
        ax.set_xlabel(f'pathogen {df_type}', fontsize=font)
        ax.set_ylabel(f'{col}', fontsize=int(font * 2))
        ax.tick_params(axis='both', which='major', labelsize=font)
        ax.tick_params(axis='x', labelrotation=45)
        # NOTE(review): the first six panels start their y-axis at 1; the reason
        # is not visible here — confirm this is intended for these columns.
        if i <= 5:
            ax.set_ylim(1, None)

    # Hide the grid cell left over when the number of columns is odd.
    for ax in axes[len(columns):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()
1005
1105
 
1006
- def _plot_recruitment(df, df_type, channel_of_interest, target, columns=[], figuresize=10):
1106
+ def _plot_recruitment(df, df_type, channel_of_interest, columns=[], figuresize=10):
1007
1107
  """
1008
1108
  Plot recruitment data for different conditions and pathogens.
1009
1109
 
@@ -1156,10 +1256,6 @@ def _plot_controls(df, mask_chans, channel_of_interest, figuresize=5):
1156
1256
  plt.tight_layout()
1157
1257
  plt.show()
1158
1258
 
1159
- ###################################################
1160
- # Classify
1161
- ###################################################
1162
-
1163
1259
  def _imshow(img, labels, nrow=20, color='white', fontsize=12):
1164
1260
  """
1165
1261
  Display multiple images in a grid with corresponding labels.
@@ -1997,4 +2093,330 @@ def create_grouped_plot(df, grouping_column, data_column, graph_type='bar', summ
1997
2093
  # Show the plot
1998
2094
  plt.show()
1999
2095
 
2000
- return plt.gcf(), results_df
2096
+ return plt.gcf(), results_df
2097
+
2098
class spacrGraph:
    """
    Grouped plot of one numeric column, with statistical tests between groups.

    The constructor builds the colour palette and cleans ``df`` via
    :meth:`preprocess_data`. :meth:`create_plot` runs normality, Levene,
    group-comparison and post-hoc tests, stores them in ``results_df`` and
    draws the requested plot.
    """

    def __init__(self, df, grouping_column, data_column, graph_type='bar', summary_func='mean',
                 order=None, colors=None, output_dir='./output', save=False, y_axis_start=None,
                 error_bar_type='std', remove_outliers=False, theme='pastel', representation='object',
                 paired=False, all_to_all=True, compare_group=None):
        """
        Store the options, build the palette and preprocess ``df``.

        Args:
            df (DataFrame): Input data.
            grouping_column (str): Column whose values define the groups.
            data_column (str): Numeric column that is plotted and tested.
            graph_type (str): One of 'bar', 'box', 'violin' or 'jitter'.
            summary_func (str or None): Aggregation used for bar heights and for
                the per-'prc' summary when ``representation == 'well'``. When
                None, create_plot draws a strip plot.
            order (list, optional): Group order; defaults to the sorted unique groups.
            colors (list, optional): Colours that override the theme palette.
            output_dir (str): Directory written to when ``save`` is True.
            save (bool): Save the figure and the results table in create_plot.
            y_axis_start (float, optional): Lower y-axis limit.
            error_bar_type (str): 'std' or 'sem'; used by bar plots.
            remove_outliers (bool): Drop per-group 1.5*IQR outliers in create_plot.
            theme (str): Seaborn palette name.
            representation (str): 'object' (rows used as-is) or 'well' (rows are
                first aggregated per 'prc' and group with ``summary_func``).
            paired (bool): Use paired tests when exactly two groups are compared.
            all_to_all (bool): Run Tukey HSD over all groups (normal data, >2 groups).
            compare_group: Reference group for Dunn's post-hoc test, used when
                ``all_to_all`` is False and there are more than two groups.
        """
        self.df = df
        self.grouping_column = grouping_column
        self.data_column = data_column
        self.graph_type = graph_type
        self.summary_func = summary_func
        self.order = order
        self.colors = colors
        self.output_dir = output_dir
        self.save = save
        self.y_axis_start = y_axis_start
        self.error_bar_type = error_bar_type
        self.remove_outliers = remove_outliers
        self.theme = theme
        self.representation = representation
        self.paired = paired
        self.all_to_all = all_to_all
        self.compare_group = compare_group

        self.results_df = pd.DataFrame()
        self.sns_palette = None
        self.fig = None  # set by create_plot

        self._set_theme()
        self.raw_df = self.df.copy()  # unprocessed copy, used for the n_object counts
        self.df = self.preprocess_data()

    def _set_theme(self):
        """Build ``self.sns_palette`` from ``self.theme`` using a fixed colour order."""
        integer_list = list(range(1, 81))
        color_order = [0, 3, 9, 4, 6, 7, 9, 2] + integer_list
        self.sns_palette = self._set_reordered_theme(self.theme, color_order, 100)

    def _set_reordered_theme(self, theme='muted', order=None, n_colors=100, show_theme=False):
        """Return the ``theme`` palette with ``n_colors`` colours, reordered by ``order`` if given."""
        palette = sns.color_palette(theme, n_colors)
        if order:
            reordered_palette = [palette[i] for i in order]
        else:
            reordered_palette = palette
        if show_theme:
            sns.palplot(reordered_palette)
            plt.show()
        return reordered_palette

    def _plot_palette(self):
        """Return the user-supplied ``colors`` if set, otherwise the theme palette."""
        return self.colors if self.colors else self.sns_palette

    def _group_values(self, df, group):
        """Return the ``data_column`` values of the rows of ``df`` that belong to ``group``."""
        return df.loc[df[self.grouping_column] == group, self.data_column]

    def _n_well(self, *groups):
        """Summed row count of ``groups`` in the processed data for 'well' representation, else NaN."""
        if self.representation != 'well':
            return np.nan
        return sum(len(self._group_values(self.df, group)) for group in groups)

    def preprocess_data(self):
        """
        Drop rows with a missing group or value, optionally summarise per well
        ('prc'), and turn the grouping column into an ordered categorical.

        Returns:
            DataFrame: A cleaned copy of ``self.df``.
        """
        # .copy() so the categorical assignment below never writes into a view.
        df = self.df.dropna(subset=[self.grouping_column, self.data_column]).copy()

        if self.representation == 'well':
            df = df.groupby(['prc', self.grouping_column])[self.data_column].agg(self.summary_func).reset_index()

        categories = self.order if self.order else sorted(df[self.grouping_column].unique())
        df[self.grouping_column] = pd.Categorical(df[self.grouping_column], categories=categories, ordered=True)
        return df

    def remove_outliers_from_plot(self):
        """Return a copy of ``self.df`` without per-group values outside 1.5*IQR of the quartiles."""
        filtered_df = self.df.copy()
        for group in filtered_df[self.grouping_column].unique():
            values = self._group_values(filtered_df, group)
            q1 = values.quantile(0.25)
            q3 = values.quantile(0.75)
            iqr = q3 - q1
            is_outlier = (values < q1 - 1.5 * iqr) | (values > q3 + 1.5 * iqr)
            filtered_df = filtered_df.drop(values[is_outlier].index)
        return filtered_df

    @staticmethod
    def _normaltest_or_nan(data):
        """Return (statistic, p-value) of ``normaltest``, or NaNs when the sample is too small."""
        try:
            result = normaltest(data)
        except ValueError:
            # normaltest relies on skewtest, which rejects samples with fewer than
            # 8 observations; normality cannot be assessed, so report NaN.
            return np.nan, np.nan
        return result.statistic, result.pvalue

    def perform_normality_tests(self):
        """
        Run D'Agostino-Pearson ``normaltest`` on every group.

        Returns:
            tuple: ``(is_normal, rows)``. ``is_normal`` is True only if every
            group's p-value is above 0.05 (a NaN p-value counts as non-normal);
            ``rows`` holds one result dict per group.
        """
        test_results = []
        p_values = []
        for group in self.df[self.grouping_column].unique():
            stat, p_value = self._normaltest_or_nan(self._group_values(self.df, group))
            p_values.append(p_value)
            test_results.append({
                'Comparison': f'Normality test for {group}',
                'Test Statistic': stat,
                'p-value': p_value,
                'Test Name': 'Normality test',
                'n_object': len(self._group_values(self.raw_df, group)),  # raw objects/cells
                'n_well': self._n_well(group),  # summarised wells (NaN unless 'well')
            })
        is_normal = all(p > 0.05 for p in p_values)
        return is_normal, test_results

    def perform_levene_test(self, unique_groups):
        """Return (statistic, p-value) of Levene's test for equal variance across ``unique_groups``."""
        grouped_data = [self._group_values(self.df, group) for group in unique_groups]
        stat, p_value = levene(*grouped_data)
        return stat, p_value

    def perform_statistical_tests(self, unique_groups, is_normal):
        """
        Compare every pair of groups.

        With exactly two groups: t-test / Mann-Whitney U, or, when ``paired`` is
        set, a paired t-test / Wilcoxon signed-rank test (paired groups must have
        equal length and matching row order). With more groups each pair is
        compared with one-way ANOVA / Kruskal-Wallis; ``paired`` is ignored there.

        Returns:
            list[dict]: One result row per pair.
        """
        from scipy.stats import ttest_rel

        if len(unique_groups) == 2 and self.paired:
            stat_test, test_name = (ttest_rel, 'Paired T-test') if is_normal else (wilcoxon, 'Paired Wilcoxon test')
        elif len(unique_groups) == 2:
            stat_test, test_name = (ttest_ind, 'T-test') if is_normal else (mannwhitneyu, 'Mann-Whitney U test')
        else:
            stat_test, test_name = (f_oneway, 'One-way ANOVA') if is_normal else (kruskal, 'Kruskal-Wallis test')

        test_results = []
        for group1, group2 in itertools.combinations(unique_groups, 2):
            data1 = self._group_values(self.df, group1)
            data2 = self._group_values(self.df, group2)
            stat, p = stat_test(data1.to_numpy(), data2.to_numpy())
            test_results.append({
                'Comparison': f'{group1} vs {group2}',
                'Test Statistic': stat,
                'p-value': p,
                'Test Name': test_name,
                'n_object': len(self._group_values(self.raw_df, group1)) + len(self._group_values(self.raw_df, group2)),
                'n_well': self._n_well(group1, group2),
            })
        return test_results

    def _tukey_posthoc(self):
        """Tukey HSD over all groups (statsmodels ``pairwise_tukeyhsd``, alpha 0.05)."""
        tukey_result = pairwise_tukeyhsd(self.df[self.data_column], self.df[self.grouping_column], alpha=0.05)
        posthoc_results = []
        # Table rows after the header are paired, in order, with ``pvalues``.
        for row, p_value in zip(tukey_result._results_table.data[1:], tukey_result.pvalues):
            group1, group2 = row[0], row[1]
            posthoc_results.append({
                'Comparison': f'{group1} vs {group2}',
                'Test Statistic': None,  # no test statistic is reported for Tukey here
                'p-value': p_value,
                'Test Name': 'Tukey HSD Post-hoc',
                'n_object': len(self._group_values(self.raw_df, group1)) + len(self._group_values(self.raw_df, group2)),
                'n_well': self._n_well(group1, group2),
            })
        return posthoc_results

    def _dunn_posthoc(self, unique_groups, compare_group):
        """
        Dunn's test on ranks, Bonferroni-adjusted over all group pairs.

        Only pairs that include ``compare_group`` are returned. The z statistic is
        (mean rank A - mean rank B) / sqrt((N(N+1)/12 - T) * (1/n_A + 1/n_B)),
        where T = sum(t^3 - t) / (12(N-1)) corrects for tied values.
        """
        from scipy.stats import norm, rankdata

        values = self.df[self.data_column].to_numpy(dtype=float)
        labels = self.df[self.grouping_column].to_numpy()
        n_total = len(values)
        ranks = rankdata(values)
        _, tie_counts = np.unique(values, return_counts=True)
        tie_counts = tie_counts.astype(float)
        tie_term = np.sum(tie_counts ** 3 - tie_counts) / (12.0 * (n_total - 1))
        rank_variance = n_total * (n_total + 1) / 12.0 - tie_term

        groups = list(unique_groups)
        sizes = {group: int(np.sum(labels == group)) for group in groups}
        mean_ranks = {group: ranks[labels == group].mean() for group in groups}
        pairs = list(itertools.combinations(groups, 2))

        posthoc_results = []
        for group1, group2 in pairs:
            if compare_group not in (group1, group2):
                continue
            se = np.sqrt(rank_variance * (1.0 / sizes[group1] + 1.0 / sizes[group2]))
            z = (mean_ranks[group1] - mean_ranks[group2]) / se
            p_adjusted = min(1.0, 2.0 * norm.sf(abs(z)) * len(pairs))  # Bonferroni over all pairs
            posthoc_results.append({
                'Comparison': f'{group1} vs {group2}',
                'Test Statistic': z,
                'p-value': p_adjusted,
                'Test Name': 'Dunn’s Post-hoc',
                'n_object': len(self._group_values(self.raw_df, group1)) + len(self._group_values(self.raw_df, group2)),
                'n_well': self._n_well(group1, group2),
            })
        return posthoc_results

    def perform_posthoc_tests(self, is_normal, unique_groups):
        """
        Post-hoc comparisons for more than two groups.

        * Normal data and ``all_to_all``: Tukey HSD over all pairs.
        * ``all_to_all`` False and ``compare_group`` set: Dunn's test against
          ``compare_group``.
        * Otherwise: no post-hoc rows.
        """
        if is_normal and len(unique_groups) > 2 and self.all_to_all:
            return self._tukey_posthoc()
        if len(unique_groups) > 2 and not self.all_to_all and self.compare_group:
            return self._dunn_posthoc(unique_groups, self.compare_group)
        return []

    def create_plot(self, ax=None):
        """
        Run all tests, store them in ``results_df`` and draw the plot.

        Args:
            ax (matplotlib Axes, optional): Axes to draw on. When None a new
                figure is created and shown; when given, the caller owns the
                figure and is responsible for showing it.
        """
        # NOTE: outliers are removed from self.df itself, so the tests below
        # also run on the filtered data.
        if self.remove_outliers:
            self.df = self.remove_outliers_from_plot()

        is_normal, normality_results = self.perform_normality_tests()

        unique_groups = self.df[self.grouping_column].unique()
        levene_stat, levene_p = self.perform_levene_test(unique_groups)
        levene_result = {
            'Comparison': 'Levene’s test for equal variance',
            'Test Statistic': levene_stat,
            'p-value': levene_p,
            'Test Name': 'Levene’s Test'
        }

        stat_results = self.perform_statistical_tests(unique_groups, is_normal)
        posthoc_results = self.perform_posthoc_tests(is_normal, unique_groups)
        self.results_df = pd.DataFrame(normality_results + [levene_result] + stat_results + posthoc_results)

        # 'n' = count of the first group (in category order) whose label occurs in the comparison text.
        counts = self.df.groupby(self.grouping_column, observed=False)[self.data_column].count()

        def _count_for(comparison):
            for group, n in counts.items():
                if str(group) in comparison:
                    return n
            return np.nan

        self.results_df['n'] = self.results_df['Comparison'].apply(_count_for)

        num_groups = len(unique_groups)
        bar_width = 0.6  # thickness of each bar
        spacing_between_groups = 0.3  # gap between bars and the axis

        created_figure = ax is None
        if created_figure:
            fig_width = num_groups * (bar_width + spacing_between_groups)
            self.fig, ax = plt.subplots(figsize=(fig_width, 6))
        else:
            self.fig = ax.figure

        sns.set(style="ticks")
        color_palette = self._plot_palette()

        # Equal half-category margins on both sides of the groups.
        ax.set_xlim(-0.5, num_groups - 0.5)

        if self.summary_func is None:
            sns.stripplot(x=self.grouping_column, y=self.data_column, data=self.df, palette=color_palette, jitter=True, alpha=0.6, ax=ax)
        elif self.graph_type == 'bar':
            self._create_bar_plot(bar_width, ax)
        elif self.graph_type == 'box':
            self._create_box_plot(ax)
        elif self.graph_type == 'violin':
            self._create_violin_plot(ax)
        elif self.graph_type == 'jitter':
            self._create_jitter_plot(ax)
        else:
            raise ValueError(f"Invalid graph_type: {self.graph_type}. Choose from 'bar', 'box', 'violin', or 'jitter'.")

        if self.y_axis_start is not None:
            ax.set_ylim(bottom=self.y_axis_start)

        ax.minorticks_on()
        ax.tick_params(axis='x', which='minor', bottom=False)  # no minor ticks on the categorical axis
        ax.tick_params(axis='x', which='major', length=6, width=2, direction='out')
        ax.tick_params(axis='y', which='major', length=6, width=2, direction='out')
        ax.tick_params(axis='y', which='minor', length=4, width=1, direction='out')
        sns.despine(ax=ax, top=True, right=True)

        if self.save:
            self._save_results()

        # Only show figures we created; showing a caller's figure here would
        # flush (and in inline backends close) it before the caller finishes it.
        if created_figure:
            plt.show()

    def get_figure(self):
        """Return the figure of the last create_plot call (None before that)."""
        return self.fig

    def _create_bar_plot(self, bar_width, ax):
        """Draw bars of ``summary_func`` per group with std or sem error bars."""
        if self.error_bar_type not in ('std', 'sem'):
            raise ValueError(f"Invalid error_bar_type: {self.error_bar_type}. Choose either 'std' or 'sem'.")

        summary_df = self.df.groupby(self.grouping_column, observed=False)[self.data_column].agg([self.summary_func, 'std', 'sem'])
        error_bars = summary_df[self.error_bar_type]

        sns.barplot(x=self.grouping_column, y=self.summary_func, data=summary_df.reset_index(), errorbar=None, palette=self._plot_palette(), width=bar_width, ax=ax)

        # Error bars centred on the category positions 0..n-1 used by barplot.
        ax.errorbar(x=np.arange(len(summary_df)), y=summary_df[self.summary_func], yerr=error_bars, fmt='none', c='black', capsize=5)

    def _create_jitter_plot(self, ax):
        """Draw a jittered strip plot of all values."""
        sns.stripplot(x=self.grouping_column, y=self.data_column, data=self.df, palette=self._plot_palette(), jitter=True, alpha=0.6, ax=ax)

    def _create_box_plot(self, ax):
        """Draw a box plot per group."""
        sns.boxplot(x=self.grouping_column, y=self.data_column, data=self.df, palette=self._plot_palette(), ax=ax)

    def _create_violin_plot(self, ax):
        """Draw a violin plot per group."""
        sns.violinplot(x=self.grouping_column, y=self.data_column, data=self.df, palette=self._plot_palette(), ax=ax)

    def _save_results(self):
        """Save the figure as grouped_plot.png and the test table as test_results.csv in output_dir."""
        os.makedirs(self.output_dir, exist_ok=True)
        plot_path = os.path.join(self.output_dir, 'grouped_plot.png')
        self.fig.savefig(plot_path)
        results_path = os.path.join(self.output_dir, 'test_results.csv')
        self.results_df.to_csv(results_path, index=False)
        print(f"Plot saved to {plot_path}")
        print(f"Test results saved to {results_path}")

    def get_results(self):
        """Return the results dataframe."""
        return self.results_df
spacr/settings.py CHANGED
@@ -246,7 +246,7 @@ def get_measure_crop_settings(settings={}):
246
246
  settings.setdefault('normalize_by','png')
247
247
  settings.setdefault('crop_mode',['cell'])
248
248
  settings.setdefault('dialate_pngs', False)
249
- settings.setdefault('dialate_png_ratios', [0.2, 0,2])
249
+ settings.setdefault('dialate_png_ratios', [0.2,0.2])
250
250
 
251
251
  # Timelapsed settings
252
252
  settings.setdefault('timelapse', False)
@@ -404,7 +404,7 @@ def deep_spacr_defaults(settings):
404
404
  settings.setdefault('sample',None)
405
405
  settings.setdefault('experiment','exp.')
406
406
  settings.setdefault('score_threshold',0.5)
407
- settings.setdefault('tar_path','path')
407
+ settings.setdefault('dataset','path')
408
408
  settings.setdefault('model_path','path')
409
409
  settings.setdefault('file_type','cell_png')
410
410
  settings.setdefault('generate_training_dataset', True)
@@ -461,7 +461,7 @@ def get_analyze_recruitment_default_settings(settings):
461
461
  settings.setdefault('pathogen_mask_dim',6)
462
462
  settings.setdefault('channel_of_interest',2)
463
463
  settings.setdefault('plot',True)
464
- settings.setdefault('plot_nr',10)
464
+ settings.setdefault('plot_nr',3)
465
465
  settings.setdefault('plot_control',True)
466
466
  settings.setdefault('figuresize',10)
467
467
  settings.setdefault('uninfected',True)
@@ -534,6 +534,7 @@ def get_perform_regression_default_settings(settings):
534
534
  settings.setdefault('random_row_column_effects',False)
535
535
  settings.setdefault('alpha',1)
536
536
  settings.setdefault('fraction_threshold',0.1)
537
+ settings.setdefault('location_column','column')
537
538
  settings.setdefault('nc','c1')
538
539
  settings.setdefault('pc','c2')
539
540
  settings.setdefault('other','c3')
@@ -855,10 +856,10 @@ expected_types = {
855
856
  'reverse_complement':bool,
856
857
  'file_type':str,
857
858
  'model_path':str,
858
- 'tar_path':str,
859
+ 'dataset':str,
859
860
  'score_threshold':float,
860
861
  'sample':None,
861
- 'file_metadata':None,
862
+ 'file_metadata':(str, type(None), list),
862
863
  'apply_model_to_dataset':False,
863
864
  "train":bool,
864
865
  "test":bool,
@@ -879,27 +880,33 @@ expected_types = {
879
880
  "generate_training_dataset":bool,
880
881
  "segmentation_mode":str,
881
882
  "train_DL_model":bool,
883
+ "normalize":bool,
884
+ "overlay":bool,
885
+ "correlate":bool,
886
+ "target_layer":str,
887
+ "normalize_input":bool,
882
888
  }
883
889
 
884
# Groups setting keys into named sections (presumably consumed by the settings
# GUI — confirm). Within one category each key is listed once; a key may still
# appear in several categories (e.g. "normalize").
categories = {"Paths":[ "src", "grna", "barcodes", "custom_model_path", "dataset","model_path","grna_csv","row_csv","column_csv"],
              "General": ["metadata_type", "custom_regex", "experiment", "channels", "magnification", "channel_dims", "apply_model_to_dataset", "generate_training_dataset", "train_DL_model", "segmentation_mode"],
              "Cellpose":["from_scratch", "n_epochs", "width_height", "model_name", "custom_model", "resample", "rescale", "CP_prob", "flow_threshold", "percentiles", "circular", "invert", "diameter", "grayscale", "background", "Signal_to_noise", "resize", "target_height", "target_width"],
              "Cell": ["cell_intensity_range", "cell_size_range", "cell_chann_dim", "cell_channel", "cell_background", "cell_Signal_to_noise", "cell_CP_prob", "cell_FT", "remove_background_cell", "cell_min_size", "cell_mask_dim", "cytoplasm", "cytoplasm_min_size", "include_uninfected", "merge_edge_pathogen_cells", "adjust_cells", "cells", "cell_loc"],
              "Nucleus": ["nucleus_intensity_range", "nucleus_size_range", "nucleus_chann_dim", "nucleus_channel", "nucleus_background", "nucleus_Signal_to_noise", "nucleus_CP_prob", "nucleus_FT", "remove_background_nucleus", "nucleus_min_size", "nucleus_mask_dim", "nucleus_loc"],
              "Pathogen": ["pathogen_intensity_range", "pathogen_size_range", "pathogen_chann_dim", "pathogen_channel", "pathogen_background", "pathogen_Signal_to_noise", "pathogen_CP_prob", "pathogen_FT", "pathogen_model", "remove_background_pathogen", "pathogen_min_size", "pathogen_mask_dim", "pathogens", "pathogen_loc", "pathogen_types", "pathogen_plate_metadata", ],
              "Measurements": ["remove_image_canvas", "remove_highly_correlated", "homogeneity", "homogeneity_distances", "radial_dist", "calculate_correlation", "manders_thresholds", "save_measurements", "tables", "image_nr", "dot_size", "filter_by", "remove_highly_correlated_features", "remove_low_variance_features", "channel_of_interest"],
              "Object Image": ["save_png", "dialate_pngs", "dialate_png_ratios", "png_size", "png_dims", "save_arrays", "normalize_by", "crop_mode", "normalize", "use_bounding_box"],
              "Sequencing": ["signal_direction","mode","comp_level","comp_type","save_h5","expected_end","offset","target_sequence","regex", "highlight"],
              "Generate Dataset":["file_metadata","class_metadata", "annotation_column","annotated_classes", "dataset_mode", "metadata_type_by","custom_measurement", "sample", "size"],
              "Hyperparamiters (Training)": ["png_type", "score_threshold","file_type", "train_channels", "epochs", "loss_type", "optimizer_type","image_size","val_split","learning_rate","weight_decay","dropout_rate", "init_weights", "train", "classes", "augment", "amsgrad","use_checkpoint","gradient_accumulation","gradient_accumulation_steps","intermedeate_save","pin_memory"],
              "Hyperparamiters (Embedding)": ["visualize","n_neighbors","min_dist","metric","resnet_features","reduction_method","embedding_by_controls","col_to_compare","log_data"],
              "Hyperparamiters (Clustering)": ["eps","min_samples","analyze_clusters","clustering","remove_cluster_noise"],
              "Hyperparamiters (Regression)":["cov_type", "class_1_threshold", "plate", "other", "fraction_threshold", "alpha", "random_row_column_effects", "regression_type", "min_cell_count", "agg_type", "transform", "dependent_variable"],
              "Hyperparamiters (Activation)":["cam_type", "normalize", "overlay", "correlation", "target_layer", "normalize_input"],
              "Annotation": ["nc_loc", "pc_loc", "nc", "pc", "cell_plate_metadata","treatment_plate_metadata", "metadata_types", "cell_types", "target","positive_control","negative_control", "location_column", "treatment_loc", "channel_of_interest", "measurement", "treatments", "um_per_pixel", "nr_imgs", "exclude", "exclude_conditions", "mix", "pos", "neg"],
              "Plot": ["plot", "plot_control", "plot_nr", "examples_to_plot", "normalize_plots", "cmap", "figuresize", "plot_cluster_grids", "img_zoom", "row_limit", "color_by", "plot_images", "smooth_lines", "plot_points", "plot_outlines", "black_background", "plot_by_cluster", "heatmap_feature","grouping","min_max","save_figure"],
              "Test": ["test_mode", "test_images", "random_test", "test_nr", "test", "test_split"],
              "Timelapse": ["timelapse", "fps", "timelapse_displacement", "timelapse_memory", "timelapse_frame_limits", "timelapse_remove_transient", "timelapse_mode", "timelapse_objects", "compartments"],
              "Advanced": ["shuffle", "target_intensity_min", "cells_per_well", "nuclei_limit", "pathogen_limit", "uninfected", "backgrounds", "schedule", "test_size","exclude","n_repeats","top_features", "model_type_ml", "model_type","minimum_cell_count","n_estimators","preprocess", "remove_background", "normalize", "lower_percentile", "merge_pathogens", "batch_size", "filter", "save", "masks", "verbose", "randomize", "n_jobs"],
              "Miscellaneous": ["all_to_mip", "pick_slice", "skip_mode", "upscale", "upscale_factor"]
              }
905
912
 
@@ -948,6 +955,14 @@ def check_settings(vars_dict, expected_types, q=None):
948
955
  settings[key] = float(value) if '.' in value else int(value)
949
956
  elif expected_type == (str, type(None)):
950
957
  settings[key] = str(value) if value else None
958
+ elif expected_type == (str, type(None), list):
959
+ if isinstance(value, list):
960
+ settings[key] = parse_list(value) if value else None
961
+ elif isinstance(value, str):
962
+ settings[key] = str(value)
963
+ else:
964
+ settings[key] = None
965
+
951
966
  elif expected_type == dict:
952
967
  try:
953
968
  # Ensure that the value is a string that can be converted to a dictionary
@@ -1202,10 +1217,10 @@ def generate_fields(variables, scrollable_frame):
1202
1217
  "complevel": "int - level of compression (0-9). Higher is slower and yealds smaller files",
1203
1218
  "file_type": "str - type of file to process",
1204
1219
  "model_path": "str - path to the model",
1205
- "tar_path": "str - path to the tar file with image dataset",
1220
+ "dataset": "str - file name of the tar file with image dataset",
1206
1221
  "score_threshold": "float - threshold for classification",
1207
1222
  "sample": "str - number of images to sample for tar dataset (including both classes). Default: None",
1208
- "file_metadata": "str - string that must be present in image path to be included in the dataset",
1223
+ "file_metadata": "str or list of strings - string(s) that must be present in image path to be included in the dataset",
1209
1224
  "apply_model_to_dataset": "bool - whether to apply model to the dataset",
1210
1225
  "train_channels": "list - channels to use for training",
1211
1226
  "dataset_mode": "str - How to generate train/test dataset.",
@@ -1246,6 +1261,13 @@ def generate_fields(variables, scrollable_frame):
1246
1261
  "mode": "(str) - Mode to use for sequence analysis (either single for R1 or R2 fastq files or paired for the combination of R1 and R2).",
1247
1262
  "signal_direction": "(str) - Direction of fastq file (R1 or R2). only relevent when mode is single.",
1248
1263
  "custom_model_path": "(str) - Path to the custom model to finetune.",
1264
+ "cam_type": "(str) - Choose between: gradcam, gradcam_pp, saliency_image, saliency_channel to generate activateion maps of DL models",
1265
+ "target_layer": "(str) - Only used for gradcam and gradcam_pp. The layer to use for the activation map.",
1266
+ "normalize": "(bool) - Normalize images before overlayng the activation maps.",
1267
+ "overlay": "(bool) - Overlay activation maps on the images.",
1268
+ "shuffle": "(bool) - Shuffle the dataset bufore generating the activation maps",
1269
+ "correlation": "(bool) - Calculate correlation between image channels and activation maps. Data is saved to .db.",
1270
+ "normalize_input": "(bool) - Normalize the input images before passing them to the model.",
1249
1271
  }
1250
1272
 
1251
1273
  for key, (var_type, options, default_value) in variables.items():
@@ -1281,6 +1303,8 @@ descriptions = {
1281
1303
 
1282
1304
  'regression': "Perform regression analysis on your data. Function: regression_tools from spacr.analysis.\n\nKey Features:\n- Statistical Analysis: Conduct various types of regression analysis to identify relationships within your data.\n- Flexible Options: Supports multiple regression models and configurations.\n- Data Insight: Gain deeper insights into your dataset through advanced regression techniques.",
1283
1305
 
1306
+ 'activation': "",
1307
+
1284
1308
  'recruitment': "Analyze recruitment data to understand sample recruitment dynamics. Function: recruitment_analysis_tools from spacr.analysis.\n\nKey Features:\n- Recruitment Analysis: Investigate and analyze the recruitment of samples over time or conditions.\n- Visualization: Generate visualizations to represent recruitment trends and patterns.\n- Integration: Utilize data from various sources for a comprehensive recruitment analysis."
1285
1309
  }
1286
1310
 
@@ -1313,4 +1337,25 @@ def set_default_generate_barecode_mapping(settings={}):
1313
1337
  settings.setdefault('mode', 'paired')
1314
1338
  settings.setdefault('single_direction', 'R1')
1315
1339
  settings.setdefault('test', False)
1340
+ return settings
1341
+
1342
def get_default_generate_activation_map_settings(settings=None):
    """
    Fill in default settings for generating activation maps of a trained model.

    Keys already present in ``settings`` are kept; missing keys are added with
    the defaults below (``dict.setdefault`` semantics).

    Args:
        settings (dict, optional): User settings, updated in place. When None a
            new dict is created.

    Returns:
        dict: ``settings`` (or the new dict) with every default key present.
    """
    if settings is None:
        settings = {}

    # Built per call so list defaults are never shared between callers.
    defaults = {
        'dataset': 'path',
        'model_type': 'maxvit',
        'model_path': 'path',
        'image_size': 224,
        'batch_size': 64,
        'normalize': True,
        'cam_type': 'gradcam',
        'target_layer': None,
        'plot': False,
        'save': True,
        'normalize_input': True,
        'channels': [1, 2, 3],
        'overlay': True,
        'shuffle': True,
        'correlation': True,
        'manders_thresholds': [15, 50, 75],
        'n_jobs': None,
    }
    for key, value in defaults.items():
        settings.setdefault(key, value)

    return settings