spacr 0.3.72__py3-none-any.whl → 0.3.80__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/gui_core.py CHANGED
@@ -169,7 +169,7 @@ def display_figure(fig):
169
169
  #flash_feedback("right")
170
170
  show_next_figure()
171
171
 
172
- def zoom(event):
172
+ def zoom_v1(event):
173
173
  nonlocal scale_factor
174
174
 
175
175
  zoom_speed = 0.1 # Adjust the zoom speed for smoother experience
@@ -197,6 +197,70 @@ def display_figure(fig):
197
197
 
198
198
  # Redraw the figure efficiently
199
199
  canvas.draw_idle()
200
+
201
+ def zoom_test(event):
202
+ if event.num == 4: # Scroll up
203
+ print("zoom in")
204
+ elif event.num == 5: # Scroll down
205
+ print("zoom out")
206
+
207
+ def zoom_2(event):
208
+ zoom_speed = 0.1 # Change this to control how fast you zoom
209
+
210
+ # Determine the zoom direction based on the scroll event
211
+ if event.num == 4 or (hasattr(event, 'delta') and event.delta > 0): # Scroll up = zoom in
212
+ factor = 1 - zoom_speed
213
+ elif event.num == 5 or (hasattr(event, 'delta') and event.delta < 0): # Scroll down = zoom out
214
+ factor = 1 + zoom_speed
215
+ else:
216
+ return # No recognized scroll direction
217
+
218
+ for ax in canvas.figure.get_axes():
219
+ xlim = ax.get_xlim()
220
+ ylim = ax.get_ylim()
221
+
222
+ x_center = (xlim[1] + xlim[0]) / 2
223
+ y_center = (ylim[1] + ylim[0]) / 2
224
+
225
+ x_range = (xlim[1] - xlim[0]) * factor
226
+ y_range = (ylim[1] - ylim[0]) * factor
227
+
228
+ # Set the new limits
229
+ ax.set_xlim([x_center - x_range / 2, x_center + x_range / 2])
230
+ ax.set_ylim([y_center - y_range / 2, y_center + y_range / 2])
231
+
232
+ # Redraw the figure efficiently
233
+ canvas.draw_idle()
234
+
235
+ def zoom(event):
236
+ # Fixed zoom factors (adjust these if you want faster or slower zoom)
237
+ zoom_in_factor = 0.9 # When zooming in, ranges shrink by 10%
238
+ zoom_out_factor = 1.1 # When zooming out, ranges increase by 10%
239
+
240
+ # Determine the zoom direction based on the scroll event
241
+ if event.num == 4 or (hasattr(event, 'delta') and event.delta > 0): # Scroll up = zoom in
242
+ factor = zoom_in_factor
243
+ elif event.num == 5 or (hasattr(event, 'delta') and event.delta < 0): # Scroll down = zoom out
244
+ factor = zoom_out_factor
245
+ else:
246
+ return # No recognized scroll direction
247
+
248
+ for ax in canvas.figure.get_axes():
249
+ xlim = ax.get_xlim()
250
+ ylim = ax.get_ylim()
251
+
252
+ x_center = (xlim[1] + xlim[0]) / 2
253
+ y_center = (ylim[1] + ylim[0]) / 2
254
+
255
+ x_range = (xlim[1] - xlim[0]) * factor
256
+ y_range = (ylim[1] - ylim[0]) * factor
257
+
258
+ # Set the new limits
259
+ ax.set_xlim([x_center - x_range / 2, x_center + x_range / 2])
260
+ ax.set_ylim([y_center - y_range / 2, y_center + y_range / 2])
261
+
262
+ # Redraw the figure efficiently
263
+ canvas.draw_idle()
200
264
 
201
265
 
202
266
  # Bind events for hover, click interactions, and zoom
@@ -205,19 +269,20 @@ def display_figure(fig):
205
269
  canvas_widget.bind("<Button-1>", on_click)
206
270
  canvas_widget.bind("<Button-3>", on_right_click)
207
271
 
208
-
209
272
  # Detect the operating system and bind the appropriate mouse wheel events
210
273
  current_os = platform.system()
211
274
 
212
275
  if current_os == "Windows":
213
276
  canvas_widget.bind("<MouseWheel>", zoom) # Windows
214
- elif current_os == "Darwin": # macOS
277
+ elif current_os == "Darwin":
215
278
  canvas_widget.bind("<MouseWheel>", zoom)
216
279
  canvas_widget.bind("<Button-4>", zoom) # Scroll up
217
280
  canvas_widget.bind("<Button-5>", zoom) # Scroll down
218
281
  elif current_os == "Linux":
219
282
  canvas_widget.bind("<Button-4>", zoom) # Linux Scroll up
220
283
  canvas_widget.bind("<Button-5>", zoom) # Linux Scroll down
284
+
285
+ process_fig_queue()
221
286
 
222
287
  def clear_unused_figures():
223
288
  global figures, figure_index
@@ -230,71 +295,97 @@ def clear_unused_figures():
230
295
  figure_index = min(max(figure_index, 0), len(figures) - 1)
231
296
 
232
297
  def show_previous_figure():
233
- global figure_index, figures, fig_queue
298
+ from .gui_elements import standardize_figure
299
+ global figure_index, figures, fig_queue, index_control
234
300
 
235
301
  if figure_index is not None and figure_index > 0:
236
302
  figure_index -= 1
303
+ index_control.set(figure_index)
304
+ figures[figure_index] = standardize_figure(figures[figure_index])
237
305
  display_figure(figures[figure_index])
238
- clear_unused_figures()
306
+ #clear_unused_figures()
239
307
 
240
308
  def show_next_figure():
241
- global figure_index, figures, fig_queue
309
+ from .gui_elements import standardize_figure
310
+ global figure_index, figures, fig_queue, index_control
242
311
  if figure_index is not None and figure_index < len(figures) - 1:
243
312
  figure_index += 1
313
+ index_control.set(figure_index)
314
+ index_control.set_to(len(figures) - 1)
315
+ figures[figure_index] = standardize_figure(figures[figure_index])
244
316
  display_figure(figures[figure_index])
245
- clear_unused_figures()
317
+ #clear_unused_figures()
318
+
246
319
  elif figure_index == len(figures) - 1 and not fig_queue.empty():
247
320
  fig = fig_queue.get_nowait()
248
321
  figures.append(fig)
249
322
  figure_index += 1
323
+ index_control.set(figure_index)
324
+ index_control.set_to(len(figures) - 1)
250
325
  display_figure(fig)
251
-
326
+
252
327
  def process_fig_queue():
253
328
  global canvas, fig_queue, canvas_widget, parent_frame, uppdate_frequency, figures, figure_index, index_control
254
-
255
329
  from .gui_elements import standardize_figure
330
+
331
+ #print("process_fig_queue called", flush=True)
256
332
  try:
333
+ got_new_figure = False
257
334
  while not fig_queue.empty():
258
335
  fig = fig_queue.get_nowait()
336
+ #print("Got a figure from fig_queue", flush=True)
259
337
 
260
338
  if fig is None:
261
- print("Warning: Retrieved a None figure from fig_queue.")
262
- continue # Skip processing if the figure is None
339
+ print("Warning: Retrieved a None figure from fig_queue.", flush=True)
340
+ continue
263
341
 
264
- # Standardize the figure appearance before adding it to the list
342
+ # Standardize the figure appearance before adding it
265
343
  fig = standardize_figure(fig)
266
-
267
344
  figures.append(fig)
268
345
 
269
- # Update the slider range and set the value to the latest figure index
346
+ # Update slider maximum
270
347
  index_control.set_to(len(figures) - 1)
348
+ #print("New maximum slider value after adding a figure:", index_control.to, flush=True)
271
349
 
350
+ # If no figure has been displayed yet
272
351
  if figure_index == -1:
273
- figure_index += 1
352
+ figure_index = 0
274
353
  display_figure(figures[figure_index])
275
354
  index_control.set(figure_index)
276
-
355
+ #print("Displayed the first figure and set slider value to 0", flush=True)
356
+
357
+ #got_new_figure = True
358
+
359
+ #if not got_new_figure:
360
+ # No new figures this time
361
+ #print("No new figures found in the queue this iteration.", flush=True)
362
+
277
363
  except Exception as e:
364
+ print("Exception in process_fig_queue:", e, flush=True)
278
365
  traceback.print_exc()
366
+
279
367
  finally:
368
+ # Schedule process_fig_queue() to run again
280
369
  after_id = canvas_widget.after(uppdate_frequency, process_fig_queue)
281
370
  parent_frame.after_tasks.append(after_id)
371
+ #print("process_fig_queue scheduled again", flush=True)
282
372
 
283
373
  def update_figure(value):
284
- global figure_index, figures
285
-
374
+ from .gui_elements import standardize_figure
375
+ global figure_index, figures, index_control
376
+
286
377
  # Convert the value to an integer
287
378
  index = int(value)
288
379
 
289
380
  # Check if the index is valid
290
381
  if 0 <= index < len(figures):
291
382
  figure_index = index
383
+ figures[figure_index] = standardize_figure(figures[figure_index])
292
384
  display_figure(figures[figure_index])
293
-
294
- # Update the index control widget's range and value
295
- index_control.set_to(len(figures) - 1)
296
- index_control.set(figure_index)
297
-
385
+ index_control.set(figure_index)
386
+ print("update_figure called with value:", figure_index)
387
+ index_control.set_to(len(figures) - 1)
388
+
298
389
  def setup_plot_section(vertical_container, settings_type):
299
390
  global canvas, canvas_widget, figures, figure_index, index_control
300
391
  from .gui_utils import display_media_in_plot_frame
@@ -305,29 +396,29 @@ def setup_plot_section(vertical_container, settings_type):
305
396
 
306
397
  # Initialize deque for storing figures and the current index
307
398
  figures = deque()
399
+ figure_index = -1 # Start with no figure displayed
308
400
 
309
401
  # Create a frame for the plot section
310
402
  plot_frame = tk.Frame(vertical_container)
311
403
  plot_frame.configure(bg=bg)
312
404
  vertical_container.add(plot_frame, stretch="always")
313
405
 
314
- # Clear the plot_frame (optional, to handle cases where it may already have content)
406
+ # Clear the plot_frame (optional)
315
407
  for widget in plot_frame.winfo_children():
316
408
  widget.destroy()
317
409
 
318
- # Create a figure and plot
410
+ # Create a figure and plot (initial figure)
319
411
  figure = Figure(figsize=(30, 4), dpi=100)
320
412
  plot = figure.add_subplot(111)
321
413
  plot.plot([], [])
322
414
  plot.axis('off')
323
415
 
324
416
  if settings_type == 'map_barcodes':
325
- # Load and display GIF
326
417
  current_dir = os.path.dirname(__file__)
327
418
  resources_path = os.path.join(current_dir, 'resources', 'icons')
328
419
  gif_path = os.path.join(resources_path, 'dna_matrix.mp4')
329
-
330
420
  display_media_in_plot_frame(gif_path, plot_frame)
421
+
331
422
  canvas = FigureCanvasTkAgg(figure, master=plot_frame)
332
423
  canvas.get_tk_widget().configure(cursor='arrow', highlightthickness=0)
333
424
  canvas_widget = canvas.get_tk_widget()
@@ -348,10 +439,11 @@ def setup_plot_section(vertical_container, settings_type):
348
439
  # Create slider
