spacr 0.3.60__py3-none-any.whl → 0.3.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/plot.py CHANGED
@@ -909,7 +909,7 @@ def plot_merged(src, settings):
909
909
  path = os.path.join(src, file)
910
910
  stack = np.load(path)
911
911
  print(f'Loaded: {path}')
912
- if not settings['uninfected']:
912
+ if settings['pathogen_limit'] > 0:
913
913
  if settings['pathogen_mask_dim'] is not None and settings['cell_mask_dim'] is not None:
914
914
  stack = _remove_noninfected(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'])
915
915
 
@@ -2198,8 +2198,8 @@ def jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot',
2198
2198
  tables,
2199
2199
  verbose=True,
2200
2200
  nuclei_limit=True,
2201
- pathogen_limit=True,
2202
- uninfected=True)
2201
+ pathogen_limit=True)
2202
+
2203
2203
  paths_df = _read_db(loc, tables=['png_list'])
2204
2204
  merged_df = pd.merge(df, paths_df[0], on='prcfo', how='left')
2205
2205
  return merged_df
@@ -2435,7 +2435,9 @@ class spacrGraph:
2435
2435
 
2436
2436
  self.df = df
2437
2437
  self.grouping_column = grouping_column
2438
+ self.order = sorted(df[self.grouping_column].unique().tolist())
2438
2439
  self.data_column = data_column if isinstance(data_column, list) else [data_column]
2440
+
2439
2441
  self.graph_type = graph_type
2440
2442
  self.summary_func = summary_func
2441
2443
  self.order = order
@@ -2909,9 +2911,11 @@ class spacrGraph:
2909
2911
  ax.set_xlim(-0.5, num_groups - 0.5)
2910
2912
 
2911
2913
  # Set ticks to match the group labels in your DataFrame
2912
- group_labels = self.df[self.grouping_column].unique()
2913
- ax.set_xticks(range(len(group_labels)))
2914
- ax.set_xticklabels(group_labels, rotation=45, ha='right')
2914
+ #group_labels = self.df[self.grouping_column].unique()
2915
+ #group_labels = self.order
2916
+ #ax.set_xticks(range(len(group_labels)))
2917
+ #ax.set_xticklabels(group_labels, rotation=45, ha='right')
2918
+ plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
2915
2919
 
2916
2920
  # Customize elements based on the graph type
2917
2921
  if graph_type == 'bar':
@@ -2943,6 +2947,66 @@ class spacrGraph:
2943
2947
 
2944
2948
  # Redraw the figure to apply changes
2945
2949
  ax.figure.canvas.draw()
