spacr 0.3.60__py3-none-any.whl → 0.3.61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/app_annotate.py CHANGED
@@ -4,14 +4,6 @@ from .gui import MainApp
4
4
  from .gui_elements import set_dark_style, spacrButton
5
5
 
6
6
  def convert_to_number(value):
7
-
8
- """
9
- Converts a string value to an integer if possible, otherwise converts to a float.
10
- Args:
11
- value (str): The string representation of the number.
12
- Returns:
13
- int or float: The converted number.
14
- """
15
7
  try:
16
8
  return int(value)
17
9
  except ValueError:
spacr/core.py CHANGED
@@ -465,10 +465,8 @@ def generate_image_umap(settings={}):
465
465
  display(settings_df)
466
466
 
467
467
  db_paths = get_db_paths(settings['src'])
468
-
469
468
  tables = settings['tables'] + ['png_list']
470
469
  all_df = pd.DataFrame()
471
- #image_paths = []
472
470
 
473
471
  for i,db_path in enumerate(db_paths):
474
472
  df = _read_and_join_tables(db_path, table_names=tables)
@@ -476,7 +474,7 @@ def generate_image_umap(settings={}):
476
474
  all_df = pd.concat([all_df, df], axis=0)
477
475
  #image_paths.extend(image_paths_tmp)
478
476
 
479
- all_df['cond'] = all_df['col'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])
477
+ all_df['cond'] = all_df['column_name'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])
480
478
 
481
479
  if settings['exclude_conditions']:
482
480
  if isinstance(settings['exclude_conditions'], str):
@@ -495,7 +493,10 @@ def generate_image_umap(settings={}):
495
493
 
496
494
  # Extract and reset the index for the column to compare
497
495
  col_to_compare = all_df[settings['col_to_compare']].reset_index(drop=True)
498
-
496
+
497
+ #if settings['only_top_features']:
498
+ # column_list = None
499
+
499
500
  # Preprocess the data to obtain numeric data
500
501
  numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])
501
502
 
@@ -571,7 +572,11 @@ def generate_image_umap(settings={}):
571
572
  print(f'Saved {reduction_method} embedding to {embedding_path} and grid to {grid_path}')
572
573
 
573
574
  # Add cluster labels to the dataframe
574
- all_df['cluster'] = labels
575
+ if len(labels) > 0:
576
+ all_df['cluster'] = labels
577
+ else:
578
+ all_df['cluster'] = 1 # Assign a default cluster label
579
+ print("No clusters found. Consider reducing 'min_samples' or increasing 'eps' for DBSCAN.")
575
580
 
576
581
  # Save the results to a CSV file
577
582
  results_dir = os.path.join(settings['src'][0], 'results')
@@ -653,7 +658,7 @@ def reducer_hyperparameter_search(settings={}, reduction_params=None, dbscan_par
653
658
  df = _read_and_join_tables(db_path, table_names=tables)
654
659
  all_df = pd.concat([all_df, df], axis=0)
655
660
 
656
- all_df['cond'] = all_df['col'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])
661
+ all_df['cond'] = all_df['column_name'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])
657
662
 
658
663
  if settings['exclude_conditions']:
659
664
  if isinstance(settings['exclude_conditions'], str):
@@ -882,7 +887,7 @@ def generate_screen_graphs(settings):
882
887
  db_loc = [os.path.join(src, 'measurements', 'measurements.db')]
883
888
 
884
889
  # Read and merge data from the database
885
- df, _ = _read_and_merge_data(db_loc, settings['tables'], verbose=True, nuclei_limit=settings['nuclei_limit'], pathogen_limit=settings['pathogen_limit'], uninfected=settings['uninfected'])
890
+ df, _ = _read_and_merge_data(db_loc, settings['tables'], verbose=True, nuclei_limit=settings['nuclei_limit'], pathogen_limit=settings['pathogen_limit'])
886
891
 
887
892
  # Annotate the data
888
893
  df = annotate_conditions(df, cells=settings['cells'], cell_loc=None, pathogens=settings['controls'], pathogen_loc=settings['controls_loc'], treatments=None, treatment_loc=None)
spacr/gui_utils.py CHANGED
@@ -225,14 +225,30 @@ def annotate(settings):
225
225
  conn.close()
226
226
 
227
227
  root = tk.Tk()
228
- root.geometry(settings['geom'])
229
- app = AnnotateApp(root, db, src, image_type=settings['image_type'], channels=settings['channels'], image_size=settings['img_size'], grid_rows=settings['rows'], grid_cols=settings['columns'], annotation_column=settings['annotation_column'], normalize=settings['normalize'], percentiles=settings['percentiles'], measurement=settings['measurement'], threshold=settings['threshold'], normalize_channels=settings['normalize_channels'])
230
- next_button = tk.Button(root, text="Next", command=app.next_page)
231
- next_button.grid(row=app.grid_rows, column=app.grid_cols - 1)
232
- back_button = tk.Button(root, text="Back", command=app.previous_page)
233
- back_button.grid(row=app.grid_rows, column=app.grid_cols - 2)
234
- exit_button = tk.Button(root, text="Exit", command=app.shutdown)
235
- exit_button.grid(row=app.grid_rows, column=app.grid_cols - 3)
228
+
229
+ root.geometry(f"{root.winfo_screenwidth()}x{root.winfo_screenheight()}")
230
+
231
+ db_path = os.path.join(settings['src'], 'measurements/measurements.db')
232
+
233
+ app = AnnotateApp(root,
234
+ db_path=db_path,
235
+ src=settings['src'],
236
+ image_type=settings['image_type'],
237
+ channels=settings['channels'],
238
+ image_size=settings['img_size'],
239
+ annotation_column=settings['annotation_column'],
240
+ normalize=settings['normalize'],
241
+ percentiles=settings['percentiles'],
242
+ measurement=settings['measurement'],
243
+ threshold=settings['threshold'],
244
+ normalize_channels=settings['normalize_channels'])
245
+
246
+ #next_button = tk.Button(root, text="Next", command=app.next_page)
247
+ #next_button.grid(row=app.grid_rows, column=app.grid_cols - 1)
248
+ #back_button = tk.Button(root, text="Back", command=app.previous_page)
249
+ #back_button.grid(row=app.grid_rows, column=app.grid_cols - 2)
250
+ #exit_button = tk.Button(root, text="Exit", command=app.shutdown)
251
+ #exit_button.grid(row=app.grid_rows, column=app.grid_cols - 3)
236
252
 
237
253
  app.load_images()
238
254
  root.mainloop()
spacr/io.py CHANGED
@@ -2089,150 +2089,6 @@ def _read_db(db_loc, tables):
2089
2089
  conn.close()
2090
2090
  return dfs
2091
2091
 
