spacr 0.1.1__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/app_mask.py ADDED
@@ -0,0 +1,8 @@
from .gui import MainApp


def start_mask_app():
    """Open the SpaCr GUI with the "Mask" app (Cellpose mask generation) pre-loaded.

    Blocks in the Tk event loop until the window is closed.
    """
    main_window = MainApp(default_app="Mask")
    main_window.mainloop()


if __name__ == "__main__":
    start_mask_app()
spacr/app_measure.py ADDED
@@ -0,0 +1,8 @@
from .gui import MainApp


def start_measure_app():
    """Open the SpaCr GUI with the "Measure" app pre-loaded.

    Blocks in the Tk event loop until the window is closed.
    """
    # Do not pass default_app="Measure": MainApp.__init__ picks the launcher with
    # a per-app index into the (launcher, description) tuple (index 1 for
    # "Measure"), which hands the description string to load_app and crashes.
    # Resolve the launcher from the pair and load it explicitly instead.
    app = MainApp()
    launcher, _description = app.gui_apps["Measure"]
    app.load_app("Measure", launcher)
    app.mainloop()


if __name__ == "__main__":
    start_measure_app()
@@ -0,0 +1,8 @@
from .gui import MainApp


def start_seq_app():
    """Open the SpaCr GUI with the "Sequencing" app pre-loaded.

    Blocks in the Tk event loop until the window is closed.
    """
    # Do not pass default_app="Sequencing": MainApp.__init__ indexes the
    # two-element (launcher, description) tuple with [5] for this app, which
    # raises IndexError. Resolve the launcher from the pair and load it explicitly.
    app = MainApp()
    launcher, _description = app.gui_apps["Sequencing"]
    app.load_app("Sequencing", launcher)
    app.mainloop()


if __name__ == "__main__":
    start_seq_app()
spacr/app_umap.py ADDED
@@ -0,0 +1,8 @@
from .gui import MainApp


def start_umap_app():
    """Open the SpaCr GUI with the "Umap" app pre-loaded.

    Blocks in the Tk event loop until the window is closed.
    """
    # Do not pass default_app="Umap": MainApp.__init__ indexes the two-element
    # (launcher, description) tuple with [6] for this app, which raises
    # IndexError. Resolve the launcher from the pair and load it explicitly.
    app = MainApp()
    launcher, _description = app.gui_apps["Umap"]
    app.load_app("Umap", launcher)
    app.mainloop()


if __name__ == "__main__":
    start_umap_app()
spacr/classify_app.py ADDED
@@ -0,0 +1,201 @@
1
+ import sys, ctypes, matplotlib
2
+ import tkinter as tk
3
+ from tkinter import ttk, scrolledtext
4
+ from matplotlib.figure import Figure
5
+ from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
6
+ from matplotlib.figure import Figure
7
+ matplotlib.use('Agg')
8
+ from tkinter import filedialog
9
+ from multiprocessing import Process, Queue, Value
10
+ import traceback
11
+
# Best-effort: ask Windows to treat this process as DPI aware (value 1 =
# system DPI aware) so Tk widgets are not bitmap-scaled on high-DPI screens.
# ``ctypes.windll`` does not exist off Windows (AttributeError), and loading
# ``shcore`` fails on Windows versions that do not ship that DLL (OSError);
# in both cases simply keep the default behaviour.
try:
    ctypes.windll.shcore.SetProcessDpiAwareness(True)
except (AttributeError, OSError):
    pass
16
+
17
+ from .logger import log_function_call
18
+ from .gui_utils import ScrollableFrame, StdoutRedirector, CustomButton, set_dark_style, set_default_font, generate_fields, process_stdout_stderr, clear_canvas, main_thread_update_function
19
+ from .gui_utils import classify_variables, check_classify_gui_settings, train_test_model_wrapper, read_settings_from_csv, update_settings_from_csv, style_text_boxes, create_menu_bar
20
+
# Shared state between the GUI callbacks and the classification worker:
#   "run_thread"     -> the multiprocessing.Process running run_classify_gui, or None
#   "stop_requested" -> the shared multiprocessing.Value stop flag, or None before
#                       the first run (initiate_abort tests this slot with `is not None`)
thread_control = {"run_thread": None, "stop_requested": None}
22
+
#@log_function_call
def initiate_abort():
    """Signal the classification worker to stop and make sure it has exited.

    Sets the shared ``stop_requested`` flag in ``thread_control`` to 1. It then
    waits up to 5 seconds for the worker process (``run_thread``) to finish and
    force-terminates it if it is still alive. Finally it clears ``run_thread``.

    Safe to call before anything has been started: the stop slot then holds a
    placeholder rather than a shared ``Value``, and no process is registered.
    """
    stop_flag = thread_control.get("stop_requested")
    # Only a multiprocessing.Value (created by start_process) can be signalled;
    # placeholders such as None/False are left untouched.
    if hasattr(stop_flag, "value"):
        stop_flag.value = 1

    worker = thread_control.get("run_thread")
    if worker is not None:
        worker.join(timeout=5)
        if worker.is_alive():
            worker.terminate()
            # Reap the terminated process so it does not linger as a zombie.
            worker.join(timeout=5)
        thread_control["run_thread"] = None
