spacr 0.3.60__py3-none-any.whl → 0.3.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/app_annotate.py +0 -8
- spacr/core.py +12 -7
- spacr/gui_utils.py +24 -8
- spacr/io.py +134 -157
- spacr/ml.py +3 -4
- spacr/plot.py +82 -23
- spacr/settings.py +4 -13
- spacr/submodules.py +299 -5
- spacr/utils.py +96 -3
- {spacr-0.3.60.dist-info → spacr-0.3.62.dist-info}/METADATA +1 -1
- {spacr-0.3.60.dist-info → spacr-0.3.62.dist-info}/RECORD +15 -15
- {spacr-0.3.60.dist-info → spacr-0.3.62.dist-info}/LICENSE +0 -0
- {spacr-0.3.60.dist-info → spacr-0.3.62.dist-info}/WHEEL +0 -0
- {spacr-0.3.60.dist-info → spacr-0.3.62.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.60.dist-info → spacr-0.3.62.dist-info}/top_level.txt +0 -0
spacr/plot.py
CHANGED
@@ -909,7 +909,7 @@ def plot_merged(src, settings):
|
|
909
909
|
path = os.path.join(src, file)
|
910
910
|
stack = np.load(path)
|
911
911
|
print(f'Loaded: {path}')
|
912
|
-
if
|
912
|
+
if settings['pathogen_limit'] > 0:
|
913
913
|
if settings['pathogen_mask_dim'] is not None and settings['cell_mask_dim'] is not None:
|
914
914
|
stack = _remove_noninfected(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'])
|
915
915
|
|
@@ -2198,8 +2198,8 @@ def jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot',
|
|
2198
2198
|
tables,
|
2199
2199
|
verbose=True,
|
2200
2200
|
nuclei_limit=True,
|
2201
|
-
pathogen_limit=True
|
2202
|
-
|
2201
|
+
pathogen_limit=True)
|
2202
|
+
|
2203
2203
|
paths_df = _read_db(loc, tables=['png_list'])
|
2204
2204
|
merged_df = pd.merge(df, paths_df[0], on='prcfo', how='left')
|
2205
2205
|
return merged_df
|
@@ -2435,7 +2435,9 @@ class spacrGraph:
|
|
2435
2435
|
|
2436
2436
|
self.df = df
|
2437
2437
|
self.grouping_column = grouping_column
|
2438
|
+
self.order = sorted(df[self.grouping_column].unique().tolist())
|
2438
2439
|
self.data_column = data_column if isinstance(data_column, list) else [data_column]
|
2440
|
+
|
2439
2441
|
self.graph_type = graph_type
|
2440
2442
|
self.summary_func = summary_func
|
2441
2443
|
self.order = order
|
@@ -2909,9 +2911,11 @@ class spacrGraph:
|
|
2909
2911
|
ax.set_xlim(-0.5, num_groups - 0.5)
|
2910
2912
|
|
2911
2913
|
# Set ticks to match the group labels in your DataFrame
|
2912
|
-
group_labels = self.df[self.grouping_column].unique()
|
2913
|
-
|
2914
|
-
ax.
|
2914
|
+
#group_labels = self.df[self.grouping_column].unique()
|
2915
|
+
#group_labels = self.order
|
2916
|
+
#ax.set_xticks(range(len(group_labels)))
|
2917
|
+
#ax.set_xticklabels(group_labels, rotation=45, ha='right')
|
2918
|
+
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
|
2915
2919
|
|
2916
2920
|
# Customize elements based on the graph type
|
2917
2921
|
if graph_type == 'bar':
|
@@ -2943,6 +2947,66 @@ class spacrGraph:
|
|
2943
2947
|
|
2944
2948
|
# Redraw the figure to apply changes
|
2945
2949
|
ax.figure.canvas.draw()
|
2950
|
+
|
2951
|
+
def _standerdize_figure_format_v1(self, ax, num_groups, graph_type):
|
2952
|
+
"""
|
2953
|
+
Adjusts the figure layout (size, bar width, jitter, and spacing) based on the number of groups.
|
2954
|
+
"""
|
2955
|
+
if graph_type in ['line', 'line_std']:
|
2956
|
+
print("Skipping layout adjustment for line graphs.")
|
2957
|
+
return # Skip layout adjustment for line graphs
|
2958
|
+
|
2959
|
+
correction_factor = 4
|
2960
|
+
|
2961
|
+
# Set figure size to ensure it remains square with a minimum size
|
2962
|
+
fig_size = max(6, num_groups * 2) / correction_factor
|
2963
|
+
ax.figure.set_size_inches(fig_size, fig_size)
|
2964
|
+
|
2965
|
+
# Configure layout based on the number of groups
|
2966
|
+
bar_width = min(0.8, 1.5 / num_groups) / correction_factor
|
2967
|
+
jitter_amount = min(0.1, 0.2 / num_groups) / correction_factor
|
2968
|
+
jitter_size = max(50 / num_groups, 200)
|
2969
|
+
|
2970
|
+
# Adjust x-axis limits to fit the specified order of groups
|
2971
|
+
ax.set_xlim(-0.5, len(self.order) - 0.5) # Use `self.order` length to ensure alignment
|
2972
|
+
|
2973
|
+
# Use `self.order` as the x-tick labels to maintain consistent ordering
|
2974
|
+
ax.set_xticks(range(len(self.order)))
|
2975
|
+
#ax.set_xticklabels(self.order, rotation=45, ha='right')
|
2976
|
+
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
|
2977
|
+
|
2978
|
+
# Customize elements based on the graph type
|
2979
|
+
if graph_type == 'bar':
|
2980
|
+
# Adjust bars' width and position
|
2981
|
+
for bar in ax.patches:
|
2982
|
+
bar.set_width(bar_width)
|
2983
|
+
bar.set_x(bar.get_x() - bar_width / 2)
|
2984
|
+
|
2985
|
+
elif graph_type in ['jitter', 'jitter_bar', 'jitter_box']:
|
2986
|
+
# Adjust jitter points' position and size
|
2987
|
+
for coll in ax.collections:
|
2988
|
+
offsets = coll.get_offsets()
|
2989
|
+
offsets[:, 0] += jitter_amount # Shift jitter points slightly
|
2990
|
+
coll.set_offsets(offsets)
|
2991
|
+
coll.set_sizes([jitter_size] * len(offsets)) # Adjust point size dynamically
|
2992
|
+
|
2993
|
+
elif graph_type in ['box', 'violin']:
|
2994
|
+
# Adjust box width for consistent spacing
|
2995
|
+
for artist in ax.artists:
|
2996
|
+
artist.set_width(bar_width)
|
2997
|
+
|
2998
|
+
# Adjust legend and axis labels
|
2999
|
+
ax.tick_params(axis='x', labelsize=max(10, 15 - num_groups // 2))
|
3000
|
+
ax.tick_params(axis='y', labelsize=max(10, 15 - num_groups // 2))
|
3001
|
+
|
3002
|
+
# Adjust legend placement and size
|
3003
|
+
if ax.get_legend():
|
3004
|
+
ax.get_legend().set_bbox_to_anchor((1.05, 1))
|
3005
|
+
ax.get_legend().prop.set_size(max(8, 12 - num_groups // 3))
|
3006
|
+
|
3007
|
+
# Redraw the figure to apply changes
|
3008
|
+
ax.figure.canvas.draw()
|
3009
|
+
|
2946
3010
|
|
2947
3011
|
def _create_bar_plot(self, ax):
|
2948
3012
|
"""Helper method to create a bar plot with consistent bar thickness and centered error bars."""
|
@@ -2959,7 +3023,7 @@ class spacrGraph:
|
|
2959
3023
|
|
2960
3024
|
summary_df = self.df_melted.groupby([x_axis_column]).agg(mean=('Value', 'mean'),std=('Value', 'std'),sem=('Value', 'sem')).reset_index()
|
2961
3025
|
error_bars = summary_df[self.error_bar_type] if self.error_bar_type in ['std', 'sem'] else None
|
2962
|
-
sns.barplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, ax=ax, dodge=self.jitter_bar_dodge, ci=None)
|
3026
|
+
sns.barplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, ax=ax, dodge=self.jitter_bar_dodge, ci=None, order=self.order)
|
2963
3027
|
|
2964
3028
|
# Adjust the bar width manually
|
2965
3029
|
if len(self.data_column) > 1:
|
@@ -2999,7 +3063,7 @@ class spacrGraph:
|
|
2999
3063
|
hue = None
|
3000
3064
|
|
3001
3065
|
# Create the jitter plot
|
3002
|
-
sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax, alpha=0.6, size=16)
|
3066
|
+
sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax, alpha=0.6, size=16, order=self.order)
|
3003
3067
|
|
3004
3068
|
# Adjust legend and labels
|
3005
3069
|
ax.set_xlabel(self.grouping_column)
|
@@ -3088,7 +3152,7 @@ class spacrGraph:
|
|
3088
3152
|
hue = None
|
3089
3153
|
|
3090
3154
|
# Create the box plot
|
3091
|
-
sns.boxplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue,palette=self.sns_palette,ax=ax)
|
3155
|
+
sns.boxplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue,palette=self.sns_palette,ax=ax, order=self.order)
|
3092
3156
|
|
3093
3157
|
# Adjust legend and labels
|
3094
3158
|
ax.set_xlabel(self.grouping_column)
|
@@ -3117,7 +3181,7 @@ class spacrGraph:
|
|
3117
3181
|
hue = None
|
3118
3182
|
|
3119
3183
|
# Create the violin plot
|
3120
|
-
sns.violinplot(data=self.df_melted,x=x_axis_column,y='Value', hue=self.hue,palette=self.sns_palette,ax=ax)
|
3184
|
+
sns.violinplot(data=self.df_melted,x=x_axis_column,y='Value', hue=self.hue,palette=self.sns_palette,ax=ax, order=self.order)
|
3121
3185
|
|
3122
3186
|
# Adjust legend and labels
|
3123
3187
|
ax.set_xlabel(self.grouping_column)
|
@@ -3148,8 +3212,8 @@ class spacrGraph:
|
|
3148
3212
|
|
3149
3213
|
summary_df = self.df_melted.groupby([x_axis_column]).agg(mean=('Value', 'mean'),std=('Value', 'std'),sem=('Value', 'sem')).reset_index()
|
3150
3214
|
error_bars = summary_df[self.error_bar_type] if self.error_bar_type in ['std', 'sem'] else None
|
3151
|
-
sns.barplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, ax=ax, dodge=self.jitter_bar_dodge, ci=None)
|
3152
|
-
sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax,alpha=0.6, edgecolor='white',linewidth=1, size=16)
|
3215
|
+
sns.barplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, ax=ax, dodge=self.jitter_bar_dodge, ci=None, order=self.order)
|
3216
|
+
sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax,alpha=0.6, edgecolor='white',linewidth=1, size=16, order=self.order)
|
3153
3217
|
|
3154
3218
|
# Adjust the bar width manually
|
3155
3219
|
if len(self.data_column) > 1:
|
@@ -3189,8 +3253,8 @@ class spacrGraph:
|
|
3189
3253
|
hue = None
|
3190
3254
|
|
3191
3255
|
# Create the box plot
|
3192
|
-
sns.boxplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue,palette=self.sns_palette,ax=ax)
|
3193
|
-
sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax,alpha=0.6, edgecolor='white',linewidth=1, size=12)
|
3256
|
+
sns.boxplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue,palette=self.sns_palette,ax=ax, order=self.order)
|
3257
|
+
sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax,alpha=0.6, edgecolor='white',linewidth=1, size=12, order=self.order)
|
3194
3258
|
|
3195
3259
|
# Adjust legend and labels
|
3196
3260
|
ax.set_xlabel(self.grouping_column)
|
@@ -3264,12 +3328,11 @@ def plot_data_from_db(settings):
|
|
3264
3328
|
[df1] = _read_db(db_loc, tables=[settings['table_names']])
|
3265
3329
|
else:
|
3266
3330
|
df1, _ = _read_and_merge_data(locs=[db_loc],
|
3267
|
-
tables = ['
|
3331
|
+
tables = settings['tables'],
|
3268
3332
|
verbose=settings['verbose'],
|
3269
3333
|
nuclei_limit=settings['nuclei_limit'],
|
3270
|
-
pathogen_limit=settings['pathogen_limit']
|
3271
|
-
|
3272
|
-
|
3334
|
+
pathogen_limit=settings['pathogen_limit'])
|
3335
|
+
|
3273
3336
|
dft = annotate_conditions(df1,
|
3274
3337
|
cells=settings['cell_types'],
|
3275
3338
|
cell_loc=settings['cell_plate_metadata'],
|
@@ -3281,10 +3344,7 @@ def plot_data_from_db(settings):
|
|
3281
3344
|
|
3282
3345
|
df = pd.concat(dfs, axis=0)
|
3283
3346
|
df['prc'] = df['plate'].astype(str) + '_' + df['row_name'].astype(str) + '_' + df['column_name'].astype(str)
|
3284
|
-
|
3285
|
-
#df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
|
3286
|
-
df['class'] = df['png_path'].apply(lambda x: 'class_1' if 'class_1' in x else ('class_0' if 'class_0' in x else None))
|
3287
|
-
|
3347
|
+
|
3288
3348
|
if settings['cell_plate_metadata'] != None:
|
3289
3349
|
df = df.dropna(subset='host_cell')
|
3290
3350
|
|
@@ -3297,7 +3357,6 @@ def plot_data_from_db(settings):
|
|
3297
3357
|
df = df.dropna(subset=settings['data_column'])
|
3298
3358
|
df = df.dropna(subset=settings['grouping_column'])
|
3299
3359
|
|
3300
|
-
|
3301
3360
|
src = srcs[0]
|
3302
3361
|
dst = os.path.join(src, 'results', settings['graph_name'])
|
3303
3362
|
os.makedirs(dst, exist_ok=True)
|
spacr/settings.py
CHANGED
@@ -2,7 +2,6 @@ import os, ast
|
|
2
2
|
|
3
3
|
def set_default_plot_merge_settings():
|
4
4
|
settings = {}
|
5
|
-
settings.setdefault('uninfected', True)
|
6
5
|
settings.setdefault('pathogen_limit', 10)
|
7
6
|
settings.setdefault('nuclei_limit', 1)
|
8
7
|
settings.setdefault('remove_background', False)
|
@@ -181,8 +180,8 @@ def set_default_umap_image_settings(settings={}):
|
|
181
180
|
settings.setdefault('n_neighbors', 1000)
|
182
181
|
settings.setdefault('min_dist', 0.1)
|
183
182
|
settings.setdefault('metric', 'euclidean')
|
184
|
-
settings.setdefault('eps', 0.
|
185
|
-
settings.setdefault('min_samples',
|
183
|
+
settings.setdefault('eps', 0.9)
|
184
|
+
settings.setdefault('min_samples', 100)
|
186
185
|
settings.setdefault('filter_by', 'channel_0')
|
187
186
|
settings.setdefault('img_zoom', 0.5)
|
188
187
|
settings.setdefault('plot_by_cluster', True)
|
@@ -201,16 +200,13 @@ def set_default_umap_image_settings(settings={}):
|
|
201
200
|
settings.setdefault('col_to_compare', 'column_name')
|
202
201
|
settings.setdefault('pos', 'c1')
|
203
202
|
settings.setdefault('neg', 'c2')
|
203
|
+
settings.setdefault('mix', 'c3')
|
204
204
|
settings.setdefault('embedding_by_controls', False)
|
205
205
|
settings.setdefault('plot_images', True)
|
206
206
|
settings.setdefault('reduction_method','umap')
|
207
207
|
settings.setdefault('save_figure', False)
|
208
208
|
settings.setdefault('n_jobs', -1)
|
209
209
|
settings.setdefault('color_by', None)
|
210
|
-
settings.setdefault('neg', 'c1')
|
211
|
-
settings.setdefault('pos', 'c2')
|
212
|
-
settings.setdefault('mix', 'c3')
|
213
|
-
settings.setdefault('mix', 'c3')
|
214
210
|
settings.setdefault('exclude_conditions', None)
|
215
211
|
settings.setdefault('analyze_clusters', False)
|
216
212
|
settings.setdefault('resnet_features', False)
|
@@ -295,7 +291,6 @@ def set_default_analyze_screen(settings):
|
|
295
291
|
settings.setdefault('exclude',None)
|
296
292
|
settings.setdefault('nuclei_limit',True)
|
297
293
|
settings.setdefault('pathogen_limit',3)
|
298
|
-
settings.setdefault('uninfected',True)
|
299
294
|
settings.setdefault('n_repeats',10)
|
300
295
|
settings.setdefault('top_features',30)
|
301
296
|
settings.setdefault('remove_low_variance_features',True)
|
@@ -353,7 +348,6 @@ def set_generate_training_dataset_defaults(settings):
|
|
353
348
|
settings.setdefault('tables',None)
|
354
349
|
settings.setdefault('nuclei_limit',True)
|
355
350
|
settings.setdefault('pathogen_limit',True)
|
356
|
-
settings.setdefault('uninfected',True)
|
357
351
|
settings.setdefault('png_type','cell_png')
|
358
352
|
|
359
353
|
return settings
|
@@ -467,7 +461,6 @@ def get_analyze_recruitment_default_settings(settings):
|
|
467
461
|
settings.setdefault('plot_nr',3)
|
468
462
|
settings.setdefault('plot_control',True)
|
469
463
|
settings.setdefault('figuresize',10)
|
470
|
-
settings.setdefault('uninfected',True)
|
471
464
|
settings.setdefault('pathogen_limit',10)
|
472
465
|
settings.setdefault('nuclei_limit',1)
|
473
466
|
settings.setdefault('cells_per_well',0)
|
@@ -691,7 +684,6 @@ expected_types = {
|
|
691
684
|
"measurement": str,
|
692
685
|
"nr_imgs": int,
|
693
686
|
"um_per_pixel": (int, float),
|
694
|
-
"uninfected": bool,
|
695
687
|
"pathogen_limit": int,
|
696
688
|
"nuclei_limit": int,
|
697
689
|
"filter_min_max": (list, type(None)),
|
@@ -898,7 +890,7 @@ categories = {"Paths":[ "src", "grna", "barcodes", "custom_model_path", "dataset
|
|
898
890
|
"Plot": ["plot", "plot_control", "plot_nr", "examples_to_plot", "normalize_plots", "cmap", "figuresize", "plot_cluster_grids", "img_zoom", "row_limit", "color_by", "plot_images", "smooth_lines", "plot_points", "plot_outlines", "black_background", "plot_by_cluster", "heatmap_feature","grouping","min_max","cmap","save_figure"],
|
899
891
|
"Test": ["test_mode", "test_images", "random_test", "test_nr", "test", "test_split"],
|
900
892
|
"Timelapse": ["timelapse", "fps", "timelapse_displacement", "timelapse_memory", "timelapse_frame_limits", "timelapse_remove_transient", "timelapse_mode", "timelapse_objects", "compartments"],
|
901
|
-
"Advanced": ["shuffle", "target_intensity_min", "cells_per_well", "nuclei_limit", "pathogen_limit", "
|
893
|
+
"Advanced": ["shuffle", "target_intensity_min", "cells_per_well", "nuclei_limit", "pathogen_limit", "background", "backgrounds", "schedule", "test_size","exclude","n_repeats","top_features", "model_type_ml", "model_type","minimum_cell_count","n_estimators","preprocess", "remove_background", "normalize", "lower_percentile", "merge_pathogens", "batch_size", "filter", "save", "masks", "verbose", "randomize", "n_jobs"],
|
902
894
|
"Miscellaneous": ["all_to_mip", "pick_slice", "skip_mode", "upscale", "upscale_factor"]
|
903
895
|
}
|
904
896
|
|
@@ -1080,7 +1072,6 @@ def generate_fields(variables, scrollable_frame):
|
|
1080
1072
|
"img_zoom": "(float) - Zoom factor for the images in plots.",
|
1081
1073
|
"nuclei_limit": "(int) - Whether to include multinucleated cells in the analysis.",
|
1082
1074
|
"pathogen_limit": "(int) - Whether to include multi-infected cells in the analysis.",
|
1083
|
-
"uninfected": "(bool) - Whether to include non-infected cells in the analysis.",
|
1084
1075
|
"uninfected": "(bool) - Whether to include uninfected cells in the analysis.",
|
1085
1076
|
"init_weights": "(bool) - Whether to initialize weights for the model.",
|
1086
1077
|
"src": "(str) - Path to the folder containing the images.",
|
spacr/submodules.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
import seaborn as sns
|
2
|
-
import os, random, sqlite3
|
2
|
+
import os, random, sqlite3, re, shap
|
3
3
|
import pandas as pd
|
4
4
|
import numpy as np
|
5
5
|
import cellpose
|
@@ -7,6 +7,9 @@ from skimage.measure import regionprops, label
|
|
7
7
|
from cellpose import models as cp_models
|
8
8
|
from cellpose import train as train_cp
|
9
9
|
from IPython.display import display
|
10
|
+
from sklearn.ensemble import RandomForestClassifier
|
11
|
+
from sklearn.inspection import permutation_importance
|
12
|
+
from math import pi
|
10
13
|
|
11
14
|
import matplotlib.pyplot as plt
|
12
15
|
from natsort import natsorted
|
@@ -43,9 +46,8 @@ def analyze_recruitment(settings={}):
|
|
43
46
|
tables=['cell', 'nucleus', 'pathogen','cytoplasm'],
|
44
47
|
verbose=True,
|
45
48
|
nuclei_limit=settings['nuclei_limit'],
|
46
|
-
pathogen_limit=settings['pathogen_limit']
|
47
|
-
|
48
|
-
|
49
|
+
pathogen_limit=settings['pathogen_limit'])
|
50
|
+
|
49
51
|
df = annotate_conditions(df,
|
50
52
|
cells=settings['cell_types'],
|
51
53
|
cell_loc=settings['cell_plate_metadata'],
|
@@ -550,4 +552,296 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
|
|
550
552
|
fig_1 = plot_line(df, x_column = 'pc_fraction', y_columns=y_columns, group_column=None, xlabel=None, ylabel='Fraction', title=None, figsize=(10, 6), save_path=save_paths[0])
|
551
553
|
fig_2 = plot_line(df, x_column = 'nc_fraction', y_columns=y_columns, group_column=None, xlabel=None, ylabel='Fraction', title=None, figsize=(10, 6), save_path=save_paths[1])
|
552
554
|
|
553
|
-
return [fig_1, fig_2]
|
555
|
+
return [fig_1, fig_2]
|
556
|
+
|
557
|
+
def interperate_vision_model(settings=None):
    """
    Relate per-object scores to measured features using a random forest.

    Measurements are read from ``<src>/measurements/measurements.db``, extended with
    between-compartment ratio columns, inner-joined with a score table read from CSV,
    and then analysed with up to three methods: random-forest feature importance,
    permutation importance and SHAP.

    Args:
        settings (dict): Keys read by this function:
            'src', 'tables', 'nuclei_limit', 'pathogen_limit' - passed to the database reader;
            'scores' - path of the score CSV; 'score_column' - name of the target column;
            'feature_importance', 'permutation_importance', 'shap' - enable each analysis;
            'top_features' - number of features plotted / used for SHAP;
            'channels', 'include_all' - grouping of summed feature importance;
            'shap_sample' - explain roughly 1% of the rows instead of all rows;
            'n_jobs' - parallelism for the forest and permutation importance;
            'save' - write every output table to ``<src>/results/<key>.csv``;
            'radar_plots' (optional, default False) - draw radar charts of mean
            absolute SHAP values per compartment and per channel.

    Returns:
        dict: Output tables keyed by 'feature_importance', 'feature_importance_compartment',
        'feature_importance_channel', 'permutation_importance' and 'shap'; only the
        analyses that were enabled are present.
    """
    from itertools import combinations

    from .io import _read_and_merge_data

    settings = {} if settings is None else settings

    join_columns = ['plate', 'row_name', 'column_name', 'field', 'object_label']

    def _safe_ratio(numerator, denominator):
        """Return numerator / denominator with inf and NaN results replaced by 0."""
        ratio = numerator / denominator
        return ratio.replace([np.inf, -np.inf], np.nan).fillna(0)

    def generate_comparison_columns(df, compartments=('cell', 'nucleus', 'pathogen', 'cytoplasm')):
        """Add ratio columns between same-named features of different compartments.

        Returns the (modified) DataFrame and a dict mapping each source column to the
        matching columns found in the other compartments.
        """
        comparison_dict = {}

        # Column lists per compartment are computed once, before any ratio column is added.
        compartment_columns = {comp: [col for col in df.columns if col.startswith(comp)] for comp in compartments}

        for comp0, comp0_columns in compartment_columns.items():
            for comp0_col in comp0_columns:
                # Strip only the leading compartment prefix; str.replace would also
                # remove later occurrences of the compartment name inside the feature.
                base_col_name = comp0_col[len(comp0):]
                related_cols = []

                # Look for matching columns in other compartments
                for prefix in compartment_columns:
                    if prefix == comp0:  # Skip same-compartment comparisons
                        continue
                    related_col = prefix + base_col_name
                    if related_col not in df.columns:
                        continue
                    related_cols.append(related_col)
                    # Assign the cleaned ratio in one step; chained inplace calls on
                    # df[col] are not guaranteed to write back to df.
                    df[f"{prefix}_{comp0}{base_col_name}"] = _safe_ratio(df[related_col], df[comp0_col])

                # Generate all-to-all comparisons between the related columns
                if related_cols:
                    comparison_dict[comp0_col] = related_cols
                    for rel_col_1, rel_col_2 in combinations(related_cols, 2):
                        comp1, comp2 = rel_col_1.split('_')[0], rel_col_2.split('_')[0]
                        df[f"{comp1}_{comp2}{base_col_name}"] = _safe_ratio(df[rel_col_1], df[rel_col_2])

        return df, comparison_dict

    def group_feature_class(df, feature_groups=('cell', 'cytoplasm', 'nucleus', 'pathogen'), name='compartment', include_all=False):
        """Label each feature with the groups it matches and sum importance per label.

        A column called ``name`` is added to ``df`` in place. Each entry of
        ``feature_groups`` is used as a regular-expression pattern.
        """

        def find_feature_class(feature, groups):
            matches = [group for group in groups if re.search(group, feature)]
            if not matches:
                return None
            # Features matching several groups get a combined label.
            return '-'.join(matches) if len(matches) > 1 else matches[0]

        df[name] = df['feature'].apply(lambda feature: find_feature_class(feature, feature_groups))

        if name == 'channel':
            # Features that match no channel pattern are labelled as morphology.
            df[name] = df[name].fillna('morphology')

        sum_column = f'{name}_importance_sum'
        importance_sum = df.groupby(name)['importance'].sum().reset_index(name=sum_column)

        if include_all:
            total_row = pd.DataFrame([{name: 'all', sum_column: importance_sum[sum_column].sum()}])
            importance_sum = pd.concat([importance_sum, total_row], ignore_index=True)

        return importance_sum

    def create_extended_radar_plot(values, labels, title):
        """Draw a radar chart of ``values`` with one spoke per label."""
        values = list(values) + [values[0]]  # Close the loop for radar chart
        angles = [n / float(len(labels)) * 2 * pi for n in range(len(labels))]
        angles += angles[:1]

        fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
        ax.plot(angles, values, linewidth=2, linestyle='solid')
        ax.fill(angles, values, alpha=0.25)

        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(labels, fontsize=10, rotation=45, ha='right')
        plt.title(title, pad=20)
        plt.show()

    def extract_compartment_channel(feature_name):
        """Return (compartment, channel) for a feature name."""
        # The compartment is the text before the first underscore.
        compartment = feature_name.split('_')[0]
        if compartment == 'cells':
            compartment = 'cell'

        # Channels are detected by substring; several channels are joined with ' + '.
        channels = [f'channel_{i}' for i in range(4) if f'channel_{i}' in feature_name]
        channel = ' + '.join(channels) if channels else 'morphology'

        return (compartment, channel)

    def plot_top_features(importance_df, title):
        """Horizontal bar chart of an importance table, most important feature on top."""
        plt.figure(figsize=(10, 6))
        plt.barh(importance_df['feature'], importance_df['importance'])
        plt.xlabel('Importance')
        plt.title(title)
        plt.gca().invert_yaxis()
        plt.show()

    def read_and_preprocess_data(settings):
        """Load measurements and scores, merge them, and split into features and target."""
        df, _ = _read_and_merge_data(
            locs=[os.path.join(settings['src'], 'measurements', 'measurements.db')],
            tables=settings['tables'],
            verbose=True,
            nuclei_limit=settings['nuclei_limit'],
            pathogen_limit=settings['pathogen_limit'],
        )

        df, _dict = generate_comparison_columns(df, compartments=['cell', 'nucleus', 'pathogen', 'cytoplasm'])
        print(f"Expanded dataframe to {len(df.columns)} columns with relative features")
        scores_df = pd.read_csv(settings['scores'])

        # Remove the 'o' prefix from the measurement object labels. Converting to str
        # first keeps the .str accessor from failing on non-string labels.
        df['object_label'] = df['object_label'].astype(str).str.replace('o', '', regex=False)

        # Accept the short column names some score tables use.
        if 'row_name' not in scores_df.columns:
            scores_df['row_name'] = scores_df['row']
        if 'column_name' not in scores_df.columns:
            scores_df['column_name'] = scores_df['col']

        # 'object' takes precedence when present (as before); a table that only has
        # 'object_label' is now accepted instead of raising KeyError.
        label_source = 'object' if 'object' in scores_df.columns else 'object_label'
        scores_df['object_label'] = scores_df[label_source].astype(str)

        # Ensure all join columns have the same data type in both DataFrames
        df[join_columns] = df[join_columns].astype(str)
        scores_df[join_columns] = scores_df[join_columns].astype(str)

        # Only the join keys and the score column are needed from the score table.
        scores_df = scores_df[join_columns + [settings['score_column']]]

        merged_df = pd.merge(df, scores_df, on=join_columns, how='inner')

        # Numeric features only; a non-numeric score column is already excluded by
        # select_dtypes, so a missing column must not raise here.
        X = merged_df.select_dtypes(include='number').drop(columns=[settings['score_column']], errors='ignore')
        y = merged_df[settings['score_column']]

        return X, y, merged_df

    X, y, _ = read_and_preprocess_data(settings)

    output = {}

    want_importance = settings['feature_importance']
    want_permutation = settings['permutation_importance']
    want_shap = settings['shap']

    # Every analysis needs the fitted forest: permutation importance scores it and
    # SHAP selects its top features from the forest's importance ranking.
    model = None
    feature_importance_df = None
    if want_importance or want_permutation or want_shap:
        model = RandomForestClassifier(random_state=42, n_jobs=settings['n_jobs'])
        model.fit(X, y)
        feature_importance_df = pd.DataFrame({'feature': X.columns, 'importance': model.feature_importances_})
        feature_importance_df = feature_importance_df.sort_values(by='importance', ascending=False)

    # Step 1: Feature Importance using Random Forest
    if want_importance:
        print("Feature Importance ...")
        plot_top_features(feature_importance_df.head(settings['top_features']),
                          f"Top {settings['top_features']} Features - Feature Importance")

        output['feature_importance'] = feature_importance_df
        output['feature_importance_compartment'] = group_feature_class(
            feature_importance_df, feature_groups=settings['tables'], name='compartment', include_all=settings['include_all'])
        output['feature_importance_channel'] = group_feature_class(
            feature_importance_df, feature_groups=settings['channels'], name='channel', include_all=settings['include_all'])

    # Step 2: Permutation Importance
    if want_permutation:
        print("Permutation Importance ...")
        perm_importance = permutation_importance(model, X, y, n_repeats=10, random_state=42, n_jobs=settings['n_jobs'])
        perm_importance_df = pd.DataFrame({'feature': X.columns, 'importance': perm_importance.importances_mean})
        perm_importance_df = perm_importance_df.sort_values(by='importance', ascending=False)
        plot_top_features(perm_importance_df.head(settings['top_features']),
                          f"Top {settings['top_features']} Features - Permutation Importance")

        output['permutation_importance'] = perm_importance_df

    # Step 3: SHAP Analysis
    if want_shap:
        print("SHAP Analysis ...")

        # Refit on the top N features (by forest importance) only.
        top_features = feature_importance_df.head(settings['top_features'])['feature']
        X_top = X[top_features]

        model = RandomForestClassifier(random_state=42, n_jobs=settings['n_jobs'])
        model.fit(X_top, y)

        # Explain about 1% of the rows to speed up SHAP, but never an empty sample.
        if settings['shap_sample']:
            sample_size = min(len(X_top), max(1, len(X_top) // 100))
            X_sample = X_top.sample(sample_size, random_state=42)
        else:
            X_sample = X_top

        # NOTE(review): SHAP's permutation explainer needs at least
        # 2 * n_features + 1 evaluations; keep 1500 as the floor - confirm against
        # the installed shap version.
        max_evals = max(1500, 2 * X_sample.shape[1] + 1)
        explainer = shap.Explainer(model.predict, X_sample)
        shap_values = explainer(X_sample, max_evals=max_evals)

        shap.summary_plot(shap_values, X_sample, max_display=settings['top_features'])

        # Re-key the SHAP columns by (compartment, channel).
        shap_df = pd.DataFrame(shap_values.values, columns=X_sample.columns)
        shap_df.columns = pd.MultiIndex.from_tuples(
            [extract_compartment_channel(feat) for feat in shap_df.columns],
            names=['compartment', 'channel']
        )

        if settings.get('radar_plots', False):
            abs_shap = shap_df.abs()
            radar_titles = {
                'compartment': "SHAP Importance by Compartment (Individual and Combined)",
                'channel': "SHAP Importance by Channel (Individual and Combined)",
            }
            for level, title in radar_titles.items():
                # Grouping the transposed frame avoids column-axis groupby and tuple
                # slicing on a MultiIndex that is not guaranteed to be sorted.
                level_mean = abs_shap.T.groupby(level=level).mean().mean(axis=1)
                labels = list(level_mean.index)
                values = list(level_mean.values)
                # Combined importance of a pair is the sum of the two individual means.
                for first, second in combinations(level_mean.index, 2):
                    labels.append(f"{first} + {second}")
                    values.append(level_mean[first] + level_mean[second])
                create_extended_radar_plot(values, labels, title)

        output['shap'] = shap_df

    if settings['save']:
        dst = os.path.join(settings['src'], 'results')
        os.makedirs(dst, exist_ok=True)
        for key, table in output.items():
            save_path = os.path.join(dst, f"{key}.csv")
            table.to_csv(save_path)
            print(f"Saved {save_path}")

    return output
|