2092
- def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathogen_limit=False, uninfected=False):
2093
- """
2094
- Read and merge data from SQLite databases and perform data preprocessing.
2095
-
2096
- Parameters:
2097
- - locs (list): A list of file paths to the SQLite database files.
2098
- - tables (list): A list of table names to read from the databases.
2099
- - verbose (bool): Whether to print verbose output. Default is False.
2100
- - nuclei_limit (bool): Whether to include multinucleated cells. Default is False.
2101
- - pathogen_limit (bool): Whether to include cells with multiple infections. Default is False.
2102
- - uninfected (bool): Whether to include non-infected cells. Default is False.
2103
-
2104
- Returns:
2105
- - merged_df (pandas.DataFrame): The merged and preprocessed dataframe.
2106
- - obj_df_ls (list): A list of pandas DataFrames, each containing the data for a specific object type.
2107
- """
2108
-
2109
- from .utils import _split_data
2110
-
2111
- #Extract plate DataFrames
2112
- all_dfs = []
2113
- for loc in locs:
2114
- db_dfs = _read_db(loc, tables)
2115
- all_dfs.append(db_dfs)
2116
-
2117
- #Extract Tables from DataFrames and concatinate rows
2118
- for i, dfs in enumerate(all_dfs):
2119
- if 'cell' in tables:
2120
- cell = dfs[0]
2121
- print(f'plate: {i+1} cells:{len(cell)}')
2122
-
2123
- if 'nucleus' in tables:
2124
- nucleus = dfs[1]
2125
- print(f'plate: {i+1} nucleus:{len(nucleus)} ')
2126
-
2127
- if 'pathogen' in tables:
2128
- pathogen = dfs[2]
2129
-
2130
- print(f'plate: {i+1} pathogens:{len(pathogen)}')
2131
- if 'cytoplasm' in tables:
2132
- if not 'pathogen' in tables:
2133
- cytoplasm = dfs[2]
2134
- else:
2135
- cytoplasm = dfs[3]
2136
- print(f'plate: {i+1} cytoplasms: {len(cytoplasm)}')
2137
-
2138
- if i > 0:
2139
- if 'cell' in tables:
2140
- cells = pd.concat([cells, cell], axis = 0)
2141
- if 'nucleus' in tables:
2142
- nucleus = pd.concat([nucleus, nucleus], axis = 0)
2143
- if 'pathogen' in tables:
2144
- pathogens = pd.concat([pathogens, pathogen], axis = 0)
2145
- if 'cytoplasm' in tables:
2146
- cytoplasms = pd.concat([cytoplasms, cytoplasm], axis = 0)
2147
- else:
2148
- if 'cell' in tables:
2149
- cells = cell.copy()
2150
- if 'nucleus' in tables:
2151
- nucleus = nucleus.copy()
2152
- if 'pathogen' in tables:
2153
- pathogens = pathogen.copy()
2154
- if 'cytoplasm' in tables:
2155
- cytoplasms = cytoplasm.copy()
2156
-
2157
- #Add an o in front of all object and cell lables to convert them to strings
2158
- if 'cell' in tables:
2159
- cells = cells.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
2160
- cells = cells.assign(prcfo = lambda x: x['prcf'] + '_' + x['object_label'])
2161
- cells_g_df, metadata = _split_data(cells, 'prcfo', 'object_label')
2162
- print(f'cells: {len(cells)}')
2163
- print(f'cells grouped: {len(cells_g_df)}')
2164
- if 'cytoplasm' in tables:
2165
- cytoplasms = cytoplasms.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
2166
- cytoplasms = cytoplasms.assign(prcfo = lambda x: x['prcf'] + '_' + x['object_label'])
2167
- cytoplasms_g_df, _ = _split_data(cytoplasms, 'prcfo', 'object_label')
2168
- merged_df = cells_g_df.merge(cytoplasms_g_df, left_index=True, right_index=True)
2169
- print(f'cytoplasms: {len(cytoplasms)}')
2170
- print(f'cytoplasms grouped: {len(cytoplasms_g_df)}')
2171
- if 'nucleus' in tables:
2172
- nucleus = nucleus.dropna(subset=['cell_id'])
2173
- nucleus = nucleus.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
2174
- nucleus = nucleus.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
2175
- nucleus = nucleus.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
2176
- nucleus['nucleus_prcfo_count'] = nucleus.groupby('prcfo')['prcfo'].transform('count')
2177
- if nuclei_limit == False:
2178
- #nucleus = nucleus[~nucleus['prcfo'].duplicated()]
2179
- nucleus = nucleus[nucleus['nucleus_prcfo_count']==1]
2180
- nucleus_g_df, _ = _split_data(nucleus, 'prcfo', 'cell_id')
2181
- print(f'nucleus: {len(nucleus)}')
2182
- print(f'nucleus grouped: {len(nucleus_g_df)}')
2183
- if 'cytoplasm' in tables:
2184
- merged_df = merged_df.merge(nucleus_g_df, left_index=True, right_index=True)
2185
- else:
2186
- merged_df = cells_g_df.merge(nucleus_g_df, left_index=True, right_index=True)
2187
- if 'pathogen' in tables:
2188
- pathogens = pathogens.dropna(subset=['cell_id'])
2189
- pathogens = pathogens.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
2190
- pathogens = pathogens.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
2191
- pathogens = pathogens.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
2192
- pathogens['pathogen_prcfo_count'] = pathogens.groupby('prcfo')['prcfo'].transform('count')
2193
- if uninfected == False:
2194
- pathogens = pathogens[pathogens['pathogen_prcfo_count']>=1]
2195
- if pathogen_limit == False:
2196
- pathogens = pathogens[pathogens['pathogen_prcfo_count']<=1]
2197
- pathogens_g_df, _ = _split_data(pathogens, 'prcfo', 'cell_id')
2198
- print(f'pathogens: {len(pathogens)}')
2199
- print(f'pathogens grouped: {len(pathogens_g_df)}')
2200
- merged_df = merged_df.merge(pathogens_g_df, left_index=True, right_index=True)
2201
-
2202
- #Add prc column (plate row column)
2203
- metadata = metadata.assign(prc = lambda x: x['plate'] + '_' + x['row_name'] + '_' +x['column_name'])
2204
-
2205
- #Count cells per well
2206
- cells_well = pd.DataFrame(metadata.groupby('prc')['object_label'].nunique())
2207
-
2208
- cells_well.reset_index(inplace=True)
2209
- cells_well.rename(columns={'object_label': 'cells_per_well'}, inplace=True)
2210
- metadata = pd.merge(metadata, cells_well, on='prc', how='inner', suffixes=('', '_drop_col'))
2211
- object_label_cols = [col for col in metadata.columns if '_drop_col' in col]
2212
- metadata.drop(columns=object_label_cols, inplace=True)
2213
-
2214
- #Add prcfo column (plate row column field object)
2215
- metadata = metadata.assign(prcfo = lambda x: x['plate'] + '_' + x['row_name'] + '_' +x['column_name']+ '_' +x['field']+ '_' +x['object_label'])
2216
- metadata.set_index('prcfo', inplace=True)
2217
-
2218
- merged_df = metadata.merge(merged_df, left_index=True, right_index=True)
2219
-
2220
- merged_df = merged_df.dropna(axis=1)
2221
-
2222
- print(f'Generated dataframe with: {len(merged_df.columns)} columns and {len(merged_df)} rows')
2223
-
2224
- obj_df_ls = []
2225
- if 'cell' in tables:
2226
- obj_df_ls.append(cells)
2227
- if 'cytoplasm' in tables:
2228
- obj_df_ls.append(cytoplasms)
2229
- if 'nucleus' in tables:
2230
- obj_df_ls.append(nucleus)
2231
- if 'pathogen' in tables:
2232
- obj_df_ls.append(pathogens)
2233
-
2234
- return merged_df, obj_df_ls
2235
-
2236
2092
  def _results_to_csv(src, df, df_well):
