spacr 0.3.80__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +0 -4
- spacr/core.py +27 -13
- spacr/deep_spacr.py +378 -5
- spacr/gui_core.py +82 -20
- spacr/gui_elements.py +192 -3
- spacr/gui_utils.py +1 -1
- spacr/io.py +5 -176
- spacr/measure.py +10 -6
- spacr/ml.py +369 -46
- spacr/plot.py +201 -90
- spacr/settings.py +80 -21
- spacr/submodules.py +282 -1
- spacr/toxo.py +98 -75
- spacr/utils.py +144 -49
- {spacr-0.3.80.dist-info → spacr-0.4.0.dist-info}/METADATA +2 -1
- {spacr-0.3.80.dist-info → spacr-0.4.0.dist-info}/RECORD +20 -20
- {spacr-0.3.80.dist-info → spacr-0.4.0.dist-info}/LICENSE +0 -0
- {spacr-0.3.80.dist-info → spacr-0.4.0.dist-info}/WHEEL +0 -0
- {spacr-0.3.80.dist-info → spacr-0.4.0.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.80.dist-info → spacr-0.4.0.dist-info}/top_level.txt +0 -0
spacr/gui_core.py
CHANGED
@@ -4,6 +4,7 @@ from tkinter import ttk
|
|
4
4
|
from tkinter import filedialog
|
5
5
|
from multiprocessing import Process, Value, Queue, set_start_method
|
6
6
|
from tkinter import ttk
|
7
|
+
import matplotlib
|
7
8
|
from matplotlib.figure import Figure
|
8
9
|
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
|
9
10
|
import numpy as np
|
@@ -323,52 +324,48 @@ def show_next_figure():
|
|
323
324
|
index_control.set(figure_index)
|
324
325
|
index_control.set_to(len(figures) - 1)
|
325
326
|
display_figure(fig)
|
326
|
-
|
327
|
+
|
327
328
|
def process_fig_queue():
|
328
329
|
global canvas, fig_queue, canvas_widget, parent_frame, uppdate_frequency, figures, figure_index, index_control
|
329
330
|
from .gui_elements import standardize_figure
|
330
331
|
|
331
|
-
#print("process_fig_queue called", flush=True)
|
332
332
|
try:
|
333
|
-
got_new_figure = False
|
334
333
|
while not fig_queue.empty():
|
335
334
|
fig = fig_queue.get_nowait()
|
336
|
-
#print("Got a figure from fig_queue", flush=True)
|
337
|
-
|
338
335
|
if fig is None:
|
339
|
-
print("Warning: Retrieved a None figure from fig_queue."
|
336
|
+
print("Warning: Retrieved a None figure from fig_queue.")
|
340
337
|
continue
|
341
338
|
|
342
339
|
# Standardize the figure appearance before adding it
|
343
340
|
fig = standardize_figure(fig)
|
344
341
|
figures.append(fig)
|
345
342
|
|
343
|
+
# OPTIONAL: Cap the size of the figures deque at 100
|
344
|
+
MAX_FIGURES = 100
|
345
|
+
while len(figures) > MAX_FIGURES:
|
346
|
+
# Discard the oldest figure
|
347
|
+
old_fig = figures.popleft()
|
348
|
+
# If needed, you could also close the figure to free memory:
|
349
|
+
matplotlib.pyplot.close(old_fig)
|
350
|
+
|
346
351
|
# Update slider maximum
|
347
352
|
index_control.set_to(len(figures) - 1)
|
348
|
-
#print("New maximum slider value after adding a figure:", index_control.to, flush=True)
|
349
353
|
|
350
354
|
# If no figure has been displayed yet
|
351
355
|
if figure_index == -1:
|
352
356
|
figure_index = 0
|
353
357
|
display_figure(figures[figure_index])
|
354
358
|
index_control.set(figure_index)
|
355
|
-
#print("Displayed the first figure and set slider value to 0", flush=True)
|
356
|
-
|
357
|
-
#got_new_figure = True
|
358
|
-
|
359
|
-
#if not got_new_figure:
|
360
|
-
# No new figures this time
|
361
|
-
#print("No new figures found in the queue this iteration.", flush=True)
|
362
359
|
|
363
360
|
except Exception as e:
|
364
|
-
print("Exception in process_fig_queue:", e
|
361
|
+
print("Exception in process_fig_queue:", e)
|
365
362
|
traceback.print_exc()
|
366
363
|
|
367
364
|
finally:
|
368
365
|
# Schedule process_fig_queue() to run again
|
369
366
|
after_id = canvas_widget.after(uppdate_frequency, process_fig_queue)
|
370
367
|
parent_frame.after_tasks.append(after_id)
|
371
|
-
|
368
|
+
|
372
369
|
|
373
370
|
def update_figure(value):
|
374
371
|
from .gui_elements import standardize_figure
|
@@ -513,7 +510,7 @@ def import_settings(settings_type='mask'):
|
|
513
510
|
#vars_dict = hide_all_settings(vars_dict, categories=None)
|
514
511
|
csv_settings = read_settings_from_csv(csv_file_path)
|
515
512
|
if settings_type == 'mask':
|
516
|
-
settings = set_default_settings_preprocess_generate_masks(
|
513
|
+
settings = set_default_settings_preprocess_generate_masks(settings={})
|
517
514
|
elif settings_type == 'measure':
|
518
515
|
settings = get_measure_crop_settings(settings={})
|
519
516
|
elif settings_type == 'classify':
|
@@ -565,7 +562,7 @@ def setup_settings_panel(vertical_container, settings_type='mask'):
|
|
565
562
|
settings_frame.grid_columnconfigure(0, weight=1)
|
566
563
|
|
567
564
|
if settings_type == 'mask':
|
568
|
-
settings = set_default_settings_preprocess_generate_masks(
|
565
|
+
settings = set_default_settings_preprocess_generate_masks(settings={})
|
569
566
|
elif settings_type == 'measure':
|
570
567
|
settings = get_measure_crop_settings(settings={})
|
571
568
|
elif settings_type == 'classify':
|
@@ -881,7 +878,7 @@ def start_process(q=None, fig_queue=None, settings_type='mask'):
|
|
881
878
|
q.put(f"Error: {e}")
|
882
879
|
return
|
883
880
|
|
884
|
-
if thread_control.get("run_thread") is not None:
|
881
|
+
if isinstance(thread_control, dict) and thread_control.get("run_thread") is not None:
|
885
882
|
initiate_abort()
|
886
883
|
|
887
884
|
stop_requested = Value('i', 0)
|
@@ -987,6 +984,66 @@ def main_thread_update_function(root, q, fig_queue, canvas_widget):
|
|
987
984
|
print(f"Error updating GUI canvas: {e}")
|
988
985
|
finally:
|
989
986
|
root.after(uppdate_frequency, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget))
|
987
|
+
|
988
|
+
def cleanup_previous_instance():
|
989
|
+
"""
|
990
|
+
Cleans up resources from the previous application instance.
|
991
|
+
"""
|
992
|
+
global parent_frame, usage_bars, figures, figure_index, thread_control, canvas, q, fig_queue
|
993
|
+
|
994
|
+
# 1. Destroy all widgets in the parent frame
|
995
|
+
if parent_frame is not None:
|
996
|
+
for widget in parent_frame.winfo_children():
|
997
|
+
try:
|
998
|
+
widget.destroy()
|
999
|
+
except Exception as e:
|
1000
|
+
print(f"Error destroying widget: {e}")
|
1001
|
+
parent_frame.update_idletasks()
|
1002
|
+
parent_frame = None
|
1003
|
+
|
1004
|
+
# 2. Cancel all pending `after` tasks
|
1005
|
+
if parent_frame is not None:
|
1006
|
+
parent_window = parent_frame.winfo_toplevel()
|
1007
|
+
if hasattr(parent_window, 'after_tasks'):
|
1008
|
+
for after_id in parent_window.after_tasks:
|
1009
|
+
parent_window.after_cancel(after_id)
|
1010
|
+
parent_window.after_tasks = []
|
1011
|
+
|
1012
|
+
# 3. Clear global queues
|
1013
|
+
if q is not None:
|
1014
|
+
while not q.empty():
|
1015
|
+
q.get()
|
1016
|
+
q = None
|
1017
|
+
|
1018
|
+
if fig_queue is not None:
|
1019
|
+
while not fig_queue.empty():
|
1020
|
+
fig_queue.get()
|
1021
|
+
fig_queue = None
|
1022
|
+
|
1023
|
+
# 4. Stop and reset global thread control
|
1024
|
+
if thread_control is not None:
|
1025
|
+
thread_control['stop'] = True
|
1026
|
+
#thread_control = None
|
1027
|
+
|
1028
|
+
# 5. Reset usage bars, figures, and indices
|
1029
|
+
usage_bars = []
|
1030
|
+
figures = deque()
|
1031
|
+
figure_index = -1
|
1032
|
+
|
1033
|
+
# 6. Clear canvas or other visualizations
|
1034
|
+
if canvas is not None:
|
1035
|
+
try:
|
1036
|
+
if hasattr(canvas, 'figure'): # Check if it's a FigureCanvasTkAgg
|
1037
|
+
canvas.figure.clear() # Clear the Matplotlib figure
|
1038
|
+
canvas.get_tk_widget().destroy() # Destroy the Tkinter widget
|
1039
|
+
else:
|
1040
|
+
# Assume it's a standard Tkinter Canvas
|
1041
|
+
canvas.delete("all")
|
1042
|
+
except Exception as e:
|
1043
|
+
print(f"Error clearing canvas: {e}")
|
1044
|
+
canvas = None
|
1045
|
+
|
1046
|
+
print("Previous instance cleaned up successfully.")
|
990
1047
|
|
991
1048
|
def initiate_root(parent, settings_type='mask'):
|
992
1049
|
"""
|
@@ -1002,7 +1059,11 @@ def initiate_root(parent, settings_type='mask'):
|
|
1002
1059
|
|
1003
1060
|
global q, fig_queue, thread_control, parent_frame, scrollable_frame, button_frame, vars_dict, canvas, canvas_widget, button_scrollable_frame, progress_bar, uppdate_frequency, figures, figure_index, index_control, usage_bars
|
1004
1061
|
|
1005
|
-
|
1062
|
+
# Clean up any previous instance
|
1063
|
+
cleanup_previous_instance()
|
1064
|
+
|
1065
|
+
from .gui_utils import setup_frame
|
1066
|
+
from .gui_elements import create_menu_bar
|
1006
1067
|
from .settings import descriptions
|
1007
1068
|
#from .openai import Chatbot
|
1008
1069
|
|
@@ -1065,6 +1126,7 @@ def initiate_root(parent, settings_type='mask'):
|
|
1065
1126
|
|
1066
1127
|
process_console_queue()
|
1067
1128
|
process_fig_queue()
|
1129
|
+
create_menu_bar(parent)
|
1068
1130
|
after_id = parent_window.after(uppdate_frequency, lambda: main_thread_update_function(parent_window, q, fig_queue, canvas_widget))
|
1069
1131
|
parent_window.after_tasks.append(after_id)
|
1070
1132
|
|
spacr/gui_elements.py
CHANGED
@@ -7,6 +7,7 @@ from tkinter import font
|
|
7
7
|
from queue import Queue
|
8
8
|
from tkinter import Label, Frame, Button
|
9
9
|
import numpy as np
|
10
|
+
import pandas as pd
|
10
11
|
from PIL import Image, ImageOps, ImageTk, ImageDraw, ImageFont, ImageEnhance
|
11
12
|
from concurrent.futures import ThreadPoolExecutor
|
12
13
|
from skimage.exposure import rescale_intensity
|
@@ -17,10 +18,28 @@ from skimage.draw import polygon, line
|
|
17
18
|
from skimage.transform import resize
|
18
19
|
from scipy.ndimage import binary_fill_holes, label
|
19
20
|
from tkinter import ttk, scrolledtext
|
20
|
-
from
|
21
|
+
from sklearn.model_selection import train_test_split
|
22
|
+
from xgboost import XGBClassifier
|
23
|
+
from sklearn.metrics import classification_report, confusion_matrix
|
21
24
|
|
22
25
|
fig = None
|
23
26
|
|
27
|
+
def restart_gui_app(root):
|
28
|
+
"""
|
29
|
+
Restarts the GUI application by destroying the current instance
|
30
|
+
and launching a fresh one.
|
31
|
+
"""
|
32
|
+
try:
|
33
|
+
# Destroy the current root window
|
34
|
+
root.destroy()
|
35
|
+
|
36
|
+
# Import and launch a new instance of the application
|
37
|
+
from spacr.gui import gui_app
|
38
|
+
new_root = tk.Tk() # Create a fresh Tkinter root instance
|
39
|
+
gui_app()
|
40
|
+
except Exception as e:
|
41
|
+
print(f"Error restarting GUI application: {e}")
|
42
|
+
|
24
43
|
def create_menu_bar(root):
|
25
44
|
from .gui import initiate_root
|
26
45
|
gui_apps = {
|
@@ -56,6 +75,7 @@ def create_menu_bar(root):
|
|
56
75
|
|
57
76
|
# Add a separator and an exit option
|
58
77
|
app_menu.add_separator()
|
78
|
+
#app_menu.add_command(label="Home",command=lambda: restart_gui_app(root))
|
59
79
|
app_menu.add_command(label="Help", command=lambda: webbrowser.open("https://spacr.readthedocs.io/en/latest/?badge=latest"))
|
60
80
|
app_menu.add_command(label="Exit", command=root.quit)
|
61
81
|
|
@@ -2201,7 +2221,8 @@ class AnnotateApp:
|
|
2201
2221
|
self.image_size = (image_size, image_size)
|
2202
2222
|
else:
|
2203
2223
|
raise ValueError("Invalid image size")
|
2204
|
-
|
2224
|
+
|
2225
|
+
self.orig_annotation_columns = annotation_column
|
2205
2226
|
self.annotation_column = annotation_column
|
2206
2227
|
self.image_type = image_type
|
2207
2228
|
self.channels = channels
|
@@ -2258,6 +2279,12 @@ class AnnotateApp:
|
|
2258
2279
|
|
2259
2280
|
self.exit_button = Button(self.button_frame, text="Exit", command=self.shutdown, bg=self.bg_color, fg=self.fg_color, highlightbackground=self.fg_color, highlightcolor=self.fg_color, highlightthickness=1)
|
2260
2281
|
self.exit_button.pack(side="right", padx=5)
|
2282
|
+
|
2283
|
+
self.train_button = Button(self.button_frame,text="Train & Classify (beta)",command=self.train_and_classify,bg=self.bg_color,fg=self.fg_color,highlightbackground=self.fg_color,highlightcolor=self.fg_color,highlightthickness=1)
|
2284
|
+
self.train_button.pack(side="right", padx=5)
|
2285
|
+
|
2286
|
+
self.train_button = Button(self.button_frame,text="orig.",command=self.swich_back_annotation_column,bg=self.bg_color,fg=self.fg_color,highlightbackground=self.fg_color,highlightcolor=self.fg_color,highlightthickness=1)
|
2287
|
+
self.train_button.pack(side="right", padx=5)
|
2261
2288
|
|
2262
2289
|
# Calculate grid rows and columns based on the root window size and image size
|
2263
2290
|
self.calculate_grid_dimensions()
|
@@ -2280,7 +2307,12 @@ class AnnotateApp:
|
|
2280
2307
|
self.grid_frame.grid_rowconfigure(row, weight=1)
|
2281
2308
|
for col in range(self.grid_cols):
|
2282
2309
|
self.grid_frame.grid_columnconfigure(col, weight=1)
|
2283
|
-
|
2310
|
+
|
2311
|
+
def swich_back_annotation_column(self):
|
2312
|
+
self.annotation_column = self.orig_annotation_columns
|
2313
|
+
self.prefilter_paths_annotations()
|
2314
|
+
self.update_display()
|
2315
|
+
|
2284
2316
|
def calculate_grid_dimensions(self):
|
2285
2317
|
window_width = self.root.winfo_width()
|
2286
2318
|
window_height = self.root.winfo_height()
|
@@ -2603,6 +2635,163 @@ class AnnotateApp:
|
|
2603
2635
|
print(f'Quit application')
|
2604
2636
|
else:
|
2605
2637
|
print('Waiting for pending updates to finish before quitting')
|
2638
|
+
|
2639
|
+
def train_and_classify(self):
|
2640
|
+
"""
|
2641
|
+
1) Merge data from the relevant DB tables (including png_list).
|
2642
|
+
2) Collect manual annotations from png_list.<annotation_column> => 'manual_annotation'.
|
2643
|
+
- 1 => class=1, 2 => class=0 (for training).
|
2644
|
+
3) If only one class is present, randomly sample unannotated images as the other class.
|
2645
|
+
4) Train an XGBoost model.
|
2646
|
+
5) Classify *all* rows -> fill XGboost_score (prob of class=1) & XGboost_annotation (1 or 2 if high confidence).
|
2647
|
+
6) Write those columns back to sqlite, so every row in png_list has a score (and possibly an annotation).
|
2648
|
+
7) Refresh the UI (prefilter_paths_annotations + load_images).
|
2649
|
+
"""
|
2650
|
+
|
2651
|
+
# Optionally, update your GUI status label
|
2652
|
+
self.update_gui_text("Merging data...")
|
2653
|
+
|
2654
|
+
from spacr.io import _read_and_merge_data # Adapt to your actual import
|
2655
|
+
|
2656
|
+
# (1) Merge data
|
2657
|
+
merged_df, obj_df_ls = _read_and_merge_data(
|
2658
|
+
locs=[self.db_path],
|
2659
|
+
tables=['cell', 'cytoplasm', 'nucleus', 'pathogen', 'png_list'],
|
2660
|
+
verbose=False
|
2661
|
+
)
|
2662
|
+
|
2663
|
+
# (2) Load manual annotations from the DB
|
2664
|
+
conn = sqlite3.connect(self.db_path)
|
2665
|
+
c = conn.cursor()
|
2666
|
+
c.execute(f"SELECT png_path, {self.annotation_column} FROM png_list WHERE {self.annotation_column} IS NOT NULL")
|
2667
|
+
annotated_rows = c.fetchall() # e.g. [(png_path, 1 or 2), ...]
|
2668
|
+
conn.close()
|
2669
|
+
|
2670
|
+
# dict {png_path -> 1 or 2}
|
2671
|
+
annot_dict = dict(annotated_rows)
|
2672
|
+
|
2673
|
+
# Add 'manual_annotation' to merged_df
|
2674
|
+
merged_df['manual_annotation'] = merged_df['png_path'].map(annot_dict)
|
2675
|
+
|
2676
|
+
# Subset with manual labels
|
2677
|
+
annotated_df = merged_df.dropna(subset=['manual_annotation']).copy()
|
2678
|
+
# Convert "2" => "0" for binary classification
|
2679
|
+
annotated_df['manual_annotation'] = annotated_df['manual_annotation'].replace({2: 0}).astype(int)
|
2680
|
+
|
2681
|
+
# (3) Handle single-class scenario
|
2682
|
+
class_counts = annotated_df['manual_annotation'].value_counts()
|
2683
|
+
if len(class_counts) == 1:
|
2684
|
+
single_class = class_counts.index[0] # 0 or 1
|
2685
|
+
needed = class_counts.iloc[0]
|
2686
|
+
other_class = 1 if single_class == 0 else 0
|
2687
|
+
|
2688
|
+
unannotated_df_all = merged_df[merged_df['manual_annotation'].isna()].copy()
|
2689
|
+
if len(unannotated_df_all) == 0:
|
2690
|
+
print("No unannotated rows to sample for the other class. Cannot proceed.")
|
2691
|
+
self.update_gui_text("Not enough data to train (no second class).")
|
2692
|
+
return
|
2693
|
+
|
2694
|
+
sample_size = min(needed, len(unannotated_df_all))
|
2695
|
+
artificially_labeled = unannotated_df_all.sample(n=sample_size, replace=False).copy()
|
2696
|
+
artificially_labeled['manual_annotation'] = other_class
|
2697
|
+
|
2698
|
+
annotated_df = pd.concat([annotated_df, artificially_labeled], ignore_index=True)
|
2699
|
+
print(f"Only one class was present => randomly labeled {sample_size} unannotated rows as {other_class}.")
|
2700
|
+
|
2701
|
+
if len(annotated_df) < 2:
|
2702
|
+
print("Not enough annotated data to train (need at least 2).")
|
2703
|
+
self.update_gui_text("Not enough data to train.")
|
2704
|
+
return
|
2705
|
+
|
2706
|
+
# (4) Train XGBoost
|
2707
|
+
self.update_gui_text("Training XGBoost model...")
|
2708
|
+
|
2709
|
+
# Identify numeric columns
|
2710
|
+
ignore_cols = {'png_path', 'manual_annotation'}
|
2711
|
+
feature_cols = [
|
2712
|
+
col for col in annotated_df.columns
|
2713
|
+
if col not in ignore_cols
|
2714
|
+
and (annotated_df[col].dtype == float or annotated_df[col].dtype == int)
|
2715
|
+
]
|
2716
|
+
|
2717
|
+
X_data = annotated_df[feature_cols].fillna(0).values
|
2718
|
+
y_data = annotated_df['manual_annotation'].values
|
2719
|
+
|
2720
|
+
# standard train/test
|
2721
|
+
X_train, X_test, y_train, y_test = train_test_split(
|
2722
|
+
X_data, y_data, test_size=0.1, random_state=42
|
2723
|
+
)
|
2724
|
+
model = XGBClassifier(use_label_encoder=False, eval_metric='logloss')
|
2725
|
+
model.fit(X_train, y_train)
|
2726
|
+
|
2727
|
+
# Evaluate
|
2728
|
+
preds = model.predict(X_test)
|
2729
|
+
print("=== Classification Report ===")
|
2730
|
+
print(classification_report(y_test, preds))
|
2731
|
+
print("=== Confusion Matrix ===")
|
2732
|
+
print(confusion_matrix(y_test, preds))
|
2733
|
+
|
2734
|
+
# (5) Classify ALL rows
|
2735
|
+
all_df = merged_df.copy()
|
2736
|
+
X_all = all_df[feature_cols].fillna(0).values
|
2737
|
+
probs_all = model.predict_proba(X_all)[:, 1]
|
2738
|
+
# Probability => XGboost_score
|
2739
|
+
all_df['XGboost_score'] = probs_all
|
2740
|
+
|
2741
|
+
# Decide XGboost_annotation
|
2742
|
+
def get_annotation_from_prob(prob):
|
2743
|
+
if prob > 0.9:
|
2744
|
+
return 1 # class=1
|
2745
|
+
elif prob < 0.1:
|
2746
|
+
return 0 # class=0
|
2747
|
+
return None # uncertain
|
2748
|
+
|
2749
|
+
xgb_anno_col = [get_annotation_from_prob(p) for p in probs_all]
|
2750
|
+
# Convert 0 => 2 if your DB uses "2" for the negative class
|
2751
|
+
xgb_anno_col = [2 if x == 0 else x for x in xgb_anno_col]
|
2752
|
+
|
2753
|
+
all_df['XGboost_annotation'] = xgb_anno_col
|
2754
|
+
|
2755
|
+
# (6) Write results back to png_list
|
2756
|
+
self.update_gui_text("Updating the database with XGBoost predictions...")
|
2757
|
+
conn = sqlite3.connect(self.db_path)
|
2758
|
+
c = conn.cursor()
|
2759
|
+
# Ensure columns exist
|
2760
|
+
try:
|
2761
|
+
c.execute("ALTER TABLE png_list ADD COLUMN XGboost_annotation INTEGER")
|
2762
|
+
except sqlite3.OperationalError:
|
2763
|
+
pass
|
2764
|
+
try:
|
2765
|
+
c.execute("ALTER TABLE png_list ADD COLUMN XGboost_score FLOAT")
|
2766
|
+
except sqlite3.OperationalError:
|
2767
|
+
pass
|
2768
|
+
|
2769
|
+
# Update each row
|
2770
|
+
for idx, row in all_df.iterrows():
|
2771
|
+
score_val = float(row['XGboost_score'])
|
2772
|
+
anno_val = row['XGboost_annotation']
|
2773
|
+
the_path = row['png_path']
|
2774
|
+
if pd.isna(the_path):
|
2775
|
+
continue # skip if no path
|
2776
|
+
|
2777
|
+
if pd.isna(anno_val):
|
2778
|
+
# We set annotation=NULL but do set the score
|
2779
|
+
c.execute("""
|
2780
|
+
UPDATE png_list
|
2781
|
+
SET XGboost_annotation = NULL,
|
2782
|
+
XGboost_score = ?
|
2783
|
+
WHERE png_path = ?
|
2784
|
+
""", (score_val, the_path))
|
2785
|
+
else:
|
2786
|
+
# numeric annotation + numeric score
|
2787
|
+
c.execute("""
|
2788
|
+
UPDATE png_list
|
2789
|
+
SET XGboost_annotation = ?,
|
2790
|
+
XGboost_score = ?
|
2791
|
+
WHERE png_path = ?
|
2792
|
+
""", (int(anno_val), score_val, the_path))
|
2793
|
+
|
2794
|
+
self.annotation_column = 'XGboost_annotation'
|
2606
2795
|
|
2607
2796
|
def standardize_figure(fig):
|
2608
2797
|
from .gui_elements import set_dark_style
|
spacr/gui_utils.py
CHANGED
@@ -482,7 +482,7 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
|
|
482
482
|
|
483
483
|
if settings_type == 'mask':
|
484
484
|
function = preprocess_generate_masks
|
485
|
-
imports =
|
485
|
+
imports = 1
|
486
486
|
elif settings_type == 'measure':
|
487
487
|
function = measure_crop
|
488
488
|
imports = 1
|
spacr/io.py
CHANGED
@@ -1773,7 +1773,7 @@ def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus',
|
|
1773
1773
|
print(e)
|
1774
1774
|
conn.close()
|
1775
1775
|
if 'png_list' in dataframes:
|
1776
|
-
png_list_df = dataframes['png_list'][['cell_id', 'png_path', 'plate', 'row_name', 'column_name']].copy()
|
1776
|
+
png_list_df = dataframes['png_list'][['cell_id', 'png_path', 'plate', 'row_name', 'column_name', 'field']].copy()
|
1777
1777
|
png_list_df['cell_id'] = png_list_df['cell_id'].str[1:].astype(int)
|
1778
1778
|
png_list_df.rename(columns={'cell_id': 'object_label'}, inplace=True)
|
1779
1779
|
if 'cell' in dataframes:
|
@@ -2275,175 +2275,6 @@ def _read_db(db_loc, tables):
|
|
2275
2275
|
dfs.append(df)
|
2276
2276
|
conn.close() # Close the connection
|
2277
2277
|
return dfs
|
2278
|
-
|
2279
|
-
def _read_and_merge_data_v1(locs, tables, verbose=False, nuclei_limit=False, pathogen_limit=False):
|
2280
|
-
|
2281
|
-
from .utils import _split_data
|
2282
|
-
|
2283
|
-
#Extract plate DataFrames
|
2284
|
-
all_dfs = []
|
2285
|
-
for loc in locs:
|
2286
|
-
db_dfs = _read_db(loc, tables)
|
2287
|
-
all_dfs.append(db_dfs)
|
2288
|
-
|
2289
|
-
#Extract Tables from DataFrames and concatinate rows
|
2290
|
-
for i, dfs in enumerate(all_dfs):
|
2291
|
-
if 'cell' in tables:
|
2292
|
-
cell = dfs[0]
|
2293
|
-
if verbose:
|
2294
|
-
print(f'plate: {i+1} cells:{len(cell)}')
|
2295
|
-
# see pathogens logic, copy logic to other tables #here
|
2296
|
-
if 'nucleus' in tables:
|
2297
|
-
nucleus = dfs[1]
|
2298
|
-
if verbose:
|
2299
|
-
print(f'plate: {i+1} nucleus:{len(nucleus)} ')
|
2300
|
-
|
2301
|
-
if 'pathogen' in tables:
|
2302
|
-
if len(tables) == 1:
|
2303
|
-
pathogen = dfs[0]
|
2304
|
-
print(len(pathogen))
|
2305
|
-
else:
|
2306
|
-
pathogen = dfs[2]
|
2307
|
-
if verbose:
|
2308
|
-
print(f'plate: {i+1} pathogens:{len(pathogen)}')
|
2309
|
-
|
2310
|
-
if 'cytoplasm' in tables:
|
2311
|
-
if not 'pathogen' in tables:
|
2312
|
-
cytoplasm = dfs[2]
|
2313
|
-
else:
|
2314
|
-
cytoplasm = dfs[3]
|
2315
|
-
if verbose:
|
2316
|
-
print(f'plate: {i+1} cytoplasms: {len(cytoplasm)}')
|
2317
|
-
|
2318
|
-
if i > 0:
|
2319
|
-
if 'cell' in tables:
|
2320
|
-
cells = pd.concat([cells, cell], axis = 0)
|
2321
|
-
if 'nucleus' in tables:
|
2322
|
-
nucleus = pd.concat([nucleus, nucleus], axis = 0)
|
2323
|
-
if 'pathogen' in tables:
|
2324
|
-
pathogens = pd.concat([pathogens, pathogen], axis = 0)
|
2325
|
-
if 'cytoplasm' in tables:
|
2326
|
-
cytoplasms = pd.concat([cytoplasms, cytoplasm], axis = 0)
|
2327
|
-
else:
|
2328
|
-
if 'cell' in tables:
|
2329
|
-
cells = cell.copy()
|
2330
|
-
if 'nucleus' in tables:
|
2331
|
-
nucleus = nucleus.copy()
|
2332
|
-
if 'pathogen' in tables:
|
2333
|
-
pathogens = pathogen.copy()
|
2334
|
-
if 'cytoplasm' in tables:
|
2335
|
-
cytoplasms = cytoplasm.copy()
|
2336
|
-
|
2337
|
-
#Add an o in front of all object and cell lables to convert them to strings
|
2338
|
-
if 'cell' in tables:
|
2339
|
-
cells = cells.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
|
2340
|
-
cells = cells.assign(prcfo = lambda x: x['prcf'] + '_' + x['object_label'])
|
2341
|
-
cells_g_df, metadata = _split_data(cells, 'prcfo', 'object_label')
|
2342
|
-
merged_df = cells_g_df.copy()
|
2343
|
-
if verbose:
|
2344
|
-
print(f'cells: {len(cells)}')
|
2345
|
-
print(f'cells grouped: {len(cells_g_df)}')
|
2346
|
-
|
2347
|
-
if 'cytoplasm' in tables:
|
2348
|
-
cytoplasms = cytoplasms.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
|
2349
|
-
cytoplasms = cytoplasms.assign(prcfo = lambda x: x['prcf'] + '_' + x['object_label'])
|
2350
|
-
cytoplasms_g_df, _ = _split_data(cytoplasms, 'prcfo', 'object_label')
|
2351
|
-
merged_df = cells_g_df.merge(cytoplasms_g_df, left_index=True, right_index=True)
|
2352
|
-
if verbose:
|
2353
|
-
print(f'cytoplasms: {len(cytoplasms)}')
|
2354
|
-
print(f'cytoplasms grouped: {len(cytoplasms_g_df)}')
|
2355
|
-
|
2356
|
-
if 'nucleus' in tables:
|
2357
|
-
if not 'cell' in tables:
|
2358
|
-
cells_g_df = pd.DataFrame()
|
2359
|
-
nucleus = nucleus.dropna(subset=['cell_id'])
|
2360
|
-
nucleus = nucleus.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
|
2361
|
-
nucleus = nucleus.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
|
2362
|
-
nucleus = nucleus.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
|
2363
|
-
nucleus['nucleus_prcfo_count'] = nucleus.groupby('prcfo')['prcfo'].transform('count')
|
2364
|
-
if nuclei_limit == False:
|
2365
|
-
nucleus = nucleus[nucleus['nucleus_prcfo_count']==1]
|
2366
|
-
nucleus_g_df, _ = _split_data(nucleus, 'prcfo', 'cell_id')
|
2367
|
-
if verbose:
|
2368
|
-
print(f'nucleus: {len(nucleus)}')
|
2369
|
-
print(f'nucleus grouped: {len(nucleus_g_df)}')
|
2370
|
-
if 'cytoplasm' in tables:
|
2371
|
-
merged_df = merged_df.merge(nucleus_g_df, left_index=True, right_index=True)
|
2372
|
-
else:
|
2373
|
-
merged_df = cells_g_df.merge(nucleus_g_df, left_index=True, right_index=True)
|
2374
|
-
|
2375
|
-
if 'pathogen' in tables:
|
2376
|
-
if not 'cell' in tables:
|
2377
|
-
cells_g_df = pd.DataFrame()
|
2378
|
-
merged_df = []
|
2379
|
-
try:
|
2380
|
-
pathogens = pathogens.dropna(subset=['cell_id'])
|
2381
|
-
|
2382
|
-
except:
|
2383
|
-
pathogens['cell_id'] = pathogens['object_label']
|
2384
|
-
pathogens = pathogens.dropna(subset=['cell_id'])
|
2385
|
-
|
2386
|
-
pathogens = pathogens.assign(object_label=lambda x: 'o' + x['object_label'].astype(int).astype(str))
|
2387
|
-
pathogens = pathogens.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
|
2388
|
-
pathogens = pathogens.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
|
2389
|
-
pathogens['pathogen_prcfo_count'] = pathogens.groupby('prcfo')['prcfo'].transform('count')
|
2390
|
-
|
2391
|
-
if isinstance(pathogen_limit, bool):
|
2392
|
-
if pathogen_limit == False:
|
2393
|
-
pathogens = pathogens[pathogens['pathogen_prcfo_count']<=1]
|
2394
|
-
print(f"after multiinfected Bool: {len(pathogens)}")
|
2395
|
-
if isinstance(pathogen_limit, float):
|
2396
|
-
pathogen_limit = int(pathogen_limit)
|
2397
|
-
if isinstance(pathogen_limit, int):
|
2398
|
-
pathogens = pathogens[pathogens['pathogen_prcfo_count']<=pathogen_limit]
|
2399
|
-
print(f"afer multiinfected Float: {len(pathogens)}")
|
2400
|
-
if not 'cell' in tables:
|
2401
|
-
pathogens_g_df, metadata = _split_data(pathogens, 'prcfo', 'cell_id')
|
2402
|
-
else:
|
2403
|
-
pathogens_g_df, _ = _split_data(pathogens, 'prcfo', 'cell_id')
|
2404
|
-
|
2405
|
-
if verbose:
|
2406
|
-
print(f'pathogens: {len(pathogens)}')
|
2407
|
-
print(f'pathogens grouped: {len(pathogens_g_df)}')
|
2408
|
-
|
2409
|
-
if len(merged_df) == 0:
|
2410
|
-
merged_df = pathogens_g_df
|
2411
|
-
else:
|
2412
|
-
merged_df = merged_df.merge(pathogens_g_df, left_index=True, right_index=True)
|
2413
|
-
|
2414
|
-
#Add prc column (plate row column)
|
2415
|
-
metadata = metadata.assign(prc = lambda x: x['plate'] + '_' + x['row_name'] + '_' +x['column_name'])
|
2416
|
-
|
2417
|
-
#Count cells per well
|
2418
|
-
cells_well = pd.DataFrame(metadata.groupby('prc')['object_label'].nunique())
|
2419
|
-
|
2420
|
-
cells_well.reset_index(inplace=True)
|
2421
|
-
cells_well.rename(columns={'object_label': 'cells_per_well'}, inplace=True)
|
2422
|
-
metadata = pd.merge(metadata, cells_well, on='prc', how='inner', suffixes=('', '_drop_col'))
|
2423
|
-
object_label_cols = [col for col in metadata.columns if '_drop_col' in col]
|
2424
|
-
metadata.drop(columns=object_label_cols, inplace=True)
|
2425
|
-
|
2426
|
-
#Add prcfo column (plate row column field object)
|
2427
|
-
metadata = metadata.assign(prcfo = lambda x: x['plate'] + '_' + x['row_name'] + '_' +x['column_name']+ '_' +x['field']+ '_' +x['object_label'])
|
2428
|
-
metadata.set_index('prcfo', inplace=True)
|
2429
|
-
|
2430
|
-
merged_df = metadata.merge(merged_df, left_index=True, right_index=True)
|
2431
|
-
|
2432
|
-
merged_df = merged_df.dropna(axis=1)
|
2433
|
-
if verbose:
|
2434
|
-
print(f'Generated dataframe with: {len(merged_df.columns)} columns and {len(merged_df)} rows')
|
2435
|
-
|
2436
|
-
obj_df_ls = []
|
2437
|
-
if 'cell' in tables:
|
2438
|
-
obj_df_ls.append(cells)
|
2439
|
-
if 'cytoplasm' in tables:
|
2440
|
-
obj_df_ls.append(cytoplasms)
|
2441
|
-
if 'nucleus' in tables:
|
2442
|
-
obj_df_ls.append(nucleus)
|
2443
|
-
if 'pathogen' in tables:
|
2444
|
-
obj_df_ls.append(pathogens)
|
2445
|
-
|
2446
|
-
return merged_df, obj_df_ls
|
2447
2278
|
|
2448
2279
|
def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=10, pathogen_limit=10, change_plate=False):
|
2449
2280
|
from .io import _read_db
|
@@ -2929,6 +2760,7 @@ def generate_training_dataset(settings):
|
|
2929
2760
|
def get_smallest_class_size(df, settings, dataset_mode):
|
2930
2761
|
if dataset_mode == 'metadata':
|
2931
2762
|
sizes = [len(df[df['condition'] == c]) for c in settings['class_metadata']]
|
2763
|
+
#sizes = [len(df[df['condition'].isin(class_list)]) for class_list in settings['class_metadata']]
|
2932
2764
|
print(f'Class sizes: {sizes}')
|
2933
2765
|
elif dataset_mode == 'annotation':
|
2934
2766
|
sizes = [len(class_paths) for class_paths in df]
|
@@ -2997,16 +2829,12 @@ def generate_training_dataset(settings):
|
|
2997
2829
|
df = df.dropna(subset=['condition'])
|
2998
2830
|
|
2999
2831
|
display(df)
|
3000
|
-
|
3001
|
-
#df['metadata_based_class'] = pd.NA
|
3002
|
-
#for i, class_ in enumerate(settings['classes']):
|
3003
|
-
# ls = settings['class_metadata'][i]
|
3004
|
-
# df.loc[df[settings['metadata_type_by']].isin(ls), 'metadata_based_class'] = class_
|
3005
2832
|
|
3006
2833
|
size = get_smallest_class_size(df, settings, 'metadata')
|
3007
2834
|
|
3008
2835
|
for class_ in settings['class_metadata']:
|
3009
2836
|
class_temp_df = df[df['condition'] == class_]
|
2837
|
+
#class_temp_df = df[df['condition'].isin(class_)]
|
3010
2838
|
print(f'Found {len(class_temp_df)} images for class {class_}')
|
3011
2839
|
class_paths_temp = class_temp_df['png_path'].tolist()
|
3012
2840
|
|
@@ -3033,6 +2861,8 @@ def generate_training_dataset(settings):
|
|
3033
2861
|
from .io import _read_and_merge_data, _read_db
|
3034
2862
|
from .utils import get_paths_from_db, annotate_conditions, save_settings
|
3035
2863
|
from .settings import set_generate_training_dataset_defaults
|
2864
|
+
|
2865
|
+
settings = set_generate_training_dataset_defaults(settings)
|
3036
2866
|
|
3037
2867
|
if 'nucleus' not in settings['tables']:
|
3038
2868
|
settings['nuclei_limit'] = False
|
@@ -3041,7 +2871,6 @@ def generate_training_dataset(settings):
|
|
3041
2871
|
settings['pathogen_limit'] = 0
|
3042
2872
|
|
3043
2873
|
# Set default settings and save
|
3044
|
-
settings = set_generate_training_dataset_defaults(settings)
|
3045
2874
|
save_settings(settings, 'cv_dataset', show=True)
|
3046
2875
|
|
3047
2876
|
class_path_list = None
|