34
+
#@log_function_call
def run_classify_gui(q, fig_queue, stop_requested):
    """Worker entry point: train/test a classifier from the GUI's current settings.

    Parameters
    ----------
    q : multiprocessing.Queue
        Console queue. It is passed to ``process_stdout_stderr`` (presumably to
        route this process's output into it), and error messages are put on it directly.
    fig_queue : multiprocessing.Queue
        Figure queue; accepted for the worker signature but not used here.
    stop_requested : multiprocessing.Value
        Shared flag that is set to 1 when this function exits, on success or failure.

    NOTE(review): the settings come from the module-global ``vars_dict``, which
    the GUI fills in the parent process. It is only visible here if the child
    inherits the parent's memory (fork start method). Confirm the behaviour
    under 'spawn'.
    """
    global vars_dict
    process_stdout_stderr(q)
    try:
        settings = check_classify_gui_settings(vars_dict)
        # Echo every resolved setting (name, value, type) to the console.
        for name in settings:
            setting = settings[name]
            print(name, setting, type(setting))
        train_test_model_wrapper(settings['src'], settings)
    except Exception as err:
        q.put(f"Error during processing: {err}")
        traceback.print_exc()
    finally:
        stop_requested.value = 1
50
+
#@log_function_call
def start_process(q, fig_queue):
    """Launch classification training in a fresh worker process.

    Any worker still registered in ``thread_control`` is aborted first. A new
    shared integer stop flag (``Value('i', 0)``) and a ``Process`` running
    ``run_classify_gui(q, fig_queue, flag)`` are stored in ``thread_control``,
    and the process is then started.
    """
    global thread_control
    if thread_control.get("run_thread") is not None:
        initiate_abort()

    # Shared between the GUI (parent) and the worker for stop signalling.
    stop_flag = Value('i', 0)
    worker = Process(target=run_classify_gui, args=(q, fig_queue, stop_flag))
    thread_control["stop_requested"] = stop_flag
    thread_control["run_thread"] = worker
    worker.start()
61
+
def import_settings(scrollable_frame):
    """Let the user pick a settings CSV and rebuild the classify settings form from it.

    The CSV is read with ``read_settings_from_csv`` and merged into the default
    ``classify_variables()`` via ``update_settings_from_csv``. The input fields
    are regenerated in *scrollable_frame*, and the module-global ``vars_dict``
    is rebound to the result.

    Returns None. If the file dialog is cancelled, nothing is changed.
    """
    global vars_dict

    csv_file_path = filedialog.askopenfilename(filetypes=[("CSV files", "*.csv")])
    # A cancelled dialog yields an empty string (an empty tuple on some Tk
    # builds); there is nothing to import in that case.
    if not csv_file_path:
        return
    csv_settings = read_settings_from_csv(csv_file_path)
    variables = classify_variables()
    new_settings = update_settings_from_csv(variables, csv_settings)
    vars_dict = generate_fields(new_settings, scrollable_frame)
