spacr 0.1.1__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +18 -12
- spacr/annotate_app.py +258 -99
- spacr/annotate_app_v2.py +163 -4
- spacr/app_annotate.py +541 -0
- spacr/app_classify.py +8 -0
- spacr/app_make_masks.py +925 -0
- spacr/app_make_masks_v2.py +686 -0
- spacr/app_mask.py +8 -0
- spacr/app_measure.py +8 -0
- spacr/app_sequencing.py +8 -0
- spacr/app_umap.py +8 -0
- spacr/classify_app.py +201 -0
- spacr/core.py +30 -28
- spacr/deep_spacr.py +9 -7
- spacr/gui.py +50 -31
- spacr/gui_annotate.py +145 -0
- spacr/gui_classify_app.py +20 -6
- spacr/gui_core.py +608 -0
- spacr/gui_elements.py +324 -0
- spacr/gui_make_masks_app.py +927 -0
- spacr/gui_make_masks_app_v2.py +688 -0
- spacr/gui_mask_app.py +8 -4
- spacr/gui_measure_app.py +15 -5
- spacr/gui_run.py +58 -0
- spacr/gui_utils.py +80 -1026
- spacr/gui_wrappers.py +149 -0
- spacr/make_masks_app.py +929 -0
- spacr/make_masks_app_v2.py +688 -0
- spacr/mask_app.py +239 -915
- spacr/measure.py +35 -15
- spacr/measure_app.py +246 -0
- spacr/plot.py +53 -1
- spacr/sequencing.py +1 -17
- spacr/settings.py +502 -9
- spacr/sim_app.py +0 -0
- spacr/utils.py +73 -11
- {spacr-0.1.1.dist-info → spacr-0.1.7.dist-info}/METADATA +13 -22
- spacr-0.1.7.dist-info/RECORD +60 -0
- spacr-0.1.7.dist-info/entry_points.txt +8 -0
- spacr-0.1.1.dist-info/RECORD +0 -40
- spacr-0.1.1.dist-info/entry_points.txt +0 -9
- {spacr-0.1.1.dist-info → spacr-0.1.7.dist-info}/LICENSE +0 -0
- {spacr-0.1.1.dist-info → spacr-0.1.7.dist-info}/WHEEL +0 -0
- {spacr-0.1.1.dist-info → spacr-0.1.7.dist-info}/top_level.txt +0 -0
spacr/app_mask.py
ADDED
spacr/app_measure.py
ADDED
spacr/app_sequencing.py
ADDED
spacr/app_umap.py
ADDED
spacr/classify_app.py
ADDED
@@ -0,0 +1,201 @@
|
|
1
|
+
import sys, ctypes, matplotlib
|
2
|
+
import tkinter as tk
|
3
|
+
from tkinter import ttk, scrolledtext
|
4
|
+
from matplotlib.figure import Figure
|
5
|
+
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
|
6
|
+
from matplotlib.figure import Figure
|
7
|
+
matplotlib.use('Agg')
|
8
|
+
from tkinter import filedialog
|
9
|
+
from multiprocessing import Process, Queue, Value
|
10
|
+
import traceback
|
11
|
+
|
12
|
+
try:
    # Opt in to DPI awareness on Windows so Tk widgets are not blurry-scaled
    # on high-DPI displays.
    ctypes.windll.shcore.SetProcessDpiAwareness(True)
except (AttributeError, OSError):
    # AttributeError: not running on Windows (ctypes has no ``windll``), or the
    #   DLL lacks SetProcessDpiAwareness.
    # OSError: shcore.dll itself cannot be loaded (older Windows versions).
    # Either way DPI awareness is a cosmetic nicety, so carry on without it.
    pass
|
16
|
+
|
17
|
+
from .logger import log_function_call
|
18
|
+
from .gui_utils import ScrollableFrame, StdoutRedirector, CustomButton, set_dark_style, set_default_font, generate_fields, process_stdout_stderr, clear_canvas, main_thread_update_function
|
19
|
+
from .gui_utils import classify_variables, check_classify_gui_settings, train_test_model_wrapper, read_settings_from_csv, update_settings_from_csv, style_text_boxes, create_menu_bar
|
20
|
+
|
21
|
+
# Shared run state used by the GUI callbacks: the worker Process (None when
# idle) and its stop flag (a plain False until start_process installs a
# multiprocessing.Value).
thread_control = dict(run_thread=None, stop_requested=False)
|
22
|
+
|
23
|
+
#@log_function_call
|
24
|
+
def initiate_abort():
    """Ask the running classification worker to stop and clean it up.

    Sets the shared ``stop_requested`` flag when one exists. Then waits up to
    5 seconds for the worker process to exit, terminates it if it is still
    alive, and clears ``thread_control["run_thread"]``.
    """
    global thread_control
    stop_requested = thread_control.get("stop_requested")
    # Until start_process runs, this slot holds the placeholder ``False``
    # rather than a multiprocessing.Value; only the shared value has ``.value``.
    if stop_requested is not None and hasattr(stop_requested, "value"):
        stop_requested.value = 1

    run_thread = thread_control.get("run_thread")
    if run_thread is not None:
        run_thread.join(timeout=5)
        if run_thread.is_alive():
            run_thread.terminate()
            # Reap the terminated process so it does not linger as a zombie.
            run_thread.join(timeout=5)
        thread_control["run_thread"] = None