2237
2093
  """
2238
2094
  Save the given dataframes as CSV files in the specified directory.
@@ -2420,7 +2276,7 @@ def _read_db(db_loc, tables):
2420
2276
  conn.close() # Close the connection
2421
2277
  return dfs
2422
2278
 
2423
- def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathogen_limit=False, uninfected=False):
2279
+ def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathogen_limit=False):
2424
2280
 
2425
2281
  from .utils import _split_data
2426
2282
 
@@ -2532,11 +2388,6 @@ def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathog
2532
2388
  pathogens = pathogens.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
2533
2389
  pathogens['pathogen_prcfo_count'] = pathogens.groupby('prcfo')['prcfo'].transform('count')
2534
2390
 
2535
- print(f"before noninfected: {len(pathogens)}")
2536
- if uninfected == False:
2537
- pathogens = pathogens[pathogens['pathogen_prcfo_count']>=1]
2538
- print(f"after noninfected: {len(pathogens)}")
2539
-
2540
2391
  if isinstance(pathogen_limit, bool):
2541
2392
  if pathogen_limit == False:
2542
2393
  pathogens = pathogens[pathogens['pathogen_prcfo_count']<=1]
@@ -2929,8 +2780,8 @@ def generate_training_dataset(settings):
2929
2780
  tables=tables,
2930
2781
  verbose=False,
2931
2782
  nuclei_limit=settings['nuclei_limit'],
2932
- pathogen_limit=settings['pathogen_limit'],
2933
- uninfected=settings['uninfected'])
2783
+ pathogen_limit=settings['pathogen_limit'])
2784
+
2934
2785
  [png_list_df] = _read_db(db_loc=db_path, tables=['png_list'])
2935
2786
  filtered_png_list_df = png_list_df[png_list_df['prcfo'].isin(df.index)]
2936
2787
  return filtered_png_list_df
@@ -2952,8 +2803,7 @@ def generate_training_dataset(settings):
2952
2803
  tables=tables,
2953
2804
  verbose=False,
2954
2805
  nuclei_limit=settings['nuclei_limit'],
2955
- pathogen_limit=settings['pathogen_limit'],
2956
- uninfected=settings['uninfected'])
2806
+ pathogen_limit=settings['pathogen_limit'])
2957
2807
 
2958
2808
  print('length df 1', len(df))
2959
2809
  df = annotate_conditions(df, cells=['HeLa'], pathogens=['pathogen'], treatments=settings['classes'],
@@ -3034,7 +2884,6 @@ def generate_training_dataset(settings):
3034
2884
 
3035
2885
  if 'pathogen' not in settings['tables']:
3036
2886
  settings['pathogen_limit'] = 0
3037
- settings['uninfected'] = True
3038
2887
 
3039
2888
  # Set default settings and save
3040
2889
  settings = set_generate_training_dataset_defaults(settings)
spacr/ml.py CHANGED
@@ -1172,15 +1172,14 @@ def generate_ml_scores(settings):
1172
1172
  db_loc = [src+'/measurements/measurements.db']
1173
1173
  tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
1174
1174
 
1175
- nuclei_limit, pathogen_limit, uninfected = settings['nuclei_limit'], settings['pathogen_limit'], settings['uninfected']
1175
+ nuclei_limit, pathogen_limit = settings['nuclei_limit'], settings['pathogen_limit']
1176
1176
 
1177
1177
  df, _ = _read_and_merge_data(db_loc,
1178
1178
  tables,
1179
1179
  settings['verbose'],
1180
1180
  nuclei_limit,
1181
- pathogen_limit,
1182
- uninfected)
1183
-
1181
+ pathogen_limit)
1182
+
1184
1183
  if settings['annotation_column'] is not None:
1185
1184
 
1186
1185
  settings['location_column'] = settings['annotation_column']
spacr/plot.py CHANGED
@@ -909,7 +909,7 @@ def plot_merged(src, settings):
909
909
  path = os.path.join(src, file)
910
910
  stack = np.load(path)
911
911
  print(f'Loaded: {path}')
912
- if not settings['uninfected']:
912
+ if settings['pathogen_limit'] > 0:
913
913
  if settings['pathogen_mask_dim'] is not None and settings['cell_mask_dim'] is not None:
914
914
  stack = _remove_noninfected(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'])
915
915
 
@@ -2198,8 +2198,8 @@ def jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot',
2198
2198
  tables,
2199
2199
  verbose=True,
2200
2200
  nuclei_limit=True,
2201
- pathogen_limit=True,
2202
- uninfected=True)
2201
+ pathogen_limit=True)
2202
+
2203
2203
  paths_df = _read_db(loc, tables=['png_list'])
2204
2204
  merged_df = pd.merge(df, paths_df[0], on='prcfo', how='left')
2205
2205
  return merged_df
@@ -2435,7 +2435,9 @@ class spacrGraph:
2435
2435
 
2436
2436
  self.df = df
2437
2437
  self.grouping_column = grouping_column
2438
+ self.order = sorted(df[self.grouping_column].unique().tolist())
2438
2439
  self.data_column = data_column if isinstance(data_column, list) else [data_column]
2440
+
2439
2441
  self.graph_type = graph_type
2440
2442
  self.summary_func = summary_func
2441
2443
  self.order = order
@@ -2909,9 +2911,11 @@ class spacrGraph:
2909
2911
  ax.set_xlim(-0.5, num_groups - 0.5)
2910
2912
 
2911
2913
  # Set ticks to match the group labels in your DataFrame
2912
- group_labels = self.df[self.grouping_column].unique()
2913
- ax.set_xticks(range(len(group_labels)))
2914
- ax.set_xticklabels(group_labels, rotation=45, ha='right')
2914
+ #group_labels = self.df[self.grouping_column].unique()
2915
+ #group_labels = self.order
2916
+ #ax.set_xticks(range(len(group_labels)))
2917
+ #ax.set_xticklabels(group_labels, rotation=45, ha='right')
2918
+ plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
2915
2919
 
2916
2920
  # Customize elements based on the graph type
2917
2921
  if graph_type == 'bar':
@@ -2943,6 +2947,66 @@ class spacrGraph:
2943
2947
 
2944
2948
  # Redraw the figure to apply changes
2945
2949
  ax.figure.canvas.draw()
2950
+
2951
+ def _standerdize_figure_format_v1(self, ax, num_groups, graph_type):
2952
+ """
2953
+ Adjusts the figure layout (size, bar width, jitter, and spacing) based on the number of groups.
2954
+ """
2955
+ if graph_type in ['line', 'line_std']:
2956
+ print("Skipping layout adjustment for line graphs.")
2957
+ return # Skip layout adjustment for line graphs
2958
+
2959
+ correction_factor = 4
2960
+
2961
+ # Set figure size to ensure it remains square with a minimum size
2962
+ fig_size = max(6, num_groups * 2) / correction_factor
2963
+ ax.figure.set_size_inches(fig_size, fig_size)
2964
+
2965
+ # Configure layout based on the number of groups
2966
+ bar_width = min(0.8, 1.5 / num_groups) / correction_factor
2967
+ jitter_amount = min(0.1, 0.2 / num_groups) / correction_factor
2968
+ jitter_size = max(50 / num_groups, 200)
2969
+
2970
+ # Adjust x-axis limits to fit the specified order of groups
2971
+ ax.set_xlim(-0.5, len(self.order) - 0.5) # Use `self.order` length to ensure alignment
2972
+
2973
+ # Use `self.order` as the x-tick labels to maintain consistent ordering
2974
+ ax.set_xticks(range(len(self.order)))
2975
+ #ax.set_xticklabels(self.order, rotation=45, ha='right')
2976
+ plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
2977
+
2978
+ # Customize elements based on the graph type
2979
+ if graph_type == 'bar':
2980
+ # Adjust bars' width and position
2981
+ for bar in ax.patches:
2982
+ bar.set_width(bar_width)
2983
+ bar.set_x(bar.get_x() - bar_width / 2)
2984
+
2985
+ elif graph_type in ['jitter', 'jitter_bar', 'jitter_box']:
2986
+ # Adjust jitter points' position and size
2987
+ for coll in ax.collections:
2988
+ offsets = coll.get_offsets()
2989
+ offsets[:, 0] += jitter_amount # Shift jitter points slightly
2990
+ coll.set_offsets(offsets)
2991
+ coll.set_sizes([jitter_size] * len(offsets)) # Adjust point size dynamically
2992
+
2993
+ elif graph_type in ['box', 'violin']:
2994
+ # Adjust box width for consistent spacing
2995
+ for artist in ax.artists:
2996
+ artist.set_width(bar_width)
2997
+
2998
+ # Adjust legend and axis labels
2999
+ ax.tick_params(axis='x', labelsize=max(10, 15 - num_groups // 2))
3000
+ ax.tick_params(axis='y', labelsize=max(10, 15 - num_groups // 2))
3001
+
3002
+ # Adjust legend placement and size
3003
+ if ax.get_legend():
3004
+ ax.get_legend().set_bbox_to_anchor((1.05, 1))
3005
+ ax.get_legend().prop.set_size(max(8, 12 - num_groups // 3))
3006
+
3007
+ # Redraw the figure to apply changes
3008
+ ax.figure.canvas.draw()
3009
+
2946
3010
 
2947
3011
  def _create_bar_plot(self, ax):
2948
3012
  """Helper method to create a bar plot with consistent bar thickness and centered error bars."""
@@ -2959,7 +3023,7 @@ class spacrGraph:
2959
3023
 
2960
3024
  summary_df = self.df_melted.groupby([x_axis_column]).agg(mean=('Value', 'mean'),std=('Value', 'std'),sem=('Value', 'sem')).reset_index()
2961
3025
  error_bars = summary_df[self.error_bar_type] if self.error_bar_type in ['std', 'sem'] else None
2962
- sns.barplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, ax=ax, dodge=self.jitter_bar_dodge, ci=None)
3026
+ sns.barplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, ax=ax, dodge=self.jitter_bar_dodge, ci=None, order=self.order)
2963
3027
 
2964
3028
  # Adjust the bar width manually
2965
3029
  if len(self.data_column) > 1:
@@ -2999,7 +3063,7 @@ class spacrGraph:
2999
3063
  hue = None
3000
3064
 
3001
3065
  # Create the jitter plot
3002
- sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax, alpha=0.6, size=16)
3066
+ sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax, alpha=0.6, size=16, order=self.order)
3003
3067
 
3004
3068
  # Adjust legend and labels
3005
3069
  ax.set_xlabel(self.grouping_column)
@@ -3088,7 +3152,7 @@ class spacrGraph:
3088
3152
  hue = None
3089
3153
 
3090
3154
  # Create the box plot
3091
- sns.boxplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue,palette=self.sns_palette,ax=ax)
3155
+ sns.boxplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue,palette=self.sns_palette,ax=ax, order=self.order)
3092
3156
 
3093
3157
  # Adjust legend and labels
3094
3158
  ax.set_xlabel(self.grouping_column)
@@ -3117,7 +3181,7 @@ class spacrGraph:
3117
3181
  hue = None
3118
3182
 
3119
3183
  # Create the violin plot
3120
- sns.violinplot(data=self.df_melted,x=x_axis_column,y='Value', hue=self.hue,palette=self.sns_palette,ax=ax)
3184
+ sns.violinplot(data=self.df_melted,x=x_axis_column,y='Value', hue=self.hue,palette=self.sns_palette,ax=ax, order=self.order)
3121
3185
 
3122
3186
  # Adjust legend and labels
3123
3187
  ax.set_xlabel(self.grouping_column)
@@ -3148,8 +3212,8 @@ class spacrGraph:
3148
3212
 
3149
3213
  summary_df = self.df_melted.groupby([x_axis_column]).agg(mean=('Value', 'mean'),std=('Value', 'std'),sem=('Value', 'sem')).reset_index()
3150
3214
  error_bars = summary_df[self.error_bar_type] if self.error_bar_type in ['std', 'sem'] else None
3151
- sns.barplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, ax=ax, dodge=self.jitter_bar_dodge, ci=None)
3152
- sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax,alpha=0.6, edgecolor='white',linewidth=1, size=16)
3215
+ sns.barplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, ax=ax, dodge=self.jitter_bar_dodge, ci=None, order=self.order)
3216
+ sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax,alpha=0.6, edgecolor='white',linewidth=1, size=16, order=self.order)
3153
3217
 
3154
3218
  # Adjust the bar width manually
3155
3219
  if len(self.data_column) > 1:
@@ -3189,8 +3253,8 @@ class spacrGraph:
3189
3253
  hue = None
3190
3254
 
3191
3255
  # Create the box plot
3192
- sns.boxplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue,palette=self.sns_palette,ax=ax)
3193
- sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax,alpha=0.6, edgecolor='white',linewidth=1, size=12)
3256
+ sns.boxplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue,palette=self.sns_palette,ax=ax, order=self.order)
3257
+ sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax,alpha=0.6, edgecolor='white',linewidth=1, size=12, order=self.order)
3194
3258
 
3195
3259
  # Adjust legend and labels
3196
3260
  ax.set_xlabel(self.grouping_column)
@@ -3264,12 +3328,11 @@ def plot_data_from_db(settings):
3264
3328
  [df1] = _read_db(db_loc, tables=[settings['table_names']])
3265
3329
  else:
3266
3330
  df1, _ = _read_and_merge_data(locs=[db_loc],
3267
- tables = ['cell', 'nucleus', 'pathogen','cytoplasm'],
3331
+ tables = settings['tables'],
3268
3332
  verbose=settings['verbose'],
3269
3333
  nuclei_limit=settings['nuclei_limit'],
3270
- pathogen_limit=settings['pathogen_limit'],
3271
- uninfected=settings['uninfected'])
3272
-
3334
+ pathogen_limit=settings['pathogen_limit'])
3335
+
3273
3336
  dft = annotate_conditions(df1,
3274
3337
  cells=settings['cell_types'],
3275
3338
  cell_loc=settings['cell_plate_metadata'],
@@ -3281,10 +3344,7 @@ def plot_data_from_db(settings):
3281
3344
 
3282
3345
  df = pd.concat(dfs, axis=0)
3283
3346
  df['prc'] = df['plate'].astype(str) + '_' + df['row_name'].astype(str) + '_' + df['column_name'].astype(str)
3284
- #df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
3285
- #df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
3286
- df['class'] = df['png_path'].apply(lambda x: 'class_1' if 'class_1' in x else ('class_0' if 'class_0' in x else None))
3287
-
3347
+
3288
3348
  if settings['cell_plate_metadata'] != None:
3289
3349
  df = df.dropna(subset='host_cell')
3290
3350
 
@@ -3297,7 +3357,6 @@ def plot_data_from_db(settings):
3297
3357
  df = df.dropna(subset=settings['data_column'])
3298
3358
  df = df.dropna(subset=settings['grouping_column'])
3299
3359
 
3300
-
3301
3360
  src = srcs[0]
3302
3361
  dst = os.path.join(src, 'results', settings['graph_name'])
3303
3362
  os.makedirs(dst, exist_ok=True)
spacr/settings.py CHANGED
@@ -2,7 +2,6 @@ import os, ast
2
2
 
3
3
  def set_default_plot_merge_settings():
4
4
  settings = {}
5
- settings.setdefault('uninfected', True)
6
5
  settings.setdefault('pathogen_limit', 10)
7
6
  settings.setdefault('nuclei_limit', 1)
8
7
  settings.setdefault('remove_background', False)
@@ -181,8 +180,8 @@ def set_default_umap_image_settings(settings={}):
181
180
  settings.setdefault('n_neighbors', 1000)
182
181
  settings.setdefault('min_dist', 0.1)
183
182
  settings.setdefault('metric', 'euclidean')
184
- settings.setdefault('eps', 0.5)
185
- settings.setdefault('min_samples', 1000)
183
+ settings.setdefault('eps', 0.9)
184
+ settings.setdefault('min_samples', 100)
186
185
  settings.setdefault('filter_by', 'channel_0')
187
186
  settings.setdefault('img_zoom', 0.5)
188
187
  settings.setdefault('plot_by_cluster', True)
@@ -201,16 +200,13 @@ def set_default_umap_image_settings(settings={}):
201
200
  settings.setdefault('col_to_compare', 'column_name')
202
201
  settings.setdefault('pos', 'c1')
203
202
  settings.setdefault('neg', 'c2')
203
+ settings.setdefault('mix', 'c3')
204
204
  settings.setdefault('embedding_by_controls', False)
205
205
  settings.setdefault('plot_images', True)
206
206
  settings.setdefault('reduction_method','umap')
207
207
  settings.setdefault('save_figure', False)
208
208
  settings.setdefault('n_jobs', -1)
209
209
  settings.setdefault('color_by', None)
210
- settings.setdefault('neg', 'c1')
211
- settings.setdefault('pos', 'c2')
212
- settings.setdefault('mix', 'c3')
213
- settings.setdefault('mix', 'c3')
214
210
  settings.setdefault('exclude_conditions', None)
215
211
  settings.setdefault('analyze_clusters', False)
216
212
  settings.setdefault('resnet_features', False)
@@ -295,7 +291,6 @@ def set_default_analyze_screen(settings):
295
291
  settings.setdefault('exclude',None)
296
292
  settings.setdefault('nuclei_limit',True)
297
293
  settings.setdefault('pathogen_limit',3)
298
- settings.setdefault('uninfected',True)
299
294
  settings.setdefault('n_repeats',10)
300
295
  settings.setdefault('top_features',30)
301
296
  settings.setdefault('remove_low_variance_features',True)
@@ -353,7 +348,6 @@ def set_generate_training_dataset_defaults(settings):
353
348
  settings.setdefault('tables',None)
354
349
  settings.setdefault('nuclei_limit',True)
355
350
  settings.setdefault('pathogen_limit',True)
356
- settings.setdefault('uninfected',True)
357
351
  settings.setdefault('png_type','cell_png')
358
352
 
359
353
  return settings
@@ -467,7 +461,6 @@ def get_analyze_recruitment_default_settings(settings):
467
461
  settings.setdefault('plot_nr',3)
468
462
  settings.setdefault('plot_control',True)
469
463
  settings.setdefault('figuresize',10)
470
- settings.setdefault('uninfected',True)
471
464
  settings.setdefault('pathogen_limit',10)
472
465
  settings.setdefault('nuclei_limit',1)
473
466
  settings.setdefault('cells_per_well',0)
@@ -691,7 +684,6 @@ expected_types = {
691
684
  "measurement": str,
692
685
  "nr_imgs": int,
693
686
  "um_per_pixel": (int, float),
694
- "uninfected": bool,
695
687
  "pathogen_limit": int,
696
688
  "nuclei_limit": int,
697
689
  "filter_min_max": (list, type(None)),
@@ -898,7 +890,7 @@ categories = {"Paths":[ "src", "grna", "barcodes", "custom_model_path", "dataset
898
890
  "Plot": ["plot", "plot_control", "plot_nr", "examples_to_plot", "normalize_plots", "cmap", "figuresize", "plot_cluster_grids", "img_zoom", "row_limit", "color_by", "plot_images", "smooth_lines", "plot_points", "plot_outlines", "black_background", "plot_by_cluster", "heatmap_feature","grouping","min_max","cmap","save_figure"],
899
891
  "Test": ["test_mode", "test_images", "random_test", "test_nr", "test", "test_split"],
900
892
  "Timelapse": ["timelapse", "fps", "timelapse_displacement", "timelapse_memory", "timelapse_frame_limits", "timelapse_remove_transient", "timelapse_mode", "timelapse_objects", "compartments"],
901
- "Advanced": ["shuffle", "target_intensity_min", "cells_per_well", "nuclei_limit", "pathogen_limit", "uninfected", "background", "backgrounds", "schedule", "test_size","exclude","n_repeats","top_features", "model_type_ml", "model_type","minimum_cell_count","n_estimators","preprocess", "remove_background", "normalize", "lower_percentile", "merge_pathogens", "batch_size", "filter", "save", "masks", "verbose", "randomize", "n_jobs"],
893
+ "Advanced": ["shuffle", "target_intensity_min", "cells_per_well", "nuclei_limit", "pathogen_limit", "background", "backgrounds", "schedule", "test_size","exclude","n_repeats","top_features", "model_type_ml", "model_type","minimum_cell_count","n_estimators","preprocess", "remove_background", "normalize", "lower_percentile", "merge_pathogens", "batch_size", "filter", "save", "masks", "verbose", "randomize", "n_jobs"],
902
894
  "Miscellaneous": ["all_to_mip", "pick_slice", "skip_mode", "upscale", "upscale_factor"]
903
895
  }
904
896
 
@@ -1080,7 +1072,6 @@ def generate_fields(variables, scrollable_frame):
1080
1072
  "img_zoom": "(float) - Zoom factor for the images in plots.",
1081
1073
  "nuclei_limit": "(int) - Whether to include multinucleated cells in the analysis.",
1082
1074
  "pathogen_limit": "(int) - Whether to include multi-infected cells in the analysis.",
1083
- "uninfected": "(bool) - Whether to include non-infected cells in the analysis.",
1084
1075
  "uninfected": "(bool) - Whether to include uninfected cells in the analysis.",
1085
1076
  "init_weights": "(bool) - Whether to initialize weights for the model.",
1086
1077
  "src": "(str) - Path to the folder containing the images.",
spacr/submodules.py CHANGED
@@ -1,5 +1,5 @@
1
1
  import seaborn as sns
2
- import os, random, sqlite3
2
+ import os, random, sqlite3, re, shap
3
3
  import pandas as pd
4
4
  import numpy as np
5
5
  import cellpose
@@ -7,6 +7,9 @@ from skimage.measure import regionprops, label
7
7
  from cellpose import models as cp_models
8
8
  from cellpose import train as train_cp
9
9
  from IPython.display import display
10
+ from sklearn.ensemble import RandomForestClassifier
11
+ from sklearn.inspection import permutation_importance
12
+ from math import pi
10
13
 
11
14
  import matplotlib.pyplot as plt
12
15
  from natsort import natsorted
@@ -43,9 +46,8 @@ def analyze_recruitment(settings={}):
43
46
  tables=['cell', 'nucleus', 'pathogen','cytoplasm'],
44
47
  verbose=True,
45
48
  nuclei_limit=settings['nuclei_limit'],
46
- pathogen_limit=settings['pathogen_limit'],
47
- uninfected=settings['uninfected'])
48
-
49
+ pathogen_limit=settings['pathogen_limit'])
50
+
49
51
  df = annotate_conditions(df,
50
52
  cells=settings['cell_types'],
51
53
  cell_loc=settings['cell_plate_metadata'],
@@ -550,4 +552,296 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
550
552
  fig_1 = plot_line(df, x_column = 'pc_fraction', y_columns=y_columns, group_column=None, xlabel=None, ylabel='Fraction', title=None, figsize=(10, 6), save_path=save_paths[0])
551
553
  fig_2 = plot_line(df, x_column = 'nc_fraction', y_columns=y_columns, group_column=None, xlabel=None, ylabel='Fraction', title=None, figsize=(10, 6), save_path=save_paths[1])
552
554
 
553
- return [fig_1, fig_2]
555
+ return [fig_1, fig_2]
556
+
557
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator`` with inf and NaN results set to 0."""
    ratio = numerator / denominator
    # Assign-back style (no chained ``inplace=True``) so the cleanup also
    # takes effect when pandas copy-on-write is enabled.
    return ratio.replace([float('inf'), -float('inf')], float('nan')).fillna(0)


