spacr 0.3.80__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
spacr/gui_core.py CHANGED
@@ -4,6 +4,7 @@ from tkinter import ttk
 from tkinter import filedialog
 from multiprocessing import Process, Value, Queue, set_start_method
 from tkinter import ttk
+import matplotlib
 from matplotlib.figure import Figure
 from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
 import numpy as np
@@ -323,52 +324,48 @@ def show_next_figure():
         index_control.set(figure_index)
         index_control.set_to(len(figures) - 1)
         display_figure(fig)
-
+
 def process_fig_queue():
     global canvas, fig_queue, canvas_widget, parent_frame, uppdate_frequency, figures, figure_index, index_control
     from .gui_elements import standardize_figure
 
-    #print("process_fig_queue called", flush=True)
     try:
-        got_new_figure = False
         while not fig_queue.empty():
             fig = fig_queue.get_nowait()
-            #print("Got a figure from fig_queue", flush=True)
-
             if fig is None:
-                print("Warning: Retrieved a None figure from fig_queue.", flush=True)
+                print("Warning: Retrieved a None figure from fig_queue.")
                 continue
 
             # Standardize the figure appearance before adding it
             fig = standardize_figure(fig)
             figures.append(fig)
 
+            # OPTIONAL: Cap the size of the figures deque at 100
+            MAX_FIGURES = 100
+            while len(figures) > MAX_FIGURES:
+                # Discard the oldest figure
+                old_fig = figures.popleft()
+                # If needed, you could also close the figure to free memory:
+                matplotlib.pyplot.close(old_fig)
+
             # Update slider maximum
             index_control.set_to(len(figures) - 1)
-            #print("New maximum slider value after adding a figure:", index_control.to, flush=True)
 
             # If no figure has been displayed yet
             if figure_index == -1:
                 figure_index = 0
                 display_figure(figures[figure_index])
                 index_control.set(figure_index)
-                #print("Displayed the first figure and set slider value to 0", flush=True)
-
-        #got_new_figure = True
-
-        #if not got_new_figure:
-            # No new figures this time
-            #print("No new figures found in the queue this iteration.", flush=True)
 
     except Exception as e:
-        print("Exception in process_fig_queue:", e, flush=True)
+        print("Exception in process_fig_queue:", e)
         traceback.print_exc()
 
     finally:
         # Schedule process_fig_queue() to run again
         after_id = canvas_widget.after(uppdate_frequency, process_fig_queue)
         parent_frame.after_tasks.append(after_id)
-        #print("process_fig_queue scheduled again", flush=True)
+
 
 def update_figure(value):
     from .gui_elements import standardize_figure
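
The new cap on the `figures` deque is the key change here: without closing evicted figures, each queued plot keeps its canvas alive and the GUI's memory grows without bound. A minimal standalone sketch of the same pattern (names such as `add_figure` are illustrative, not spacr API):

```python
# Bounded figure cache: keep only the most recent figures and close evicted
# ones so matplotlib releases their canvases and buffers.
from collections import deque

import matplotlib
matplotlib.use("Agg")  # headless backend, so the sketch runs anywhere
import matplotlib.pyplot as plt

MAX_FIGURES = 100
figures = deque()

def add_figure(fig):
    """Append a figure, evicting (and closing) the oldest beyond the cap."""
    figures.append(fig)
    while len(figures) > MAX_FIGURES:
        old_fig = figures.popleft()
        plt.close(old_fig)  # frees the figure's resources

for i in range(150):
    fig, ax = plt.subplots()
    ax.plot([0, 1], [0, i])
    add_figure(fig)

print(len(figures))  # 100: only the most recent figures are retained
```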
@@ -513,7 +510,7 @@ def import_settings(settings_type='mask'):
     #vars_dict = hide_all_settings(vars_dict, categories=None)
     csv_settings = read_settings_from_csv(csv_file_path)
     if settings_type == 'mask':
-        settings = set_default_settings_preprocess_generate_masks(src='path', settings={})
+        settings = set_default_settings_preprocess_generate_masks(settings={})
     elif settings_type == 'measure':
         settings = get_measure_crop_settings(settings={})
     elif settings_type == 'classify':
@@ -565,7 +562,7 @@ def setup_settings_panel(vertical_container, settings_type='mask'):
     settings_frame.grid_columnconfigure(0, weight=1)
 
     if settings_type == 'mask':
-        settings = set_default_settings_preprocess_generate_masks(src='path', settings={})
+        settings = set_default_settings_preprocess_generate_masks(settings={})
     elif settings_type == 'measure':
         settings = get_measure_crop_settings(settings={})
     elif settings_type == 'classify':
@@ -881,7 +878,7 @@ def start_process(q=None, fig_queue=None, settings_type='mask'):
             q.put(f"Error: {e}")
             return
 
-    if thread_control.get("run_thread") is not None:
+    if isinstance(thread_control, dict) and thread_control.get("run_thread") is not None:
         initiate_abort()
 
     stop_requested = Value('i', 0)
@@ -987,6 +984,66 @@ def main_thread_update_function(root, q, fig_queue, canvas_widget):
         print(f"Error updating GUI canvas: {e}")
     finally:
         root.after(uppdate_frequency, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget))
+
+def cleanup_previous_instance():
+    """
+    Cleans up resources from the previous application instance.
+    """
+    global parent_frame, usage_bars, figures, figure_index, thread_control, canvas, q, fig_queue
+
+    # 1. Cancel all pending `after` tasks while parent_frame is still available
+    if parent_frame is not None:
+        parent_window = parent_frame.winfo_toplevel()
+        if hasattr(parent_window, 'after_tasks'):
+            for after_id in parent_window.after_tasks:
+                parent_window.after_cancel(after_id)
+            parent_window.after_tasks = []
+
+    # 2. Destroy all widgets in the parent frame
+    if parent_frame is not None:
+        for widget in parent_frame.winfo_children():
+            try:
+                widget.destroy()
+            except Exception as e:
+                print(f"Error destroying widget: {e}")
+        parent_frame.update_idletasks()
+        parent_frame = None
+
+    # 3. Clear global queues
+    if q is not None:
+        while not q.empty():
+            q.get()
+        q = None
+
+    if fig_queue is not None:
+        while not fig_queue.empty():
+            fig_queue.get()
+        fig_queue = None
+
+    # 4. Stop and reset global thread control
+    if thread_control is not None:
+        thread_control['stop'] = True
+        #thread_control = None
+
+    # 5. Reset usage bars, figures, and indices
+    usage_bars = []
+    figures = deque()
+    figure_index = -1
+
+    # 6. Clear canvas or other visualizations
+    if canvas is not None:
+        try:
+            if hasattr(canvas, 'figure'):  # Check if it's a FigureCanvasTkAgg
+                canvas.figure.clear()  # Clear the Matplotlib figure
+                canvas.get_tk_widget().destroy()  # Destroy the Tkinter widget
+            else:
+                # Assume it's a standard Tkinter Canvas
+                canvas.delete("all")
+        except Exception as e:
+            print(f"Error clearing canvas: {e}")
+        canvas = None
+
+    print("Previous instance cleaned up successfully.")
 
 def initiate_root(parent, settings_type='mask'):
     """
@@ -1002,7 +1059,11 @@ def initiate_root(parent, settings_type='mask'):
 
     global q, fig_queue, thread_control, parent_frame, scrollable_frame, button_frame, vars_dict, canvas, canvas_widget, button_scrollable_frame, progress_bar, uppdate_frequency, figures, figure_index, index_control, usage_bars
 
-    from .gui_utils import setup_frame, get_screen_dimensions
+    # Clean up any previous instance
+    cleanup_previous_instance()
+
+    from .gui_utils import setup_frame
+    from .gui_elements import create_menu_bar
     from .settings import descriptions
     #from .openai import Chatbot
@@ -1065,6 +1126,7 @@ def initiate_root(parent, settings_type='mask'):
 
     process_console_queue()
     process_fig_queue()
+    create_menu_bar(parent)
     after_id = parent_window.after(uppdate_frequency, lambda: main_thread_update_function(parent_window, q, fig_queue, canvas_widget))
     parent_window.after_tasks.append(after_id)
 
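`initiate_root` now calls `cleanup_previous_instance()` before building a new UI. The ordering constraint that function must respect — cancel pending `after` callbacks before destroying the widgets they reference, or Tk fires callbacks against dead widgets — can be shown in a small self-contained Tk sketch (widget names and timings here are illustrative, not spacr's):

```python
import tkinter as tk

def build(root):
    """Build a frame that schedules a repeating after-callback on itself."""
    frame = tk.Frame(root)
    frame.pack()
    frame.after_tasks = []  # track ids so they can be cancelled later

    def tick():
        after_id = frame.after(100, tick)
        frame.after_tasks.append(after_id)
    tick()
    return frame

def cleanup(frame):
    """Cancel pending callbacks first, then destroy widgets."""
    for after_id in frame.after_tasks:
        frame.after_cancel(after_id)  # must happen while the widget exists
    frame.after_tasks.clear()
    for widget in frame.winfo_children():
        widget.destroy()
    frame.destroy()

root = tk.Tk()
frame = build(root)
root.after(500, lambda: (cleanup(frame), root.destroy()))
root.mainloop()
```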
spacr/gui_elements.py CHANGED
@@ -7,6 +7,7 @@ from tkinter import font
 from queue import Queue
 from tkinter import Label, Frame, Button
 import numpy as np
+import pandas as pd
 from PIL import Image, ImageOps, ImageTk, ImageDraw, ImageFont, ImageEnhance
 from concurrent.futures import ThreadPoolExecutor
 from skimage.exposure import rescale_intensity
@@ -17,10 +18,28 @@ from skimage.draw import polygon, line
 from skimage.transform import resize
 from scipy.ndimage import binary_fill_holes, label
 from tkinter import ttk, scrolledtext
-from skimage.color import rgb2gray
+from sklearn.model_selection import train_test_split
+from xgboost import XGBClassifier
+from sklearn.metrics import classification_report, confusion_matrix
 
 fig = None
 
+def restart_gui_app(root):
+    """
+    Restarts the GUI application by destroying the current instance
+    and launching a fresh one.
+    """
+    try:
+        # Destroy the current root window
+        root.destroy()
+
+        # Import and launch a new instance of the application
+        from spacr.gui import gui_app
+        new_root = tk.Tk()  # Create a fresh Tkinter root instance
+        gui_app()
+    except Exception as e:
+        print(f"Error restarting GUI application: {e}")
+
 def create_menu_bar(root):
     from .gui import initiate_root
     gui_apps = {
@@ -56,6 +75,7 @@ def create_menu_bar(root):
 
     # Add a separator and an exit option
     app_menu.add_separator()
+    #app_menu.add_command(label="Home", command=lambda: restart_gui_app(root))
     app_menu.add_command(label="Help", command=lambda: webbrowser.open("https://spacr.readthedocs.io/en/latest/?badge=latest"))
     app_menu.add_command(label="Exit", command=root.quit)
 
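The commented-out Home entry would wire `restart_gui_app` into this menu. A self-contained sketch of the underlying destroy-and-relaunch pattern (`build_app` is a hypothetical stand-in for `spacr.gui.gui_app`):

```python
import tkinter as tk

def build_app():
    """Build a root window with a menu that can restart the app."""
    root = tk.Tk()
    menubar = tk.Menu(root)
    app_menu = tk.Menu(menubar, tearoff=0)
    app_menu.add_command(label="Restart", command=lambda: restart(root))
    app_menu.add_command(label="Exit", command=root.quit)
    menubar.add_cascade(label="App", menu=app_menu)
    root.config(menu=menubar)
    return root

def restart(root):
    # Tear down the current instance, then launch a fresh one — the same
    # order restart_gui_app uses (destroy first, then rebuild).
    root.destroy()
    build_app().mainloop()

build_app().mainloop()
```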
@@ -2201,7 +2221,8 @@ class AnnotateApp:
             self.image_size = (image_size, image_size)
         else:
             raise ValueError("Invalid image size")
-
+
+        self.orig_annotation_columns = annotation_column
         self.annotation_column = annotation_column
         self.image_type = image_type
         self.channels = channels
@@ -2258,6 +2279,12 @@ class AnnotateApp:
 
         self.exit_button = Button(self.button_frame, text="Exit", command=self.shutdown, bg=self.bg_color, fg=self.fg_color, highlightbackground=self.fg_color, highlightcolor=self.fg_color, highlightthickness=1)
         self.exit_button.pack(side="right", padx=5)
+
+        self.train_button = Button(self.button_frame, text="Train & Classify (beta)", command=self.train_and_classify, bg=self.bg_color, fg=self.fg_color, highlightbackground=self.fg_color, highlightcolor=self.fg_color, highlightthickness=1)
+        self.train_button.pack(side="right", padx=5)
+
+        self.orig_button = Button(self.button_frame, text="orig.", command=self.swich_back_annotation_column, bg=self.bg_color, fg=self.fg_color, highlightbackground=self.fg_color, highlightcolor=self.fg_color, highlightthickness=1)
+        self.orig_button.pack(side="right", padx=5)
 
         # Calculate grid rows and columns based on the root window size and image size
         self.calculate_grid_dimensions()
@@ -2280,7 +2307,12 @@ class AnnotateApp:
             self.grid_frame.grid_rowconfigure(row, weight=1)
         for col in range(self.grid_cols):
             self.grid_frame.grid_columnconfigure(col, weight=1)
-
+
+    def swich_back_annotation_column(self):
+        self.annotation_column = self.orig_annotation_columns
+        self.prefilter_paths_annotations()
+        self.update_display()
+
     def calculate_grid_dimensions(self):
         window_width = self.root.winfo_width()
         window_height = self.root.winfo_height()
@@ -2603,6 +2635,163 @@ class AnnotateApp:
             print(f'Quit application')
         else:
             print('Waiting for pending updates to finish before quitting')
+
+    def train_and_classify(self):
+        """
+        1) Merge data from the relevant DB tables (including png_list).
+        2) Collect manual annotations from png_list.<annotation_column> => 'manual_annotation'.
+           - 1 => class=1, 2 => class=0 (for training).
+        3) If only one class is present, randomly sample unannotated images as the other class.
+        4) Train an XGBoost model.
+        5) Classify *all* rows -> fill XGboost_score (prob of class=1) & XGboost_annotation (1 or 2 if high confidence).
+        6) Write those columns back to sqlite, so every row in png_list has a score (and possibly an annotation).
+        7) Refresh the UI (prefilter_paths_annotations + load_images).
+        """
+
+        # Optionally, update your GUI status label
+        self.update_gui_text("Merging data...")
+
+        from spacr.io import _read_and_merge_data  # Adapt to your actual import
+
+        # (1) Merge data
+        merged_df, obj_df_ls = _read_and_merge_data(
+            locs=[self.db_path],
+            tables=['cell', 'cytoplasm', 'nucleus', 'pathogen', 'png_list'],
+            verbose=False
+        )
+
+        # (2) Load manual annotations from the DB
+        conn = sqlite3.connect(self.db_path)
+        c = conn.cursor()
+        c.execute(f"SELECT png_path, {self.annotation_column} FROM png_list WHERE {self.annotation_column} IS NOT NULL")
+        annotated_rows = c.fetchall()  # e.g. [(png_path, 1 or 2), ...]
+        conn.close()
+
+        # dict {png_path -> 1 or 2}
+        annot_dict = dict(annotated_rows)
+
+        # Add 'manual_annotation' to merged_df
+        merged_df['manual_annotation'] = merged_df['png_path'].map(annot_dict)
+
+        # Subset with manual labels
+        annotated_df = merged_df.dropna(subset=['manual_annotation']).copy()
+        # Convert "2" => "0" for binary classification
+        annotated_df['manual_annotation'] = annotated_df['manual_annotation'].replace({2: 0}).astype(int)
+
+        # (3) Handle single-class scenario
+        class_counts = annotated_df['manual_annotation'].value_counts()
+        if len(class_counts) == 1:
+            single_class = class_counts.index[0]  # 0 or 1
+            needed = class_counts.iloc[0]
+            other_class = 1 if single_class == 0 else 0
+
+            unannotated_df_all = merged_df[merged_df['manual_annotation'].isna()].copy()
+            if len(unannotated_df_all) == 0:
+                print("No unannotated rows to sample for the other class. Cannot proceed.")
+                self.update_gui_text("Not enough data to train (no second class).")
+                return
+
+            sample_size = min(needed, len(unannotated_df_all))
+            artificially_labeled = unannotated_df_all.sample(n=sample_size, replace=False).copy()
+            artificially_labeled['manual_annotation'] = other_class
+
+            annotated_df = pd.concat([annotated_df, artificially_labeled], ignore_index=True)
+            print(f"Only one class was present => randomly labeled {sample_size} unannotated rows as {other_class}.")
+
+        if len(annotated_df) < 2:
+            print("Not enough annotated data to train (need at least 2).")
+            self.update_gui_text("Not enough data to train.")
+            return
+
+        # (4) Train XGBoost
+        self.update_gui_text("Training XGBoost model...")
+
+        # Identify numeric columns
+        ignore_cols = {'png_path', 'manual_annotation'}
+        feature_cols = [
+            col for col in annotated_df.columns
+            if col not in ignore_cols
+            and (annotated_df[col].dtype == float or annotated_df[col].dtype == int)
+        ]
+
+        X_data = annotated_df[feature_cols].fillna(0).values
+        y_data = annotated_df['manual_annotation'].values
+
+        # standard train/test split
+        X_train, X_test, y_train, y_test = train_test_split(
+            X_data, y_data, test_size=0.1, random_state=42
+        )
+        model = XGBClassifier(use_label_encoder=False, eval_metric='logloss')
+        model.fit(X_train, y_train)
+
+        # Evaluate
+        preds = model.predict(X_test)
+        print("=== Classification Report ===")
+        print(classification_report(y_test, preds))
+        print("=== Confusion Matrix ===")
+        print(confusion_matrix(y_test, preds))
+
+        # (5) Classify ALL rows
+        all_df = merged_df.copy()
+        X_all = all_df[feature_cols].fillna(0).values
+        probs_all = model.predict_proba(X_all)[:, 1]
+        # Probability => XGboost_score
+        all_df['XGboost_score'] = probs_all
+
+        # Decide XGboost_annotation
+        def get_annotation_from_prob(prob):
+            if prob > 0.9:
+                return 1  # class=1
+            elif prob < 0.1:
+                return 0  # class=0
+            return None  # uncertain
+
+        xgb_anno_col = [get_annotation_from_prob(p) for p in probs_all]
+        # Convert 0 => 2 if your DB uses "2" for the negative class
+        xgb_anno_col = [2 if x == 0 else x for x in xgb_anno_col]
+
+        all_df['XGboost_annotation'] = xgb_anno_col
+
+        # (6) Write results back to png_list
+        self.update_gui_text("Updating the database with XGBoost predictions...")
+        conn = sqlite3.connect(self.db_path)
+        c = conn.cursor()
+        # Ensure columns exist
+        try:
+            c.execute("ALTER TABLE png_list ADD COLUMN XGboost_annotation INTEGER")
+        except sqlite3.OperationalError:
+            pass
+        try:
+            c.execute("ALTER TABLE png_list ADD COLUMN XGboost_score FLOAT")
+        except sqlite3.OperationalError:
+            pass
+
+        # Update each row
+        for idx, row in all_df.iterrows():
+            score_val = float(row['XGboost_score'])
+            anno_val = row['XGboost_annotation']
+            the_path = row['png_path']
+            if pd.isna(the_path):
+                continue  # skip if no path
+
+            if pd.isna(anno_val):
+                # We set annotation=NULL but do set the score
+                c.execute("""
+                    UPDATE png_list
+                    SET XGboost_annotation = NULL,
+                        XGboost_score = ?
+                    WHERE png_path = ?
+                """, (score_val, the_path))
+            else:
+                # numeric annotation + numeric score
+                c.execute("""
+                    UPDATE png_list
+                    SET XGboost_annotation = ?,
+                        XGboost_score = ?
+                    WHERE png_path = ?
+                """, (int(anno_val), score_val, the_path))
+
+        conn.commit()
+        conn.close()
+
+        # (7) Switch to the XGBoost annotations and refresh the UI
+        self.annotation_column = 'XGboost_annotation'
+        self.prefilter_paths_annotations()
+        self.load_images()
 
 def standardize_figure(fig):
     from .gui_elements import set_dark_style
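
The heart of `train_and_classify` is a train-then-pseudo-label recipe: fit a classifier on the manual 1/2 annotations, score every row, and persist only high-confidence predictions as annotations. A toy, self-contained run of the same recipe on synthetic features (the 0.9/0.1 thresholds and the 1/2 label encoding follow the method above; the data and dimensions are made up for illustration):

```python
import numpy as np
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier

rng = np.random.default_rng(42)
X = rng.normal(size=(200, 5))
y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)  # stand-in for manual labels

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
model = XGBClassifier(eval_metric='logloss')
model.fit(X_train, y_train)

probs = model.predict_proba(X)[:, 1]  # probability of class 1 for every row

def to_db_annotation(prob):
    """Map a probability to the DB's 1/2 encoding; None means 'uncertain'."""
    if prob > 0.9:
        return 1   # confident class 1
    if prob < 0.1:
        return 2   # confident class 0, stored as 2 in the DB
    return None

annotations = [to_db_annotation(p) for p in probs]
print(sum(a is not None for a in annotations), "of", len(annotations), "rows auto-annotated")
```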
spacr/gui_utils.py CHANGED
@@ -482,7 +482,7 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
 
     if settings_type == 'mask':
         function = preprocess_generate_masks
-        imports = 2
+        imports = 1
     elif settings_type == 'measure':
         function = measure_crop
         imports = 1
spacr/io.py CHANGED
@@ -1773,7 +1773,7 @@ def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus',
             print(e)
     conn.close()
     if 'png_list' in dataframes:
-        png_list_df = dataframes['png_list'][['cell_id', 'png_path', 'plate', 'row_name', 'column_name']].copy()
+        png_list_df = dataframes['png_list'][['cell_id', 'png_path', 'plate', 'row_name', 'column_name', 'field']].copy()
         png_list_df['cell_id'] = png_list_df['cell_id'].str[1:].astype(int)
         png_list_df.rename(columns={'cell_id': 'object_label'}, inplace=True)
         if 'cell' in dataframes:
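
The added 'field' column matters because downstream merges key every object on a `prcfo` (plate_row_column_field_object) id, as the merge code below shows. A toy illustration with hypothetical values:

```python
import pandas as pd

# Hypothetical png_list rows; the column set mirrors the selection above.
png_list_df = pd.DataFrame({
    'cell_id': ['o1', 'o2'],
    'png_path': ['/a.png', '/b.png'],
    'plate': ['p1', 'p1'],
    'row_name': ['r1', 'r1'],
    'column_name': ['c1', 'c2'],
    'field': ['f1', 'f1'],
})

# Build the prcfo key; without 'field' this id cannot be constructed.
png_list_df['prcfo'] = (
    png_list_df['plate'] + '_' + png_list_df['row_name'] + '_'
    + png_list_df['column_name'] + '_' + png_list_df['field'] + '_'
    + png_list_df['cell_id']
)
print(png_list_df['prcfo'].tolist())  # ['p1_r1_c1_f1_o1', 'p1_r1_c2_f1_o2']
```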
@@ -2275,175 +2275,6 @@ def _read_db(db_loc, tables):
         dfs.append(df)
     conn.close()  # Close the connection
     return dfs
-
-def _read_and_merge_data_v1(locs, tables, verbose=False, nuclei_limit=False, pathogen_limit=False):
-
-    from .utils import _split_data
-
-    # Extract plate DataFrames
-    all_dfs = []
-    for loc in locs:
-        db_dfs = _read_db(loc, tables)
-        all_dfs.append(db_dfs)
-
-    # Extract tables from DataFrames and concatenate rows
-    for i, dfs in enumerate(all_dfs):
-        if 'cell' in tables:
-            cell = dfs[0]
-            if verbose:
-                print(f'plate: {i+1} cells:{len(cell)}')
-            # see pathogens logic, copy logic to other tables #here
-        if 'nucleus' in tables:
-            nucleus = dfs[1]
-            if verbose:
-                print(f'plate: {i+1} nucleus:{len(nucleus)} ')
-
-        if 'pathogen' in tables:
-            if len(tables) == 1:
-                pathogen = dfs[0]
-                print(len(pathogen))
-            else:
-                pathogen = dfs[2]
-            if verbose:
-                print(f'plate: {i+1} pathogens:{len(pathogen)}')
-
-        if 'cytoplasm' in tables:
-            if not 'pathogen' in tables:
-                cytoplasm = dfs[2]
-            else:
-                cytoplasm = dfs[3]
-            if verbose:
-                print(f'plate: {i+1} cytoplasms: {len(cytoplasm)}')
-
-        if i > 0:
-            if 'cell' in tables:
-                cells = pd.concat([cells, cell], axis=0)
-            if 'nucleus' in tables:
-                nucleus = pd.concat([nucleus, nucleus], axis=0)
-            if 'pathogen' in tables:
-                pathogens = pd.concat([pathogens, pathogen], axis=0)
-            if 'cytoplasm' in tables:
-                cytoplasms = pd.concat([cytoplasms, cytoplasm], axis=0)
-        else:
-            if 'cell' in tables:
-                cells = cell.copy()
-            if 'nucleus' in tables:
-                nucleus = nucleus.copy()
-            if 'pathogen' in tables:
-                pathogens = pathogen.copy()
-            if 'cytoplasm' in tables:
-                cytoplasms = cytoplasm.copy()
-
-    # Add an 'o' in front of all object and cell labels to convert them to strings
-    if 'cell' in tables:
-        cells = cells.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
-        cells = cells.assign(prcfo=lambda x: x['prcf'] + '_' + x['object_label'])
-        cells_g_df, metadata = _split_data(cells, 'prcfo', 'object_label')
-        merged_df = cells_g_df.copy()
-        if verbose:
-            print(f'cells: {len(cells)}')
-            print(f'cells grouped: {len(cells_g_df)}')
-
-    if 'cytoplasm' in tables:
-        cytoplasms = cytoplasms.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
-        cytoplasms = cytoplasms.assign(prcfo=lambda x: x['prcf'] + '_' + x['object_label'])
-        cytoplasms_g_df, _ = _split_data(cytoplasms, 'prcfo', 'object_label')
-        merged_df = cells_g_df.merge(cytoplasms_g_df, left_index=True, right_index=True)
-        if verbose:
-            print(f'cytoplasms: {len(cytoplasms)}')
-            print(f'cytoplasms grouped: {len(cytoplasms_g_df)}')
-
-    if 'nucleus' in tables:
-        if not 'cell' in tables:
-            cells_g_df = pd.DataFrame()
-        nucleus = nucleus.dropna(subset=['cell_id'])
-        nucleus = nucleus.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
-        nucleus = nucleus.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
-        nucleus = nucleus.assign(prcfo=lambda x: x['prcf'] + '_' + x['cell_id'])
-        nucleus['nucleus_prcfo_count'] = nucleus.groupby('prcfo')['prcfo'].transform('count')
-        if nuclei_limit == False:
-            nucleus = nucleus[nucleus['nucleus_prcfo_count'] == 1]
-        nucleus_g_df, _ = _split_data(nucleus, 'prcfo', 'cell_id')
-        if verbose:
-            print(f'nucleus: {len(nucleus)}')
-            print(f'nucleus grouped: {len(nucleus_g_df)}')
-        if 'cytoplasm' in tables:
-            merged_df = merged_df.merge(nucleus_g_df, left_index=True, right_index=True)
-        else:
-            merged_df = cells_g_df.merge(nucleus_g_df, left_index=True, right_index=True)
-
-    if 'pathogen' in tables:
-        if not 'cell' in tables:
-            cells_g_df = pd.DataFrame()
-            merged_df = []
-        try:
-            pathogens = pathogens.dropna(subset=['cell_id'])
-
-        except:
-            pathogens['cell_id'] = pathogens['object_label']
-            pathogens = pathogens.dropna(subset=['cell_id'])
-
-        pathogens = pathogens.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
-        pathogens = pathogens.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
-        pathogens = pathogens.assign(prcfo=lambda x: x['prcf'] + '_' + x['cell_id'])
-        pathogens['pathogen_prcfo_count'] = pathogens.groupby('prcfo')['prcfo'].transform('count')
-
-        if isinstance(pathogen_limit, bool):
-            if pathogen_limit == False:
-                pathogens = pathogens[pathogens['pathogen_prcfo_count'] <= 1]
-                print(f"after multiinfected Bool: {len(pathogens)}")
-        if isinstance(pathogen_limit, float):
-            pathogen_limit = int(pathogen_limit)
-        if isinstance(pathogen_limit, int):
-            pathogens = pathogens[pathogens['pathogen_prcfo_count'] <= pathogen_limit]
-            print(f"after multiinfected Float: {len(pathogens)}")
-        if not 'cell' in tables:
-            pathogens_g_df, metadata = _split_data(pathogens, 'prcfo', 'cell_id')
-        else:
-            pathogens_g_df, _ = _split_data(pathogens, 'prcfo', 'cell_id')
-
-        if verbose:
-            print(f'pathogens: {len(pathogens)}')
-            print(f'pathogens grouped: {len(pathogens_g_df)}')
-
-        if len(merged_df) == 0:
-            merged_df = pathogens_g_df
-        else:
-            merged_df = merged_df.merge(pathogens_g_df, left_index=True, right_index=True)
-
-    # Add prc column (plate row column)
-    metadata = metadata.assign(prc=lambda x: x['plate'] + '_' + x['row_name'] + '_' + x['column_name'])
-
-    # Count cells per well
-    cells_well = pd.DataFrame(metadata.groupby('prc')['object_label'].nunique())
-
-    cells_well.reset_index(inplace=True)
-    cells_well.rename(columns={'object_label': 'cells_per_well'}, inplace=True)
-    metadata = pd.merge(metadata, cells_well, on='prc', how='inner', suffixes=('', '_drop_col'))
-    object_label_cols = [col for col in metadata.columns if '_drop_col' in col]
-    metadata.drop(columns=object_label_cols, inplace=True)
-
-    # Add prcfo column (plate row column field object)
-    metadata = metadata.assign(prcfo=lambda x: x['plate'] + '_' + x['row_name'] + '_' + x['column_name'] + '_' + x['field'] + '_' + x['object_label'])
-    metadata.set_index('prcfo', inplace=True)
-
-    merged_df = metadata.merge(merged_df, left_index=True, right_index=True)
-
-    merged_df = merged_df.dropna(axis=1)
-    if verbose:
-        print(f'Generated dataframe with: {len(merged_df.columns)} columns and {len(merged_df)} rows')
-
-    obj_df_ls = []
-    if 'cell' in tables:
-        obj_df_ls.append(cells)
-    if 'cytoplasm' in tables:
-        obj_df_ls.append(cytoplasms)
-    if 'nucleus' in tables:
-        obj_df_ls.append(nucleus)
-    if 'pathogen' in tables:
-        obj_df_ls.append(pathogens)
-
-    return merged_df, obj_df_ls
 
 def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=10, pathogen_limit=10, change_plate=False):
     from .io import _read_db
@@ -2929,6 +2760,7 @@ def generate_training_dataset(settings):
     def get_smallest_class_size(df, settings, dataset_mode):
         if dataset_mode == 'metadata':
             sizes = [len(df[df['condition'] == c]) for c in settings['class_metadata']]
+            #sizes = [len(df[df['condition'].isin(class_list)]) for class_list in settings['class_metadata']]
             print(f'Class sizes: {sizes}')
         elif dataset_mode == 'annotation':
             sizes = [len(class_paths) for class_paths in df]
@@ -2997,16 +2829,12 @@ def generate_training_dataset(settings):
         df = df.dropna(subset=['condition'])
 
         display(df)
-
-        #df['metadata_based_class'] = pd.NA
-        #for i, class_ in enumerate(settings['classes']):
-        #    ls = settings['class_metadata'][i]
-        #    df.loc[df[settings['metadata_type_by']].isin(ls), 'metadata_based_class'] = class_
 
         size = get_smallest_class_size(df, settings, 'metadata')
 
         for class_ in settings['class_metadata']:
             class_temp_df = df[df['condition'] == class_]
+            #class_temp_df = df[df['condition'].isin(class_)]
             print(f'Found {len(class_temp_df)} images for class {class_}')
             class_paths_temp = class_temp_df['png_path'].tolist()
@@ -3033,6 +2861,8 @@ def generate_training_dataset(settings):
     from .io import _read_and_merge_data, _read_db
     from .utils import get_paths_from_db, annotate_conditions, save_settings
     from .settings import set_generate_training_dataset_defaults
+
+    settings = set_generate_training_dataset_defaults(settings)
 
     if 'nucleus' not in settings['tables']:
         settings['nuclei_limit'] = False
@@ -3041,7 +2871,6 @@ def generate_training_dataset(settings):
         settings['pathogen_limit'] = 0
 
     # Set default settings and save
-    settings = set_generate_training_dataset_defaults(settings)
     save_settings(settings, 'cv_dataset', show=True)
 
     class_path_list = None
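
Moving `set_generate_training_dataset_defaults(settings)` above the `settings['tables']` checks matters because reading a key before defaults are applied raises `KeyError` on a partial settings dict. A minimal sketch of the pattern (`set_defaults` is a hypothetical stand-in):

```python
def set_defaults(settings):
    """Fill in any missing keys; values here are illustrative only."""
    defaults = {'tables': ['cell', 'nucleus'], 'nuclei_limit': 10, 'pathogen_limit': 10}
    for key, value in defaults.items():
        settings.setdefault(key, value)
    return settings

def generate(settings):
    settings = set_defaults(settings)          # must run first...
    if 'nucleus' not in settings['tables']:    # ...so this lookup cannot KeyError
        settings['nuclei_limit'] = False
    return settings

print(generate({}))  # works even with an empty settings dict
```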