|
34
|
+
|
35
|
+
#@log_function_call
|
36
|
+
def run_classify_gui(q, fig_queue, stop_requested, settings=None):
    """Worker-process entry point: train/test a classifier from GUI settings.

    Routes this process's stdout/stderr into ``q`` via
    ``process_stdout_stderr``. It then prints every setting and calls
    ``train_test_model_wrapper(settings['src'], settings)``. Any exception is
    reported through ``q`` and printed with a traceback. ``stop_requested`` is
    always set to 1 on exit.

    Args:
        q: Queue receiving console output from this process.
        fig_queue: Figure queue; accepted for interface compatibility, unused here.
        stop_requested: Shared integer flag (``multiprocessing.Value``) set to 1 on exit.
        settings (dict, optional): Settings already read in the GUI process.
            When None they are read from the module-level ``vars_dict``. That
            global only exists here if this process was forked from the GUI
            process.
    """
    global vars_dict
    process_stdout_stderr(q)
    try:
        if settings is None:
            settings = check_classify_gui_settings(vars_dict)
        for key, value in settings.items():
            print(key, value, type(value))
        train_test_model_wrapper(settings['src'], settings)
    except Exception as e:
        q.put(f"Error during processing: {e}")
        traceback.print_exc()
    finally:
        stop_requested.value = 1

#@log_function_call
def start_process(q, fig_queue):
    """Start a new classification run in a separate process.

    Any still-running worker is aborted first. The settings are read from the
    GUI fields here, in the GUI process. A child started with the "spawn"
    method (forced by gui.py, and the default on Windows/macOS) re-imports this
    module and never sees ``vars_dict``. The field objects it holds (presumably
    Tk variables/widgets) could not be sent to the child anyway.
    """
    global thread_control
    if thread_control.get("run_thread") is not None:
        initiate_abort()

    try:
        settings = check_classify_gui_settings(vars_dict)
    except Exception as e:
        # Report invalid settings the same way the worker reports failures.
        q.put(f"Error during processing: {e}")
        traceback.print_exc()
        return

    stop_requested = Value('i', 0)  # multiprocessing shared value for inter-process communication
    thread_control["stop_requested"] = stop_requested
    thread_control["run_thread"] = Process(target=run_classify_gui, args=(q, fig_queue, stop_requested, settings))
    thread_control["run_thread"].start()
|
61
|
+
|
62
|
+
def import_settings(scrollable_frame):
    """Load classify settings from a user-selected CSV and rebuild the input fields.

    Opens a file dialog filtered to ``*.csv``. The chosen file's values are
    merged into the default ``classify_variables()``, and the fields are
    regenerated in ``scrollable_frame``. The module-level ``vars_dict`` is
    rebound to the new fields. Does nothing if the dialog is cancelled.
    """
    global vars_dict

    csv_file_path = filedialog.askopenfilename(filetypes=[("CSV files", "*.csv")])
    # A cancelled dialog yields an empty path; don't try to read it.
    if not csv_file_path:
        return
    csv_settings = read_settings_from_csv(csv_file_path)
    variables = classify_variables()
    new_settings = update_settings_from_csv(variables, csv_settings)
    vars_dict = generate_fields(new_settings, scrollable_frame)