def _generate_comparison_columns(df, compartments=('cell', 'nucleus', 'pathogen', 'cytoplasm')):
    """Add between-compartment ratio columns to ``df``.

    For every column whose name starts with a compartment name, the same
    feature is looked up under each other compartment prefix. Every match
    produces a ``<other>_<compartment><feature>`` ratio column, and every
    pair of matches produces a ``<first>_<second><feature>`` ratio column.

    Args:
        df (pd.DataFrame): Measurement table; modified in place.
        compartments (iterable of str): Column-name prefixes to compare.

    Returns:
        tuple: ``(df, comparison_dict)`` where ``comparison_dict`` maps each
        source column to the related columns found in other compartments.
    """
    comparison_dict = {}

    # Snapshot of the columns per compartment before any ratio is added.
    compartment_columns = {
        comp: [col for col in df.columns if col.startswith(comp)]
        for comp in compartments
    }

    for comp0, comp0_columns in compartment_columns.items():
        for comp0_col in comp0_columns:
            # Strip only the leading compartment name; ``str.replace`` would
            # also delete later occurrences of it inside the feature name.
            base_col_name = comp0_col[len(comp0):]

            related_cols = []
            for prefix in compartment_columns:
                if prefix == comp0:  # Skip same-compartment comparisons
                    continue
                related_col = prefix + base_col_name
                if related_col not in df.columns:
                    continue
                related_cols.append(related_col)
                df[f"{prefix}_{comp0}{base_col_name}"] = _safe_ratio(df[related_col], df[comp0_col])

            if not related_cols:
                continue

            comparison_dict[comp0_col] = related_cols

            # All-to-all ratios between the related columns themselves.
            for i, rel_col_1 in enumerate(related_cols):
                for rel_col_2 in related_cols[i + 1:]:
                    comp1 = rel_col_1.split('_')[0]
                    comp2 = rel_col_2.split('_')[0]
                    df[f"{comp1}_{comp2}{base_col_name}"] = _safe_ratio(df[rel_col_1], df[rel_col_2])

    return df, comparison_dict


