spacr 0.1.11__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/app_classify.py ADDED
@@ -0,0 +1,206 @@
1
+ import sys, ctypes, matplotlib
2
+ import tkinter as tk
3
+ from tkinter import ttk, scrolledtext
4
+ from matplotlib.figure import Figure
5
+ from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
6
+ from matplotlib.figure import Figure
7
+ matplotlib.use('Agg')
8
+ from tkinter import filedialog
9
+ from multiprocessing import Process, Queue, Value
10
+ import traceback
11
+
12
+ try:
13
+ ctypes.windll.shcore.SetProcessDpiAwareness(True)
14
+ except AttributeError:
15
+ pass
16
+
17
+ from .logger import log_function_call
18
+ from .settings import set_default_train_test_model
19
+ from .gui_utils import ScrollableFrame, StdoutRedirector, CustomButton, set_dark_style, set_default_font, generate_fields, process_stdout_stderr, clear_canvas, main_thread_update_function, convert_settings_dict_for_gui
20
+ from .gui_utils import check_classify_gui_settings, train_test_model_wrapper, read_settings_from_csv, update_settings_from_csv, set_dark_style, create_menu_bar
21
+
22
# Shared state for the background worker:
#   "run_thread"     -> the active multiprocessing.Process, or None
#   "stop_requested" -> that worker's shared multiprocessing.Value flag, or
#                       None before the first run. None (not False) so the
#                       `is not None` check in initiate_abort never tries to
#                       set `.value` on a plain bool.
thread_control = {"run_thread": None, "stop_requested": None}
23
+
24
#@log_function_call
def initiate_abort():
    """Signal the running worker to stop and tear it down.

    Sets the shared ``stop_requested`` flag when one exists. Then waits up
    to 5 seconds for the worker process to exit, terminates (and reaps) it
    if it is still alive, and clears ``thread_control["run_thread"]``.
    Does nothing harmful when no worker has been started yet.
    """
    global thread_control
    stop_requested = thread_control.get("stop_requested")
    # Before the first run the placeholder is a plain bool/None with no
    # `.value`; only a shared Value flag can be set.
    if hasattr(stop_requested, "value"):
        stop_requested.value = 1

    run_thread = thread_control.get("run_thread")
    if run_thread is not None:
        run_thread.join(timeout=5)
        if run_thread.is_alive():
            run_thread.terminate()
            # Reap the terminated process so it does not linger as a zombie.
            run_thread.join(timeout=5)
        thread_control["run_thread"] = None
35
+
36
#@log_function_call
def run_classify_gui(q, fig_queue, stop_requested):
    """Worker entry point: build settings from the GUI fields and train.

    Executed as the target of the Process created in ``start_process``.
    Routes output through ``q`` via ``process_stdout_stderr``, converts the
    module-level ``vars_dict`` into a settings dict, prints every setting
    with its type, and calls ``train_test_model_wrapper``. An exception is
    reported on ``q`` and its traceback printed. ``stop_requested.value``
    is always set to 1 on exit. ``fig_queue`` is accepted but not used here.

    NOTE(review): reads the global ``vars_dict`` inside the child process,
    which presumably only works with the 'fork' start method — confirm on
    spawn-based platforms.
    """
    global vars_dict
    process_stdout_stderr(q)
    try:
        gui_settings = check_classify_gui_settings(vars_dict)
        for name, setting in gui_settings.items():
            print(name, setting, type(setting))
        source_dir = gui_settings['src']
        train_test_model_wrapper(source_dir, gui_settings)
    except Exception as err:
        q.put(f"Error during processing: {err}")
        traceback.print_exc()
    finally:
        stop_requested.value = 1
51
+
52
#@log_function_call
def start_process(q, fig_queue):
    """Launch ``run_classify_gui`` in a new ``multiprocessing.Process``.

    Aborts any worker still registered in ``thread_control``. Then creates
    a shared integer flag (initially 0), stores it as
    ``thread_control["stop_requested"]``, and starts a process running
    ``run_classify_gui(q, fig_queue, flag)``. That process is stored as
    ``thread_control["run_thread"]``.
    """
    global thread_control
    previous = thread_control.get("run_thread")
    if previous is not None:
        initiate_abort()

    # Shared int both sides use as the stop flag.
    flag = Value('i', 0)
    thread_control["stop_requested"] = flag
    worker = Process(target=run_classify_gui, args=(q, fig_queue, flag))
    thread_control["run_thread"] = worker
    worker.start()
