spacr 0.2.56__py3-none-any.whl → 0.2.61__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/core.py +135 -472
- spacr/deep_spacr.py +189 -270
- spacr/gui_core.py +296 -87
- spacr/gui_elements.py +34 -81
- spacr/gui_utils.py +61 -47
- spacr/io.py +104 -41
- spacr/plot.py +47 -1
- spacr/settings.py +27 -31
- spacr/utils.py +14 -13
- {spacr-0.2.56.dist-info → spacr-0.2.61.dist-info}/METADATA +1 -1
- {spacr-0.2.56.dist-info → spacr-0.2.61.dist-info}/RECORD +15 -15
- {spacr-0.2.56.dist-info → spacr-0.2.61.dist-info}/LICENSE +0 -0
- {spacr-0.2.56.dist-info → spacr-0.2.61.dist-info}/WHEEL +0 -0
- {spacr-0.2.56.dist-info → spacr-0.2.61.dist-info}/entry_points.txt +0 -0
- {spacr-0.2.56.dist-info → spacr-0.2.61.dist-info}/top_level.txt +0 -0
spacr/gui_elements.py
CHANGED
@@ -516,6 +516,7 @@ class spacrDropdownMenu(tk.Frame):
|
|
516
516
|
self.inactive_color = color_settings['inactive_color']
|
517
517
|
self.active_color = color_settings['active_color']
|
518
518
|
self.fg_color = color_settings['fg_color']
|
519
|
+
self.bg_color = style_out['bg_color']
|
519
520
|
|
520
521
|
# Create the button with rounded edges
|
521
522
|
self.button_bg = self.create_rounded_rectangle(2, 2, self.button_width + 2, self.size + 2, radius=20, fill=self.inactive_color, outline=self.inactive_color)
|
@@ -536,8 +537,8 @@ class spacrDropdownMenu(tk.Frame):
|
|
536
537
|
self.canvas.bind("<Leave>", self.on_leave)
|
537
538
|
self.canvas.bind("<Button-1>", self.on_click)
|
538
539
|
|
539
|
-
# Create a popup menu
|
540
|
-
self.menu = tk.Menu(self, tearoff=0)
|
540
|
+
# Create a popup menu with the desired background color
|
541
|
+
self.menu = tk.Menu(self, tearoff=0, bg=self.bg_color, fg=self.fg_color)
|
541
542
|
for option in self.options:
|
542
543
|
self.menu.add_command(label=option, command=lambda opt=option: self.on_select(opt))
|
543
544
|
|
@@ -591,7 +592,6 @@ class spacrDropdownMenu(tk.Frame):
|
|
591
592
|
else:
|
592
593
|
self.menu.entryconfig(idx, background=style_out['bg_color'], foreground=style_out['fg_color'])
|
593
594
|
|
594
|
-
|
595
595
|
class spacrCheckbutton(ttk.Checkbutton):
|
596
596
|
def __init__(self, parent, text="", variable=None, command=None, *args, **kwargs):
|
597
597
|
super().__init__(parent, *args, **kwargs)
|
@@ -613,17 +613,26 @@ class spacrProgressBar(ttk.Progressbar):
|
|
613
613
|
self.bg_color = style_out['bg_color']
|
614
614
|
self.active_color = style_out['active_color']
|
615
615
|
self.inactive_color = style_out['inactive_color']
|
616
|
+
self.font_size = style_out['font_size']
|
617
|
+
self.font_loader = style_out['font_loader']
|
616
618
|
|
617
619
|
# Configure the style for the progress bar
|
618
620
|
self.style = ttk.Style()
|
621
|
+
|
622
|
+
# Remove any borders and ensure the active color fills the entire space
|
619
623
|
self.style.configure(
|
620
624
|
"spacr.Horizontal.TProgressbar",
|
621
|
-
troughcolor=self.
|
622
|
-
background=self.active_color,
|
623
|
-
|
624
|
-
|
625
|
-
|
625
|
+
troughcolor=self.inactive_color, # Set the trough to bg color
|
626
|
+
background=self.active_color, # Active part is the active color
|
627
|
+
borderwidth=0, # Remove border width
|
628
|
+
pbarrelief="flat", # Flat relief for the progress bar
|
629
|
+
troughrelief="flat", # Flat relief for the trough
|
630
|
+
thickness=20, # Set the thickness of the progress bar
|
631
|
+
darkcolor=self.active_color, # Ensure darkcolor matches the active color
|
632
|
+
lightcolor=self.active_color, # Ensure lightcolor matches the active color
|
633
|
+
bordercolor=self.bg_color # Set the border color to the background color to hide it
|
626
634
|
)
|
635
|
+
|
627
636
|
self.configure(style="spacr.Horizontal.TProgressbar")
|
628
637
|
|
629
638
|
# Set initial value to 0
|
@@ -641,15 +650,14 @@ class spacrProgressBar(ttk.Progressbar):
|
|
641
650
|
justify='left',
|
642
651
|
bg=self.inactive_color,
|
643
652
|
fg=self.fg_color,
|
644
|
-
wraplength=300
|
653
|
+
wraplength=300,
|
654
|
+
font=self.font_loader.get_font(size=self.font_size)
|
645
655
|
)
|
646
|
-
self.progress_label.grid_forget()
|
656
|
+
self.progress_label.grid_forget()
|
647
657
|
|
648
658
|
# Initialize attributes for time and operation
|
649
659
|
self.operation_type = None
|
650
|
-
self.
|
651
|
-
self.time_batch = None
|
652
|
-
self.time_left = None
|
660
|
+
self.additional_info = None
|
653
661
|
|
654
662
|
def set_label_position(self):
|
655
663
|
if self.label and self.progress_label:
|
@@ -664,74 +672,19 @@ class spacrProgressBar(ttk.Progressbar):
|
|
664
672
|
label_text = f"Processing: {self['value']}/{self['maximum']}"
|
665
673
|
if self.operation_type:
|
666
674
|
label_text += f", {self.operation_type}"
|
667
|
-
if self.
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
|
677
|
-
|
678
|
-
|
679
|
-
|
680
|
-
style_out = set_dark_style(ttk.Style())
|
681
|
-
|
682
|
-
self.fg_color = style_out['fg_color']
|
683
|
-
self.bg_color = style_out['bg_color']
|
684
|
-
self.active_color = style_out['active_color']
|
685
|
-
self.inactive_color = style_out['inactive_color']
|
686
|
-
|
687
|
-
# Configure the style for the progress bar
|
688
|
-
self.style = ttk.Style()
|
689
|
-
self.style.configure(
|
690
|
-
"spacr.Horizontal.TProgressbar",
|
691
|
-
troughcolor=self.bg_color,
|
692
|
-
background=self.active_color,
|
693
|
-
thickness=20,
|
694
|
-
troughrelief='flat',
|
695
|
-
borderwidth=0
|
696
|
-
)
|
697
|
-
self.configure(style="spacr.Horizontal.TProgressbar")
|
698
|
-
|
699
|
-
# Set initial value to 0
|
700
|
-
self['value'] = 0
|
701
|
-
|
702
|
-
# Track whether to show the progress label
|
703
|
-
self.label = label
|
704
|
-
|
705
|
-
# Create the progress label (defer placement)
|
706
|
-
if self.label:
|
707
|
-
self.progress_label = tk.Label(parent, text="Processing: 0/0", anchor='w', justify='left', bg=self.inactive_color, fg=self.fg_color)
|
708
|
-
self.progress_label.grid_forget() # Temporarily hide it
|
709
|
-
|
710
|
-
# Initialize attributes for time and operation
|
711
|
-
self.operation_type = None
|
712
|
-
self.time_image = None
|
713
|
-
self.time_batch = None
|
714
|
-
self.time_left = None
|
715
|
-
|
716
|
-
def set_label_position(self):
|
717
|
-
if self.label and self.progress_label:
|
718
|
-
row_info = self.grid_info().get('row', 0)
|
719
|
-
col_info = self.grid_info().get('column', 0)
|
720
|
-
col_span = self.grid_info().get('columnspan', 1)
|
721
|
-
self.progress_label.grid(row=row_info + 1, column=col_info, columnspan=col_span, pady=5, padx=5, sticky='ew')
|
722
|
-
|
723
|
-
def update_label(self):
|
724
|
-
if self.label and self.progress_label:
|
725
|
-
# Update the progress label with current progress and additional info
|
726
|
-
label_text = f"Processing: {self['value']}/{self['maximum']}"
|
727
|
-
if self.operation_type:
|
728
|
-
label_text += f", {self.operation_type}"
|
729
|
-
if self.time_image:
|
730
|
-
label_text += f", Time/image: {self.time_image:.3f} sec"
|
731
|
-
if self.time_batch:
|
732
|
-
label_text += f", Time/batch: {self.time_batch:.3f} sec"
|
733
|
-
if self.time_left:
|
734
|
-
label_text += f", Time_left: {self.time_left:.3f} min"
|
675
|
+
if hasattr(self, 'additional_info') and self.additional_info:
|
676
|
+
# Add a space between progress information and additional information
|
677
|
+
label_text += "\n\n"
|
678
|
+
# Split the additional_info into a list of items
|
679
|
+
items = self.additional_info.split(", ")
|
680
|
+
formatted_additional_info = ""
|
681
|
+
# Group the items in pairs, adding them to formatted_additional_info
|
682
|
+
for i in range(0, len(items), 2):
|
683
|
+
if i + 1 < len(items):
|
684
|
+
formatted_additional_info += f"{items[i]}, {items[i + 1]}\n\n"
|
685
|
+
else:
|
686
|
+
formatted_additional_info += f"{items[i]}\n\n" # If there's an odd item out, add it alone
|
687
|
+
label_text += formatted_additional_info.strip()
|
735
688
|
self.progress_label.config(text=label_text)
|
736
689
|
|
737
690
|
def spacrScrollbarStyle(style, inactive_color, active_color):
|
spacr/gui_utils.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1
|
-
import os, io, sys, ast, ctypes, ast, sqlite3, requests, time, traceback
|
1
|
+
import os, io, sys, ast, ctypes, ast, sqlite3, requests, time, traceback, torch
|
2
2
|
import tkinter as tk
|
3
3
|
from tkinter import ttk
|
4
4
|
import matplotlib
|
5
5
|
import matplotlib.pyplot as plt
|
6
6
|
matplotlib.use('Agg')
|
7
7
|
from huggingface_hub import list_repo_files
|
8
|
+
import psutil
|
8
9
|
|
9
10
|
from .gui_elements import AnnotateApp, spacrEntry, spacrCheck, spacrCombo
|
10
11
|
|
@@ -12,6 +13,36 @@ try:
|
|
12
13
|
ctypes.windll.shcore.SetProcessDpiAwareness(True)
|
13
14
|
except AttributeError:
|
14
15
|
pass
|
16
|
+
|
17
|
+
def initialize_cuda():
|
18
|
+
"""
|
19
|
+
Initializes CUDA in the main process by performing a simple GPU operation.
|
20
|
+
"""
|
21
|
+
if torch.cuda.is_available():
|
22
|
+
# Allocate a small tensor on the GPU
|
23
|
+
_ = torch.tensor([0.0], device='cuda')
|
24
|
+
print("CUDA initialized in the main process.")
|
25
|
+
else:
|
26
|
+
print("CUDA is not available.")
|
27
|
+
|
28
|
+
def set_high_priority(process):
|
29
|
+
try:
|
30
|
+
p = psutil.Process(process.pid)
|
31
|
+
if os.name == 'nt': # Windows
|
32
|
+
p.nice(psutil.HIGH_PRIORITY_CLASS)
|
33
|
+
else: # Unix-like systems
|
34
|
+
p.nice(-10) # Adjusted priority level
|
35
|
+
print(f"Successfully set high priority for process: {process.pid}")
|
36
|
+
except psutil.AccessDenied as e:
|
37
|
+
print(f"Access denied when trying to set high priority for process {process.pid}: {e}")
|
38
|
+
except psutil.NoSuchProcess as e:
|
39
|
+
print(f"No such process {process.pid}: {e}")
|
40
|
+
except Exception as e:
|
41
|
+
print(f"Failed to set high priority for process {process.pid}: {e}")
|
42
|
+
|
43
|
+
def set_cpu_affinity(process):
|
44
|
+
p = psutil.Process(process.pid)
|
45
|
+
p.cpu_affinity(list(range(os.cpu_count())))
|
15
46
|
|
16
47
|
def proceed_with_app(root, app_name, app_func):
|
17
48
|
# Clear the current content frame
|
@@ -44,16 +75,6 @@ def load_app(root, app_name, app_func):
|
|
44
75
|
else:
|
45
76
|
proceed_with_app(root, app_name, app_func)
|
46
77
|
|
47
|
-
def parse_list_v1(value):
|
48
|
-
try:
|
49
|
-
parsed_value = ast.literal_eval(value)
|
50
|
-
if isinstance(parsed_value, list):
|
51
|
-
return parsed_value
|
52
|
-
else:
|
53
|
-
raise ValueError(f"Expected a list but got {type(parsed_value).__name__}")
|
54
|
-
except (ValueError, SyntaxError) as e:
|
55
|
-
raise ValueError(f"Invalid format for list: {value}. Error: {e}")
|
56
|
-
|
57
78
|
def parse_list(value):
|
58
79
|
try:
|
59
80
|
parsed_value = ast.literal_eval(value)
|
@@ -83,6 +104,7 @@ def create_input_field(frame, label_text, row, var_type='entry', options=None, d
|
|
83
104
|
size_dict['settings_width'] = size_dict['settings_width'] - int(size_dict['settings_width']*0.1)
|
84
105
|
|
85
106
|
# Replace underscores with spaces and capitalize the first letter
|
107
|
+
|
86
108
|
label_text = label_text.replace('_', ' ').capitalize()
|
87
109
|
|
88
110
|
# Configure the column widths
|
@@ -97,32 +119,35 @@ def create_input_field(frame, label_text, row, var_type='entry', options=None, d
|
|
97
119
|
custom_frame.config(highlightbackground=style_out['bg_color'], highlightthickness=1, bd=2)
|
98
120
|
|
99
121
|
# Create and configure the label
|
100
|
-
|
101
|
-
label = ttk.Label(custom_frame, text=label_text, background=style_out['bg_color'], foreground=style_out['fg_color'], font=font_loader.get_font(size=font_size), anchor='e', justify='right')
|
102
|
-
label = ttk.Label(custom_frame, text=label_text, background=style_out['bg_color'], foreground=style_out['fg_color'], font=(style_out['font_family'], style_out['font_size']), anchor='e', justify='right')
|
122
|
+
label = tk.Label(custom_frame, text=label_text, bg=style_out['bg_color'], fg=style_out['fg_color'], font=font_loader.get_font(size=font_size), anchor='e', justify='right')
|
103
123
|
label.grid(column=label_column, row=0, sticky=tk.W, padx=(5, 2), pady=5) # Place the label in the first row
|
104
124
|
|
105
125
|
# Create and configure the input widget based on var_type
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
+
try:
|
127
|
+
if var_type == 'entry':
|
128
|
+
var = tk.StringVar(value=default_value)
|
129
|
+
entry = spacrEntry(custom_frame, textvariable=var, outline=False, width=size_dict['settings_width'])
|
130
|
+
entry.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the entry in the second row
|
131
|
+
return (label, entry, var, custom_frame) # Return both the label and the entry, and the variable
|
132
|
+
elif var_type == 'check':
|
133
|
+
var = tk.BooleanVar(value=default_value) # Set default value (True/False)
|
134
|
+
check = spacrCheck(custom_frame, text="", variable=var)
|
135
|
+
check.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the checkbutton in the second row
|
136
|
+
return (label, check, var, custom_frame) # Return both the label and the checkbutton, and the variable
|
137
|
+
elif var_type == 'combo':
|
138
|
+
var = tk.StringVar(value=default_value) # Set default value
|
139
|
+
combo = spacrCombo(custom_frame, textvariable=var, values=options, width=size_dict['settings_width']) # Apply TCombobox style
|
140
|
+
combo.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the combobox in the second row
|
141
|
+
if default_value:
|
142
|
+
combo.set(default_value)
|
143
|
+
return (label, combo, var, custom_frame) # Return both the label and the combobox, and the variable
|
144
|
+
else:
|
145
|
+
var = None # Placeholder in case of an undefined var_type
|
146
|
+
return (label, None, var, custom_frame)
|
147
|
+
except Exception as e:
|
148
|
+
traceback.print_exc()
|
149
|
+
print(f"Error creating input field: {e}")
|
150
|
+
print(f"Wrong type for {label_text} Expected {var_type}")
|
126
151
|
|
127
152
|
def process_stdout_stderr(q):
|
128
153
|
"""
|
@@ -150,16 +175,6 @@ def cancel_after_tasks(frame):
|
|
150
175
|
frame.after_cancel(task)
|
151
176
|
frame.after_tasks.clear()
|
152
177
|
|
153
|
-
def main_thread_update_function(root, q, fig_queue, canvas_widget):
|
154
|
-
try:
|
155
|
-
#ansi_escape_pattern = re.compile(r'\x1B\[[0-?]*[ -/]*[@-~]')
|
156
|
-
while not q.empty():
|
157
|
-
message = q.get_nowait()
|
158
|
-
except Exception as e:
|
159
|
-
print(f"Error updating GUI canvas: {e}")
|
160
|
-
finally:
|
161
|
-
root.after(100, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget))
|
162
|
-
|
163
178
|
def annotate(settings):
|
164
179
|
from .settings import set_annotate_default_settings
|
165
180
|
settings = set_annotate_default_settings(settings)
|
@@ -341,7 +356,7 @@ def convert_settings_dict_for_gui(settings):
|
|
341
356
|
'channels': ('combo', ['[0,1,2,3]', '[0,1,2]', '[0,1]', '[0]'], '[0,1,2,3]'),
|
342
357
|
'train_channels': ('combo', ["['r','g','b']", "['r','g']", "['r','b']", "['g','b']", "['r']", "['g']", "['b']"], "['r','g','b']"),
|
343
358
|
'channel_dims': ('combo', ['[0,1,2,3]', '[0,1,2]', '[0,1]', '[0]'], '[0,1,2,3]'),
|
344
|
-
'dataset_mode': ('combo', ['annotation', 'metadata', 'recruitment'], '
|
359
|
+
'dataset_mode': ('combo', ['annotation', 'metadata', 'recruitment'], 'metadata'),
|
345
360
|
'cell_mask_dim': ('combo', chans, None),
|
346
361
|
'cell_chann_dim': ('combo', chans, None),
|
347
362
|
'nucleus_mask_dim': ('combo', chans, None),
|
@@ -476,7 +491,7 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
|
|
476
491
|
imports = 2
|
477
492
|
elif settings_type == 'recruitment':
|
478
493
|
function = analyze_recruitment
|
479
|
-
imports =
|
494
|
+
imports = 1
|
480
495
|
elif settings_type == 'umap':
|
481
496
|
function = generate_image_umap
|
482
497
|
imports = 1
|
@@ -490,7 +505,6 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
|
|
490
505
|
finally:
|
491
506
|
stop_requested.value = 1
|
492
507
|
|
493
|
-
|
494
508
|
def hide_all_settings(vars_dict, categories):
|
495
509
|
"""
|
496
510
|
Function to initially hide all settings in the GUI.
|
spacr/io.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1
|
-
import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob
|
1
|
+
import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob, queue
|
2
2
|
import numpy as np
|
3
3
|
import pandas as pd
|
4
4
|
import tifffile
|
5
|
-
from PIL import Image
|
6
|
-
from collections import defaultdict, Counter
|
5
|
+
from PIL import Image, ImageOps
|
6
|
+
from collections import defaultdict, Counter, deque
|
7
7
|
from pathlib import Path
|
8
8
|
from functools import partial
|
9
9
|
from matplotlib.animation import FuncAnimation
|
@@ -17,12 +17,12 @@ import imageio.v2 as imageio2
|
|
17
17
|
import matplotlib.pyplot as plt
|
18
18
|
from io import BytesIO
|
19
19
|
from IPython.display import display, clear_output
|
20
|
-
from multiprocessing import Pool, cpu_count
|
21
|
-
from torch.utils.data import Dataset
|
20
|
+
from multiprocessing import Pool, cpu_count, Process, Queue
|
21
|
+
from torch.utils.data import Dataset, DataLoader
|
22
22
|
import matplotlib.pyplot as plt
|
23
23
|
from torchvision.transforms import ToTensor
|
24
24
|
import seaborn as sns
|
25
|
-
|
25
|
+
import atexit
|
26
26
|
|
27
27
|
from .logger import log_function_call
|
28
28
|
|
@@ -444,20 +444,7 @@ class NoClassDataset(Dataset):
|
|
444
444
|
# Return both the image and its filename
|
445
445
|
return img, self.filenames[index]
|
446
446
|
|
447
|
-
class
|
448
|
-
"""
|
449
|
-
A custom dataset class for loading and processing image data.
|
450
|
-
|
451
|
-
Args:
|
452
|
-
data_dir (str): The directory path where the image data is stored.
|
453
|
-
loader_classes (list): A list of class names for the dataset.
|
454
|
-
transform (callable, optional): A function/transform to apply to the image data. Default is None.
|
455
|
-
shuffle (bool, optional): Whether to shuffle the dataset. Default is True.
|
456
|
-
pin_memory (bool, optional): Whether to pin the loaded images to memory. Default is False.
|
457
|
-
specific_files (list, optional): A list of specific file paths to include in the dataset. Default is None.
|
458
|
-
specific_labels (list, optional): A list of specific labels corresponding to the specific files. Default is None.
|
459
|
-
"""
|
460
|
-
|
447
|
+
class spacrDataset(Dataset):
|
461
448
|
def __init__(self, data_dir, loader_classes, transform=None, shuffle=True, pin_memory=False, specific_files=None, specific_labels=None):
|
462
449
|
self.data_dir = data_dir
|
463
450
|
self.classes = loader_classes
|
@@ -466,7 +453,7 @@ class MyDataset(Dataset):
|
|
466
453
|
self.pin_memory = pin_memory
|
467
454
|
self.filenames = []
|
468
455
|
self.labels = []
|
469
|
-
|
456
|
+
|
470
457
|
if specific_files and specific_labels:
|
471
458
|
self.filenames = specific_files
|
472
459
|
self.labels = specific_labels
|
@@ -479,33 +466,113 @@ class MyDataset(Dataset):
|
|
479
466
|
|
480
467
|
if self.shuffle:
|
481
468
|
self.shuffle_dataset()
|
482
|
-
|
469
|
+
|
483
470
|
if self.pin_memory:
|
484
|
-
|
485
|
-
|
471
|
+
# Use multiprocessing to load images in parallel
|
472
|
+
with Pool(processes=cpu_count()) as pool:
|
473
|
+
self.images = pool.map(self.load_image, self.filenames)
|
474
|
+
else:
|
475
|
+
self.images = None
|
476
|
+
|
486
477
|
def load_image(self, img_path):
|
487
478
|
img = Image.open(img_path).convert('RGB')
|
479
|
+
img = ImageOps.exif_transpose(img) # Handle image orientation
|
488
480
|
return img
|
489
|
-
|
481
|
+
|
490
482
|
def __len__(self):
|
491
483
|
return len(self.filenames)
|
492
|
-
|
484
|
+
|
493
485
|
def shuffle_dataset(self):
|
494
486
|
combined = list(zip(self.filenames, self.labels))
|
495
487
|
random.shuffle(combined)
|
496
488
|
self.filenames, self.labels = zip(*combined)
|
497
|
-
|
489
|
+
|
498
490
|
def get_plate(self, filepath):
|
499
|
-
filename = os.path.basename(filepath)
|
491
|
+
filename = os.path.basename(filepath)
|
500
492
|
return filename.split('_')[0]
|
501
|
-
|
493
|
+
|
502
494
|
def __getitem__(self, index):
|
495
|
+
if self.pin_memory:
|
496
|
+
img = self.images[index]
|
497
|
+
else:
|
498
|
+
img = self.load_image(self.filenames[index])
|
503
499
|
label = self.labels[index]
|
504
500
|
filename = self.filenames[index]
|
505
|
-
img = self.load_image(filename)
|
506
501
|
if self.transform:
|
507
502
|
img = self.transform(img)
|
508
503
|
return img, label, filename
|
504
|
+
|
505
|
+
class spacrDataLoader(DataLoader):
|
506
|
+
def __init__(self, *args, preload_batches=1, **kwargs):
|
507
|
+
super().__init__(*args, **kwargs)
|
508
|
+
self.preload_batches = preload_batches
|
509
|
+
self.batch_queue = Queue(maxsize=preload_batches)
|
510
|
+
self.process = None
|
511
|
+
self.current_batch_index = 0
|
512
|
+
self._stop_event = False
|
513
|
+
self.pin_memory = kwargs.get('pin_memory', False)
|
514
|
+
atexit.register(self.cleanup)
|
515
|
+
|
516
|
+
def _preload_next_batches(self):
|
517
|
+
try:
|
518
|
+
for _ in range(self.preload_batches):
|
519
|
+
if self._stop_event:
|
520
|
+
break
|
521
|
+
batch = next(self._iterator)
|
522
|
+
if self.pin_memory:
|
523
|
+
batch = self._pin_memory_batch(batch)
|
524
|
+
self.batch_queue.put(batch)
|
525
|
+
except StopIteration:
|
526
|
+
pass
|
527
|
+
|
528
|
+
def _start_preloading(self):
|
529
|
+
if self.process is None or not self.process.is_alive():
|
530
|
+
self._iterator = iter(super().__iter__())
|
531
|
+
if not self.pin_memory:
|
532
|
+
self.process = Process(target=self._preload_next_batches)
|
533
|
+
self.process.start()
|
534
|
+
else:
|
535
|
+
self._preload_next_batches() # Directly load if pin_memory is True
|
536
|
+
|
537
|
+
def _pin_memory_batch(self, batch):
|
538
|
+
if isinstance(batch, (list, tuple)):
|
539
|
+
return [b.pin_memory() if isinstance(b, torch.Tensor) else b for b in batch]
|
540
|
+
elif isinstance(batch, torch.Tensor):
|
541
|
+
return batch.pin_memory()
|
542
|
+
else:
|
543
|
+
return batch
|
544
|
+
|
545
|
+
def __iter__(self):
|
546
|
+
self._start_preloading()
|
547
|
+
return self
|
548
|
+
|
549
|
+
def __next__(self):
|
550
|
+
if self.process and not self.process.is_alive() and self.batch_queue.empty():
|
551
|
+
raise StopIteration
|
552
|
+
|
553
|
+
try:
|
554
|
+
if self.pin_memory:
|
555
|
+
next_batch = self.batch_queue.get(timeout=60)
|
556
|
+
else:
|
557
|
+
next_batch = self.batch_queue.get(timeout=60)
|
558
|
+
self.current_batch_index += 1
|
559
|
+
|
560
|
+
# Start preloading the next batches
|
561
|
+
if self.batch_queue.qsize() < self.preload_batches:
|
562
|
+
self._start_preloading()
|
563
|
+
|
564
|
+
return next_batch
|
565
|
+
except queue.Empty:
|
566
|
+
raise StopIteration
|
567
|
+
|
568
|
+
def cleanup(self):
|
569
|
+
self._stop_event = True
|
570
|
+
if self.process and self.process.is_alive():
|
571
|
+
self.process.terminate()
|
572
|
+
self.process.join()
|
573
|
+
|
574
|
+
def __del__(self):
|
575
|
+
self.cleanup()
|
509
576
|
|
510
577
|
class NoClassDataset(Dataset):
|
511
578
|
def __init__(self, data_dir, transform=None, shuffle=True, load_to_memory=False):
|
@@ -2292,7 +2359,7 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
|
|
2292
2359
|
|
2293
2360
|
def save_model_at_threshold(threshold, epoch, suffix=""):
|
2294
2361
|
percentile = str(threshold * 100)
|
2295
|
-
print(f'
|
2362
|
+
print(f'Found: {percentile}% accurate model')
|
2296
2363
|
model_path = f'{dst}/{model_type}_epoch_{str(epoch)}{suffix}_acc_{percentile}_channels_{channels_str}.pth'
|
2297
2364
|
torch.save(model, model_path)
|
2298
2365
|
return model_path
|
@@ -2303,7 +2370,8 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
|
|
2303
2370
|
return model_path
|
2304
2371
|
|
2305
2372
|
for threshold in intermedeate_save:
|
2306
|
-
if results_df['neg_accuracy']
|
2373
|
+
if results_df['neg_accuracy'] >= threshold and results_df['pos_accuracy'] >= threshold:
|
2374
|
+
print(f"Nc class accuracy: {results_df['neg_accuracy']} Pc class Accuracy: {results_df['pos_accuracy']}")
|
2307
2375
|
model_path = save_model_at_threshold(threshold, epoch)
|
2308
2376
|
break
|
2309
2377
|
else:
|
@@ -2311,7 +2379,7 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
|
|
2311
2379
|
|
2312
2380
|
return model_path
|
2313
2381
|
|
2314
|
-
def _save_progress(dst, results_df,
|
2382
|
+
def _save_progress(dst, results_df, result_type='train'):
|
2315
2383
|
"""
|
2316
2384
|
Save the progress of the classification model.
|
2317
2385
|
|
@@ -2325,18 +2393,13 @@ def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
|
|
2325
2393
|
"""
|
2326
2394
|
# Save accuracy, loss, PRAUC
|
2327
2395
|
os.makedirs(dst, exist_ok=True)
|
2328
|
-
results_path = os.path.join(dst, '
|
2396
|
+
results_path = os.path.join(dst, f'{result_type}.csv')
|
2329
2397
|
if not os.path.exists(results_path):
|
2330
2398
|
results_df.to_csv(results_path, index=True, header=True, mode='w')
|
2331
2399
|
else:
|
2332
2400
|
results_df.to_csv(results_path, index=True, header=False, mode='a')
|
2333
|
-
|
2334
|
-
|
2335
|
-
if not os.path.exists(training_metrics_path):
|
2336
|
-
train_metrics_df.to_csv(training_metrics_path, index=True, header=True, mode='w')
|
2337
|
-
else:
|
2338
|
-
train_metrics_df.to_csv(training_metrics_path, index=True, header=False, mode='a')
|
2339
|
-
if epoch == epochs:
|
2401
|
+
|
2402
|
+
if result_type == 'train':
|
2340
2403
|
read_plot_model_stats(results_path, save=True)
|
2341
2404
|
return
|
2342
2405
|
|
spacr/plot.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
import os,re, random, cv2, glob, time, math
|
1
|
+
import os,re, random, cv2, glob, time, math, torch
|
2
2
|
|
3
3
|
import numpy as np
|
4
4
|
import pandas as pd
|
@@ -1186,6 +1186,52 @@ def _imshow(img, labels, nrow=20, color='white', fontsize=12):
|
|
1186
1186
|
y = row * img_height + 15
|
1187
1187
|
plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')
|
1188
1188
|
return fig
|
1189
|
+
|
1190
|
+
def _imshow_gpu(img, labels, nrow=20, color='white', fontsize=12):
|
1191
|
+
"""
|
1192
|
+
Display multiple images in a grid with corresponding labels.
|
1193
|
+
|
1194
|
+
Args:
|
1195
|
+
img (torch.Tensor): A batch of images as a tensor.
|
1196
|
+
labels (list): List of labels corresponding to each image.
|
1197
|
+
nrow (int, optional): Number of images per row in the grid. Defaults to 20.
|
1198
|
+
color (str, optional): Color of the label text. Defaults to 'white'.
|
1199
|
+
fontsize (int, optional): Font size of the label text. Defaults to 12.
|
1200
|
+
"""
|
1201
|
+
if img.is_cuda:
|
1202
|
+
img = img.cpu() # Move to CPU if the tensor is on GPU
|
1203
|
+
|
1204
|
+
n_images = len(labels)
|
1205
|
+
n_col = nrow
|
1206
|
+
n_row = int(np.ceil(n_images / n_col))
|
1207
|
+
|
1208
|
+
img_height = img.shape[2] # Height of the image
|
1209
|
+
img_width = img.shape[3] # Width of the image
|
1210
|
+
|
1211
|
+
# Prepare the canvas on CPU
|
1212
|
+
canvas = torch.zeros((img_height * n_row, img_width * n_col, 3))
|
1213
|
+
|
1214
|
+
for i in range(n_row):
|
1215
|
+
for j in range(n_col):
|
1216
|
+
idx = i * n_col + j
|
1217
|
+
if idx < n_images:
|
1218
|
+
# Place the image on the canvas
|
1219
|
+
canvas[i * img_height:(i + 1) * img_height, j * img_width:(j + 1) * img_width] = img[idx].permute(1, 2, 0)
|
1220
|
+
|
1221
|
+
canvas = canvas.numpy() # Convert to NumPy for plotting
|
1222
|
+
|
1223
|
+
fig = plt.figure(figsize=(50, 50))
|
1224
|
+
plt.imshow(canvas)
|
1225
|
+
plt.axis("off")
|
1226
|
+
|
1227
|
+
for i, label in enumerate(labels):
|
1228
|
+
row = i // n_col
|
1229
|
+
col = i % n_col
|
1230
|
+
x = col * img_width + 2
|
1231
|
+
y = row * img_height + 15
|
1232
|
+
plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')
|
1233
|
+
|
1234
|
+
return fig
|
1189
1235
|
|
1190
1236
|
def _plot_histograms_and_stats(df):
|
1191
1237
|
conditions = df['condition'].unique()
|