def _summarize_feature_groups(df, feature_groups, name='compartment', include_all=False):
    """Label each feature with its group(s) and sum importance per label.

    Args:
        df (pd.DataFrame): Table with ``feature`` and ``importance`` columns.
            A ``name`` column holding the label is added in place.
        feature_groups (iterable of str): Regex patterns searched in each
            feature name; multiple matches are joined with ``-``.
        name (str): Name of the label column. When it is ``'channel'``,
            features matching no group are labelled ``'morphology'``.
        include_all (bool): Append an ``'all'`` row holding the grand total.

    Returns:
        pd.DataFrame: Columns ``name`` and ``f'{name}_importance_sum'``.
    """
    def find_feature_class(feature):
        matches = [group for group in feature_groups if re.search(group, feature)]
        return '-'.join(matches) if matches else None

    df[name] = df['feature'].apply(find_feature_class)

    if name == 'channel':
        df[name] = df[name].fillna('morphology')

    sum_column = f'{name}_importance_sum'
    importance_sum = df.groupby(name)['importance'].sum().reset_index(name=sum_column)

    if include_all:
        total_row = pd.DataFrame([{name: 'all', sum_column: importance_sum[sum_column].sum()}])
        importance_sum = pd.concat([importance_sum, total_row], ignore_index=True)

    return importance_sum