|
70
|
+
|
71
|
+
#@log_function_call
|
72
|
+
def initiate_classify_root(parent_frame):
    """Build the classify GUI (settings, plot canvas, console) inside ``parent_frame``.

    Side effects:
    - Assigns the module globals ``vars_dict``, ``q``, ``canvas``,
      ``fig_queue`` and ``canvas_widget``.
    - Redirects ``sys.stdout``/``sys.stderr`` to the console widget.
    - Starts self-rescheduling 100 ms pollers for the console queue and the
      figure queue.
    - Schedules ``main_thread_update_function``.

    Returns:
        tuple: ``(parent_frame, vars_dict)``.
    """
    global vars_dict, q, canvas, fig_queue, canvas_widget, thread_control

    style = ttk.Style(parent_frame)
    set_dark_style(style)
    style_text_boxes(style)
    set_default_font(parent_frame, font_name="Helvetica", size=8)

    parent_frame.configure(bg='#333333')
    parent_frame.grid_rowconfigure(0, weight=1)
    parent_frame.grid_columnconfigure(0, weight=1)
    fig_queue = Queue()

    def _process_fig_queue():
        # Draw any queued figures on the plot canvas (axes hidden, sized to the
        # widget), then re-arm; the finally clause keeps polling alive on errors.
        global canvas
        try:
            while not fig_queue.empty():
                clear_canvas(canvas)
                fig = fig_queue.get_nowait()
                for ax in fig.get_axes():
                    ax.set_xticks([])  # Remove x-axis ticks
                    ax.set_yticks([])  # Remove y-axis ticks
                    ax.xaxis.set_visible(False)  # Hide the x-axis
                    ax.yaxis.set_visible(False)  # Hide the y-axis
                fig.tight_layout()
                fig.set_facecolor('#333333')
                canvas.figure = fig
                fig_width, fig_height = canvas_widget.winfo_width(), canvas_widget.winfo_height()
                fig.set_size_inches(fig_width / fig.dpi, fig_height / fig.dpi, forward=True)
                canvas.draw_idle()
        except Exception:
            traceback.print_exc()
        finally:
            canvas_widget.after(100, _process_fig_queue)

    def _process_console_queue():
        # Append queued messages to the console, then re-arm. Guarded like the
        # figure poller so one failure (e.g. queue.Empty racing a stale
        # empty() check) cannot stop console polling for good.
        try:
            while not q.empty():
                message = q.get_nowait()
                console_output.insert(tk.END, message)
                console_output.see(tk.END)
        except Exception:
            traceback.print_exc()
        finally:
            console_output.after(100, _process_console_queue)

    # Horizontal paned window holding the settings, plot and console side by side.
    paned_container = tk.PanedWindow(parent_frame, orient=tk.HORIZONTAL)
    paned_container.grid(row=0, column=0, sticky=tk.NSEW)

    # Settings Section
    settings_frame = tk.Frame(paned_container, bg='#333333')
    paned_container.add(settings_frame, stretch="always")
    settings_label = ttk.Label(settings_frame, text="Settings", background="#333333", foreground="white")
    settings_label.grid(row=0, column=0, pady=10, padx=10)
    scrollable_frame = ScrollableFrame(settings_frame, width=500)
    scrollable_frame.grid(row=1, column=0, sticky="nsew")
    settings_frame.grid_rowconfigure(1, weight=1)
    settings_frame.grid_columnconfigure(0, weight=1)

    # Setup for user input fields (variables)
    variables = classify_variables()
    vars_dict = generate_fields(variables, scrollable_frame)

    # Button section
    import_btn = CustomButton(scrollable_frame.scrollable_frame, text="Import", command=lambda: import_settings(scrollable_frame), font=('Helvetica', 10))
    import_btn.grid(row=47, column=0, pady=20, padx=20)
    run_button = CustomButton(scrollable_frame.scrollable_frame, text="Run", command=lambda: start_process(q, fig_queue), font=('Helvetica', 10))
    run_button.grid(row=45, column=0, pady=20, padx=20)
    abort_button = CustomButton(scrollable_frame.scrollable_frame, text="Abort", command=initiate_abort, font=('Helvetica', 10))
    abort_button.grid(row=45, column=1, pady=20, padx=20)
    progress_label = ttk.Label(scrollable_frame.scrollable_frame, text="Processing: 0%", background="black", foreground="white")  # Create progress field
    progress_label.grid(row=50, column=0, columnspan=2, sticky="ew", pady=(5, 0), padx=10)

    # Plot Canvas Section
    plot_frame = tk.PanedWindow(paned_container, orient=tk.VERTICAL)
    paned_container.add(plot_frame, stretch="always")
    figure = Figure(figsize=(30, 4), dpi=100, facecolor='#333333')
    plot = figure.add_subplot(111)
    plot.plot([], [])
    plot.axis('off')
    canvas = FigureCanvasTkAgg(figure, master=plot_frame)  # the canvas already holds `figure`
    canvas_widget = canvas.get_tk_widget()
    canvas_widget.configure(cursor='arrow', background='#333333', highlightthickness=0)
    plot_frame.add(canvas_widget, stretch="always")
    canvas.draw()

    # Console Section
    console_frame = tk.Frame(paned_container, bg='#333333')
    paned_container.add(console_frame, stretch="always")
    console_label = ttk.Label(console_frame, text="Console", background="#333333", foreground="white")
    console_label.grid(row=0, column=0, pady=10, padx=10)
    console_output = scrolledtext.ScrolledText(console_frame, height=10, bg='#333333', fg='white', insertbackground='white')
    console_output.grid(row=1, column=0, sticky="nsew")
    console_frame.grid_rowconfigure(1, weight=1)
    console_frame.grid_columnconfigure(0, weight=1)

    q = Queue()
    sys.stdout = StdoutRedirector(console_output)
    sys.stderr = StdoutRedirector(console_output)

    _process_console_queue()
    _process_fig_queue()

    # NOTE(review): main_thread_update_function also receives q; if it reads
    # from it as well, console messages are split between two consumers — confirm.
    parent_frame.after(100, lambda: main_thread_update_function(parent_frame, q, fig_queue, canvas_widget, progress_label))

    return parent_frame, vars_dict