2950
+
2951
def _standerdize_figure_format_v1(self, ax, num_groups, graph_type):
    """
    Adjust the figure layout (size, bar width, jitter and spacing) based on the number of groups.

    The axes are modified in place and the canvas is redrawn at the end.
    Line graphs ('line', 'line_std') are left untouched.

    Args:
        ax: Axes object holding the already-drawn plot.
        num_groups: Number of groups on the x-axis; values below 1 are treated as 1
            so the width and jitter formulas never divide by zero.
        graph_type (str): One of 'bar', 'jitter', 'jitter_bar', 'jitter_box', 'box',
            'violin', 'line' or 'line_std'. Other values only receive the generic
            size, tick and legend adjustments.

    Returns:
        None
    """
    if graph_type in ['line', 'line_std']:
        print("Skipping layout adjustment for line graphs.")
        return  # Skip layout adjustment for line graphs

    # The formulas below divide by the group count, so an empty plot must not reach them as 0.
    num_groups = max(1, num_groups)
    correction_factor = 4

    # Keep the figure square, with a minimum size.
    fig_size = max(6, num_groups * 2) / correction_factor
    ax.figure.set_size_inches(fig_size, fig_size)

    # Element sizes shrink as the number of groups grows.
    bar_width = min(0.8, 1.5 / num_groups) / correction_factor
    jitter_amount = min(0.1, 0.2 / num_groups) / correction_factor
    # NOTE(review): this evaluates to 200 for every num_groups >= 1, so the point size is
    # effectively constant -- confirm whether min() was intended.
    jitter_size = max(50 / num_groups, 200)

    # One tick per entry of `self.order` when an explicit order is set; otherwise fall back
    # to the group count (the attribute may be None or absent).
    order = getattr(self, 'order', None)
    num_ticks = len(order) if order else num_groups
    ax.set_xlim(-0.5, num_ticks - 0.5)
    ax.set_xticks(range(num_ticks))
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    # Customize elements based on the graph type
    if graph_type == 'bar':
        for bar in ax.patches:
            # Re-centre each bar on its original midpoint; shifting the left edge by half
            # the new width alone would slide the narrowed bar away from its tick.
            center = bar.get_x() + bar.get_width() / 2
            bar.set_width(bar_width)
            bar.set_x(center - bar_width / 2)

    elif graph_type in ['jitter', 'jitter_bar', 'jitter_box']:
        for coll in ax.collections:
            offsets = coll.get_offsets()
            if len(offsets) == 0:
                continue  # nothing to move or resize in an empty collection
            offsets[:, 0] += jitter_amount  # Shift jitter points slightly
            coll.set_offsets(offsets)
            coll.set_sizes([jitter_size] * len(offsets))

    elif graph_type in ['box', 'violin']:
        for artist in ax.artists:
            # Not every artist exposes set_width; skip the ones that do not.
            if hasattr(artist, 'set_width'):
                artist.set_width(bar_width)

    # Tick label size shrinks with the group count but never drops below 10.
    tick_label_size = max(10, 15 - num_groups // 2)
    ax.tick_params(axis='x', labelsize=tick_label_size)
    ax.tick_params(axis='y', labelsize=tick_label_size)

    # Adjust legend placement and size
    legend = ax.get_legend()
    if legend:
        legend.set_bbox_to_anchor((1.05, 1))
        legend.prop.set_size(max(8, 12 - num_groups // 3))

    # Redraw the figure to apply changes
    ax.figure.canvas.draw()
3009
+
2946
3010
 
2947
3011
  def _create_bar_plot(self, ax):
2948
3012
  """Helper method to create a bar plot with consistent bar thickness and centered error bars."""
@@ -2959,7 +3023,7 @@ class spacrGraph:
2959
3023
 
2960
3024
  summary_df = self.df_melted.groupby([x_axis_column]).agg(mean=('Value', 'mean'),std=('Value', 'std'),sem=('Value', 'sem')).reset_index()
2961
3025
  error_bars = summary_df[self.error_bar_type] if self.error_bar_type in ['std', 'sem'] else None
2962
- sns.barplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, ax=ax, dodge=self.jitter_bar_dodge, ci=None)
3026
+ sns.barplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, ax=ax, dodge=self.jitter_bar_dodge, ci=None, order=self.order)
2963
3027
 
2964
3028
  # Adjust the bar width manually
2965
3029
  if len(self.data_column) > 1:
@@ -2999,7 +3063,7 @@ class spacrGraph:
2999
3063
  hue = None
3000
3064
 
3001
3065
  # Create the jitter plot
3002
- sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax, alpha=0.6, size=16)
3066
+ sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax, alpha=0.6, size=16, order=self.order)
3003
3067
 
3004
3068
  # Adjust legend and labels
3005
3069
  ax.set_xlabel(self.grouping_column)
@@ -3088,7 +3152,7 @@ class spacrGraph:
3088
3152
  hue = None
3089
3153
 
3090
3154
  # Create the box plot
3091
- sns.boxplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue,palette=self.sns_palette,ax=ax)
3155
+ sns.boxplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue,palette=self.sns_palette,ax=ax, order=self.order)
3092
3156
 
3093
3157
  # Adjust legend and labels
3094
3158
  ax.set_xlabel(self.grouping_column)
@@ -3117,7 +3181,7 @@ class spacrGraph:
3117
3181
  hue = None
3118
3182
 
3119
3183
  # Create the violin plot
3120
- sns.violinplot(data=self.df_melted,x=x_axis_column,y='Value', hue=self.hue,palette=self.sns_palette,ax=ax)
3184
+ sns.violinplot(data=self.df_melted,x=x_axis_column,y='Value', hue=self.hue,palette=self.sns_palette,ax=ax, order=self.order)
3121
3185
 
3122
3186
  # Adjust legend and labels
3123
3187
  ax.set_xlabel(self.grouping_column)