70
+
#@log_function_call
def initiate_classify_root(parent_frame):
    """Build the classify GUI inside *parent_frame* and start its polling loops.

    Lays out three side-by-side panes in a horizontal ``tk.PanedWindow``:

    * a scrollable settings form with Import / Run / Abort buttons and a progress label;
    * a matplotlib canvas that shows figures arriving on ``fig_queue``;
    * a console that shows messages arriving on ``q``.

    ``sys.stdout`` and ``sys.stderr`` are redirected to the console widget. Both
    queues are polled every 100 ms via ``after``.

    Side effects: (re)binds the module globals ``vars_dict``, ``q``, ``canvas``,
    ``fig_queue`` and ``canvas_widget``.

    Returns:
        tuple: ``(parent_frame, vars_dict)``, where ``vars_dict`` is what
        ``generate_fields`` returned for the classify variables.
    """
    global vars_dict, q, canvas, fig_queue, canvas_widget, thread_control
    import queue  # multiprocessing.Queue.get_nowait() raises queue.Empty

    style = ttk.Style(parent_frame)
    set_dark_style(style)
    style_text_boxes(style)
    set_default_font(parent_frame, font_name="Helvetica", size=8)

    parent_frame.configure(bg='#333333')
    parent_frame.grid_rowconfigure(0, weight=1)
    parent_frame.grid_columnconfigure(0, weight=1)
    fig_queue = Queue()

    def _process_fig_queue():
        """Draw any figures waiting on fig_queue onto the canvas, then reschedule."""
        try:
            while not fig_queue.empty():
                # empty() is only a hint for multiprocessing queues: fetch first,
                # and clear the canvas only once a figure has actually arrived.
                try:
                    fig = fig_queue.get_nowait()
                except queue.Empty:
                    break
                clear_canvas(canvas)
                for ax in fig.get_axes():
                    ax.set_xticks([])  # Remove x-axis ticks
                    ax.set_yticks([])  # Remove y-axis ticks
                    ax.xaxis.set_visible(False)  # Hide the x-axis
                    ax.yaxis.set_visible(False)  # Hide the y-axis
                fig.tight_layout()
                fig.set_facecolor('#333333')
                canvas.figure = fig
                # Fit the figure to the widget's current pixel size.
                fig_width, fig_height = canvas_widget.winfo_width(), canvas_widget.winfo_height()
                fig.set_size_inches(fig_width / fig.dpi, fig_height / fig.dpi, forward=True)
                canvas.draw_idle()
        except Exception:
            traceback.print_exc()
        finally:
            canvas_widget.after(100, _process_fig_queue)

    def _process_console_queue():
        """Append pending messages from q to the console widget, then reschedule."""
        try:
            while not q.empty():
                # empty() is only a hint for multiprocessing queues; a racing
                # get_nowait() may still find nothing.
                try:
                    message = q.get_nowait()
                except queue.Empty:
                    break
                console_output.insert(tk.END, message)
                console_output.see(tk.END)
        finally:
            # Always re-arm, so one failed read never stops console updates.
            console_output.after(100, _process_console_queue)

    # Three side-by-side panes: settings | plot | console.
    vertical_container = tk.PanedWindow(parent_frame, orient=tk.HORIZONTAL)
    vertical_container.grid(row=0, column=0, sticky=tk.NSEW)

    # Settings Section
    settings_frame = tk.Frame(vertical_container, bg='#333333')
    vertical_container.add(settings_frame, stretch="always")
    settings_label = ttk.Label(settings_frame, text="Settings", background="#333333", foreground="white")
    settings_label.grid(row=0, column=0, pady=10, padx=10)
    scrollable_frame = ScrollableFrame(settings_frame, width=500)
    scrollable_frame.grid(row=1, column=0, sticky="nsew")
    settings_frame.grid_rowconfigure(1, weight=1)
    settings_frame.grid_columnconfigure(0, weight=1)

    # Setup for user input fields (variables)
    variables = classify_variables()
    vars_dict = generate_fields(variables, scrollable_frame)

    # Button section
    import_btn = CustomButton(scrollable_frame.scrollable_frame, text="Import", command=lambda: import_settings(scrollable_frame), font=('Helvetica', 10))
    import_btn.grid(row=47, column=0, pady=20, padx=20)
    # q and fig_queue are looked up as globals when the button is clicked.
    run_button = CustomButton(scrollable_frame.scrollable_frame, text="Run", command=lambda: start_process(q, fig_queue), font=('Helvetica', 10))
    run_button.grid(row=45, column=0, pady=20, padx=20)
    abort_button = CustomButton(scrollable_frame.scrollable_frame, text="Abort", command=initiate_abort, font=('Helvetica', 10))
    abort_button.grid(row=45, column=1, pady=20, padx=20)
    progress_label = ttk.Label(scrollable_frame.scrollable_frame, text="Processing: 0%", background="black", foreground="white")  # Create progress field
    progress_label.grid(row=50, column=0, columnspan=2, sticky="ew", pady=(5, 0), padx=10)

    # Plot Canvas Section
    plot_frame = tk.PanedWindow(vertical_container, orient=tk.VERTICAL)
    vertical_container.add(plot_frame, stretch="always")
    figure = Figure(figsize=(30, 4), dpi=100, facecolor='#333333')
    plot = figure.add_subplot(111)
    plot.plot([], [])
    plot.axis('off')
    canvas = FigureCanvasTkAgg(figure, master=plot_frame)
    canvas_widget = canvas.get_tk_widget()
    canvas_widget.configure(cursor='arrow', background='#333333', highlightthickness=0)
    plot_frame.add(canvas_widget, stretch="always")
    canvas.draw()
    canvas.figure = figure

    # Console Section
    console_frame = tk.Frame(vertical_container, bg='#333333')
    vertical_container.add(console_frame, stretch="always")
    console_label = ttk.Label(console_frame, text="Console", background="#333333", foreground="white")
    console_label.grid(row=0, column=0, pady=10, padx=10)
    console_output = scrolledtext.ScrolledText(console_frame, height=10, bg='#333333', fg='white', insertbackground='white')
    console_output.grid(row=1, column=0, sticky="nsew")
    console_frame.grid_rowconfigure(1, weight=1)
    console_frame.grid_columnconfigure(0, weight=1)

    q = Queue()
    sys.stdout = StdoutRedirector(console_output)
    sys.stderr = StdoutRedirector(console_output)

    # Kick off the two self-rescheduling pollers.
    _process_console_queue()
    _process_fig_queue()

    # NOTE(review): main_thread_update_function also receives q and fig_queue; if
    # it drains them as well, it competes with the pollers above. Confirm in gui_utils.
    parent_frame.after(100, lambda: main_thread_update_function(parent_frame, q, fig_queue, canvas_widget, progress_label))

    return parent_frame, vars_dict