349
440
  control_frame = tk.Frame(plot_frame, height=15*2, bg=bg)
350
441
  control_frame.grid(row=1, column=0, sticky="ew", padx=10, pady=5)
351
- control_frame.grid_propagate(False)
442
+ control_frame.grid_propagate(False)
352
443
 
353
- # Pass the update_figure function as the command to spacrSlider
354
- index_control = spacrSlider(control_frame, from_=0, to=0, value=0, thickness=2, knob_radius=10, position="center", show_index=True, command=update_figure)
444
+ index_control = spacrSlider(control_frame, from_=0, to=0, value=0, thickness=2, knob_radius=10,
445
+ position="center", show_index=True, command=update_figure)
446
+
355
447
  index_control.grid(row=0, column=0, sticky="ew")
356
448
  control_frame.grid_columnconfigure(0, weight=1)
357
449
 
@@ -359,10 +451,17 @@ def setup_plot_section(vertical_container, settings_type):
359
451
  style = ttk.Style(vertical_container)
360
452
  _ = set_dark_style(style, containers=containers, widgets=widgets)
361
453
 
454
+ # Now ensure the first figure is displayed and recognized:
455
+ figures.append(figure)
456
+ figure_index = 0
457
+ display_figure(figures[figure_index])
458
+ index_control.set_to(len(figures) - 1) # Slider max = 0 in this case, since there's only one figure
459
+ index_control.set(figure_index) # Set slider to 0 to indicate the first figure
460
+
362
461
  return canvas, canvas_widget
363
462
 
364
- def set_globals(thread_control_var, q_var, console_output_var, parent_frame_var, vars_dict_var, canvas_var, canvas_widget_var, scrollable_frame_var, fig_queue_var, figures_var, figure_index_var, index_control_var, progress_bar_var, usage_bars_var):
365
- global thread_control, q, console_output, parent_frame, vars_dict, canvas, canvas_widget, scrollable_frame, fig_queue, figures, figure_index, progress_bar, usage_bars, index_control
463
+ def set_globals(thread_control_var, q_var, console_output_var, parent_frame_var, vars_dict_var, canvas_var, canvas_widget_var, scrollable_frame_var, fig_queue_var, progress_bar_var, usage_bars_var):
464
+ global thread_control, q, console_output, parent_frame, vars_dict, canvas, canvas_widget, scrollable_frame, fig_queue, progress_bar, usage_bars
366
465
  thread_control = thread_control_var
367
466
  q = q_var
368
467
  console_output = console_output_var