|
177
|
+
|
178
|
+
def gui_classify():
    """Open the classify GUI in its own screen-sized Tk window and run the event loop."""
    root = tk.Tk()
    width = root.winfo_screenwidth()
    height = root.winfo_screenheight()
    root.geometry(f"{width}x{height}")
    root.title("SpaCr: classify objects")

    # A freshly created Tk root never has a ``content_frame`` yet, so there is
    # no previous content to clear: create and lay out the container directly.
    root.content_frame = tk.Frame(root)
    root.content_frame.grid(row=1, column=0, sticky="nsew")
    root.grid_rowconfigure(1, weight=1)
    root.grid_columnconfigure(0, weight=1)

    initiate_classify_root(root.content_frame)
    create_menu_bar(root)
    root.mainloop()
|
199
|
+
|
200
|
+
# Running this module directly opens the standalone classify GUI.
if __name__ == "__main__":
    gui_classify()
|
spacr/core.py
CHANGED
@@ -971,7 +971,7 @@ def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample
|
|
971
971
|
shutil.rmtree(temp_dir)
|
972
972
|
print(f"\nSaved {total_images} images to {tar_name}")
|
973
973
|
|
974
|
-
def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images',
|
974
|
+
def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', n_job=10, threshold=0.5, verbose=False):
|
975
975
|
|
976
976
|
from .io import TarImageDataset
|
977
977
|
from .utils import process_vision_results
|
@@ -994,7 +994,7 @@ def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=22
|
|
994
994
|
model = torch.load(model_path)
|
995
995
|
|
996
996
|
dataset = TarImageDataset(tar_path, transform=transform)
|
997
|
-
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
|
997
|
+
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, n_job=n_job, pin_memory=True)
|
998
998
|
|
999
999
|
model_name = os.path.splitext(os.path.basename(model_path))[0]
|
1000
1000
|
dataset_name = os.path.splitext(os.path.basename(tar_path))[0]
|
@@ -1034,7 +1034,7 @@ def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=22
|
|
1034
1034
|
torch.cuda.memory.empty_cache()
|
1035
1035
|
return df
|
1036
1036
|
|
1037
|
-
def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True,
|
1037
|
+
def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True, n_job=10):
|
1038
1038
|
|
1039
1039
|
from .io import NoClassDataset
|
1040
1040
|
|
@@ -1055,7 +1055,7 @@ def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True,
|
|
1055
1055
|
|
1056
1056
|
print(f'Loading dataset in {src} with {len(src)} images')
|
1057
1057
|
dataset = NoClassDataset(data_dir=src, transform=transform, shuffle=True, load_to_memory=False)
|
1058
|
-
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
|
1058
|
+
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, n_job=n_job)
|
1059
1059
|
print(f'Loaded {len(src)} images')
|
1060
1060
|
|
1061
1061
|
result_loc = os.path.splitext(model_path)[0]+datetime.date.today().strftime('%y%m%d')+'_'+os.path.splitext(model_path)[1]+'_test_result.csv'
|
@@ -1302,7 +1302,7 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
|
|
1302
1302
|
|
1303
1303
|
return
|
1304
1304
|
|
1305
|
-
def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'],
|
1305
|
+
def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], n_job=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, verbose=False):
|
1306
1306
|
|
1307
1307
|
"""
|
1308
1308
|
Generate data loaders for training and validation/test datasets.
|
@@ -1314,7 +1314,7 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
|
|
1314
1314
|
- image_size (int): The size of the input images.
|
1315
1315
|
- batch_size (int): The batch size for the data loaders.
|
1316
1316
|
- classes (list): The list of classes to consider.
|
1317
|
-
-
|
1317
|
+
- n_job (int): The number of worker threads for data loading.
|
1318
1318
|
- validation_split (float): The fraction of data to use for validation when train_mode is 'erm'.
|
1319
1319
|
- max_show (int): The maximum number of images to show when verbose is True.
|
1320
1320
|
- pin_memory (bool): Whether to pin memory for faster data transfer.
|
@@ -1404,10 +1404,10 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
|
|
1404
1404
|
#val_dataset = augment_dataset(val_dataset, is_grayscale=(len(channels) == 1))
|
1405
1405
|
print(f'Data after augmentation: Train: {len(train_dataset)}')#, Validataion:{len(val_dataset)}')
|
1406
1406
|
|
1407
|
-
train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle,
|
1408
|
-
val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle,
|
1407
|
+
train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, n_job=n_job if n_job is not None else 0, pin_memory=pin_memory)
|
1408
|
+
val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, n_job=n_job if n_job is not None else 0, pin_memory=pin_memory)
|
1409
1409
|
else:
|
1410
|
-
train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle,
|
1410
|
+
train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, n_job=n_job if n_job is not None else 0, pin_memory=pin_memory)
|
1411
1411
|
|
1412
1412
|
elif train_mode == 'irm':
|
1413
1413
|
data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
|
@@ -1436,13 +1436,13 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
|
|
1436
1436
|
#val_dataset = augment_dataset(val_dataset, is_grayscale=(len(channels) == 1))
|
1437
1437
|
print(f'Data after augmentation: Train: {len(train_dataset)}')#, Validataion:{len(val_dataset)}')
|
1438
1438
|
|
1439
|
-
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle,
|
1440
|
-
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle,
|
1439
|
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, n_job=n_job if n_job is not None else 0, pin_memory=pin_memory)
|
1440
|
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, n_job=n_job if n_job is not None else 0, pin_memory=pin_memory)
|
1441
1441
|
|
1442
1442
|
train_loaders.append(train_loader)
|
1443
1443
|
val_loaders.append(val_loader)
|
1444
1444
|
else:
|
1445
|
-
train_loader = DataLoader(plate_data, batch_size=batch_size, shuffle=shuffle,
|
1445
|
+
train_loader = DataLoader(plate_data, batch_size=batch_size, shuffle=shuffle, n_job=n_job if n_job is not None else 0, pin_memory=pin_memory)
|
1446
1446
|
train_loaders.append(train_loader)
|
1447
1447
|
val_loaders.append(None)
|
1448
1448
|
|
@@ -1668,7 +1668,7 @@ def preprocess_generate_masks(src, settings={}):
|
|
1668
1668
|
|
1669
1669
|
from .io import preprocess_img_data, _load_and_concatenate_arrays
|
1670
1670
|
from .plot import plot_merged, plot_arrays
|
1671
|
-
from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks
|
1671
|
+
from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks
|
1672
1672
|
from .settings import set_default_settings_preprocess_generate_masks, set_default_plot_merge_settings
|
1673
1673
|
|
1674
1674
|
settings = set_default_settings_preprocess_generate_masks(src, settings)
|
@@ -2078,11 +2078,11 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
|
|
2078
2078
|
else:
|
2079
2079
|
radius = 100
|
2080
2080
|
|
2081
|
-
|
2082
|
-
if
|
2083
|
-
|
2081
|
+
n_job = os.cpu_count()-2
|
2082
|
+
if n_job < 1:
|
2083
|
+
n_job = 1
|
2084
2084
|
|
2085
|
-
mask_stack = _btrack_track_cells(src, name, batch_filenames, object_type, plot, save, masks_3D=masks, mode=timelapse_mode, timelapse_remove_transient=timelapse_remove_transient, radius=radius,
|
2085
|
+
mask_stack = _btrack_track_cells(src, name, batch_filenames, object_type, plot, save, masks_3D=masks, mode=timelapse_mode, timelapse_remove_transient=timelapse_remove_transient, radius=radius, n_job=n_job)
|
2086
2086
|
if timelapse_mode == 'trackpy':
|
2087
2087
|
mask_stack = _trackpy_track_cells(src, name, batch_filenames, object_type, masks, timelapse_displacement, timelapse_memory, timelapse_remove_transient, plot, save, timelapse_mode)
|
2088
2088
|
|
@@ -2303,9 +2303,9 @@ def generate_cellpose_masks(src, settings, object_type):
|
|
2303
2303
|
else:
|
2304
2304
|
radius = 100
|
2305
2305
|
|
2306
|
-
|
2307
|
-
if
|
2308
|
-
|
2306
|
+
n_job = os.cpu_count()-2
|
2307
|
+
if n_job < 1:
|
2308
|
+
n_job = 1
|
2309
2309
|
|
2310
2310
|
mask_stack = _btrack_track_cells(src=src,
|
2311
2311
|
name=name,
|
@@ -2317,7 +2317,7 @@ def generate_cellpose_masks(src, settings, object_type):
|
|
2317
2317
|
mode=timelapse_mode,
|
2318
2318
|
timelapse_remove_transient=timelapse_remove_transient,
|
2319
2319
|
radius=radius,
|
2320
|
-
|
2320
|
+
n_job=n_job)
|
2321
2321
|
if timelapse_mode == 'trackpy':
|
2322
2322
|
mask_stack = _trackpy_track_cells(src=src,
|
2323
2323
|
name=name,
|
@@ -2551,7 +2551,7 @@ def compare_cellpose_masks(src, verbose=False, processes=None, save=True):
|
|
2551
2551
|
common_files.intersection_update(os.listdir(d))
|
2552
2552
|
common_files = list(common_files)
|
2553
2553
|
|
2554
|
-
# Create a pool of
|
2554
|
+
# Create a pool of n_job
|
2555
2555
|
with Pool(processes=processes) as pool:
|
2556
2556
|
args = [(src, filename, dirs, conditions) for filename in common_files]
|
2557
2557
|
results = pool.map(compare_mask, args)
|
@@ -3021,9 +3021,9 @@ def generate_image_umap(settings={}):
|
|
3021
3021
|
"""
|
3022
3022
|
|
3023
3023
|
from .io import _read_and_join_tables
|
3024
|
-
from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid, cluster_feature_analysis
|
3025
|
-
from .settings import
|
3026
|
-
settings =
|
3024
|
+
from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid, cluster_feature_analysis #, generate_umap_from_images
|
3025
|
+
from .settings import set_default_umap_image_settings
|
3026
|
+
settings = set_default_umap_image_settings(settings)
|
3027
3027
|
|
3028
3028
|
if isinstance(settings['src'], str):
|
3029
3029
|
settings['src'] = [settings['src']]
|
@@ -3109,7 +3109,9 @@ def generate_image_umap(settings={}):
|
|
3109
3109
|
|
3110
3110
|
else:
|
3111
3111
|
if settings['resnet_features']:
|
3112
|
-
|
3112
|
+
# placeholder for resnet features, not implemented yet
|
3113
|
+
pass
|
3114
|
+
#numeric_data, embedding, labels = generate_umap_from_images(image_paths, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['clustering'], settings['eps'], settings['min_samples'], settings['n_jobs'], settings['verbose'])
|
3113
3115
|
else:
|
3114
3116
|
# Apply the trained reducer to the entire dataset
|
3115
3117
|
numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])
|
@@ -3205,9 +3207,9 @@ def reducer_hyperparameter_search(settings={}, reduction_params=None, dbscan_par
|
|
3205
3207
|
|
3206
3208
|
from .io import _read_and_join_tables
|
3207
3209
|
from .utils import get_db_paths, preprocess_data, search_reduction_and_clustering, generate_colors
|
3208
|
-
from .settings import
|
3210
|
+
from .settings import set_default_umap_image_settings
|
3209
3211
|
|
3210
|
-
settings =
|
3212
|
+
settings = set_default_umap_image_settings(settings)
|
3211
3213
|
pointsize = settings['dot_size']
|
3212
3214
|
if isinstance(dbscan_params, dict):
|
3213
3215
|
dbscan_params = [dbscan_params]
|
spacr/deep_spacr.py
CHANGED
@@ -10,6 +10,9 @@ import matplotlib.pyplot as plt
|
|
10
10
|
from PIL import Image
|
11
11
|
|
12
12
|
from .logger import log_function_call
|
13
|
+
from .utils import close_multiprocessing_processes, reset_mp
|
14
|
+
#reset_mp()
|
15
|
+
#close_multiprocessing_processes()
|
13
16
|
|
14
17
|
def evaluate_model_core(model, loader, loader_name, epoch, loss_type):
|
15
18
|
"""
|
@@ -42,7 +45,6 @@ def evaluate_model_core(model, loader, loader_name, epoch, loss_type):
|
|
42
45
|
for batch_idx, (data, target, _) in enumerate(loader, start=1):
|
43
46
|
start_time = time.time()
|
44
47
|
data, target = data.to(device), target.to(device).float()
|
45
|
-
#data, target = data.to(torch.float).to(device), target.to(device).float()
|
46
48
|
output = model(data)
|
47
49
|
loss += F.binary_cross_entropy_with_logits(output, target, reduction='sum').item()
|
48
50
|
loss = calculate_loss(output, target, loss_type=loss_type)
|
@@ -228,7 +230,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
228
230
|
image_size=settings['image_size'],
|
229
231
|
batch_size=settings['batch_size'],
|
230
232
|
classes=settings['classes'],
|
231
|
-
|
233
|
+
n_job=settings['n_job'],
|
232
234
|
validation_split=settings['val_split'],
|
233
235
|
pin_memory=settings['pin_memory'],
|
234
236
|
normalize=settings['normalize'],
|
@@ -253,7 +255,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
253
255
|
optimizer_type = settings['optimizer_type'],
|
254
256
|
use_checkpoint = settings['use_checkpoint'],
|
255
257
|
dropout_rate = settings['dropout_rate'],
|
256
|
-
|
258
|
+
n_job = settings['n_job'],
|
257
259
|
val_loaders = val,
|
258
260
|
test_loaders = None,
|
259
261
|
intermedeate_save = settings['intermedeate_save'],
|
@@ -274,7 +276,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
274
276
|
image_size=settings['image_size'],
|
275
277
|
batch_size=settings['batch_size'],
|
276
278
|
classes=settings['classes'],
|
277
|
-
|
279
|
+
n_job=settings['n_job'],
|
278
280
|
validation_split=0.0,
|
279
281
|
pin_memory=settings['pin_memory'],
|
280
282
|
normalize=settings['normalize'],
|
@@ -313,7 +315,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
313
315
|
torch.cuda.memory.empty_cache()
|
314
316
|
gc.collect()
|
315
317
|
|
316
|
-
def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='erm', epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0,
|
318
|
+
def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='erm', epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, n_job=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule = None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4, channels=['r','g','b']):
|
317
319
|
"""
|
318
320
|
Trains a model using the specified parameters.
|
319
321
|
|
@@ -330,7 +332,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
330
332
|
optimizer_type (str, optional): The type of optimizer to use. Defaults to 'adamw'.
|
331
333
|
use_checkpoint (bool, optional): Whether to use checkpointing during training. Defaults to False.
|
332
334
|
dropout_rate (float, optional): The dropout rate for the model. Defaults to 0.
|
333
|
-
|
335
|
+
n_job (int, optional): The number of n_job for data loading. Defaults to 20.
|
334
336
|
val_loaders (list, optional): A list of validation data loaders. Defaults to None.
|
335
337
|
test_loaders (list, optional): A list of test data loaders. Defaults to None.
|
336
338
|
init_weights (str, optional): The initialization weights for the model. Defaults to 'imagenet'.
|
@@ -355,7 +357,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
355
357
|
|
356
358
|
use_cuda = torch.cuda.is_available()
|
357
359
|
device = torch.device("cuda" if use_cuda else "cpu")
|
358
|
-
kwargs = {'
|
360
|
+
kwargs = {'n_job': n_job, 'pin_memory': True} if use_cuda else {}
|
359
361
|
|
360
362
|
for idx, (images, labels, filenames) in enumerate(train_loaders):
|
361
363
|
batch, chans, height, width = images.shape
|
spacr/gui.py
CHANGED
@@ -1,52 +1,67 @@
|
|
1
1
|
import tkinter as tk
|
2
2
|
from tkinter import ttk
|
3
|
-
from tkinter import font as tkFont
|
4
3
|
from PIL import Image, ImageTk
|
5
|
-
import os
|
6
|
-
import
|
7
|
-
|
8
|
-
|
9
|
-
from .
|
10
|
-
from .
|
11
|
-
from .annotate_app import initiate_annotation_app_root
|
12
|
-
from .mask_app import initiate_mask_app_root
|
13
|
-
from .gui_classify_app import initiate_classify_root
|
14
|
-
|
15
|
-
from .gui_utils import CustomButton, style_text_boxes
|
4
|
+
import os, requests
|
5
|
+
from multiprocessing import set_start_method
|
6
|
+
from .gui_elements import spacrButton, create_menu_bar, set_dark_style
|
7
|
+
from .gui_core import initiate_root
|
8
|
+
from .app_annotate import initiate_annotation_app_root
|
9
|
+
from .app_make_masks import initiate_mask_app_root
|
16
10
|
|
17
11
|
class MainApp(tk.Tk):
|
18
|
-
def __init__(self):
|
12
|
+
def __init__(self, default_app=None):
    """Create the screen-sized SpaCr launcher window.

    Args:
        default_app (str, optional): Key of ``self.gui_apps`` ("Mask",
            "Measure", "Annotate", "Make Masks", "Classify", "Sequencing" or
            "Umap") to open right away. None or an unknown name leaves the
            startup screen showing.
    """
    super().__init__()
    width = self.winfo_screenwidth()
    height = self.winfo_screenheight()
    self.geometry(f"{width}x{height}")
    self.title("SpaCr GUI Collection")
    self.configure(bg="black")
    style = ttk.Style()
    set_dark_style(style)

    # Button label -> (init function taking the host frame, description text).
    self.gui_apps = {
        "Mask": (lambda frame: initiate_root(frame, 'mask'), "Generate cellpose masks for cells, nuclei and pathogen images."),
        "Measure": (lambda frame: initiate_root(frame, 'measure'), "Measure single object intensity and morphological feature. Crop and save single object image"),
        "Annotate": (initiate_annotation_app_root, "Annotation single object images on a grid. Annotations are saved to database."),
        "Make Masks": (initiate_mask_app_root, "Adjust pre-existing Cellpose models to your specific dataset for improved performance"),
        "Classify": (lambda frame: initiate_root(frame, 'classify'), "Train Torch Convolutional Neural Networks (CNNs) or Transformers to classify single object images."),
        "Sequencing": (lambda frame: initiate_root(frame, 'sequencing'), "Analyze sequensing data."),
        "Umap": (lambda frame: initiate_root(frame, 'umap'), "Generate UMAP embedings with datapoints represented as images.")
    }

    self.selected_app = tk.StringVar()
    self.create_widgets()

    # Every gui_apps value is a 2-tuple, so the init function is always item 0.
    # Indexing by the app's position (1..6) would hand load_app a description
    # string ("Measure") or raise IndexError (all later apps).
    if default_app in self.gui_apps:
        self.load_app(default_app, self.gui_apps[default_app][0])
|
50
|
+
|
39
51
|
def create_widgets(self):
|
40
52
|
# Create the menu bar
|
41
|
-
|
53
|
+
create_menu_bar(self)
|
54
|
+
|
42
55
|
# Create a canvas to hold the selected app and other elements
|
43
|
-
self.canvas = tk.Canvas(self, bg="black", highlightthickness=0
|
56
|
+
self.canvas = tk.Canvas(self, bg="black", highlightthickness=0)
|
44
57
|
self.canvas.grid(row=0, column=0, sticky="nsew")
|
45
58
|
self.grid_rowconfigure(0, weight=1)
|
46
59
|
self.grid_columnconfigure(0, weight=1)
|
60
|
+
|
47
61
|
# Create a frame inside the canvas to hold the main content
|
48
62
|
self.content_frame = tk.Frame(self.canvas, bg="black")
|
49
63
|
self.content_frame.pack(fill=tk.BOTH, expand=True)
|
64
|
+
|
50
65
|
# Create startup screen with buttons for each GUI app
|
51
66
|
self.create_startup_screen()
|
52
67
|
|
@@ -59,10 +74,10 @@ class MainApp(tk.Tk):
|
|
59
74
|
|
60
75
|
# Load the logo image
|
61
76
|
if not self.load_logo(logo_frame):
|
62
|
-
tk.Label(logo_frame, text="Logo not found", bg="black", fg="white", font=('Helvetica', 24
|
77
|
+
tk.Label(logo_frame, text="Logo not found", bg="black", fg="white", font=('Helvetica', 24)).pack(padx=10, pady=10)
|
63
78
|
|
64
79
|
# Add SpaCr text below the logo with padding for sharper text
|
65
|
-
tk.Label(logo_frame, text="SpaCr", bg="black", fg="#008080", font=('Helvetica', 24
|
80
|
+
tk.Label(logo_frame, text="SpaCr", bg="black", fg="#008080", font=('Helvetica', 24)).pack(padx=10, pady=10)
|
66
81
|
|
67
82
|
# Create a frame for the buttons and descriptions
|
68
83
|
buttons_frame = tk.Frame(self.content_frame, bg="black")
|
@@ -72,10 +87,11 @@ class MainApp(tk.Tk):
|
|
72
87
|
app_func, app_desc = app_data
|
73
88
|
|
74
89
|
# Create custom button with text
|
75
|
-
button =
|
90
|
+
button = spacrButton(buttons_frame, text=app_name, command=lambda app_name=app_name, app_func=app_func: self.load_app(app_name, app_func), font=('Helvetica', 12))
|
91
|
+
#button = ttk.Button(buttons_frame, text=app_name, command=lambda app_name=app_name, app_func=app_func: self.load_app(app_name, app_func), style='Custom.TButton')
|
76
92
|
button.grid(row=i, column=0, pady=10, padx=10, sticky="w")
|
77
93
|
|
78
|
-
description_label = tk.Label(buttons_frame, text=app_desc, bg="black", fg="white", wraplength=800, justify="left", font=('Helvetica',
|
94
|
+
description_label = tk.Label(buttons_frame, text=app_desc, bg="black", fg="white", wraplength=800, justify="left", font=('Helvetica', 12))
|
79
95
|
description_label.grid(row=i, column=1, pady=10, padx=10, sticky="w")
|
80
96
|
|
81
97
|
# Ensure buttons have a fixed width
|
@@ -98,7 +114,6 @@ class MainApp(tk.Tk):
|
|
98
114
|
|
99
115
|
try:
|
100
116
|
img_path = os.path.join(os.path.dirname(__file__), 'logo_spacr.png')
|
101
|
-
print(f"Trying to load logo from {img_path}")
|
102
117
|
logo_image = Image.open(img_path)
|
103
118
|
except (FileNotFoundError, Image.UnidentifiedImageError):
|
104
119
|
print(f"File {img_path} not found or is not a valid image. Attempting to download from GitHub.")
|
@@ -115,7 +130,9 @@ class MainApp(tk.Tk):
|
|
115
130
|
print(f"An error occurred while loading the logo: {e}")
|
116
131
|
return False
|
117
132
|
try:
|
118
|
-
|
133
|
+
screen_height = frame.winfo_screenheight()
|
134
|
+
new_height = int(screen_height // 4)
|
135
|
+
logo_image = logo_image.resize((new_height, new_height), Image.Resampling.LANCZOS)
|
119
136
|
logo_photo = ImageTk.PhotoImage(logo_image)
|
120
137
|
logo_label = tk.Label(frame, image=logo_photo, bg="black")
|
121
138
|
logo_label.image = logo_photo # Keep a reference to avoid garbage collection
|
@@ -125,13 +142,14 @@ class MainApp(tk.Tk):
|
|
125
142
|
print(f"An error occurred while processing the logo image: {e}")
|
126
143
|
return False
|
127
144
|
|
128
|
-
def load_app(self, app_name):
|
129
|
-
|
145
|
+
def load_app(self, app_name, app_func):
    """Swap the launcher's content area for the chosen app.

    Args:
        app_name (str): Label of the chosen app (not used by this method).
        app_func (callable): Initializer called with the new host frame.
    """
    # Drop whatever is on screen now (startup screen or a previous app).
    self.clear_frame(self.content_frame)

    # Fresh black container filling the content area; the app builds into it.
    host = tk.Frame(self.content_frame, bg="black")
    host.pack(fill=tk.BOTH, expand=True)
    app_func(host)
|
135
153
|
|
136
154
|
def clear_frame(self, frame):
|
137
155
|
for widget in frame.winfo_children():
|
@@ -142,4 +160,5 @@ def gui_app():
|
|
142
160
|
app.mainloop()
|
143
161
|
|
144
162
|
if __name__ == "__main__":
    # Start worker processes with "spawn" (a fresh interpreter) instead of
    # forking the Tk GUI process; force=True replaces any start method that
    # was already set.
    set_start_method('spawn', force=True)
    gui_app()
|