def _extract_compartment_channel(feature_name):
    """Return ``(compartment, channel)`` parsed from a feature name.

    The compartment is the text before the first underscore (``'cells'`` is
    normalised to ``'cell'``). The channel is every ``channel_0`` ..
    ``channel_3`` substring present, joined with ``' + '``, or
    ``'morphology'`` when none is present.
    """
    compartment = feature_name.split('_')[0]
    if compartment == 'cells':
        compartment = 'cell'

    channels = [f'channel_{i}' for i in range(4) if f'channel_{i}' in feature_name]
    channel = ' + '.join(channels) if channels else 'morphology'

    return (compartment, channel)


def _plot_top_importances(importance_df, title):
    """Show a horizontal bar chart of the given feature importances."""
    plt.figure(figsize=(10, 6))
    plt.barh(importance_df['feature'], importance_df['importance'])
    plt.xlabel('Importance')
    plt.title(title)
    plt.gca().invert_yaxis()
    plt.show()


def _load_vision_model_data(settings):
    """Load measurements, add ratio features and join them with the scores.

    Reads ``settings['src']/measurements/measurements.db`` through
    ``_read_and_merge_data`` and the CSV at ``settings['scores']``, then
    inner-joins both on plate, row_name, column_name, field and object_label.

    Returns:
        tuple: ``(X, y, merged_df)`` - the numeric feature columns, the score
        column and the full merged table.

    Raises:
        ValueError: If no measurement row matches a score row.
    """
    from .io import _read_and_merge_data

    join_keys = ['plate', 'row_name', 'column_name', 'field', 'object_label']
    score_column = settings['score_column']

    df, _ = _read_and_merge_data(
        locs=[os.path.join(settings['src'], 'measurements', 'measurements.db')],
        tables=settings['tables'],
        verbose=True,
        nuclei_limit=settings['nuclei_limit'],
        pathogen_limit=settings['pathogen_limit']
    )

    df, _ = _generate_comparison_columns(df, compartments=['cell', 'nucleus', 'pathogen', 'cytoplasm'])
    print(f"Expanded dataframe to {len(df.columns)} columns with relative features")
    scores_df = pd.read_csv(settings['scores'])

    # Fall back to the short column names only when the long ones are absent.
    for long_name, short_name in (('row_name', 'row'), ('column_name', 'col'), ('object_label', 'object')):
        if long_name not in scores_df.columns:
            scores_df[long_name] = scores_df[short_name]

    # Drop the 'o' characters from the measurement labels so they can match the score labels.
    df['object_label'] = df['object_label'].astype(str).str.replace('o', '', regex=False)

    # Join columns must share one dtype in both tables.
    df[join_keys] = df[join_keys].astype(str)
    scores_df[join_keys] = scores_df[join_keys].astype(str)

    scores_df = scores_df[join_keys + [score_column]]
    merged_df = pd.merge(df, scores_df, on=join_keys, how='inner')

    if merged_df.empty:
        raise ValueError("No rows left after merging measurements with scores; check the join columns.")

    # errors='ignore': a non-numeric score column is already excluded by select_dtypes.
    X = merged_df.select_dtypes(include='number').drop(columns=[score_column], errors='ignore')
    y = merged_df[score_column]

    return X, y, merged_df


