spacr 0.2.53__py3-none-any.whl → 0.2.61__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/core.py +218 -283
- spacr/deep_spacr.py +248 -269
- spacr/gui.py +1 -1
- spacr/gui_core.py +301 -94
- spacr/gui_elements.py +43 -20
- spacr/gui_utils.py +81 -47
- spacr/io.py +116 -45
- spacr/plot.py +47 -1
- spacr/sequencing.py +443 -643
- spacr/settings.py +192 -64
- spacr/utils.py +22 -13
- {spacr-0.2.53.dist-info → spacr-0.2.61.dist-info}/METADATA +2 -1
- {spacr-0.2.53.dist-info → spacr-0.2.61.dist-info}/RECORD +17 -17
- {spacr-0.2.53.dist-info → spacr-0.2.61.dist-info}/LICENSE +0 -0
- {spacr-0.2.53.dist-info → spacr-0.2.61.dist-info}/WHEEL +0 -0
- {spacr-0.2.53.dist-info → spacr-0.2.61.dist-info}/entry_points.txt +0 -0
- {spacr-0.2.53.dist-info → spacr-0.2.61.dist-info}/top_level.txt +0 -0
spacr/gui_elements.py
CHANGED
@@ -516,6 +516,7 @@ class spacrDropdownMenu(tk.Frame):
|
|
516
516
|
self.inactive_color = color_settings['inactive_color']
|
517
517
|
self.active_color = color_settings['active_color']
|
518
518
|
self.fg_color = color_settings['fg_color']
|
519
|
+
self.bg_color = style_out['bg_color']
|
519
520
|
|
520
521
|
# Create the button with rounded edges
|
521
522
|
self.button_bg = self.create_rounded_rectangle(2, 2, self.button_width + 2, self.size + 2, radius=20, fill=self.inactive_color, outline=self.inactive_color)
|
@@ -536,8 +537,8 @@ class spacrDropdownMenu(tk.Frame):
|
|
536
537
|
self.canvas.bind("<Leave>", self.on_leave)
|
537
538
|
self.canvas.bind("<Button-1>", self.on_click)
|
538
539
|
|
539
|
-
# Create a popup menu
|
540
|
-
self.menu = tk.Menu(self, tearoff=0)
|
540
|
+
# Create a popup menu with the desired background color
|
541
|
+
self.menu = tk.Menu(self, tearoff=0, bg=self.bg_color, fg=self.fg_color)
|
541
542
|
for option in self.options:
|
542
543
|
self.menu.add_command(label=option, command=lambda opt=option: self.on_select(opt))
|
543
544
|
|
@@ -591,7 +592,6 @@ class spacrDropdownMenu(tk.Frame):
|
|
591
592
|
else:
|
592
593
|
self.menu.entryconfig(idx, background=style_out['bg_color'], foreground=style_out['fg_color'])
|
593
594
|
|
594
|
-
|
595
595
|
class spacrCheckbutton(ttk.Checkbutton):
|
596
596
|
def __init__(self, parent, text="", variable=None, command=None, *args, **kwargs):
|
597
597
|
super().__init__(parent, *args, **kwargs)
|
@@ -613,17 +613,26 @@ class spacrProgressBar(ttk.Progressbar):
|
|
613
613
|
self.bg_color = style_out['bg_color']
|
614
614
|
self.active_color = style_out['active_color']
|
615
615
|
self.inactive_color = style_out['inactive_color']
|
616
|
+
self.font_size = style_out['font_size']
|
617
|
+
self.font_loader = style_out['font_loader']
|
616
618
|
|
617
619
|
# Configure the style for the progress bar
|
618
620
|
self.style = ttk.Style()
|
621
|
+
|
622
|
+
# Remove any borders and ensure the active color fills the entire space
|
619
623
|
self.style.configure(
|
620
624
|
"spacr.Horizontal.TProgressbar",
|
621
|
-
troughcolor=self.
|
622
|
-
background=self.active_color,
|
623
|
-
|
624
|
-
|
625
|
-
|
625
|
+
troughcolor=self.inactive_color, # Set the trough to bg color
|
626
|
+
background=self.active_color, # Active part is the active color
|
627
|
+
borderwidth=0, # Remove border width
|
628
|
+
pbarrelief="flat", # Flat relief for the progress bar
|
629
|
+
troughrelief="flat", # Flat relief for the trough
|
630
|
+
thickness=20, # Set the thickness of the progress bar
|
631
|
+
darkcolor=self.active_color, # Ensure darkcolor matches the active color
|
632
|
+
lightcolor=self.active_color, # Ensure lightcolor matches the active color
|
633
|
+
bordercolor=self.bg_color # Set the border color to the background color to hide it
|
626
634
|
)
|
635
|
+
|
627
636
|
self.configure(style="spacr.Horizontal.TProgressbar")
|
628
637
|
|
629
638
|
# Set initial value to 0
|
@@ -632,16 +641,23 @@ class spacrProgressBar(ttk.Progressbar):
|
|
632
641
|
# Track whether to show the progress label
|
633
642
|
self.label = label
|
634
643
|
|
635
|
-
# Create the progress label
|
644
|
+
# Create the progress label with text wrapping
|
636
645
|
if self.label:
|
637
|
-
self.progress_label = tk.Label(
|
638
|
-
|
646
|
+
self.progress_label = tk.Label(
|
647
|
+
parent,
|
648
|
+
text="Processing: 0/0",
|
649
|
+
anchor='w',
|
650
|
+
justify='left',
|
651
|
+
bg=self.inactive_color,
|
652
|
+
fg=self.fg_color,
|
653
|
+
wraplength=300,
|
654
|
+
font=self.font_loader.get_font(size=self.font_size)
|
655
|
+
)
|
656
|
+
self.progress_label.grid_forget()
|
639
657
|
|
640
658
|
# Initialize attributes for time and operation
|
641
659
|
self.operation_type = None
|
642
|
-
self.
|
643
|
-
self.time_batch = None
|
644
|
-
self.time_left = None
|
660
|
+
self.additional_info = None
|
645
661
|
|
646
662
|
def set_label_position(self):
|
647
663
|
if self.label and self.progress_label:
|
@@ -656,12 +672,19 @@ class spacrProgressBar(ttk.Progressbar):
|
|
656
672
|
label_text = f"Processing: {self['value']}/{self['maximum']}"
|
657
673
|
if self.operation_type:
|
658
674
|
label_text += f", {self.operation_type}"
|
659
|
-
if self.
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
675
|
+
if hasattr(self, 'additional_info') and self.additional_info:
|
676
|
+
# Add a space between progress information and additional information
|
677
|
+
label_text += "\n\n"
|
678
|
+
# Split the additional_info into a list of items
|
679
|
+
items = self.additional_info.split(", ")
|
680
|
+
formatted_additional_info = ""
|
681
|
+
# Group the items in pairs, adding them to formatted_additional_info
|
682
|
+
for i in range(0, len(items), 2):
|
683
|
+
if i + 1 < len(items):
|
684
|
+
formatted_additional_info += f"{items[i]}, {items[i + 1]}\n\n"
|
685
|
+
else:
|
686
|
+
formatted_additional_info += f"{items[i]}\n\n" # If there's an odd item out, add it alone
|
687
|
+
label_text += formatted_additional_info.strip()
|
665
688
|
self.progress_label.config(text=label_text)
|
666
689
|
|
667
690
|
def spacrScrollbarStyle(style, inactive_color, active_color):
|
spacr/gui_utils.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1
|
-
import os, io, sys, ast, ctypes, ast, sqlite3, requests, time, traceback
|
1
|
+
import os, io, sys, ast, ctypes, ast, sqlite3, requests, time, traceback, torch
|
2
2
|
import tkinter as tk
|
3
3
|
from tkinter import ttk
|
4
4
|
import matplotlib
|
5
5
|
import matplotlib.pyplot as plt
|
6
6
|
matplotlib.use('Agg')
|
7
7
|
from huggingface_hub import list_repo_files
|
8
|
+
import psutil
|
8
9
|
|
9
10
|
from .gui_elements import AnnotateApp, spacrEntry, spacrCheck, spacrCombo
|
10
11
|
|
@@ -12,6 +13,36 @@ try:
|
|
12
13
|
ctypes.windll.shcore.SetProcessDpiAwareness(True)
|
13
14
|
except AttributeError:
|
14
15
|
pass
|
16
|
+
|
17
|
+
def initialize_cuda():
|
18
|
+
"""
|
19
|
+
Initializes CUDA in the main process by performing a simple GPU operation.
|
20
|
+
"""
|
21
|
+
if torch.cuda.is_available():
|
22
|
+
# Allocate a small tensor on the GPU
|
23
|
+
_ = torch.tensor([0.0], device='cuda')
|
24
|
+
print("CUDA initialized in the main process.")
|
25
|
+
else:
|
26
|
+
print("CUDA is not available.")
|
27
|
+
|
28
|
+
def set_high_priority(process):
|
29
|
+
try:
|
30
|
+
p = psutil.Process(process.pid)
|
31
|
+
if os.name == 'nt': # Windows
|
32
|
+
p.nice(psutil.HIGH_PRIORITY_CLASS)
|
33
|
+
else: # Unix-like systems
|
34
|
+
p.nice(-10) # Adjusted priority level
|
35
|
+
print(f"Successfully set high priority for process: {process.pid}")
|
36
|
+
except psutil.AccessDenied as e:
|
37
|
+
print(f"Access denied when trying to set high priority for process {process.pid}: {e}")
|
38
|
+
except psutil.NoSuchProcess as e:
|
39
|
+
print(f"No such process {process.pid}: {e}")
|
40
|
+
except Exception as e:
|
41
|
+
print(f"Failed to set high priority for process {process.pid}: {e}")
|
42
|
+
|
43
|
+
def set_cpu_affinity(process):
|
44
|
+
p = psutil.Process(process.pid)
|
45
|
+
p.cpu_affinity(list(range(os.cpu_count())))
|
15
46
|
|
16
47
|
def proceed_with_app(root, app_name, app_func):
|
17
48
|
# Clear the current content frame
|
@@ -48,12 +79,18 @@ def parse_list(value):
|
|
48
79
|
try:
|
49
80
|
parsed_value = ast.literal_eval(value)
|
50
81
|
if isinstance(parsed_value, list):
|
51
|
-
|
82
|
+
# Check if the list elements are homogeneous (all int or all str)
|
83
|
+
if all(isinstance(item, int) for item in parsed_value):
|
84
|
+
return parsed_value
|
85
|
+
elif all(isinstance(item, str) for item in parsed_value):
|
86
|
+
return parsed_value
|
87
|
+
else:
|
88
|
+
raise ValueError("List contains mixed types or unsupported types")
|
52
89
|
else:
|
53
90
|
raise ValueError(f"Expected a list but got {type(parsed_value).__name__}")
|
54
91
|
except (ValueError, SyntaxError) as e:
|
55
92
|
raise ValueError(f"Invalid format for list: {value}. Error: {e}")
|
56
|
-
|
93
|
+
|
57
94
|
# Usage example in your create_input_field function
|
58
95
|
def create_input_field(frame, label_text, row, var_type='entry', options=None, default_value=None):
|
59
96
|
from .gui_elements import set_dark_style, set_element_size
|
@@ -67,6 +104,7 @@ def create_input_field(frame, label_text, row, var_type='entry', options=None, d
|
|
67
104
|
size_dict['settings_width'] = size_dict['settings_width'] - int(size_dict['settings_width']*0.1)
|
68
105
|
|
69
106
|
# Replace underscores with spaces and capitalize the first letter
|
107
|
+
|
70
108
|
label_text = label_text.replace('_', ' ').capitalize()
|
71
109
|
|
72
110
|
# Configure the column widths
|
@@ -81,32 +119,35 @@ def create_input_field(frame, label_text, row, var_type='entry', options=None, d
|
|
81
119
|
custom_frame.config(highlightbackground=style_out['bg_color'], highlightthickness=1, bd=2)
|
82
120
|
|
83
121
|
# Create and configure the label
|
84
|
-
|
85
|
-
label = ttk.Label(custom_frame, text=label_text, background=style_out['bg_color'], foreground=style_out['fg_color'], font=font_loader.get_font(size=font_size), anchor='e', justify='right')
|
86
|
-
label = ttk.Label(custom_frame, text=label_text, background=style_out['bg_color'], foreground=style_out['fg_color'], font=(style_out['font_family'], style_out['font_size']), anchor='e', justify='right')
|
122
|
+
label = tk.Label(custom_frame, text=label_text, bg=style_out['bg_color'], fg=style_out['fg_color'], font=font_loader.get_font(size=font_size), anchor='e', justify='right')
|
87
123
|
label.grid(column=label_column, row=0, sticky=tk.W, padx=(5, 2), pady=5) # Place the label in the first row
|
88
124
|
|
89
125
|
# Create and configure the input widget based on var_type
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
126
|
+
try:
|
127
|
+
if var_type == 'entry':
|
128
|
+
var = tk.StringVar(value=default_value)
|
129
|
+
entry = spacrEntry(custom_frame, textvariable=var, outline=False, width=size_dict['settings_width'])
|
130
|
+
entry.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the entry in the second row
|
131
|
+
return (label, entry, var, custom_frame) # Return both the label and the entry, and the variable
|
132
|
+
elif var_type == 'check':
|
133
|
+
var = tk.BooleanVar(value=default_value) # Set default value (True/False)
|
134
|
+
check = spacrCheck(custom_frame, text="", variable=var)
|
135
|
+
check.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the checkbutton in the second row
|
136
|
+
return (label, check, var, custom_frame) # Return both the label and the checkbutton, and the variable
|
137
|
+
elif var_type == 'combo':
|
138
|
+
var = tk.StringVar(value=default_value) # Set default value
|
139
|
+
combo = spacrCombo(custom_frame, textvariable=var, values=options, width=size_dict['settings_width']) # Apply TCombobox style
|
140
|
+
combo.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the combobox in the second row
|
141
|
+
if default_value:
|
142
|
+
combo.set(default_value)
|
143
|
+
return (label, combo, var, custom_frame) # Return both the label and the combobox, and the variable
|
144
|
+
else:
|
145
|
+
var = None # Placeholder in case of an undefined var_type
|
146
|
+
return (label, None, var, custom_frame)
|
147
|
+
except Exception as e:
|
148
|
+
traceback.print_exc()
|
149
|
+
print(f"Error creating input field: {e}")
|
150
|
+
print(f"Wrong type for {label_text} Expected {var_type}")
|
110
151
|
|
111
152
|
def process_stdout_stderr(q):
|
112
153
|
"""
|
@@ -134,16 +175,6 @@ def cancel_after_tasks(frame):
|
|
134
175
|
frame.after_cancel(task)
|
135
176
|
frame.after_tasks.clear()
|
136
177
|
|
137
|
-
def main_thread_update_function(root, q, fig_queue, canvas_widget):
|
138
|
-
try:
|
139
|
-
#ansi_escape_pattern = re.compile(r'\x1B\[[0-?]*[ -/]*[@-~]')
|
140
|
-
while not q.empty():
|
141
|
-
message = q.get_nowait()
|
142
|
-
except Exception as e:
|
143
|
-
print(f"Error updating GUI canvas: {e}")
|
144
|
-
finally:
|
145
|
-
root.after(100, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget))
|
146
|
-
|
147
178
|
def annotate(settings):
|
148
179
|
from .settings import set_annotate_default_settings
|
149
180
|
settings = set_annotate_default_settings(settings)
|
@@ -323,7 +354,9 @@ def convert_settings_dict_for_gui(settings):
|
|
323
354
|
special_cases = {
|
324
355
|
'metadata_type': ('combo', ['cellvoyager', 'cq1', 'nikon', 'zeis', 'custom'], 'cellvoyager'),
|
325
356
|
'channels': ('combo', ['[0,1,2,3]', '[0,1,2]', '[0,1]', '[0]'], '[0,1,2,3]'),
|
357
|
+
'train_channels': ('combo', ["['r','g','b']", "['r','g']", "['r','b']", "['g','b']", "['r']", "['g']", "['b']"], "['r','g','b']"),
|
326
358
|
'channel_dims': ('combo', ['[0,1,2,3]', '[0,1,2]', '[0,1]', '[0]'], '[0,1,2,3]'),
|
359
|
+
'dataset_mode': ('combo', ['annotation', 'metadata', 'recruitment'], 'metadata'),
|
327
360
|
'cell_mask_dim': ('combo', chans, None),
|
328
361
|
'cell_chann_dim': ('combo', chans, None),
|
329
362
|
'nucleus_mask_dim': ('combo', chans, None),
|
@@ -369,6 +402,7 @@ def convert_settings_dict_for_gui(settings):
|
|
369
402
|
variables[key] = ('entry', None, str(value))
|
370
403
|
else:
|
371
404
|
variables[key] = ('entry', None, str(value))
|
405
|
+
|
372
406
|
return variables
|
373
407
|
|
374
408
|
|
@@ -413,13 +447,14 @@ def function_gui_wrapper(function=None, settings={}, q=None, fig_queue=None, imp
|
|
413
447
|
plt.show = original_show
|
414
448
|
|
415
449
|
def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
|
450
|
+
|
416
451
|
from .gui_utils import process_stdout_stderr
|
417
|
-
from .core import preprocess_generate_masks, generate_ml_scores, identify_masks_finetune, check_cellpose_models, analyze_recruitment, train_cellpose, compare_cellpose_masks, analyze_plaques, generate_dataset, apply_model_to_tar
|
452
|
+
from .core import generate_image_umap, preprocess_generate_masks, generate_ml_scores, identify_masks_finetune, check_cellpose_models, analyze_recruitment, train_cellpose, compare_cellpose_masks, analyze_plaques, generate_dataset, apply_model_to_tar
|
418
453
|
from .io import generate_cellpose_train_test
|
419
454
|
from .measure import measure_crop
|
420
455
|
from .sim import run_multiple_simulations
|
421
|
-
from .deep_spacr import
|
422
|
-
from .sequencing import
|
456
|
+
from .deep_spacr import deep_spacr
|
457
|
+
from .sequencing import generate_barecode_mapping, perform_regression
|
423
458
|
process_stdout_stderr(q)
|
424
459
|
|
425
460
|
print(f'run_function_gui settings_type: {settings_type}')
|
@@ -433,12 +468,9 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
|
|
433
468
|
elif settings_type == 'simulation':
|
434
469
|
function = run_multiple_simulations
|
435
470
|
imports = 1
|
436
|
-
elif settings_type == 'sequencing':
|
437
|
-
function = analyze_reads
|
438
|
-
imports = 1
|
439
471
|
elif settings_type == 'classify':
|
440
|
-
function =
|
441
|
-
imports =
|
472
|
+
function = deep_spacr
|
473
|
+
imports = 1
|
442
474
|
elif settings_type == 'train_cellpose':
|
443
475
|
function = train_cellpose
|
444
476
|
imports = 1
|
@@ -452,14 +484,17 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
|
|
452
484
|
function = check_cellpose_models
|
453
485
|
imports = 1
|
454
486
|
elif settings_type == 'map_barcodes':
|
455
|
-
function =
|
487
|
+
function = generate_barecode_mapping
|
456
488
|
imports = 1
|
457
489
|
elif settings_type == 'regression':
|
458
490
|
function = perform_regression
|
459
491
|
imports = 2
|
460
492
|
elif settings_type == 'recruitment':
|
461
493
|
function = analyze_recruitment
|
462
|
-
imports =
|
494
|
+
imports = 1
|
495
|
+
elif settings_type == 'umap':
|
496
|
+
function = generate_image_umap
|
497
|
+
imports = 1
|
463
498
|
else:
|
464
499
|
raise ValueError(f"Invalid settings type: {settings_type}")
|
465
500
|
try:
|
@@ -470,7 +505,6 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
|
|
470
505
|
finally:
|
471
506
|
stop_requested.value = 1
|
472
507
|
|
473
|
-
|
474
508
|
def hide_all_settings(vars_dict, categories):
|
475
509
|
"""
|
476
510
|
Function to initially hide all settings in the GUI.
|
spacr/io.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1
|
-
import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob
|
1
|
+
import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob, queue
|
2
2
|
import numpy as np
|
3
3
|
import pandas as pd
|
4
4
|
import tifffile
|
5
|
-
from PIL import Image
|
6
|
-
from collections import defaultdict, Counter
|
5
|
+
from PIL import Image, ImageOps
|
6
|
+
from collections import defaultdict, Counter, deque
|
7
7
|
from pathlib import Path
|
8
8
|
from functools import partial
|
9
9
|
from matplotlib.animation import FuncAnimation
|
@@ -17,12 +17,12 @@ import imageio.v2 as imageio2
|
|
17
17
|
import matplotlib.pyplot as plt
|
18
18
|
from io import BytesIO
|
19
19
|
from IPython.display import display, clear_output
|
20
|
-
from multiprocessing import Pool, cpu_count
|
21
|
-
from torch.utils.data import Dataset
|
20
|
+
from multiprocessing import Pool, cpu_count, Process, Queue
|
21
|
+
from torch.utils.data import Dataset, DataLoader
|
22
22
|
import matplotlib.pyplot as plt
|
23
23
|
from torchvision.transforms import ToTensor
|
24
24
|
import seaborn as sns
|
25
|
-
|
25
|
+
import atexit
|
26
26
|
|
27
27
|
from .logger import log_function_call
|
28
28
|
|
@@ -444,20 +444,7 @@ class NoClassDataset(Dataset):
|
|
444
444
|
# Return both the image and its filename
|
445
445
|
return img, self.filenames[index]
|
446
446
|
|
447
|
-
class
|
448
|
-
"""
|
449
|
-
A custom dataset class for loading and processing image data.
|
450
|
-
|
451
|
-
Args:
|
452
|
-
data_dir (str): The directory path where the image data is stored.
|
453
|
-
loader_classes (list): A list of class names for the dataset.
|
454
|
-
transform (callable, optional): A function/transform to apply to the image data. Default is None.
|
455
|
-
shuffle (bool, optional): Whether to shuffle the dataset. Default is True.
|
456
|
-
pin_memory (bool, optional): Whether to pin the loaded images to memory. Default is False.
|
457
|
-
specific_files (list, optional): A list of specific file paths to include in the dataset. Default is None.
|
458
|
-
specific_labels (list, optional): A list of specific labels corresponding to the specific files. Default is None.
|
459
|
-
"""
|
460
|
-
|
447
|
+
class spacrDataset(Dataset):
|
461
448
|
def __init__(self, data_dir, loader_classes, transform=None, shuffle=True, pin_memory=False, specific_files=None, specific_labels=None):
|
462
449
|
self.data_dir = data_dir
|
463
450
|
self.classes = loader_classes
|
@@ -466,7 +453,7 @@ class MyDataset(Dataset):
|
|
466
453
|
self.pin_memory = pin_memory
|
467
454
|
self.filenames = []
|
468
455
|
self.labels = []
|
469
|
-
|
456
|
+
|
470
457
|
if specific_files and specific_labels:
|
471
458
|
self.filenames = specific_files
|
472
459
|
self.labels = specific_labels
|
@@ -479,33 +466,113 @@ class MyDataset(Dataset):
|
|
479
466
|
|
480
467
|
if self.shuffle:
|
481
468
|
self.shuffle_dataset()
|
482
|
-
|
469
|
+
|
483
470
|
if self.pin_memory:
|
484
|
-
|
485
|
-
|
471
|
+
# Use multiprocessing to load images in parallel
|
472
|
+
with Pool(processes=cpu_count()) as pool:
|
473
|
+
self.images = pool.map(self.load_image, self.filenames)
|
474
|
+
else:
|
475
|
+
self.images = None
|
476
|
+
|
486
477
|
def load_image(self, img_path):
|
487
478
|
img = Image.open(img_path).convert('RGB')
|
479
|
+
img = ImageOps.exif_transpose(img) # Handle image orientation
|
488
480
|
return img
|
489
|
-
|
481
|
+
|
490
482
|
def __len__(self):
|
491
483
|
return len(self.filenames)
|
492
|
-
|
484
|
+
|
493
485
|
def shuffle_dataset(self):
|
494
486
|
combined = list(zip(self.filenames, self.labels))
|
495
487
|
random.shuffle(combined)
|
496
488
|
self.filenames, self.labels = zip(*combined)
|
497
|
-
|
489
|
+
|
498
490
|
def get_plate(self, filepath):
|
499
|
-
filename = os.path.basename(filepath)
|
491
|
+
filename = os.path.basename(filepath)
|
500
492
|
return filename.split('_')[0]
|
501
|
-
|
493
|
+
|
502
494
|
def __getitem__(self, index):
|
495
|
+
if self.pin_memory:
|
496
|
+
img = self.images[index]
|
497
|
+
else:
|
498
|
+
img = self.load_image(self.filenames[index])
|
503
499
|
label = self.labels[index]
|
504
500
|
filename = self.filenames[index]
|
505
|
-
img = self.load_image(filename)
|
506
501
|
if self.transform:
|
507
502
|
img = self.transform(img)
|
508
503
|
return img, label, filename
|
504
|
+
|
505
|
+
class spacrDataLoader(DataLoader):
|
506
|
+
def __init__(self, *args, preload_batches=1, **kwargs):
|
507
|
+
super().__init__(*args, **kwargs)
|
508
|
+
self.preload_batches = preload_batches
|
509
|
+
self.batch_queue = Queue(maxsize=preload_batches)
|
510
|
+
self.process = None
|
511
|
+
self.current_batch_index = 0
|
512
|
+
self._stop_event = False
|
513
|
+
self.pin_memory = kwargs.get('pin_memory', False)
|
514
|
+
atexit.register(self.cleanup)
|
515
|
+
|
516
|
+
def _preload_next_batches(self):
|
517
|
+
try:
|
518
|
+
for _ in range(self.preload_batches):
|
519
|
+
if self._stop_event:
|
520
|
+
break
|
521
|
+
batch = next(self._iterator)
|
522
|
+
if self.pin_memory:
|
523
|
+
batch = self._pin_memory_batch(batch)
|
524
|
+
self.batch_queue.put(batch)
|
525
|
+
except StopIteration:
|
526
|
+
pass
|
527
|
+
|
528
|
+
def _start_preloading(self):
|
529
|
+
if self.process is None or not self.process.is_alive():
|
530
|
+
self._iterator = iter(super().__iter__())
|
531
|
+
if not self.pin_memory:
|
532
|
+
self.process = Process(target=self._preload_next_batches)
|
533
|
+
self.process.start()
|
534
|
+
else:
|
535
|
+
self._preload_next_batches() # Directly load if pin_memory is True
|
536
|
+
|
537
|
+
def _pin_memory_batch(self, batch):
|
538
|
+
if isinstance(batch, (list, tuple)):
|
539
|
+
return [b.pin_memory() if isinstance(b, torch.Tensor) else b for b in batch]
|
540
|
+
elif isinstance(batch, torch.Tensor):
|
541
|
+
return batch.pin_memory()
|
542
|
+
else:
|
543
|
+
return batch
|
544
|
+
|
545
|
+
def __iter__(self):
|
546
|
+
self._start_preloading()
|
547
|
+
return self
|
548
|
+
|
549
|
+
def __next__(self):
|
550
|
+
if self.process and not self.process.is_alive() and self.batch_queue.empty():
|
551
|
+
raise StopIteration
|
552
|
+
|
553
|
+
try:
|
554
|
+
if self.pin_memory:
|
555
|
+
next_batch = self.batch_queue.get(timeout=60)
|
556
|
+
else:
|
557
|
+
next_batch = self.batch_queue.get(timeout=60)
|
558
|
+
self.current_batch_index += 1
|
559
|
+
|
560
|
+
# Start preloading the next batches
|
561
|
+
if self.batch_queue.qsize() < self.preload_batches:
|
562
|
+
self._start_preloading()
|
563
|
+
|
564
|
+
return next_batch
|
565
|
+
except queue.Empty:
|
566
|
+
raise StopIteration
|
567
|
+
|
568
|
+
def cleanup(self):
|
569
|
+
self._stop_event = True
|
570
|
+
if self.process and self.process.is_alive():
|
571
|
+
self.process.terminate()
|
572
|
+
self.process.join()
|
573
|
+
|
574
|
+
def __del__(self):
|
575
|
+
self.cleanup()
|
509
576
|
|
510
577
|
class NoClassDataset(Dataset):
|
511
578
|
def __init__(self, data_dir, transform=None, shuffle=True, load_to_memory=False):
|
@@ -2292,18 +2359,27 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
|
|
2292
2359
|
|
2293
2360
|
def save_model_at_threshold(threshold, epoch, suffix=""):
|
2294
2361
|
percentile = str(threshold * 100)
|
2295
|
-
print(f'
|
2296
|
-
|
2362
|
+
print(f'Found: {percentile}% accurate model')
|
2363
|
+
model_path = f'{dst}/{model_type}_epoch_{str(epoch)}{suffix}_acc_{percentile}_channels_{channels_str}.pth'
|
2364
|
+
torch.save(model, model_path)
|
2365
|
+
return model_path
|
2297
2366
|
|
2298
2367
|
if epoch % 100 == 0 or epoch == epochs:
|
2299
|
-
|
2368
|
+
model_path = f'{dst}/{model_type}_epoch_{str(epoch)}_channels_{channels_str}.pth'
|
2369
|
+
torch.save(model, model_path)
|
2370
|
+
return model_path
|
2300
2371
|
|
2301
2372
|
for threshold in intermedeate_save:
|
2302
|
-
if results_df['neg_accuracy']
|
2303
|
-
|
2304
|
-
|
2373
|
+
if results_df['neg_accuracy'] >= threshold and results_df['pos_accuracy'] >= threshold:
|
2374
|
+
print(f"Nc class accuracy: {results_df['neg_accuracy']} Pc class Accuracy: {results_df['pos_accuracy']}")
|
2375
|
+
model_path = save_model_at_threshold(threshold, epoch)
|
2376
|
+
break
|
2377
|
+
else:
|
2378
|
+
model_path = None
|
2379
|
+
|
2380
|
+
return model_path
|
2305
2381
|
|
2306
|
-
def _save_progress(dst, results_df,
|
2382
|
+
def _save_progress(dst, results_df, result_type='train'):
|
2307
2383
|
"""
|
2308
2384
|
Save the progress of the classification model.
|
2309
2385
|
|
@@ -2317,18 +2393,13 @@ def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
|
|
2317
2393
|
"""
|
2318
2394
|
# Save accuracy, loss, PRAUC
|
2319
2395
|
os.makedirs(dst, exist_ok=True)
|
2320
|
-
results_path = os.path.join(dst, '
|
2396
|
+
results_path = os.path.join(dst, f'{result_type}.csv')
|
2321
2397
|
if not os.path.exists(results_path):
|
2322
2398
|
results_df.to_csv(results_path, index=True, header=True, mode='w')
|
2323
2399
|
else:
|
2324
2400
|
results_df.to_csv(results_path, index=True, header=False, mode='a')
|
2325
|
-
|
2326
|
-
|
2327
|
-
if not os.path.exists(training_metrics_path):
|
2328
|
-
train_metrics_df.to_csv(training_metrics_path, index=True, header=True, mode='w')
|
2329
|
-
else:
|
2330
|
-
train_metrics_df.to_csv(training_metrics_path, index=True, header=False, mode='a')
|
2331
|
-
if epoch == epochs:
|
2401
|
+
|
2402
|
+
if result_type == 'train':
|
2332
2403
|
read_plot_model_stats(results_path, save=True)
|
2333
2404
|
return
|
2334
2405
|
|
spacr/plot.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
import os,re, random, cv2, glob, time, math
|
1
|
+
import os,re, random, cv2, glob, time, math, torch
|
2
2
|
|
3
3
|
import numpy as np
|
4
4
|
import pandas as pd
|
@@ -1186,6 +1186,52 @@ def _imshow(img, labels, nrow=20, color='white', fontsize=12):
|
|
1186
1186
|
y = row * img_height + 15
|
1187
1187
|
plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')
|
1188
1188
|
return fig
|
1189
|
+
|
1190
|
+
def _imshow_gpu(img, labels, nrow=20, color='white', fontsize=12):
|
1191
|
+
"""
|
1192
|
+
Display multiple images in a grid with corresponding labels.
|
1193
|
+
|
1194
|
+
Args:
|
1195
|
+
img (torch.Tensor): A batch of images as a tensor.
|
1196
|
+
labels (list): List of labels corresponding to each image.
|
1197
|
+
nrow (int, optional): Number of images per row in the grid. Defaults to 20.
|
1198
|
+
color (str, optional): Color of the label text. Defaults to 'white'.
|
1199
|
+
fontsize (int, optional): Font size of the label text. Defaults to 12.
|
1200
|
+
"""
|
1201
|
+
if img.is_cuda:
|
1202
|
+
img = img.cpu() # Move to CPU if the tensor is on GPU
|
1203
|
+
|
1204
|
+
n_images = len(labels)
|
1205
|
+
n_col = nrow
|
1206
|
+
n_row = int(np.ceil(n_images / n_col))
|
1207
|
+
|
1208
|
+
img_height = img.shape[2] # Height of the image
|
1209
|
+
img_width = img.shape[3] # Width of the image
|
1210
|
+
|
1211
|
+
# Prepare the canvas on CPU
|
1212
|
+
canvas = torch.zeros((img_height * n_row, img_width * n_col, 3))
|
1213
|
+
|
1214
|
+
for i in range(n_row):
|
1215
|
+
for j in range(n_col):
|
1216
|
+
idx = i * n_col + j
|
1217
|
+
if idx < n_images:
|
1218
|
+
# Place the image on the canvas
|
1219
|
+
canvas[i * img_height:(i + 1) * img_height, j * img_width:(j + 1) * img_width] = img[idx].permute(1, 2, 0)
|
1220
|
+
|
1221
|
+
canvas = canvas.numpy() # Convert to NumPy for plotting
|
1222
|
+
|
1223
|
+
fig = plt.figure(figsize=(50, 50))
|
1224
|
+
plt.imshow(canvas)
|
1225
|
+
plt.axis("off")
|
1226
|
+
|
1227
|
+
for i, label in enumerate(labels):
|
1228
|
+
row = i // n_col
|
1229
|
+
col = i % n_col
|
1230
|
+
x = col * img_width + 2
|
1231
|
+
y = row * img_height + 15
|
1232
|
+
plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')
|
1233
|
+
|
1234
|
+
return fig
|
1189
1235
|
|
1190
1236
|
def _plot_histograms_and_stats(df):
|
1191
1237
|
conditions = df['condition'].unique()
|