@@ -3148,8 +3212,8 @@ class spacrGraph:
3148
3212
 
3149
3213
  summary_df = self.df_melted.groupby([x_axis_column]).agg(mean=('Value', 'mean'),std=('Value', 'std'),sem=('Value', 'sem')).reset_index()
3150
3214
  error_bars = summary_df[self.error_bar_type] if self.error_bar_type in ['std', 'sem'] else None
3151
- sns.barplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, ax=ax, dodge=self.jitter_bar_dodge, ci=None)
3152
- sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax,alpha=0.6, edgecolor='white',linewidth=1, size=16)
3215
+ sns.barplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, ax=ax, dodge=self.jitter_bar_dodge, ci=None, order=self.order)
3216
+ sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax,alpha=0.6, edgecolor='white',linewidth=1, size=16, order=self.order)
3153
3217
 
3154
3218
  # Adjust the bar width manually
3155
3219
  if len(self.data_column) > 1:
@@ -3189,8 +3253,8 @@ class spacrGraph:
3189
3253
  hue = None
3190
3254
 
3191
3255
  # Create the box plot
3192
- sns.boxplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue,palette=self.sns_palette,ax=ax)
3193
- sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax,alpha=0.6, edgecolor='white',linewidth=1, size=12)
3256
+ sns.boxplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue,palette=self.sns_palette,ax=ax, order=self.order)
3257
+ sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax,alpha=0.6, edgecolor='white',linewidth=1, size=12, order=self.order)
3194
3258
 
3195
3259
  # Adjust legend and labels
3196
3260
  ax.set_xlabel(self.grouping_column)
@@ -3264,12 +3328,11 @@ def plot_data_from_db(settings):
3264
3328
  [df1] = _read_db(db_loc, tables=[settings['table_names']])
3265
3329
  else:
3266
3330
  df1, _ = _read_and_merge_data(locs=[db_loc],
3267
- tables = ['cell', 'nucleus', 'pathogen','cytoplasm'],
3331
+ tables = settings['tables'],
3268
3332
  verbose=settings['verbose'],
3269
3333
  nuclei_limit=settings['nuclei_limit'],
3270
- pathogen_limit=settings['pathogen_limit'],
3271
- uninfected=settings['uninfected'])
3272
-
3334
+ pathogen_limit=settings['pathogen_limit'])
3335
+
3273
3336
  dft = annotate_conditions(df1,
3274
3337
  cells=settings['cell_types'],
3275
3338
  cell_loc=settings['cell_plate_metadata'],
@@ -3281,10 +3344,7 @@ def plot_data_from_db(settings):
3281
3344
 
3282
3345
  df = pd.concat(dfs, axis=0)
3283
3346
  df['prc'] = df['plate'].astype(str) + '_' + df['row_name'].astype(str) + '_' + df['column_name'].astype(str)
3284
- #df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
3285
- #df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
3286
- df['class'] = df['png_path'].apply(lambda x: 'class_1' if 'class_1' in x else ('class_0' if 'class_0' in x else None))
3287
-
3347
+
3288
3348
  if settings['cell_plate_metadata'] != None:
3289
3349
  df = df.dropna(subset='host_cell')
3290
3350
 
@@ -3297,7 +3357,6 @@ def plot_data_from_db(settings):
3297
3357
  df = df.dropna(subset=settings['data_column'])
3298
3358
  df = df.dropna(subset=settings['grouping_column'])
3299
3359
 
3300
-
3301
3360
  src = srcs[0]
3302
3361
  dst = os.path.join(src, 'results', settings['graph_name'])
3303
3362
  os.makedirs(dst, exist_ok=True)
spacr/settings.py CHANGED
@@ -2,7 +2,6 @@ import os, ast
2
2
 
3
3
  def set_default_plot_merge_settings():
4
4
  settings = {}
5
- settings.setdefault('uninfected', True)
6
5
  settings.setdefault('pathogen_limit', 10)
7
6
  settings.setdefault('nuclei_limit', 1)
8
7
  settings.setdefault('remove_background', False)
@@ -181,8 +180,8 @@ def set_default_umap_image_settings(settings={}):
181
180
  settings.setdefault('n_neighbors', 1000)
182
181
  settings.setdefault('min_dist', 0.1)
183
182
  settings.setdefault('metric', 'euclidean')
184
- settings.setdefault('eps', 0.5)
185
- settings.setdefault('min_samples', 1000)
183
+ settings.setdefault('eps', 0.9)
184
+ settings.setdefault('min_samples', 100)
186
185
  settings.setdefault('filter_by', 'channel_0')
187
186
  settings.setdefault('img_zoom', 0.5)
188
187
  settings.setdefault('plot_by_cluster', True)
@@ -201,16 +200,13 @@ def set_default_umap_image_settings(settings={}):
201
200
  settings.setdefault('col_to_compare', 'column_name')
202
201
  settings.setdefault('pos', 'c1')
203
202
  settings.setdefault('neg', 'c2')
203
+ settings.setdefault('mix', 'c3')
204
204
  settings.setdefault('embedding_by_controls', False)
205
205
  settings.setdefault('plot_images', True)
206
206
  settings.setdefault('reduction_method','umap')
207
207
  settings.setdefault('save_figure', False)
208
208
  settings.setdefault('n_jobs', -1)
209
209
  settings.setdefault('color_by', None)
210
- settings.setdefault('neg', 'c1')
211
- settings.setdefault('pos', 'c2')
212
- settings.setdefault('mix', 'c3')
213
- settings.setdefault('mix', 'c3')
214
210
  settings.setdefault('exclude_conditions', None)
215
211
  settings.setdefault('analyze_clusters', False)
216
212
  settings.setdefault('resnet_features', False)
@@ -295,7 +291,6 @@ def set_default_analyze_screen(settings):
295
291
  settings.setdefault('exclude',None)
296
292
  settings.setdefault('nuclei_limit',True)
297
293
  settings.setdefault('pathogen_limit',3)
298
- settings.setdefault('uninfected',True)
299
294
  settings.setdefault('n_repeats',10)
300
295
  settings.setdefault('top_features',30)
301
296
  settings.setdefault('remove_low_variance_features',True)
@@ -353,7 +348,6 @@ def set_generate_training_dataset_defaults(settings):
353
348
  settings.setdefault('tables',None)
354
349
  settings.setdefault('nuclei_limit',True)
355
350
  settings.setdefault('pathogen_limit',True)
356
- settings.setdefault('uninfected',True)
357
351
  settings.setdefault('png_type','cell_png')
358
352
 
359
353
  return settings
@@ -467,7 +461,6 @@ def get_analyze_recruitment_default_settings(settings):
467
461
  settings.setdefault('plot_nr',3)
468
462
  settings.setdefault('plot_control',True)
469
463
  settings.setdefault('figuresize',10)
470
- settings.setdefault('uninfected',True)
471
464
  settings.setdefault('pathogen_limit',10)
472
465
  settings.setdefault('nuclei_limit',1)
473
466
  settings.setdefault('cells_per_well',0)
@@ -691,7 +684,6 @@ expected_types = {
691
684
  "measurement": str,
692
685
  "nr_imgs": int,
693
686
  "um_per_pixel": (int, float),
694
- "uninfected": bool,
695
687
  "pathogen_limit": int,
696
688
  "nuclei_limit": int,
697
689
  "filter_min_max": (list, type(None)),
@@ -898,7 +890,7 @@ categories = {"Paths":[ "src", "grna", "barcodes", "custom_model_path", "dataset
898
890
  "Plot": ["plot", "plot_control", "plot_nr", "examples_to_plot", "normalize_plots", "cmap", "figuresize", "plot_cluster_grids", "img_zoom", "row_limit", "color_by", "plot_images", "smooth_lines", "plot_points", "plot_outlines", "black_background", "plot_by_cluster", "heatmap_feature","grouping","min_max","cmap","save_figure"],
899
891
  "Test": ["test_mode", "test_images", "random_test", "test_nr", "test", "test_split"],
900
892
  "Timelapse": ["timelapse", "fps", "timelapse_displacement", "timelapse_memory", "timelapse_frame_limits", "timelapse_remove_transient", "timelapse_mode", "timelapse_objects", "compartments"],
901
- "Advanced": ["shuffle", "target_intensity_min", "cells_per_well", "nuclei_limit", "pathogen_limit", "uninfected", "background", "backgrounds", "schedule", "test_size","exclude","n_repeats","top_features", "model_type_ml", "model_type","minimum_cell_count","n_estimators","preprocess", "remove_background", "normalize", "lower_percentile", "merge_pathogens", "batch_size", "filter", "save", "masks", "verbose", "randomize", "n_jobs"],
893
+ "Advanced": ["shuffle", "target_intensity_min", "cells_per_well", "nuclei_limit", "pathogen_limit", "background", "backgrounds", "schedule", "test_size","exclude","n_repeats","top_features", "model_type_ml", "model_type","minimum_cell_count","n_estimators","preprocess", "remove_background", "normalize", "lower_percentile", "merge_pathogens", "batch_size", "filter", "save", "masks", "verbose", "randomize", "n_jobs"],
902
894
  "Miscellaneous": ["all_to_mip", "pick_slice", "skip_mode", "upscale", "upscale_factor"]
903
895
  }
904
896
 
@@ -1080,7 +1072,6 @@ def generate_fields(variables, scrollable_frame):
1080
1072
  "img_zoom": "(float) - Zoom factor for the images in plots.",
1081
1073
  "nuclei_limit": "(int) - Whether to include multinucleated cells in the analysis.",
1082
1074
  "pathogen_limit": "(int) - Whether to include multi-infected cells in the analysis.",
1083
- "uninfected": "(bool) - Whether to include non-infected cells in the analysis.",
1084
1075
  "uninfected": "(bool) - Whether to include uninfected cells in the analysis.",
1085
1076
  "init_weights": "(bool) - Whether to initialize weights for the model.",
1086
1077
  "src": "(str) - Path to the folder containing the images.",
spacr/submodules.py CHANGED
@@ -1,5 +1,5 @@
1
1
  import seaborn as sns
2
- import os, random, sqlite3
2
+ import os, random, sqlite3, re, shap
3
3
  import pandas as pd
4
4
  import numpy as np
5
5
  import cellpose
@@ -7,6 +7,9 @@ from skimage.measure import regionprops, label
7
7
  from cellpose import models as cp_models
8
8
  from cellpose import train as train_cp
9
9
  from IPython.display import display
10
+ from sklearn.ensemble import RandomForestClassifier
11
+ from sklearn.inspection import permutation_importance
12
+ from math import pi
10
13
 
11
14
  import matplotlib.pyplot as plt
12
15
  from natsort import natsorted
@@ -43,9 +46,8 @@ def analyze_recruitment(settings={}):
43
46
  tables=['cell', 'nucleus', 'pathogen','cytoplasm'],
44
47
  verbose=True,
45
48
  nuclei_limit=settings['nuclei_limit'],
46
- pathogen_limit=settings['pathogen_limit'],
47
- uninfected=settings['uninfected'])
48
-
49
+ pathogen_limit=settings['pathogen_limit'])
50
+
49
51
  df = annotate_conditions(df,
50
52
  cells=settings['cell_types'],
51
53
  cell_loc=settings['cell_plate_metadata'],
@@ -550,4 +552,296 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
550
552
  fig_1 = plot_line(df, x_column = 'pc_fraction', y_columns=y_columns, group_column=None, xlabel=None, ylabel='Fraction', title=None, figsize=(10, 6), save_path=save_paths[0])
551
553
  fig_2 = plot_line(df, x_column = 'nc_fraction', y_columns=y_columns, group_column=None, xlabel=None, ylabel='Fraction', title=None, figsize=(10, 6), save_path=save_paths[1])
552
554
 
553
- return [fig_1, fig_2]
555
+ return [fig_1, fig_2]
556
+
557
def interperate_vision_model(settings=None):
    """
    Relate per-object measurements to vision-model scores with feature-importance analyses.

    Measurements are read from ``<src>/measurements/measurements.db``, expanded with
    between-compartment ratio features, inner-joined to the scores CSV on
    plate / row_name / column_name / field / object_label, and fed to a random forest.
    Depending on the settings, impurity-based feature importance, permutation
    importance and SHAP values are computed, plotted and optionally saved as CSV.

    Args:
        settings (dict): Keys read here: 'src', 'tables', 'nuclei_limit',
            'pathogen_limit', 'scores', 'score_column', 'feature_importance',
            'permutation_importance', 'shap', 'shap_sample', 'top_features', 'n_jobs',
            'channels', 'include_all', 'save'. The optional key 'plot_radar'
            (default False) enables radar charts of the aggregated SHAP importances.
            NOTE(review): a RandomForestClassifier is fitted, so the score column is
            presumably a class label rather than a continuous score -- confirm.

    Returns:
        dict: DataFrames keyed by 'feature_importance',
        'feature_importance_compartment', 'feature_importance_channel',
        'permutation_importance' and 'shap'; only the requested analyses are present.

    Raises:
        KeyError: If a required settings key or join column is missing.
        ValueError: If no measurement row matches a score row.
    """
    from .io import _read_and_merge_data

    settings = {} if settings is None else settings
    join_columns = ['plate', 'row_name', 'column_name', 'field', 'object_label']

    def safe_ratio(numerator, denominator):
        """Return numerator / denominator with inf and NaN replaced by 0."""
        # np.nan (not pd.NA) keeps the result a float column so it survives select_dtypes('number').
        ratio = numerator / denominator
        return ratio.replace([np.inf, -np.inf], np.nan).fillna(0)

    def generate_comparison_columns(df, compartments=('cell', 'nucleus', 'pathogen', 'cytoplasm')):
        """Add ratio columns between same-named features of different compartments.

        Returns the expanded DataFrame and a dict mapping each source column to the
        matching columns found in the other compartments.
        """
        # Snapshot the input columns so generated ratios are never mistaken for measured features.
        original_columns = set(df.columns)
        compartment_columns = {comp: [col for col in df.columns if col.startswith(comp)]
                               for comp in compartments}
        comparison_dict = {}
        ratios = {}

        for comp0, comp0_columns in compartment_columns.items():
            for comp0_col in comp0_columns:
                # Strip only the leading compartment prefix; str.replace would also remove
                # later occurrences of the compartment name inside the feature name.
                base_col_name = comp0_col[len(comp0):]
                related = [(prefix, prefix + base_col_name)
                           for prefix in compartment_columns
                           if prefix != comp0 and prefix + base_col_name in original_columns]
                if not related:
                    continue
                comparison_dict[comp0_col] = [col for _, col in related]

                # Ratio of every other compartment to this one. Format: prefix_comp0_base
                for prefix, related_col in related:
                    ratios[f"{prefix}_{comp0}{base_col_name}"] = safe_ratio(df[related_col], df[comp0_col])

                # All-to-all ratios among the related compartments.
                for i, (comp1, col1) in enumerate(related):
                    for comp2, col2 in related[i + 1:]:
                        ratios[f"{comp1}_{comp2}{base_col_name}"] = safe_ratio(df[col1], df[col2])

        if ratios:
            # Attach all ratios in one concat instead of inserting columns one at a time.
            ratio_df = pd.concat(ratios, axis=1)
            collisions = [col for col in ratio_df.columns if col in original_columns]
            df = pd.concat([df.drop(columns=collisions), ratio_df], axis=1)

        return df, comparison_dict

    def find_feature_class(feature, feature_groups):
        """Return the group(s) whose pattern occurs in the feature name, joined by '-', or None."""
        matches = [group for group in feature_groups if re.search(group, feature)]
        if not matches:
            return None
        return '-'.join(matches)

    def group_feature_class(df, feature_groups=('cell', 'cytoplasm', 'nucleus', 'pathogen'), name='compartment', include_all=False):
        """Sum the 'importance' column per feature group.

        A column called ``name`` holding each feature's group is added to ``df`` in place.
        """
        df[name] = df['feature'].apply(lambda feature: find_feature_class(feature, feature_groups))

        if name == 'channel':
            # Features that mention no channel are treated as morphology features.
            df['channel'] = df['channel'].fillna('morphology')

        sum_column = f'{name}_importance_sum'
        importance_sum = df.groupby(name)['importance'].sum().reset_index(name=sum_column)

        if include_all:
            total_row = pd.DataFrame([{name: 'all', sum_column: importance_sum[sum_column].sum()}])
            importance_sum = pd.concat([importance_sum, total_row], ignore_index=True)

        return importance_sum

    def create_extended_radar_plot(values, labels, title):
        """Draw a radar chart of the given values, one spoke per label."""
        values = list(values) + [values[0]]  # Close the loop for radar chart
        angles = [n / float(len(labels)) * 2 * pi for n in range(len(labels))]
        angles += angles[:1]

        fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
        ax.plot(angles, values, linewidth=2, linestyle='solid')
        ax.fill(angles, values, alpha=0.25)

        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(labels, fontsize=10, rotation=45, ha='right')
        plt.title(title, pad=20)
        plt.show()

    def extract_compartment_channel(feature_name):
        """Return (compartment, channel) parsed from a feature name."""
        # The compartment is the part before the first underscore.
        compartment = feature_name.split('_')[0]
        if compartment == 'cells':
            compartment = 'cell'

        channels = [f'channel_{i}' for i in range(4) if f'channel_{i}' in feature_name]
        # Several channels are joined with ' + '; no channel identifier means morphology.
        channel = ' + '.join(channels) if channels else 'morphology'
        return (compartment, channel)

    def plot_top_features(importance_df, label):
        """Show a horizontal bar chart of the given (already truncated) importance table."""
        plt.figure(figsize=(10, 6))
        plt.barh(importance_df['feature'], importance_df['importance'])
        plt.xlabel('Importance')
        plt.title(f"Top {settings['top_features']} Features - {label}")
        plt.gca().invert_yaxis()
        plt.show()

    def read_and_preprocess_data(settings):
        """Load measurements, add ratio features and join them to the score table."""
        df, _ = _read_and_merge_data(
            locs=[os.path.join(settings['src'], 'measurements', 'measurements.db')],
            tables=settings['tables'],
            verbose=True,
            nuclei_limit=settings['nuclei_limit'],
            pathogen_limit=settings['pathogen_limit']
        )

        df, _dict = generate_comparison_columns(df, compartments=['cell', 'nucleus', 'pathogen', 'cytoplasm'])
        print(f"Expanded dataframe to {len(df.columns)} columns with relative features")
        scores_df = pd.read_csv(settings['scores'])

        # Accept the short column names used by some score files.
        if 'row_name' not in scores_df.columns:
            scores_df['row_name'] = scores_df['row']
        if 'column_name' not in scores_df.columns:
            scores_df['column_name'] = scores_df['col']
        # 'object' takes precedence when present; otherwise an existing 'object_label' is used as is.
        if 'object' in scores_df.columns or 'object_label' not in scores_df.columns:
            scores_df['object_label'] = scores_df['object']

        # Cast to str before using .str, then drop the 'o' prefix so both sides share one label format.
        df['object_label'] = df['object_label'].astype(str).str.replace('o', '', regex=False)
        scores_df['object_label'] = scores_df['object_label'].astype(str).str.replace('o', '', regex=False)

        # Ensure all join columns have the same data type in both DataFrames
        df[join_columns] = df[join_columns].astype(str)
        scores_df[join_columns] = scores_df[join_columns].astype(str)

        # Select only the necessary columns from scores_df for merging
        scores_df = scores_df[join_columns + [settings['score_column']]]
        merged_df = pd.merge(df, scores_df, on=join_columns, how='inner')
        if merged_df.empty:
            raise ValueError("No measurement rows matched the score table on "
                             "plate/row_name/column_name/field/object_label.")

        # Separate numerical features and the score column; the score column is only
        # part of the numeric selection when it is itself numeric.
        X = merged_df.select_dtypes(include='number').drop(columns=[settings['score_column']], errors='ignore')
        y = merged_df[settings['score_column']]

        return X, y, merged_df

    X, y, merged_df = read_and_preprocess_data(settings)

    output = {}

    want_feature_importance = settings['feature_importance']
    want_permutation_importance = settings['permutation_importance']
    want_shap = settings['shap']

    # Step 1: Feature Importance using Random Forest
    # The forest backs all three analyses, so fit it when any of them is requested.
    model = None
    if want_feature_importance or want_permutation_importance or want_shap:
        model = RandomForestClassifier(random_state=42, n_jobs=settings['n_jobs'])
        model.fit(X, y)

    # SHAP selects its features from this ranking, so build it for SHAP as well.
    feature_importance_df = None
    if want_feature_importance or want_shap:
        feature_importance_df = pd.DataFrame({'feature': X.columns, 'importance': model.feature_importances_})
        feature_importance_df = feature_importance_df.sort_values(by='importance', ascending=False)

    if want_feature_importance:
        print("Feature Importance ...")
        plot_top_features(feature_importance_df.head(settings['top_features']), 'Feature Importance')

        output['feature_importance'] = feature_importance_df
        output['feature_importance_compartment'] = group_feature_class(
            feature_importance_df, feature_groups=settings['tables'], name='compartment', include_all=settings['include_all'])
        output['feature_importance_channel'] = group_feature_class(
            feature_importance_df, feature_groups=settings['channels'], name='channel', include_all=settings['include_all'])

    # Step 2: Permutation Importance
    if want_permutation_importance:
        print("Permutation Importance ...")
        perm_importance = permutation_importance(model, X, y, n_repeats=10, random_state=42, n_jobs=settings['n_jobs'])
        perm_importance_df = pd.DataFrame({'feature': X.columns, 'importance': perm_importance.importances_mean})
        perm_importance_df = perm_importance_df.sort_values(by='importance', ascending=False)
        plot_top_features(perm_importance_df.head(settings['top_features']), 'Permutation Importance')

        output['permutation_importance'] = perm_importance_df

    # Step 3: SHAP Analysis
    if want_shap:
        print("SHAP Analysis ...")

        # Select top N features based on Random Forest importance and refit on these features only
        top_features = feature_importance_df.head(settings['top_features'])['feature'].tolist()
        X_top = X[top_features]

        shap_model = RandomForestClassifier(random_state=42, n_jobs=settings['n_jobs'])
        shap_model.fit(X_top, y)

        # Sample roughly 1% of the rows to speed up SHAP, but never fewer than one row.
        if settings['shap_sample']:
            sample_size = min(max(1, len(X_top) // 100), len(X_top))
            X_sample = X_top.sample(sample_size, random_state=42)
        else:
            X_sample = X_top

        explainer = shap.Explainer(shap_model.predict, X_sample)
        # The permutation explainer needs at least 2 * n_features + 1 evaluations.
        shap_values = explainer(X_sample, max_evals=max(1500, 2 * X_sample.shape[1] + 1))

        # Plot SHAP summary for the selected sample and top features
        shap.summary_plot(shap_values, X_sample, max_display=settings['top_features'])

        # Convert SHAP values to a DataFrame with (compartment, channel) MultiIndex columns
        shap_df = pd.DataFrame(shap_values.values, columns=X_sample.columns)
        shap_df.columns = pd.MultiIndex.from_tuples(
            [extract_compartment_channel(feat) for feat in shap_df.columns],
            names=['compartment', 'channel']
        )

        if settings.get('plot_radar', False):
            # Mean |SHAP| per compartment / channel; grouping the transposed frame
            # avoids the deprecated groupby(axis=1).
            abs_shap = shap_df.abs()
            compartment_mean = abs_shap.T.groupby(level='compartment').mean().T.mean(axis=0)
            channel_mean = abs_shap.T.groupby(level='channel').mean().T.mean(axis=0)

            def with_pairwise_sums(means):
                """Return values and labels of the means followed by every pairwise sum."""
                values, labels = list(means.values), list(means.index)
                for i, first in enumerate(means.index):
                    for second in means.index[i + 1:]:
                        values.append(means[first] + means[second])
                        labels.append(f"{first} + {second}")
                return values, labels

            compartment_values, compartment_labels = with_pairwise_sums(compartment_mean)
            channel_values, channel_labels = with_pairwise_sums(channel_mean)
            create_extended_radar_plot(compartment_values, compartment_labels, "SHAP Importance by Compartment (Individual and Combined)")
            create_extended_radar_plot(channel_values, channel_labels, "SHAP Importance by Channel (Individual and Combined)")

        output['shap'] = shap_df

    if settings['save']:
        dst = os.path.join(settings['src'], 'results')
        os.makedirs(dst, exist_ok=True)
        for key, result_df in output.items():
            save_path = os.path.join(dst, f"{key}.csv")
            result_df.to_csv(save_path)
            print(f"Saved {save_path}")

    return output