def interperate_vision_model(settings={}):
    """Explain vision-model scores using measured object features.

    A random forest is trained to predict ``settings['score_column']`` from
    the numeric measurement features (plus between-compartment ratios), and
    up to three analyses are run on it.

    Settings keys read: ``src``, ``tables``, ``nuclei_limit``,
    ``pathogen_limit``, ``scores``, ``score_column``, ``feature_importance``,
    ``permutation_importance``, ``shap``, ``shap_sample``, ``top_features``,
    ``n_jobs``, ``include_all``, ``channels`` and ``save``.

    Args:
        settings (dict): Analysis configuration (see keys above).

    Returns:
        dict: Result DataFrames keyed by ``'feature_importance'``,
        ``'feature_importance_compartment'``, ``'feature_importance_channel'``,
        ``'permutation_importance'`` and ``'shap'``; only the analyses that
        were enabled are present. With ``settings['save']`` each one is also
        written to ``<src>/results/<key>.csv``.
    """
    X, y, _ = _load_vision_model_data(settings)

    output = {}
    top_n = settings['top_features']
    run_feature_importance = settings['feature_importance']
    run_permutation_importance = settings['permutation_importance']
    run_shap = settings['shap']

    # Every analysis needs the forest fitted on all features: permutation
    # importance scores it directly and SHAP uses its ranking to pick features.
    model = None
    feature_importance_df = None
    if run_feature_importance or run_permutation_importance or run_shap:
        model = RandomForestClassifier(random_state=42, n_jobs=settings['n_jobs'])
        model.fit(X, y)
        feature_importance_df = pd.DataFrame({'feature': X.columns, 'importance': model.feature_importances_})
        feature_importance_df = feature_importance_df.sort_values(by='importance', ascending=False)

    # Step 1: Feature Importance using Random Forest
    if run_feature_importance:
        print("Feature Importance ...")
        _plot_top_importances(feature_importance_df.head(top_n), f"Top {top_n} Features - Feature Importance")

        output['feature_importance'] = feature_importance_df
        output['feature_importance_compartment'] = _summarize_feature_groups(
            feature_importance_df, feature_groups=settings['tables'],
            name='compartment', include_all=settings['include_all'])
        output['feature_importance_channel'] = _summarize_feature_groups(
            feature_importance_df, feature_groups=settings['channels'],
            name='channel', include_all=settings['include_all'])

    # Step 2: Permutation Importance
    if run_permutation_importance:
        print("Permutation Importance ...")
        perm_importance = permutation_importance(model, X, y, n_repeats=10, random_state=42, n_jobs=settings['n_jobs'])
        perm_importance_df = pd.DataFrame({'feature': X.columns, 'importance': perm_importance.importances_mean})
        perm_importance_df = perm_importance_df.sort_values(by='importance', ascending=False)
        _plot_top_importances(perm_importance_df.head(top_n), f"Top {top_n} Features - Permutation Importance")

        output['permutation_importance'] = perm_importance_df

    # Step 3: SHAP Analysis
    if run_shap:
        print("SHAP Analysis ...")

        # Refit on the top-ranked features only to keep SHAP tractable.
        top_features = feature_importance_df.head(top_n)['feature'].tolist()
        X_top = X[top_features]
        shap_model = RandomForestClassifier(random_state=42, n_jobs=settings['n_jobs'])
        shap_model.fit(X_top, y)

        if settings['shap_sample']:
            # Explain ~1% of the rows, but never an empty sample.
            X_sample = X_top.sample(max(1, len(X_top) // 100), random_state=42)
        else:
            X_sample = X_top

        explainer = shap.Explainer(shap_model.predict, X_sample)
        # NOTE(review): the permutation explainer is understood to need at
        # least 2 * n_features + 1 evaluations - confirm against SHAP docs.
        shap_values = explainer(X_sample, max_evals=max(1500, 2 * X_sample.shape[1] + 1))

        shap.summary_plot(shap_values, X_sample, max_display=top_n)

        shap_df = pd.DataFrame(shap_values.values, columns=X_sample.columns)
        shap_df.columns = pd.MultiIndex.from_tuples(
            [_extract_compartment_channel(feat) for feat in shap_df.columns],
            names=['compartment', 'channel']
        )

        output['shap'] = shap_df

    if settings['save']:
        dst = os.path.join(settings['src'], 'results')
        os.makedirs(dst, exist_ok=True)
        for key, result_df in output.items():
            save_path = os.path.join(dst, f"{key}.csv")
            result_df.to_csv(save_path)
            print(f"Saved {save_path}")

    return output
spacr/utils.py CHANGED
@@ -4052,7 +4052,7 @@ def measure_test_mode(settings):
4052
4052
 
4053
4053
  return settings
4054
4054
 
4055
- def preprocess_data(df, filter_by, remove_highly_correlated, log_data, exclude):
4055
+ def preprocess_data(df, filter_by, remove_highly_correlated, log_data, exclude, column_list=False):
4056
4056
  """
4057
4057
  Preprocesses the given dataframe by applying filtering, removing highly correlated columns,
4058
4058
  applying log transformation, filling NaN values, and scaling the numeric data.
@@ -4076,7 +4076,10 @@ def preprocess_data(df, filter_by, remove_highly_correlated, log_data, exclude):
4076
4076
  # Apply filtering based on the `filter_by` parameter
4077
4077
  if filter_by is not None:
4078
4078
  df, _ = filter_dataframe_features(df, channel_of_interest=filter_by, exclude=exclude)
4079
-
4079
+
4080
+ if column_list:
4081
+ df = df[column_list]
4082
+
4080
4083
  # Select numerical features
4081
4084
  numeric_data = df.select_dtypes(include=['number'])
4082
4085
 
@@ -4181,6 +4184,7 @@ def filter_dataframe_features(df, channel_of_interest, exclude=None, remove_low_
4181
4184
 
4182
4185
  if verbose:
4183
4186
  print("Columns to remove:", count_and_id_columns)
4187
+
4184
4188
  df = df.drop(columns=count_and_id_columns)
4185
4189
 
4186
4190
  if not channel_of_interest is None:
@@ -4189,6 +4193,9 @@ def filter_dataframe_features(df, channel_of_interest, exclude=None, remove_low_
4189
4193
  if isinstance(channel_of_interest, list):
4190
4194
  feature_strings = [f"channel_{channel}" for channel in channel_of_interest]
4191
4195
 
4196
+ elif isinstance(channel_of_interest, str):
4197
+ feature_strings = [channel_of_interest]
4198
+
4192
4199
  elif isinstance(channel_of_interest, int):
4193
4200
  feature_string = f"channel_{channel_of_interest}"
4194
4201
  feature_strings = [feature_string]
@@ -5164,3 +5171,33 @@ def rename_columns_in_db(db_path):
5164
5171
  # After closing the 'with' block, run VACUUM outside of any transaction
5165
5172
  with sqlite3.connect(db_path) as conn:
5166
5173
  conn.execute("VACUUM;")
5174
+
5175
def group_feature_class(df, feature_groups=('cell', 'cytoplasm', 'nucleus', 'pathogen'), name='compartment',
                        include_all=True, return_summary=False):
    """Label each feature with the group(s) its name matches.

    A ``name`` column is added to ``df`` in place. Each entry of
    ``feature_groups`` is used as a regex pattern against the ``feature``
    column; a feature matching several groups gets them joined with ``-``
    and a feature matching none gets a missing value. When ``name`` is
    ``'channel'``, unmatched features are labelled ``'morphology'`` instead.

    Args:
        df (pd.DataFrame): Table with a ``feature`` column (and an
            ``importance`` column when ``return_summary`` is True).
        feature_groups (iterable of str): Patterns to search for.
        name (str): Name of the label column to create.
        include_all (bool): With ``return_summary``, append an ``'all'`` row
            holding the total importance.
        return_summary (bool): Return the per-label importance sums instead
            of the annotated ``df``.

    Returns:
        pd.DataFrame: The annotated ``df`` by default; with
        ``return_summary=True`` a table with columns ``name`` and
        ``f'{name}_importance_sum'``.
    """
    def find_feature_class(feature):
        matches = [group for group in feature_groups if re.search(group, feature)]
        return '-'.join(matches) if matches else None

    df[name] = df['feature'].apply(find_feature_class)

    if name == 'channel':
        # Assign back instead of chained ``inplace=True`` so the fill also
        # takes effect when pandas copy-on-write is enabled.
        df[name] = df[name].fillna('morphology')

    if not return_summary:
        return df

    # Summed importance per label, optionally followed by the grand total.
    sum_column = f'{name}_importance_sum'
    importance_sum = df.groupby(name)['importance'].sum().reset_index(name=sum_column)

    if include_all:
        total_row = pd.DataFrame([{name: 'all', sum_column: importance_sum[sum_column].sum()}])
        importance_sum = pd.concat([importance_sum, total_row], ignore_index=True)

    return importance_sum
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: spacr
3
- Version: 0.3.60
3
+ Version: 0.3.61
4
4
  Summary: Spatial phenotype analysis of crisp screens (SpaCr)
5
5
  Home-page: https://github.com/EinarOlafsson/spacr
6
6
  Author: Einar Birnir Olafsson
@@ -1,6 +1,6 @@
1
1
  spacr/__init__.py,sha256=CZtAdU5etLcb9dVmz-4Y7Hjhw3ubjMzfjG0L5ybyFVA,1592
2
2
  spacr/__main__.py,sha256=bkAJJD2kjIqOP-u1kLvct9jQQCeUXzlEjdgitwi1Lm8,75
3
- spacr/app_annotate.py,sha256=zGmAJplDOckhaUZijkHgbFH9LJNbd6TolU2hamplOBc,2769
3
+ spacr/app_annotate.py,sha256=W9eLPa_LZIvXsXx_-0iDFEU938LBDvRy6prXo0qF4KQ,2533
4
4
  spacr/app_classify.py,sha256=urTP_wlZ58hSyM5a19slYlBxN0PdC-9-ga0hvq8CGWc,165
5
5
  spacr/app_make_masks.py,sha256=pqDhRpluiHZz-kPX2Zh_KbYe4TsU43qYBa_7f-rsjpw,1694
6
6
  spacr/app_mask.py,sha256=l-dBY8ftzCMdDe6-pXc2Nh_u-idNL9G7UOARiLJBtds,153
@@ -9,26 +9,26 @@ spacr/app_sequencing.py,sha256=DjG26jy4cpddnV8WOOAIiExtOe9MleVMY4MFa5uTo5w,157
9
9
  spacr/app_umap.py,sha256=ZWAmf_OsIKbYvolYuWPMYhdlVe-n2CADoJulAizMiEo,153
10
10
  spacr/cellpose.py,sha256=RBHMs2vwXcfkj0xqAULpALyzJYXddSRycgZSzmwI7v0,14755
11
11
  spacr/chat_bot.py,sha256=n3Fhqg3qofVXHmh3H9sUcmfYy9MmgRnr48663MVdY9E,1244
12
- spacr/core.py,sha256=dW9RrAKFLfVsFhX0-kaVMc2T7b47Ky0pTXK-CEVOeWQ,48235
12
+ spacr/core.py,sha256=3u2qKmPmTlswvE1uKTF4gi7KQ3sJBHV9No_ysgk7JCU,48487
13
13
  spacr/deep_spacr.py,sha256=HdOcNU8cHcE_19nP7_5uTz-ih3E169ffr2Hm--NvMvA,43255
14
14
  spacr/gui.py,sha256=ARyn9Q_g8HoP-cXh1nzMLVFCKqthY4v2u9yORyaQqQE,8230
15
15
  spacr/gui_core.py,sha256=N7R7yvfK_dJhOReM_kW3Ci8Bokhi1OzsxeKqvSGdvV4,41460
16
16
  spacr/gui_elements.py,sha256=EKlvEg_4_je7jciEdR3NTgPrcTraowa2e2RUt-xqd6M,138254
17
- spacr/gui_utils.py,sha256=Ud6hRRPhombKjeGUhlleEr9I75SNnFj8UD11yKfp9Wo,40860
18
- spacr/io.py,sha256=VHs6h8o0gBEyKxfdNqEhpzjQXPrj7UGG47DwHeUyUDw,143390
17
+ spacr/gui_utils.py,sha256=u9RoIOWpAXFEOnUlLpMQZrc1pWSg6omZsJMIhJdRv_g,41211
18
+ spacr/io.py,sha256=p-ky3yjtoSSvdsktPXVy_dx8dHgMeWqUZOtOwwfrk2o,136108
19
19
  spacr/logger.py,sha256=lJhTqt-_wfAunCPl93xE65Wr9Y1oIHJWaZMjunHUeIw,1538
20
20
  spacr/measure.py,sha256=2lK-ZcTxLM-MpXV1oZnucRD9iz5aprwahRKw9IEqshg,55085
21
21
  spacr/mediar.py,sha256=FwLvbLQW5LQzPgvJZG8Lw7GniA2vbZx6Jv6vIKu7I5c,14743
22
- spacr/ml.py,sha256=aberLbvUM9F6uNpEOFHzn8_w-fiW0sDG3jVb6TDxakI,68275
22
+ spacr/ml.py,sha256=aLDeeaAl0d4-RP1CzFHPqz5br2HrFbJhvPexEm9lvSI,68198
23
23
  spacr/openai.py,sha256=5vBZ3Jl2llYcW3oaTEXgdyCB2aJujMUIO5K038z7w_A,1246
24
- spacr/plot.py,sha256=Y5_VuRHNsIH7iezK8kWXHg9fwh5sW3S34ncIFshbBco,157893
24
+ spacr/plot.py,sha256=zITe54dzQRz-gk_ZT0qJyARuUWJivIBKW8V4rjUH8SE,160320
25
25
  spacr/sequencing.py,sha256=ClUfwPPK6rNUbUuiEkzcwakzVyDKKUMv9ricrxT8qQY,25227
26
- spacr/settings.py,sha256=6_GB1QQw_w_4yq8dH-Ypc4rJw__Cgs6g_BnR9bIjdZI,77669
26
+ spacr/settings.py,sha256=zANLspVmllDZeYjQWIfrHN3VkVgicnYGTduv30MmQ18,77257
27
27
  spacr/sim.py,sha256=1xKhXimNU3ukzIw-3l9cF3Znc_brW8h20yv8fSTzvss,71173
28
- spacr/submodules.py,sha256=dn-QSKX6ZqyyEr8_v69jVGpB-wd3KbaMRacIA8DXONU,28155
28
+ spacr/submodules.py,sha256=Xq4gjvooHN8S7cTk5PIAkd7XD2c7CMVqNpeo8GCvtHc,42489
29
29
  spacr/timelapse.py,sha256=KGfG4L4-QnFfgbF7L6C5wL_3gd_rqr05Foje6RsoTBg,39603
30
30
  spacr/toxo.py,sha256=z2nT5aAze3NUIlwnBQcnkARihDwoPfqOgQIVoUluyK0,25087
31
- spacr/utils.py,sha256=5XGA0aPray3DzCAgwJjPRlsaxsuSRJyTTTZ7rNDTRTg,219202
31
+ spacr/utils.py,sha256=tqIKiSc30xEX0IlfSpoctFJQDVnGHDAX7l1VakRCBuY,220601
32
32
  spacr/version.py,sha256=axH5tnGwtgSnJHb5IDhiu4Zjk5GhLyAEDRe-rnaoFOA,409
33
33
  spacr/resources/MEDIAR/.gitignore,sha256=Ff1q9Nme14JUd-4Q3jZ65aeQ5X4uttptssVDgBVHYo8,152
34
34
  spacr/resources/MEDIAR/LICENSE,sha256=yEj_TRDLUfDpHDNM0StALXIt6mLqSgaV2hcCwa6_TcY,1065
@@ -151,9 +151,9 @@ spacr/resources/icons/umap.png,sha256=dOLF3DeLYy9k0nkUybiZMe1wzHQwLJFRmgccppw-8b
151
151
  spacr/resources/images/plate1_E01_T0001F001L01A01Z01C02.tif,sha256=Tl0ZUfZ_AYAbu0up_nO0tPRtF1BxXhWQ3T3pURBCCRo,7958528
152
152
  spacr/resources/images/plate1_E01_T0001F001L01A02Z01C01.tif,sha256=m8N-V71rA1TT4dFlENNg8s0Q0YEXXs8slIn7yObmZJQ,7958528
153
153
  spacr/resources/images/plate1_E01_T0001F001L01A03Z01C03.tif,sha256=Pbhk7xn-KUP6RSIhJsxQcrHFImBm3GEpLkzx7WOc-5M,7958528
154
- spacr-0.3.60.dist-info/LICENSE,sha256=SR-2MeGc6SCM1UORJYyarSWY_A-JaOMFDj7ReSs9tRM,1083
155
- spacr-0.3.60.dist-info/METADATA,sha256=UF63-vN6-XEslhGhnotkQz6JanIajbV56bKcSEaEIjE,6032
156
- spacr-0.3.60.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
157
- spacr-0.3.60.dist-info/entry_points.txt,sha256=BMC0ql9aNNpv8lUZ8sgDLQMsqaVnX5L535gEhKUP5ho,296
158
- spacr-0.3.60.dist-info/top_level.txt,sha256=GJPU8FgwRXGzKeut6JopsSRY2R8T3i9lDgya42tLInY,6
159
- spacr-0.3.60.dist-info/RECORD,,
154
+ spacr-0.3.61.dist-info/LICENSE,sha256=SR-2MeGc6SCM1UORJYyarSWY_A-JaOMFDj7ReSs9tRM,1083
155
+ spacr-0.3.61.dist-info/METADATA,sha256=2jlzT9lkaXx01IWlYMYrpf24p48qDHvrRLZm-YUUl-0,6032
156
+ spacr-0.3.61.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
157
+ spacr-0.3.61.dist-info/entry_points.txt,sha256=BMC0ql9aNNpv8lUZ8sgDLQMsqaVnX5L535gEhKUP5ho,296
158
+ spacr-0.3.61.dist-info/top_level.txt,sha256=GJPU8FgwRXGzKeut6JopsSRY2R8T3i9lDgya42tLInY,6
159
+ spacr-0.3.61.dist-info/RECORD,,
File without changes