@@ -372,11 +471,11 @@ def set_globals(thread_control_var, q_var, console_output_var, parent_frame_var,
372
471
  canvas_widget = canvas_widget_var
373
472
  scrollable_frame = scrollable_frame_var
374
473
  fig_queue = fig_queue_var
375
- figures = figures_var
376
- figure_index = figure_index_var
474
+ #figures = figures_var
475
+ #figure_index = figure_index_var
476
+ #index_control = index_control_var
377
477
  progress_bar = progress_bar_var
378
478
  usage_bars = usage_bars_var
379
- index_control = index_control_var
380
479
 
381
480
  def import_settings(settings_type='mask'):
382
481
  global vars_dict, scrollable_frame, button_scrollable_frame
@@ -606,6 +705,7 @@ def setup_button_section(horizontal_container, settings_type='mask', run=True, a
606
705
  widgets.append(import_button)
607
706
  btn_row += 1
608
707
 
708
+ btn_row += 1
609
709
  # Add the batch progress bar
610
710
  progress_bar = spacrProgressBar(button_scrollable_frame.scrollable_frame, orient='horizontal', mode='determinate')
611
711
  progress_bar.grid(row=btn_row, column=0, columnspan=7, pady=5, padx=5, sticky='ew')
@@ -853,7 +953,8 @@ def process_console_queue():
853
953
  if progress_bar:
854
954
  progress_bar['maximum'] = total_progress
855
955
  progress_bar['value'] = unique_progress_count
856
-
956
+ #print("Current progress bar value:", progress_bar['value']) # Debugg
957
+
857
958
  # Store operation type and additional info
858
959
  if operation_type:
859
960
  progress_bar.operation_type = operation_type
@@ -955,7 +1056,7 @@ def initiate_root(parent, settings_type='mask'):
955
1056
  else:
956
1057
  usage_bars = []
957
1058
 
958
- set_globals(thread_control, q, console_output, parent_frame, vars_dict, canvas, canvas_widget, scrollable_frame, fig_queue, figures, figure_index, index_control, progress_bar, usage_bars)
1059
+ set_globals(thread_control, q, console_output, parent_frame, vars_dict, canvas, canvas_widget, scrollable_frame, fig_queue, progress_bar, usage_bars)
959
1060
  description_text = descriptions.get(settings_type, "No description available for this module.")
960
1061
 
961
1062
  q.put(f"Console")
spacr/gui_elements.py CHANGED
@@ -667,7 +667,7 @@ class spacrProgressBar(ttk.Progressbar):
667
667
  # Remove any borders and ensure the active color fills the entire space
668
668
  self.style.configure(
669
669
  "spacr.Horizontal.TProgressbar",
670
- troughcolor=self.inactive_color, # Set the trough to bg color
670
+ troughcolor=self.inactive_color, # Set the trough to bg color
671
671
  background=self.active_color, # Active part is the active color
672
672
  borderwidth=0, # Remove border width
673
673
  pbarrelief="flat", # Flat relief for the progress bar
spacr/ml.py CHANGED
@@ -27,6 +27,9 @@ from sklearn.linear_model import Lasso, Ridge
27
27
  from sklearn.preprocessing import FunctionTransformer
28
28
  from patsy import dmatrices
29
29
 
30
+ from sklearn.metrics import classification_report, accuracy_score
31
+ from sklearn.model_selection import StratifiedKFold, cross_val_score
32
+ from sklearn.feature_selection import SelectKBest, f_classif
30
33
  from sklearn.model_selection import train_test_split
31
34
  from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingClassifier
32
35
  from sklearn.linear_model import LogisticRegression
@@ -1165,21 +1168,29 @@ def generate_ml_scores(settings):
1165
1168
 
1166
1169
  settings = set_default_analyze_screen(settings)
1167
1170
 
1168
- src = settings['src']
1171
+ srcs = settings['src']
1169
1172
 
1170
1173
  settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
1171
1174
  display(settings_df)
1172
-
1173
- db_loc = [src+'/measurements/measurements.db']
1174
- tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
1175
-
1176
- nuclei_limit, pathogen_limit = settings['nuclei_limit'], settings['pathogen_limit']
1177
1175
 
1178
- df, _ = _read_and_merge_data(db_loc,
1179
- tables,
1180
- settings['verbose'],
1181
- nuclei_limit,
1182
- pathogen_limit)
1176
+ if isinstance(srcs, str):
1177
+ srcs = [srcs]
1178
+
1179
+ df = pd.DataFrame()
1180
+ for idx, src in enumerate(srcs):
1181
+
1182
+ if idx == 0:
1183
+ src1 = src
1184
+
1185
+ db_loc = [src+'/measurements/measurements.db']
1186
+ tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
1187
+
1188
+ dft, _ = _read_and_merge_data(db_loc,
1189
+ tables,
1190
+ settings['verbose'],
1191
+ nuclei_limit=settings['nuclei_limit'],
1192
+ pathogen_limit=settings['pathogen_limit'])
1193
+ df = pd.concat([df, dft])
1183
1194
 
1184
1195
  if settings['annotation_column'] is not None:
1185
1196
 
@@ -1191,6 +1202,7 @@ def generate_ml_scores(settings):
1191
1202
  annotated_df = png_list_df[['prcfo', settings['annotation_column']]].set_index('prcfo')
1192
1203
  df = annotated_df.merge(df, left_index=True, right_index=True)
1193
1204
  unique_values = df[settings['annotation_column']].dropna().unique()
1205
+
1194
1206
  if len(unique_values) == 1:
1195
1207
  unannotated_rows = df[df[settings['annotation_column']].isna()].index
1196
1208
  existing_value = unique_values[0]
@@ -1213,8 +1225,8 @@ def generate_ml_scores(settings):
1213
1225
  df[settings['annotation_column']] = df[settings['annotation_column']].apply(str)
1214
1226
 
1215
1227
  if settings['channel_of_interest'] in [0,1,2,3]:
1216
-
1217
- df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity"]/df[f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
1228
+ if f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity" and f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity" in df.columns:
1229
+ df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity"]/df[f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
1218
1230
 
1219
1231
  output, figs = ml_analysis(df,
1220
1232
  settings['channel_of_interest'],
@@ -1224,18 +1236,24 @@ def generate_ml_scores(settings):
1224
1236
  settings['exclude'],
1225
1237
  settings['n_repeats'],
1226
1238
  settings['top_features'],
1239
+ settings['reg_alpha'],
1240
+ settings['reg_lambda'],
1241
+ settings['learning_rate'],
1227
1242
  settings['n_estimators'],
1228
1243
  settings['test_size'],
1229
1244
  settings['model_type_ml'],
1230
1245
  settings['n_jobs'],
1231
1246
  settings['remove_low_variance_features'],
1232
1247
  settings['remove_highly_correlated_features'],
1248
+ settings['prune_features'],
1249
+ settings['cross_validation'],
1233
1250
  settings['verbose'])
1234
1251
 
1235
1252
  shap_fig = shap_analysis(output[3], output[4], output[5])
1236
1253
 
1237
1254
  features = output[0].select_dtypes(include=[np.number]).columns.tolist()
1238
-
1255
+ train_features_df = pd.DataFrame(output[9], columns=['feature'])
1256
+
1239
1257
  if not settings['heatmap_feature'] in features:
1240
1258
  raise ValueError(f"Variable {settings['heatmap_feature']} not found in the dataframe. Please choose one of the following: {features}")
1241
1259
 
@@ -1247,15 +1265,16 @@ def generate_ml_scores(settings):
1247
1265
  min_count=settings['minimum_cell_count'],
1248
1266
  verbose=settings['verbose'])
1249
1267
 
1250
- data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv = get_ml_results_paths(src, settings['model_type_ml'], settings['channel_of_interest'])
1251
- df, permutation_df, feature_importance_df, _, _, _, _, _, metrics_df = output
1268
+ data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv, ml_features = get_ml_results_paths(src1, settings['model_type_ml'], settings['channel_of_interest'])
1269
+ df, permutation_df, feature_importance_df, _, _, _, _, _, metrics_df, _ = output
1252
1270
 
1253
1271
  settings_df.to_csv(settings_csv, index=False)
1254
1272
  df.to_csv(data_path, mode='w', encoding='utf-8')
1255
1273
  permutation_df.to_csv(permutation_path, mode='w', encoding='utf-8')
1256
1274
  feature_importance_df.to_csv(feature_importance_path, mode='w', encoding='utf-8')
1275
+ train_features_df.to_csv(ml_features, mode='w', encoding='utf-8')
1257
1276
  metrics_df.to_csv(model_metricks_path, mode='w', encoding='utf-8')
1258
-
1277
+
1259
1278
  plate_heatmap.savefig(plate_heatmap_path, format='pdf')
1260
1279
  figs[0].savefig(permutation_fig_path, format='pdf')
1261
1280
  figs[1].savefig(feature_importance_fig_path, format='pdf')
@@ -1263,7 +1282,7 @@ def generate_ml_scores(settings):
1263
1282
 
1264
1283
  if settings['save_to_db']:
1265
1284
  settings['csv_path'] = data_path
1266
- settings['db_path'] = os.path.join(src, 'measurements', 'measurements.db')
1285
+ settings['db_path'] = os.path.join(src1, 'measurements', 'measurements.db')
1267
1286
  settings['table_name'] = 'png_list'
1268
1287
  settings['update_column'] = 'predictions'
1269
1288
  settings['match_column'] = 'prcfo'
@@ -1271,7 +1290,7 @@ def generate_ml_scores(settings):
1271
1290
 
1272
1291
  return [output, plate_heatmap]
1273
1292
 
1274
- def ml_analysis(df, channel_of_interest=3, location_column='column_name', positive_control='c2', negative_control='c1', exclude=None, n_repeats=10, top_features=30, n_estimators=100, test_size=0.2, model_type='xgboost', n_jobs=-1, remove_low_variance_features=True, remove_highly_correlated_features=True, verbose=False):
1293
+ def ml_analysis(df, channel_of_interest=3, location_column='column_name', positive_control='c2', negative_control='c1', exclude=None, n_repeats=10, top_features=30, reg_alpha=0.1, reg_lambda=1.0, learning_rate=0.00001, n_estimators=1000, test_size=0.2, model_type='xgboost', n_jobs=-1, remove_low_variance_features=True, remove_highly_correlated_features=True, prune_features=False, cross_validation=False, verbose=False):
1275
1294
 
1276
1295
  """
1277
1296
  Calculates permutation importance for numerical features in the dataframe,
@@ -1313,7 +1332,8 @@ def ml_analysis(df, channel_of_interest=3, location_column='column_name', positi
1313
1332
  if verbose:
1314
1333
  print(f'Found {len(features)} numerical features in the dataframe')
1315
1334
  print(f'Features used in training: {features}')
1316
-
1335
+ print(f'Features: {features}')
1336
+
1317
1337
  df = pd.concat([df, df_metadata[location_column]], axis=1)
1318
1338
 
1319
1339
  # Subset the dataframe based on specified column values
@@ -1327,14 +1347,26 @@ def ml_analysis(df, channel_of_interest=3, location_column='column_name', positi
1327
1347
  # Combine the subsets for analysis
1328
1348
  combined_df = pd.concat([df1, df2])
1329
1349
  combined_df = combined_df.drop(columns=[location_column])
1350
+
1330
1351
  if verbose:
1331
1352
  print(f'Found {len(df1)} samples for {negative_control} and {len(df2)} samples for {positive_control}. Total: {len(combined_df)}')
1332
-
1353
+
1333
1354
  X = combined_df[features]
1334
1355
  y = combined_df['target']
1335
-
1336
- print(X)
1337
- print(y)
1356
+
1357
+ if prune_features:
1358
+ before_pruning = len(X.columns)
1359
+ selector = SelectKBest(score_func=f_classif, k=top_features)
1360
+ X_selected = selector.fit_transform(X, y)
1361
+
1362
+ # Get the selected feature names
1363
+ selected_features = X.columns[selector.get_support()]
1364
+ X = pd.DataFrame(X_selected, columns=selected_features, index=X.index)
1365
+
1366
+ features = selected_features.tolist()
1367
+
1368
+ after_pruning = len(X.columns)
1369
+ print(f"Removed {before_pruning - after_pruning} features using SelectKBest")
1338
1370
 
1339
1371
  # Split the data into training and testing sets
1340
1372
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
@@ -1353,12 +1385,102 @@ def ml_analysis(df, channel_of_interest=3, location_column='column_name', positi
1353
1385
  elif model_type == 'gradient_boosting':
1354
1386
  model = HistGradientBoostingClassifier(max_iter=n_estimators, random_state=random_state) # Supports n_jobs internally
1355
1387
  elif model_type == 'xgboost':
1356
- model = XGBClassifier(n_estimators=n_estimators, random_state=random_state, nthread=n_jobs, use_label_encoder=False, eval_metric='logloss')
1388
+ model = XGBClassifier(reg_alpha=reg_alpha, reg_lambda=reg_lambda, learning_rate=learning_rate, n_estimators=n_estimators, random_state=random_state, nthread=n_jobs, use_label_encoder=False, eval_metric='logloss')
1389
+
1357
1390
  else:
1358
1391
  raise ValueError(f"Unsupported model_type: {model_type}")
1359
1392
 
1360
- model.fit(X_train, y_train)
1393
+ # Perform k-fold cross-validation
1394
+ if cross_validation:
1395
+
1396
+ # Cross-validation setup
1397
+ kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)
1398
+ fold_metrics = []
1399
+
1400
+ for fold_idx, (train_index, test_index) in enumerate(kfold.split(X, y), start=1):
1401
+ X_train, X_test = X.iloc[train_index], X.iloc[test_index]
1402
+ y_train, y_test = y.iloc[train_index], y.iloc[test_index]
1403
+
1404
+ # Train the model
1405
+ model.fit(X_train, y_train)
1406
+
1407
+ # Predict for the current test set
1408
+ predictions_test = model.predict(X_test)
1409
+ combined_df.loc[X_test.index, 'predictions'] = predictions_test
1410
+
1411
+ # Get prediction probabilities for the test set
1412
+ prediction_probabilities_test = model.predict_proba(X_test)
1361
1413
 
1414
+ # Find the optimal threshold
1415
+ optimal_threshold = find_optimal_threshold(y_test, prediction_probabilities_test[:, 1])
1416
+ if verbose:
1417
+ print(f'Fold {fold_idx} - Optimal threshold: {optimal_threshold}')
1418
+
1419
+ # Assign predictions and probabilities to the test set in the DataFrame
1420
+ df.loc[X_test.index, 'predictions'] = predictions_test
1421
+ for i in range(prediction_probabilities_test.shape[1]):
1422
+ df.loc[X_test.index, f'prediction_probability_class_{i}'] = prediction_probabilities_test[:, i]
1423
+
1424
+ # Evaluate performance for the current fold
1425
+ fold_report = classification_report(y_test, predictions_test, output_dict=True)
1426
+ fold_metrics.append(pd.DataFrame(fold_report).transpose())
1427
+
1428
+ if verbose:
1429
+ print(f"Fold {fold_idx} Classification Report:")
1430
+ print(classification_report(y_test, predictions_test))
1431
+
1432
+ # Aggregate metrics across all folds
1433
+ metrics_df = pd.concat(fold_metrics).groupby(level=0).mean()
1434
+
1435
+ # Re-train on full data (X, y) and then apply to entire df
1436
+ model.fit(X, y)
1437
+ all_predictions = model.predict(df[features]) # Predict on entire df
1438
+ df['predictions'] = all_predictions
1439
+
1440
+ # Get prediction probabilities for all rows in df
1441
+ prediction_probabilities = model.predict_proba(df[features])
1442
+ for i in range(prediction_probabilities.shape[1]):
1443
+ df[f'prediction_probability_class_{i}'] = prediction_probabilities[:, i]
1444
+
1445
+ if verbose:
1446
+ print("\nFinal Classification Report on Full Dataset:")
1447
+ print(classification_report(y, all_predictions))
1448
+
1449
+ # Generate metrics DataFrame
1450
+ final_report_dict = classification_report(y, all_predictions, output_dict=True)
1451
+ metrics_df = pd.DataFrame(final_report_dict).transpose()
1452
+
1453
+ else:
1454
+ model.fit(X_train, y_train)
1455
+ # Predicting the target variable for the test set
1456
+ predictions_test = model.predict(X_test)
1457
+ combined_df.loc[X_test.index, 'predictions'] = predictions_test
1458
+
1459
+ # Get prediction probabilities for the test set
1460
+ prediction_probabilities_test = model.predict_proba(X_test)
1461
+
1462
+ # Find the optimal threshold
1463
+ optimal_threshold = find_optimal_threshold(y_test, prediction_probabilities_test[:, 1])
1464
+ if verbose:
1465
+ print(f'Optimal threshold: {optimal_threshold}')
1466
+
1467
+ # Predicting the target variable for all other rows in the dataframe
1468
+ X_all = df[features]
1469
+ all_predictions = model.predict(X_all)
1470
+ df['predictions'] = all_predictions
1471
+
1472
+ # Get prediction probabilities for all rows in the dataframe
1473
+ prediction_probabilities = model.predict_proba(X_all)
1474
+ for i in range(prediction_probabilities.shape[1]):
1475
+ df[f'prediction_probability_class_{i}'] = prediction_probabilities[:, i]
1476
+
1477
+ if verbose:
1478
+ print("\nClassification Report:")
1479
+ print(classification_report(y_test, predictions_test))
1480
+
1481
+ report_dict = classification_report(y_test, predictions_test, output_dict=True)
1482
+ metrics_df = pd.DataFrame(report_dict).transpose()
1483
+
1362
1484
  perm_importance = permutation_importance(model, X_train, y_train, n_repeats=n_repeats, random_state=random_state, n_jobs=n_jobs)
1363
1485
 
1364
1486
  # Create a DataFrame for permutation importances
@@ -1387,40 +1509,13 @@ def ml_analysis(df, channel_of_interest=3, location_column='column_name', positi
1387
1509
  else:
1388
1510
  feature_importance_df = pd.DataFrame()
1389
1511
 
1390
- # Predicting the target variable for the test set
1391
- predictions_test = model.predict(X_test)
1392
- combined_df.loc[X_test.index, 'predictions'] = predictions_test
1393
-
1394
- # Get prediction probabilities for the test set
1395
- prediction_probabilities_test = model.predict_proba(X_test)
1396
-
1397
- # Find the optimal threshold
1398
- optimal_threshold = find_optimal_threshold(y_test, prediction_probabilities_test[:, 1])
1399
- if verbose:
1400
- print(f'Optimal threshold: {optimal_threshold}')
1401
-
1402
- # Predicting the target variable for all other rows in the dataframe
1403
- X_all = df[features]
1404
- all_predictions = model.predict(X_all)
1405
- df['predictions'] = all_predictions
1406
-
1407
- # Get prediction probabilities for all rows in the dataframe
1408
- prediction_probabilities = model.predict_proba(X_all)
1409
- for i in range(prediction_probabilities.shape[1]):
1410
- df[f'prediction_probability_class_{i}'] = prediction_probabilities[:, i]
1411
- if verbose:
1412
- print("\nClassification Report:")
1413
- print(classification_report(y_test, predictions_test))
1414
- report_dict = classification_report(y_test, predictions_test, output_dict=True)
1415
- metrics_df = pd.DataFrame(report_dict).transpose()
1416
-
1417
1512
  df = _calculate_similarity(df, features, location_column, positive_control, negative_control)
1418
1513
 
1419
1514
  df['prcfo'] = df.index.astype(str)
1420
1515
  df[['plate', 'row_name', 'column_name', 'field', 'object']] = df['prcfo'].str.split('_', expand=True)
1421
1516
  df['prc'] = df['plate'] + '_' + df['row_name'] + '_' + df['column_name']
1422
1517
 
1423
- return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test, metrics_df], [permutation_fig, feature_importance_fig]
1518
+ return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test, metrics_df, features], [permutation_fig, feature_importance_fig]
1424
1519
 
1425
1520
  def shap_analysis(model, X_train, X_test):
1426
1521
 
@@ -1495,9 +1590,9 @@ def _calculate_similarity(df, features, col_to_compare, val1, val2):
1495
1590
  inv_cov_matrix = np.linalg.inv(cov_matrix + np.eye(cov_matrix.shape[0]) * epsilon)
1496
1591
 
1497
1592
  # Calculate similarity scores
1498
- def safe_similarity(func, row, control):
1593
+ def safe_similarity(func, row, control, *args, **kwargs):
1499
1594
  try:
1500
- return func(row, control)
1595
+ return func(row, control, *args, **kwargs)
1501
1596
  except Exception:
1502
1597
  return np.nan
1503
1598
 
spacr/settings.py CHANGED
@@ -283,7 +283,10 @@ def set_default_analyze_screen(settings):
283
283
  settings.setdefault('cmap','viridis')
284
284
  settings.setdefault('channel_of_interest',3)
285
285
  settings.setdefault('minimum_cell_count',25)
286
- settings.setdefault('n_estimators',100)
286
+ settings.setdefault('reg_alpha',0.1)
287
+ settings.setdefault('reg_lambda',1.0)
288
+ settings.setdefault('learning_rate',0.001)
289
+ settings.setdefault('n_estimators',1000)
287
290
  settings.setdefault('test_size',0.2)
288
291
  settings.setdefault('location_column','column_name')
289
292
  settings.setdefault('positive_control','c2')
@@ -296,6 +299,8 @@ def set_default_analyze_screen(settings):
296
299
  settings.setdefault('remove_low_variance_features',True)
297
300
  settings.setdefault('remove_highly_correlated_features',True)
298
301
  settings.setdefault('n_jobs',-1)
302
+ settings.setdefault('prune_features',False)
303
+ settings.setdefault('cross_validation',True)
299
304
  settings.setdefault('verbose',True)
300
305
  return settings
301
306
 
@@ -872,6 +877,13 @@ expected_types = {
872
877
  "target_layer":str,
873
878
  "save_to_db":bool,
874
879
  "test_mode":bool,
880
+ "test_images":int,
881
+ "remove_background_cell":bool,
882
+ "remove_background_nucleus":bool,
883
+ "remove_background_pathogen":bool,
884
+ "figuresize":int,
885
+ "cmap":str,
886
+ "pathogen_model":str,
875
887
  "normalize_input":bool,
876
888
  }
877
889
 
spacr/utils.py CHANGED
@@ -4277,6 +4277,12 @@ def filter_dataframe_features(df, channel_of_interest, exclude=None, remove_low_
4277
4277
 
4278
4278
  if remove_highly_correlated_features:
4279
4279
  df = remove_highly_correlated_columns(df, threshold=0.95, verbose=verbose)
4280
+
4281
+ # Remove columns with NaN values
4282
+ before_drop_NaN = len(df.columns)
4283
+ df = df.dropna(axis=1)
4284
+ after_drop_NaN = len(df.columns)
4285
+ print(f"Dropped {before_drop_NaN - after_drop_NaN} columns with NaN values")
4280
4286
 
4281
4287
  # Select numerical features
4282
4288
  features = df.select_dtypes(include=[np.number]).columns.tolist()
@@ -4759,7 +4765,8 @@ def get_ml_results_paths(src, model_type='xgboost', channel_of_interest=1):
4759
4765
  shap_fig_path = os.path.join(res_fldr, 'shap.pdf')
4760
4766
  plate_heatmap_path = os.path.join(res_fldr, 'plate_heatmap.pdf')
4761
4767
  settings_csv = os.path.join(res_fldr, 'ml_settings.csv')
4762
- return data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv
4768
+ ml_features = os.path.join(res_fldr, 'ml_features.csv')
4769
+ return data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv, ml_features
4763
4770
 
4764
4771
  def augment_image(image):
4765
4772
  """
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: spacr
3
- Version: 0.3.72
3
+ Version: 0.3.80
4
4
  Summary: Spatial phenotype analysis of crisp screens (SpaCr)
5
5
  Home-page: https://github.com/EinarOlafsson/spacr
6
6
  Author: Einar Birnir Olafsson
@@ -12,24 +12,24 @@ spacr/chat_bot.py,sha256=n3Fhqg3qofVXHmh3H9sUcmfYy9MmgRnr48663MVdY9E,1244
12
12
  spacr/core.py,sha256=3u2qKmPmTlswvE1uKTF4gi7KQ3sJBHV9No_ysgk7JCU,48487
13
13
  spacr/deep_spacr.py,sha256=V3diLyxX-0_F5UxhX_b94ROOvL9eoLvnoUmF3nMBqPQ,43250
14
14
  spacr/gui.py,sha256=ARyn9Q_g8HoP-cXh1nzMLVFCKqthY4v2u9yORyaQqQE,8230
15
- spacr/gui_core.py,sha256=N7R7yvfK_dJhOReM_kW3Ci8Bokhi1OzsxeKqvSGdvV4,41460
16
- spacr/gui_elements.py,sha256=EKlvEg_4_je7jciEdR3NTgPrcTraowa2e2RUt-xqd6M,138254
15
+ spacr/gui_core.py,sha256=6NKv8ebqC9Zuior4f2-L1By_Pjtt-RPCrEgnRuE9P54,45576
16
+ spacr/gui_elements.py,sha256=I_eSYF1RkAG0zsa-ZiQT0EaaVvUpucULCuWCowO6t4E,138248
17
17
  spacr/gui_utils.py,sha256=u9RoIOWpAXFEOnUlLpMQZrc1pWSg6omZsJMIhJdRv_g,41211
18
18
  spacr/io.py,sha256=LF6lpphw7GSeuoHQijPykjKNF56wNTFEWFZuDQp3O6Q,145739
19
19
  spacr/logger.py,sha256=lJhTqt-_wfAunCPl93xE65Wr9Y1oIHJWaZMjunHUeIw,1538
20
20
  spacr/measure.py,sha256=2lK-ZcTxLM-MpXV1oZnucRD9iz5aprwahRKw9IEqshg,55085
21
21
  spacr/mediar.py,sha256=FwLvbLQW5LQzPgvJZG8Lw7GniA2vbZx6Jv6vIKu7I5c,14743
22
- spacr/ml.py,sha256=h0IrXoNnyNzZLPYbtZPFI6c4Qeu1gH8R3iUz_O7-ar0,78114
22
+ spacr/ml.py,sha256=x19S8OsR5omb8e6MU9I99Nz95J_QvM5siyk-zaAU3p8,82866
23
23
  spacr/openai.py,sha256=5vBZ3Jl2llYcW3oaTEXgdyCB2aJujMUIO5K038z7w_A,1246
24
24
  spacr/plot.py,sha256=gXC7y3uT4sx8KRODeSFWQG_A1CylsuJ5B7HYe_un6so,165177
25
25
  spacr/sequencing.py,sha256=ClUfwPPK6rNUbUuiEkzcwakzVyDKKUMv9ricrxT8qQY,25227
26
- spacr/settings.py,sha256=14PFxw3YK9tUqbaC6BqfbrWk3sN7gyTZAAI8KNy5KBA,80461
26
+ spacr/settings.py,sha256=xTFTD04H8uXRJ5m4Pnr4Znhx0f_FxdgStMPXol3apxM,80888
27
27
  spacr/sim.py,sha256=1xKhXimNU3ukzIw-3l9cF3Znc_brW8h20yv8fSTzvss,71173
28
28
  spacr/stats.py,sha256=mbhwsyIqt5upsSD346qGjdCw7CFBa0tIS7zHU9e0jNI,9536
29
29
  spacr/submodules.py,sha256=SK8YEs850LAx30YAiwap7ecLpp1_p-bci6H-Or0GLoA,55500
30
30
  spacr/timelapse.py,sha256=KGfG4L4-QnFfgbF7L6C5wL_3gd_rqr05Foje6RsoTBg,39603
31
31
  spacr/toxo.py,sha256=z2nT5aAze3NUIlwnBQcnkARihDwoPfqOgQIVoUluyK0,25087
32
- spacr/utils.py,sha256=LX2Hu6QC-yG9ZVBiM2dkSN9yytCB0eTTRGfExiZzYzE,221940
32
+ spacr/utils.py,sha256=SiUcctyUETEX_GZ-Nflba5whZiEjJynncaH-xcZPK1k,222242
33
33
  spacr/version.py,sha256=axH5tnGwtgSnJHb5IDhiu4Zjk5GhLyAEDRe-rnaoFOA,409
34
34
  spacr/resources/MEDIAR/.gitignore,sha256=Ff1q9Nme14JUd-4Q3jZ65aeQ5X4uttptssVDgBVHYo8,152
35
35
  spacr/resources/MEDIAR/LICENSE,sha256=yEj_TRDLUfDpHDNM0StALXIt6mLqSgaV2hcCwa6_TcY,1065
@@ -152,9 +152,9 @@ spacr/resources/icons/umap.png,sha256=dOLF3DeLYy9k0nkUybiZMe1wzHQwLJFRmgccppw-8b
152
152
  spacr/resources/images/plate1_E01_T0001F001L01A01Z01C02.tif,sha256=Tl0ZUfZ_AYAbu0up_nO0tPRtF1BxXhWQ3T3pURBCCRo,7958528
153
153
  spacr/resources/images/plate1_E01_T0001F001L01A02Z01C01.tif,sha256=m8N-V71rA1TT4dFlENNg8s0Q0YEXXs8slIn7yObmZJQ,7958528
154
154
  spacr/resources/images/plate1_E01_T0001F001L01A03Z01C03.tif,sha256=Pbhk7xn-KUP6RSIhJsxQcrHFImBm3GEpLkzx7WOc-5M,7958528
155
- spacr-0.3.72.dist-info/LICENSE,sha256=SR-2MeGc6SCM1UORJYyarSWY_A-JaOMFDj7ReSs9tRM,1083
156
- spacr-0.3.72.dist-info/METADATA,sha256=Kt166mcmw6Hb0u47_tZVq1EiZuK3Z_aDC0T7jE41dnI,6032
157
- spacr-0.3.72.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
158
- spacr-0.3.72.dist-info/entry_points.txt,sha256=BMC0ql9aNNpv8lUZ8sgDLQMsqaVnX5L535gEhKUP5ho,296
159
- spacr-0.3.72.dist-info/top_level.txt,sha256=GJPU8FgwRXGzKeut6JopsSRY2R8T3i9lDgya42tLInY,6
160
- spacr-0.3.72.dist-info/RECORD,,
155
+ spacr-0.3.80.dist-info/LICENSE,sha256=SR-2MeGc6SCM1UORJYyarSWY_A-JaOMFDj7ReSs9tRM,1083
156
+ spacr-0.3.80.dist-info/METADATA,sha256=Q0YV4N-C8XyUHH8HFW_k9ryAftcU8v9oMxNhgzvU8cA,6032
157
+ spacr-0.3.80.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
158
+ spacr-0.3.80.dist-info/entry_points.txt,sha256=BMC0ql9aNNpv8lUZ8sgDLQMsqaVnX5L535gEhKUP5ho,296
159
+ spacr-0.3.80.dist-info/top_level.txt,sha256=GJPU8FgwRXGzKeut6JopsSRY2R8T3i9lDgya42tLInY,6
160
+ spacr-0.3.80.dist-info/RECORD,,
File without changes