62
+
63
def import_settings(scrollable_frame):
    """Load settings from a user-chosen CSV file and rebuild the form.

    Opens a file dialog. If the user picks a file, its settings are merged
    over the defaults from ``set_default_train_test_model`` and the fields
    are regenerated in ``scrollable_frame``, rebinding the module global
    ``vars_dict``. If the dialog is cancelled, nothing changes.

    Args:
        scrollable_frame: Container passed to ``generate_fields``.
    """
    global vars_dict

    csv_file_path = filedialog.askopenfilename(filetypes=[("CSV files", "*.csv")])
    # askopenfilename returns an empty value ('' or ()) when cancelled;
    # passing that on as a path would presumably fail, so keep current state.
    if not csv_file_path:
        return
    csv_settings = read_settings_from_csv(csv_file_path)
    settings = set_default_train_test_model({})
    variables = convert_settings_dict_for_gui(settings)
    new_settings = update_settings_from_csv(variables, csv_settings)
    vars_dict = generate_fields(new_settings, scrollable_frame)
72
+
73
#@log_function_call
def initiate_classify_root(parent_frame):
    """Build the classification GUI inside *parent_frame*.

    Lays out three panes in a horizontal ``tk.PanedWindow``:
      - a scrollable settings form with Run / Abort / Import buttons and a
        progress label;
      - a matplotlib canvas for figures;
      - a console text widget.

    Redirects ``sys.stdout``/``sys.stderr`` to the console widget. Starts
    100 ms polling loops that drain the console queue ``q`` and the figure
    queue ``fig_queue``. Finally schedules ``main_thread_update_function``.

    Side effects: rebinds the module globals ``vars_dict``, ``q``,
    ``canvas``, ``fig_queue`` and ``canvas_widget``.

    Args:
        parent_frame: Tk container widget the GUI is built into.

    Returns:
        tuple: ``(parent_frame, vars_dict)``, where ``vars_dict`` is what
        ``generate_fields`` returned for the default settings.
    """
    global vars_dict, q, canvas, fig_queue, canvas_widget, thread_control
    # multiprocessing.Queue.get_nowait() signals "nothing there" with queue.Empty.
    from queue import Empty

    style = ttk.Style(parent_frame)
    set_dark_style(style)
    set_default_font(parent_frame, font_name="Helvetica", size=8)

    parent_frame.configure(bg='black')
    parent_frame.grid_rowconfigure(0, weight=1)
    parent_frame.grid_columnconfigure(0, weight=1)
    fig_queue = Queue()

    def _process_fig_queue():
        # Show queued figures on the shared canvas (axes hidden, black
        # background, sized to the widget), then poll again in 100 ms.
        global canvas
        try:
            while not fig_queue.empty():
                clear_canvas(canvas)
                fig = fig_queue.get_nowait()
                for ax in fig.get_axes():
                    ax.set_xticks([])  # Remove x-axis ticks
                    ax.set_yticks([])  # Remove y-axis ticks
                    ax.xaxis.set_visible(False)  # Hide the x-axis
                    ax.yaxis.set_visible(False)  # Hide the y-axis
                fig.tight_layout()
                fig.set_facecolor('black')
                canvas.figure = fig
                fig_width, fig_height = canvas_widget.winfo_width(), canvas_widget.winfo_height()
                fig.set_size_inches(fig_width / fig.dpi, fig_height / fig.dpi, forward=True)
                canvas.draw_idle()
        except Empty:
            # empty() is only a hint for multiprocessing queues; a racing
            # get_nowait() finding nothing is not an error.
            pass
        except Exception:
            traceback.print_exc()
        finally:
            canvas_widget.after(100, _process_fig_queue)

    def _process_console_queue():
        # Append queued messages to the console, then ALWAYS reschedule so
        # that one failure does not stop console updates for the session.
        try:
            while not q.empty():
                message = q.get_nowait()
                console_output.insert(tk.END, message)
                console_output.see(tk.END)
        except Empty:
            pass
        except Exception:
            traceback.print_exc()
        finally:
            console_output.after(100, _process_console_queue)

    vertical_container = tk.PanedWindow(parent_frame, orient=tk.HORIZONTAL)
    vertical_container.grid(row=0, column=0, sticky=tk.NSEW)

    # Settings Section
    settings_frame = tk.Frame(vertical_container, bg='black')
    vertical_container.add(settings_frame, stretch="always")
    settings_label = ttk.Label(settings_frame, text="Settings", background="black", foreground="white")
    settings_label.grid(row=0, column=0, pady=10, padx=10)
    scrollable_frame = ScrollableFrame(settings_frame, bg='black')
    scrollable_frame.grid(row=1, column=0, sticky="nsew")
    settings_frame.grid_rowconfigure(1, weight=1)
    settings_frame.grid_columnconfigure(0, weight=1)

    # Setup for user input fields (variables)
    settings = set_default_train_test_model({})
    variables = convert_settings_dict_for_gui(settings)
    vars_dict = generate_fields(variables, scrollable_frame)

    # Button section
    btn_row = 1
    run_button = CustomButton(scrollable_frame.scrollable_frame, text="Run", command=lambda: start_process(q, fig_queue), font=('Helvetica', 10))
    run_button.grid(row=btn_row, column=0, pady=20, padx=20)
    abort_button = CustomButton(scrollable_frame.scrollable_frame, text="Abort", command=initiate_abort, font=('Helvetica', 10))
    abort_button.grid(row=btn_row, column=1, pady=20, padx=20)
    btn_row += 1
    import_btn = CustomButton(scrollable_frame.scrollable_frame, text="Import", command=lambda: import_settings(scrollable_frame), font=('Helvetica', 10))
    import_btn.grid(row=btn_row, column=0, pady=20, padx=20)
    btn_row += 1
    progress_label = ttk.Label(scrollable_frame.scrollable_frame, text="Processing: 0%", background="black", foreground="white")  # Create progress field
    progress_label.grid(row=btn_row, column=0, columnspan=2, sticky="ew", pady=(5, 0), padx=10)

    # Plot Canvas Section
    plot_frame = tk.PanedWindow(vertical_container, orient=tk.VERTICAL)
    vertical_container.add(plot_frame, stretch="always")
    figure = Figure(figsize=(30, 4), dpi=100, facecolor='black')
    plot = figure.add_subplot(111)
    plot.plot([], [])
    plot.axis('off')
    canvas = FigureCanvasTkAgg(figure, master=plot_frame)
    canvas.get_tk_widget().configure(cursor='arrow', background='black', highlightthickness=0)
    canvas_widget = canvas.get_tk_widget()
    plot_frame.add(canvas_widget, stretch="always")
    canvas.draw()
    canvas.figure = figure

    # Console Section
    console_frame = tk.Frame(vertical_container, bg='black')
    vertical_container.add(console_frame, stretch="always")
    console_label = ttk.Label(console_frame, text="Console", background="black", foreground="white")
    console_label.grid(row=0, column=0, pady=10, padx=10)
    console_output = scrolledtext.ScrolledText(console_frame, height=10, bg='black', fg='white', insertbackground='white')
    console_output.grid(row=1, column=0, sticky="nsew")
    console_frame.grid_rowconfigure(1, weight=1)
    console_frame.grid_columnconfigure(0, weight=1)

    q = Queue()
    sys.stdout = StdoutRedirector(console_output)
    sys.stderr = StdoutRedirector(console_output)

    _process_console_queue()
    _process_fig_queue()

    parent_frame.after(100, lambda: main_thread_update_function(parent_frame, q, fig_queue, canvas_widget, progress_label))

    return parent_frame, vars_dict
182
+
183
def gui_classify():
    """Open a screen-sized Tk window hosting the classification GUI.

    Creates a new ``tk.Tk`` root sized to the screen and builds the GUI
    into a ``content_frame`` gridded at row 1. Adds the menu bar via
    ``create_menu_bar``, then blocks in ``root.mainloop()`` until the
    window is closed.
    """
    root = tk.Tk()
    width = root.winfo_screenwidth()
    height = root.winfo_screenheight()
    root.geometry(f"{width}x{height}")
    root.title("SpaCr: classify objects")

    # The root was just created, so there is no previous content to clear.
    root.content_frame = tk.Frame(root)
    root.content_frame.grid(row=1, column=0, sticky="nsew")
    root.grid_rowconfigure(1, weight=1)
    root.grid_columnconfigure(0, weight=1)

    initiate_classify_root(root.content_frame)
    create_menu_bar(root)
    root.mainloop()


# NOTE(review): this module uses package-relative imports (e.g.
# ``from .logger import ...``), so it presumably must be launched as a module
# (``python -m <package>.app_classify``) rather than as a plain script.
if __name__ == "__main__":
    gui_classify()