177
+
def gui_classify():
    """Open a standalone, screen-sized window hosting the classify GUI.

    Builds the content frame, populates it with ``initiate_classify_root``,
    adds the menu bar, and blocks in the Tk event loop until the window closes.
    """
    root = tk.Tk()
    width = root.winfo_screenwidth()
    height = root.winfo_screenheight()
    root.geometry(f"{width}x{height}")
    root.title("SpaCr: classify objects")

    # ``root`` was created just above, so there is never a previous
    # ``content_frame`` to clear; build the content area directly.
    root.content_frame = tk.Frame(root)
    root.content_frame.grid(row=1, column=0, sticky="nsew")
    root.grid_rowconfigure(1, weight=1)
    root.grid_columnconfigure(0, weight=1)

    initiate_classify_root(root.content_frame)
    create_menu_bar(root)
    root.mainloop()


if __name__ == "__main__":
    gui_classify()
spacr/core.py CHANGED
@@ -971,7 +971,7 @@ def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample
971
971
  shutil.rmtree(temp_dir)
972
972
  print(f"\nSaved {total_images} images to {tar_name}")
973
973
 
974
- def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', num_workers=10, threshold=0.5, verbose=False):
974
+ def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', n_job=10, threshold=0.5, verbose=False):
975
975
 
976
976
  from .io import TarImageDataset
977
977
  from .utils import process_vision_results
@@ -994,7 +994,7 @@ def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=22
994
994
  model = torch.load(model_path)
995
995
 
996
996
  dataset = TarImageDataset(tar_path, transform=transform)
997
- data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
997
+ data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, n_job=n_job, pin_memory=True)
998
998
 
999
999
  model_name = os.path.splitext(os.path.basename(model_path))[0]
1000
1000
  dataset_name = os.path.splitext(os.path.basename(tar_path))[0]
@@ -1034,7 +1034,7 @@ def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=22
1034
1034
  torch.cuda.memory.empty_cache()
1035
1035
  return df
1036
1036
 
1037
- def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True, num_workers=10):
1037
+ def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True, n_job=10):
1038
1038
 
1039
1039
  from .io import NoClassDataset
1040
1040
 
@@ -1055,7 +1055,7 @@ def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True,
1055
1055
 
1056
1056
  print(f'Loading dataset in {src} with {len(src)} images')
1057
1057
  dataset = NoClassDataset(data_dir=src, transform=transform, shuffle=True, load_to_memory=False)
1058
- data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
1058
+ data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, n_job=n_job)
1059
1059
  print(f'Loaded {len(src)} images')
1060
1060
 
1061
1061
  result_loc = os.path.splitext(model_path)[0]+datetime.date.today().strftime('%y%m%d')+'_'+os.path.splitext(model_path)[1]+'_test_result.csv'
@@ -1302,7 +1302,7 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
1302
1302
 
1303
1303
  return
1304
1304
 
1305
- def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, verbose=False):
1305
+ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], n_job=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, verbose=False):
1306
1306
 
1307
1307
  """
1308
1308
  Generate data loaders for training and validation/test datasets.
@@ -1314,7 +1314,7 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
1314
1314
  - image_size (int): The size of the input images.
1315
1315
  - batch_size (int): The batch size for the data loaders.
1316
1316
  - classes (list): The list of classes to consider.
1317
- - num_workers (int): The number of worker threads for data loading.
1317
+ - n_job (int): The number of worker threads for data loading.
1318
1318
  - validation_split (float): The fraction of data to use for validation when train_mode is 'erm'.
1319
1319
  - max_show (int): The maximum number of images to show when verbose is True.
1320
1320
  - pin_memory (bool): Whether to pin memory for faster data transfer.
@@ -1404,10 +1404,10 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
1404
1404
  #val_dataset = augment_dataset(val_dataset, is_grayscale=(len(channels) == 1))
1405
1405
  print(f'Data after augmentation: Train: {len(train_dataset)}')#, Validataion:{len(val_dataset)}')
1406
1406
 
1407
- train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1408
- val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1407
+ train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, n_job=n_job if n_job is not None else 0, pin_memory=pin_memory)
1408
+ val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, n_job=n_job if n_job is not None else 0, pin_memory=pin_memory)
1409
1409
  else:
1410
- train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1410
+ train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, n_job=n_job if n_job is not None else 0, pin_memory=pin_memory)
1411
1411
 
1412
1412
  elif train_mode == 'irm':
1413
1413
  data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
@@ -1436,13 +1436,13 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
1436
1436
  #val_dataset = augment_dataset(val_dataset, is_grayscale=(len(channels) == 1))
1437
1437
  print(f'Data after augmentation: Train: {len(train_dataset)}')#, Validataion:{len(val_dataset)}')
1438
1438
 
1439
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1440
- val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1439
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, n_job=n_job if n_job is not None else 0, pin_memory=pin_memory)
1440
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, n_job=n_job if n_job is not None else 0, pin_memory=pin_memory)
1441
1441
 
1442
1442
  train_loaders.append(train_loader)
1443
1443
  val_loaders.append(val_loader)
1444
1444
  else:
1445
- train_loader = DataLoader(plate_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1445
+ train_loader = DataLoader(plate_data, batch_size=batch_size, shuffle=shuffle, n_job=n_job if n_job is not None else 0, pin_memory=pin_memory)
1446
1446
  train_loaders.append(train_loader)
1447
1447
  val_loaders.append(None)
1448
1448
 
@@ -1668,7 +1668,7 @@ def preprocess_generate_masks(src, settings={}):
1668
1668
 
1669
1669
  from .io import preprocess_img_data, _load_and_concatenate_arrays
1670
1670
  from .plot import plot_merged, plot_arrays
1671
- from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, _merge_cells_based_on_parasite_overlap, process_masks
1671
+ from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks
1672
1672
  from .settings import set_default_settings_preprocess_generate_masks, set_default_plot_merge_settings
1673
1673
 
1674
1674
  settings = set_default_settings_preprocess_generate_masks(src, settings)
@@ -2078,11 +2078,11 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
2078
2078
  else:
2079
2079
  radius = 100
2080
2080
 
2081
- workers = os.cpu_count()-2
2082
- if workers < 1:
2083
- workers = 1
2081
+ n_job = os.cpu_count()-2
2082
+ if n_job < 1:
2083
+ n_job = 1
2084
2084
 
2085
- mask_stack = _btrack_track_cells(src, name, batch_filenames, object_type, plot, save, masks_3D=masks, mode=timelapse_mode, timelapse_remove_transient=timelapse_remove_transient, radius=radius, workers=workers)
2085
+ mask_stack = _btrack_track_cells(src, name, batch_filenames, object_type, plot, save, masks_3D=masks, mode=timelapse_mode, timelapse_remove_transient=timelapse_remove_transient, radius=radius, n_job=n_job)
2086
2086
  if timelapse_mode == 'trackpy':
2087
2087
  mask_stack = _trackpy_track_cells(src, name, batch_filenames, object_type, masks, timelapse_displacement, timelapse_memory, timelapse_remove_transient, plot, save, timelapse_mode)
2088
2088
 
@@ -2303,9 +2303,9 @@ def generate_cellpose_masks(src, settings, object_type):
2303
2303
  else:
2304
2304
  radius = 100
2305
2305
 
2306
- workers = os.cpu_count()-2
2307
- if workers < 1:
2308
- workers = 1
2306
+ n_job = os.cpu_count()-2
2307
+ if n_job < 1:
2308
+ n_job = 1
2309
2309
 
2310
2310
  mask_stack = _btrack_track_cells(src=src,
2311
2311
  name=name,
@@ -2317,7 +2317,7 @@ def generate_cellpose_masks(src, settings, object_type):
2317
2317
  mode=timelapse_mode,
2318
2318
  timelapse_remove_transient=timelapse_remove_transient,
2319
2319
  radius=radius,
2320
- workers=workers)
2320
+ n_job=n_job)
2321
2321
  if timelapse_mode == 'trackpy':
2322
2322
  mask_stack = _trackpy_track_cells(src=src,
2323
2323
  name=name,
@@ -2551,7 +2551,7 @@ def compare_cellpose_masks(src, verbose=False, processes=None, save=True):
2551
2551
  common_files.intersection_update(os.listdir(d))
2552
2552
  common_files = list(common_files)
2553
2553
 
2554
- # Create a pool of workers
2554
+ # Create a pool of n_job
2555
2555
  with Pool(processes=processes) as pool:
2556
2556
  args = [(src, filename, dirs, conditions) for filename in common_files]
2557
2557
  results = pool.map(compare_mask, args)
@@ -3021,9 +3021,9 @@ def generate_image_umap(settings={}):
3021
3021
  """
3022
3022
 
3023
3023
  from .io import _read_and_join_tables
3024
- from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid, cluster_feature_analysis, generate_umap_from_images
3025
- from .settings import get_umap_image_settings
3026
- settings = get_umap_image_settings(settings)
3024
+ from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid, cluster_feature_analysis #, generate_umap_from_images
3025
+ from .settings import set_default_umap_image_settings
3026
+ settings = set_default_umap_image_settings(settings)
3027
3027
 
3028
3028
  if isinstance(settings['src'], str):
3029
3029
  settings['src'] = [settings['src']]
@@ -3109,7 +3109,9 @@ def generate_image_umap(settings={}):
3109
3109
 
3110
3110
  else:
3111
3111
  if settings['resnet_features']:
3112
- numeric_data, embedding, labels = generate_umap_from_images(image_paths, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['clustering'], settings['eps'], settings['min_samples'], settings['n_jobs'], settings['verbose'])
3112
+ # placeholder for resnet features, not implemented yet
3113
+ pass
3114
+ #numeric_data, embedding, labels = generate_umap_from_images(image_paths, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['clustering'], settings['eps'], settings['min_samples'], settings['n_jobs'], settings['verbose'])
3113
3115
  else:
3114
3116
  # Apply the trained reducer to the entire dataset
3115
3117
  numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])
@@ -3205,9 +3207,9 @@ def reducer_hyperparameter_search(settings={}, reduction_params=None, dbscan_par
3205
3207
 
3206
3208
  from .io import _read_and_join_tables
3207
3209
  from .utils import get_db_paths, preprocess_data, search_reduction_and_clustering, generate_colors
3208
- from .settings import get_umap_image_settings
3210
+ from .settings import set_default_umap_image_settings
3209
3211
 
3210
- settings = get_umap_image_settings(settings)
3212
+ settings = set_default_umap_image_settings(settings)
3211
3213
  pointsize = settings['dot_size']
3212
3214
  if isinstance(dbscan_params, dict):
3213
3215
  dbscan_params = [dbscan_params]
spacr/deep_spacr.py CHANGED
@@ -10,6 +10,9 @@ import matplotlib.pyplot as plt
10
10
  from PIL import Image
11
11
 
12
12
  from .logger import log_function_call
13
+ from .utils import close_multiprocessing_processes, reset_mp
14
+ #reset_mp()
15
+ #close_multiprocessing_processes()
13
16
 
14
17
  def evaluate_model_core(model, loader, loader_name, epoch, loss_type):
15
18
  """
@@ -42,7 +45,6 @@ def evaluate_model_core(model, loader, loader_name, epoch, loss_type):
42
45
  for batch_idx, (data, target, _) in enumerate(loader, start=1):
43
46
  start_time = time.time()
44
47
  data, target = data.to(device), target.to(device).float()
45
- #data, target = data.to(torch.float).to(device), target.to(device).float()
46
48
  output = model(data)
47
49
  loss += F.binary_cross_entropy_with_logits(output, target, reduction='sum').item()
48
50
  loss = calculate_loss(output, target, loss_type=loss_type)
@@ -228,7 +230,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
228
230
  image_size=settings['image_size'],
229
231
  batch_size=settings['batch_size'],
230
232
  classes=settings['classes'],
231
- num_workers=settings['num_workers'],
233
+ n_job=settings['n_job'],
232
234
  validation_split=settings['val_split'],
233
235
  pin_memory=settings['pin_memory'],
234
236
  normalize=settings['normalize'],
@@ -253,7 +255,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
253
255
  optimizer_type = settings['optimizer_type'],
254
256
  use_checkpoint = settings['use_checkpoint'],
255
257
  dropout_rate = settings['dropout_rate'],
256
- num_workers = settings['num_workers'],
258
+ n_job = settings['n_job'],
257
259
  val_loaders = val,
258
260
  test_loaders = None,
259
261
  intermedeate_save = settings['intermedeate_save'],
@@ -274,7 +276,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
274
276
  image_size=settings['image_size'],
275
277
  batch_size=settings['batch_size'],
276
278
  classes=settings['classes'],
277
- num_workers=settings['num_workers'],
279
+ n_job=settings['n_job'],
278
280
  validation_split=0.0,
279
281
  pin_memory=settings['pin_memory'],
280
282
  normalize=settings['normalize'],
@@ -313,7 +315,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
313
315
  torch.cuda.memory.empty_cache()
314
316
  gc.collect()
315
317
 
316
- def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='erm', epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, num_workers=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule = None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4, channels=['r','g','b']):
318
+ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='erm', epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, n_job=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule = None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4, channels=['r','g','b']):
317
319
  """
318
320
  Trains a model using the specified parameters.
319
321
 
@@ -330,7 +332,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
330
332
  optimizer_type (str, optional): The type of optimizer to use. Defaults to 'adamw'.
331
333
  use_checkpoint (bool, optional): Whether to use checkpointing during training. Defaults to False.
332
334
  dropout_rate (float, optional): The dropout rate for the model. Defaults to 0.
333
- num_workers (int, optional): The number of workers for data loading. Defaults to 20.
335
+ n_job (int, optional): The number of n_job for data loading. Defaults to 20.
334
336
  val_loaders (list, optional): A list of validation data loaders. Defaults to None.
335
337
  test_loaders (list, optional): A list of test data loaders. Defaults to None.
336
338
  init_weights (str, optional): The initialization weights for the model. Defaults to 'imagenet'.
@@ -355,7 +357,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
355
357
 
356
358
  use_cuda = torch.cuda.is_available()
357
359
  device = torch.device("cuda" if use_cuda else "cpu")
358
- kwargs = {'num_workers': num_workers, 'pin_memory': True} if use_cuda else {}
360
+ kwargs = {'n_job': n_job, 'pin_memory': True} if use_cuda else {}
359
361
 
360
362
  for idx, (images, labels, filenames) in enumerate(train_loaders):
361
363
  batch, chans, height, width = images.shape
spacr/gui.py CHANGED
@@ -1,52 +1,67 @@
1
1
  import tkinter as tk
2
2
  from tkinter import ttk
3
- from tkinter import font as tkFont
4
3
  from PIL import Image, ImageTk
5
- import os
6
- import requests
7
-
8
- # Import your GUI apps
9
- from .gui_mask_app import initiate_mask_root
10
- from .gui_measure_app import initiate_measure_root
11
- from .annotate_app import initiate_annotation_app_root
12
- from .mask_app import initiate_mask_app_root
13
- from .gui_classify_app import initiate_classify_root
14
-
15
- from .gui_utils import CustomButton, style_text_boxes
4
+ import os, requests
5
+ from multiprocessing import set_start_method
6
+ from .gui_elements import spacrButton, create_menu_bar, set_dark_style
7
+ from .gui_core import initiate_root
8
+ from .app_annotate import initiate_annotation_app_root
9
+ from .app_make_masks import initiate_mask_app_root
16
10
 
17
11
  class MainApp(tk.Tk):
18
- def __init__(self):
12
+ def __init__(self, default_app=None):
19
13
  super().__init__()
14
+ width = self.winfo_screenwidth()
15
+ height = self.winfo_screenheight()
16
+ self.geometry(f"{width}x{height}")
20
17
  self.title("SpaCr GUI Collection")
21
- self.geometry("1100x1500")
22
18
  self.configure(bg="black")
23
- #self.attributes('-fullscreen', True)
24
-
25
19
  style = ttk.Style()
26
- style_text_boxes(style)
20
+ set_dark_style(style)
27
21
 
28
22
  self.gui_apps = {
29
- "Mask": (initiate_mask_root, "Generate cellpose masks for cells, nuclei and pathogen images."),
30
- "Measure": (initiate_measure_root, "Measure single object intensity and morphological feature. Crop and save single object image"),
23
+ "Mask": (lambda frame: initiate_root(frame, 'mask'), "Generate cellpose masks for cells, nuclei and pathogen images."),
24
+ "Measure": (lambda frame: initiate_root(frame, 'measure'), "Measure single object intensity and morphological feature. Crop and save single object image"),
31
25
  "Annotate": (initiate_annotation_app_root, "Annotation single object images on a grid. Annotations are saved to database."),
32
26
  "Make Masks": (initiate_mask_app_root, "Adjust pre-existing Cellpose models to your specific dataset for improved performance"),
33
- "Classify": (initiate_classify_root, "Train Torch Convolutional Neural Networks (CNNs) or Transformers to classify single object images.")
27
+ "Classify": (lambda frame: initiate_root(frame, 'classify'), "Train Torch Convolutional Neural Networks (CNNs) or Transformers to classify single object images."),
28
+ "Sequencing": (lambda frame: initiate_root(frame, 'sequencing'), "Analyze sequensing data."),
29
+ "Umap": (lambda frame: initiate_root(frame, 'umap'), "Generate UMAP embedings with datapoints represented as images.")
34
30
  }
35
31
 
36
32
  self.selected_app = tk.StringVar()
37
33
  self.create_widgets()
38
34
 
35
+
36
+ if default_app == "Mask":
37
+ self.load_app(default_app, self.gui_apps[default_app][0])
38
+ elif default_app == "Measure":
39
+ self.load_app(default_app, self.gui_apps[default_app][1])
40
+ elif default_app == "Annotate":
41
+ self.load_app(default_app, self.gui_apps[default_app][2])
42
+ elif default_app == "Make Masks":
43
+ self.load_app(default_app, self.gui_apps[default_app][3])
44
+ elif default_app == "Classify":
45
+ self.load_app(default_app, self.gui_apps[default_app][4])
46
+ elif default_app == "Sequencing":
47
+ self.load_app(default_app, self.gui_apps[default_app][5])
48
+ elif default_app == "Umap":
49
+ self.load_app(default_app, self.gui_apps[default_app][6])
50
+
39
51
  def create_widgets(self):
40
52
  # Create the menu bar
41
- #create_menu_bar(self)
53
+ create_menu_bar(self)
54
+
42
55
  # Create a canvas to hold the selected app and other elements
43
- self.canvas = tk.Canvas(self, bg="black", highlightthickness=0, width=4000, height=4000)
56
+ self.canvas = tk.Canvas(self, bg="black", highlightthickness=0)
44
57
  self.canvas.grid(row=0, column=0, sticky="nsew")
45
58
  self.grid_rowconfigure(0, weight=1)
46
59
  self.grid_columnconfigure(0, weight=1)
60
+
47
61
  # Create a frame inside the canvas to hold the main content
48
62
  self.content_frame = tk.Frame(self.canvas, bg="black")
49
63
  self.content_frame.pack(fill=tk.BOTH, expand=True)
64
+
50
65
  # Create startup screen with buttons for each GUI app
51
66
  self.create_startup_screen()
52
67
 
@@ -59,10 +74,10 @@ class MainApp(tk.Tk):
59
74
 
60
75
  # Load the logo image
61
76
  if not self.load_logo(logo_frame):
62
- tk.Label(logo_frame, text="Logo not found", bg="black", fg="white", font=('Helvetica', 24, tkFont.NORMAL)).pack(padx=10, pady=10)
77
+ tk.Label(logo_frame, text="Logo not found", bg="black", fg="white", font=('Helvetica', 24)).pack(padx=10, pady=10)
63
78
 
64
79
  # Add SpaCr text below the logo with padding for sharper text
65
- tk.Label(logo_frame, text="SpaCr", bg="black", fg="#008080", font=('Helvetica', 24, tkFont.NORMAL)).pack(padx=10, pady=10)
80
+ tk.Label(logo_frame, text="SpaCr", bg="black", fg="#008080", font=('Helvetica', 24)).pack(padx=10, pady=10)
66
81
 
67
82
  # Create a frame for the buttons and descriptions
68
83
  buttons_frame = tk.Frame(self.content_frame, bg="black")
@@ -72,10 +87,11 @@ class MainApp(tk.Tk):
72
87
  app_func, app_desc = app_data
73
88
 
74
89
  # Create custom button with text
75
- button = CustomButton(buttons_frame, text=app_name, command=lambda app_name=app_name: self.load_app(app_name), font=('Helvetica', 12))
90
+ button = spacrButton(buttons_frame, text=app_name, command=lambda app_name=app_name, app_func=app_func: self.load_app(app_name, app_func), font=('Helvetica', 12))
91
+ #button = ttk.Button(buttons_frame, text=app_name, command=lambda app_name=app_name, app_func=app_func: self.load_app(app_name, app_func), style='Custom.TButton')
76
92
  button.grid(row=i, column=0, pady=10, padx=10, sticky="w")
77
93
 
78
- description_label = tk.Label(buttons_frame, text=app_desc, bg="black", fg="white", wraplength=800, justify="left", font=('Helvetica', 10, tkFont.NORMAL))
94
+ description_label = tk.Label(buttons_frame, text=app_desc, bg="black", fg="white", wraplength=800, justify="left", font=('Helvetica', 12))
79
95
  description_label.grid(row=i, column=1, pady=10, padx=10, sticky="w")
80
96
 
81
97
  # Ensure buttons have a fixed width
@@ -98,7 +114,6 @@ class MainApp(tk.Tk):
98
114
 
99
115
  try:
100
116
  img_path = os.path.join(os.path.dirname(__file__), 'logo_spacr.png')
101
- print(f"Trying to load logo from {img_path}")
102
117
  logo_image = Image.open(img_path)
103
118
  except (FileNotFoundError, Image.UnidentifiedImageError):
104
119
  print(f"File {img_path} not found or is not a valid image. Attempting to download from GitHub.")
@@ -115,7 +130,9 @@ class MainApp(tk.Tk):
115
130
  print(f"An error occurred while loading the logo: {e}")
116
131
  return False
117
132
  try:
118
- logo_image = logo_image.resize((800, 800), Image.Resampling.LANCZOS)
133
+ screen_height = frame.winfo_screenheight()
134
+ new_height = int(screen_height // 4)
135
+ logo_image = logo_image.resize((new_height, new_height), Image.Resampling.LANCZOS)
119
136
  logo_photo = ImageTk.PhotoImage(logo_image)
120
137
  logo_label = tk.Label(frame, image=logo_photo, bg="black")
121
138
  logo_label.image = logo_photo # Keep a reference to avoid garbage collection
@@ -125,13 +142,14 @@ class MainApp(tk.Tk):
125
142
  print(f"An error occurred while processing the logo image: {e}")
126
143
  return False
127
144
 
128
- def load_app(self, app_name):
129
- selected_app_func, _ = self.gui_apps[app_name]
145
+ def load_app(self, app_name, app_func):
146
+ # Clear the current content frame
130
147
  self.clear_frame(self.content_frame)
131
148
 
149
+ # Initialize the selected app
132
150
  app_frame = tk.Frame(self.content_frame, bg="black")
133
151
  app_frame.pack(fill=tk.BOTH, expand=True)
134
- selected_app_func(app_frame)#, self.winfo_width(), self.winfo_height())
152
+ app_func(app_frame)
135
153
 
136
154
  def clear_frame(self, frame):
137
155
  for widget in frame.winfo_children():
@@ -142,4 +160,5 @@ def gui_app():
142
160
  app.mainloop()
143
161
 
144
162
  if __name__ == "__main__":
163
+ set_start_method('spawn', force=True)
145
164
  gui_app()