spacr 0.2.56__py3-none-any.whl → 0.2.61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/gui_elements.py CHANGED
@@ -516,6 +516,7 @@ class spacrDropdownMenu(tk.Frame):
516
516
  self.inactive_color = color_settings['inactive_color']
517
517
  self.active_color = color_settings['active_color']
518
518
  self.fg_color = color_settings['fg_color']
519
+ self.bg_color = style_out['bg_color']
519
520
 
520
521
  # Create the button with rounded edges
521
522
  self.button_bg = self.create_rounded_rectangle(2, 2, self.button_width + 2, self.size + 2, radius=20, fill=self.inactive_color, outline=self.inactive_color)
@@ -536,8 +537,8 @@ class spacrDropdownMenu(tk.Frame):
536
537
  self.canvas.bind("<Leave>", self.on_leave)
537
538
  self.canvas.bind("<Button-1>", self.on_click)
538
539
 
539
- # Create a popup menu
540
- self.menu = tk.Menu(self, tearoff=0)
540
+ # Create a popup menu with the desired background color
541
+ self.menu = tk.Menu(self, tearoff=0, bg=self.bg_color, fg=self.fg_color)
541
542
  for option in self.options:
542
543
  self.menu.add_command(label=option, command=lambda opt=option: self.on_select(opt))
543
544
 
@@ -591,7 +592,6 @@ class spacrDropdownMenu(tk.Frame):
591
592
  else:
592
593
  self.menu.entryconfig(idx, background=style_out['bg_color'], foreground=style_out['fg_color'])
593
594
 
594
-
595
595
  class spacrCheckbutton(ttk.Checkbutton):
596
596
  def __init__(self, parent, text="", variable=None, command=None, *args, **kwargs):
597
597
  super().__init__(parent, *args, **kwargs)
@@ -613,17 +613,26 @@ class spacrProgressBar(ttk.Progressbar):
613
613
  self.bg_color = style_out['bg_color']
614
614
  self.active_color = style_out['active_color']
615
615
  self.inactive_color = style_out['inactive_color']
616
+ self.font_size = style_out['font_size']
617
+ self.font_loader = style_out['font_loader']
616
618
 
617
619
  # Configure the style for the progress bar
618
620
  self.style = ttk.Style()
621
+
622
+ # Remove any borders and ensure the active color fills the entire space
619
623
  self.style.configure(
620
624
  "spacr.Horizontal.TProgressbar",
621
- troughcolor=self.bg_color,
622
- background=self.active_color,
623
- thickness=20,
624
- troughrelief='flat',
625
- borderwidth=0
625
+ troughcolor=self.inactive_color, # Set the trough to bg color
626
+ background=self.active_color, # Active part is the active color
627
+ borderwidth=0, # Remove border width
628
+ pbarrelief="flat", # Flat relief for the progress bar
629
+ troughrelief="flat", # Flat relief for the trough
630
+ thickness=20, # Set the thickness of the progress bar
631
+ darkcolor=self.active_color, # Ensure darkcolor matches the active color
632
+ lightcolor=self.active_color, # Ensure lightcolor matches the active color
633
+ bordercolor=self.bg_color # Set the border color to the background color to hide it
626
634
  )
635
+
627
636
  self.configure(style="spacr.Horizontal.TProgressbar")
628
637
 
629
638
  # Set initial value to 0
@@ -641,15 +650,14 @@ class spacrProgressBar(ttk.Progressbar):
641
650
  justify='left',
642
651
  bg=self.inactive_color,
643
652
  fg=self.fg_color,
644
- wraplength=300 # Adjust the wraplength as needed
653
+ wraplength=300,
654
+ font=self.font_loader.get_font(size=self.font_size)
645
655
  )
646
- self.progress_label.grid_forget() # Temporarily hide it
656
+ self.progress_label.grid_forget()
647
657
 
648
658
  # Initialize attributes for time and operation
649
659
  self.operation_type = None
650
- self.time_image = None
651
- self.time_batch = None
652
- self.time_left = None
660
+ self.additional_info = None
653
661
 
654
662
  def set_label_position(self):
655
663
  if self.label and self.progress_label:
@@ -664,74 +672,19 @@ class spacrProgressBar(ttk.Progressbar):
664
672
  label_text = f"Processing: {self['value']}/{self['maximum']}"
665
673
  if self.operation_type:
666
674
  label_text += f", {self.operation_type}"
667
- if self.time_image:
668
- label_text += f", Time/image: {self.time_image:.3f} sec"
669
- if self.time_batch:
670
- label_text += f", Time/batch: {self.time_batch:.3f} sec"
671
- if self.time_left:
672
- label_text += f", Time_left: {self.time_left:.3f} min"
673
- self.progress_label.config(text=label_text)
674
-
675
- class spacrProgressBar_v1(ttk.Progressbar):
676
- def __init__(self, parent, label=True, *args, **kwargs):
677
- super().__init__(parent, *args, **kwargs)
678
-
679
- # Get the style colors
680
- style_out = set_dark_style(ttk.Style())
681
-
682
- self.fg_color = style_out['fg_color']
683
- self.bg_color = style_out['bg_color']
684
- self.active_color = style_out['active_color']
685
- self.inactive_color = style_out['inactive_color']
686
-
687
- # Configure the style for the progress bar
688
- self.style = ttk.Style()
689
- self.style.configure(
690
- "spacr.Horizontal.TProgressbar",
691
- troughcolor=self.bg_color,
692
- background=self.active_color,
693
- thickness=20,
694
- troughrelief='flat',
695
- borderwidth=0
696
- )
697
- self.configure(style="spacr.Horizontal.TProgressbar")
698
-
699
- # Set initial value to 0
700
- self['value'] = 0
701
-
702
- # Track whether to show the progress label
703
- self.label = label
704
-
705
- # Create the progress label (defer placement)
706
- if self.label:
707
- self.progress_label = tk.Label(parent, text="Processing: 0/0", anchor='w', justify='left', bg=self.inactive_color, fg=self.fg_color)
708
- self.progress_label.grid_forget() # Temporarily hide it
709
-
710
- # Initialize attributes for time and operation
711
- self.operation_type = None
712
- self.time_image = None
713
- self.time_batch = None
714
- self.time_left = None
715
-
716
- def set_label_position(self):
717
- if self.label and self.progress_label:
718
- row_info = self.grid_info().get('row', 0)
719
- col_info = self.grid_info().get('column', 0)
720
- col_span = self.grid_info().get('columnspan', 1)
721
- self.progress_label.grid(row=row_info + 1, column=col_info, columnspan=col_span, pady=5, padx=5, sticky='ew')
722
-
723
- def update_label(self):
724
- if self.label and self.progress_label:
725
- # Update the progress label with current progress and additional info
726
- label_text = f"Processing: {self['value']}/{self['maximum']}"
727
- if self.operation_type:
728
- label_text += f", {self.operation_type}"
729
- if self.time_image:
730
- label_text += f", Time/image: {self.time_image:.3f} sec"
731
- if self.time_batch:
732
- label_text += f", Time/batch: {self.time_batch:.3f} sec"
733
- if self.time_left:
734
- label_text += f", Time_left: {self.time_left:.3f} min"
675
+ if hasattr(self, 'additional_info') and self.additional_info:
676
+ # Add a space between progress information and additional information
677
+ label_text += "\n\n"
678
+ # Split the additional_info into a list of items
679
+ items = self.additional_info.split(", ")
680
+ formatted_additional_info = ""
681
+ # Group the items in pairs, adding them to formatted_additional_info
682
+ for i in range(0, len(items), 2):
683
+ if i + 1 < len(items):
684
+ formatted_additional_info += f"{items[i]}, {items[i + 1]}\n\n"
685
+ else:
686
+ formatted_additional_info += f"{items[i]}\n\n" # If there's an odd item out, add it alone
687
+ label_text += formatted_additional_info.strip()
735
688
  self.progress_label.config(text=label_text)
736
689
 
737
690
  def spacrScrollbarStyle(style, inactive_color, active_color):
spacr/gui_utils.py CHANGED
@@ -1,10 +1,11 @@
1
- import os, io, sys, ast, ctypes, ast, sqlite3, requests, time, traceback
1
+ import os, io, sys, ast, ctypes, ast, sqlite3, requests, time, traceback, torch
2
2
  import tkinter as tk
3
3
  from tkinter import ttk
4
4
  import matplotlib
5
5
  import matplotlib.pyplot as plt
6
6
  matplotlib.use('Agg')
7
7
  from huggingface_hub import list_repo_files
8
+ import psutil
8
9
 
9
10
  from .gui_elements import AnnotateApp, spacrEntry, spacrCheck, spacrCombo
10
11
 
@@ -12,6 +13,36 @@ try:
12
13
  ctypes.windll.shcore.SetProcessDpiAwareness(True)
13
14
  except AttributeError:
14
15
  pass
16
+
17
+ def initialize_cuda():
18
+ """
19
+ Initializes CUDA in the main process by performing a simple GPU operation.
20
+ """
21
+ if torch.cuda.is_available():
22
+ # Allocate a small tensor on the GPU
23
+ _ = torch.tensor([0.0], device='cuda')
24
+ print("CUDA initialized in the main process.")
25
+ else:
26
+ print("CUDA is not available.")
27
+
28
+ def set_high_priority(process):
29
+ try:
30
+ p = psutil.Process(process.pid)
31
+ if os.name == 'nt': # Windows
32
+ p.nice(psutil.HIGH_PRIORITY_CLASS)
33
+ else: # Unix-like systems
34
+ p.nice(-10) # Adjusted priority level
35
+ print(f"Successfully set high priority for process: {process.pid}")
36
+ except psutil.AccessDenied as e:
37
+ print(f"Access denied when trying to set high priority for process {process.pid}: {e}")
38
+ except psutil.NoSuchProcess as e:
39
+ print(f"No such process {process.pid}: {e}")
40
+ except Exception as e:
41
+ print(f"Failed to set high priority for process {process.pid}: {e}")
42
+
43
+ def set_cpu_affinity(process):
44
+ p = psutil.Process(process.pid)
45
+ p.cpu_affinity(list(range(os.cpu_count())))
15
46
 
16
47
  def proceed_with_app(root, app_name, app_func):
17
48
  # Clear the current content frame
@@ -44,16 +75,6 @@ def load_app(root, app_name, app_func):
44
75
  else:
45
76
  proceed_with_app(root, app_name, app_func)
46
77
 
47
- def parse_list_v1(value):
48
- try:
49
- parsed_value = ast.literal_eval(value)
50
- if isinstance(parsed_value, list):
51
- return parsed_value
52
- else:
53
- raise ValueError(f"Expected a list but got {type(parsed_value).__name__}")
54
- except (ValueError, SyntaxError) as e:
55
- raise ValueError(f"Invalid format for list: {value}. Error: {e}")
56
-
57
78
  def parse_list(value):
58
79
  try:
59
80
  parsed_value = ast.literal_eval(value)
@@ -83,6 +104,7 @@ def create_input_field(frame, label_text, row, var_type='entry', options=None, d
83
104
  size_dict['settings_width'] = size_dict['settings_width'] - int(size_dict['settings_width']*0.1)
84
105
 
85
106
  # Replace underscores with spaces and capitalize the first letter
107
+
86
108
  label_text = label_text.replace('_', ' ').capitalize()
87
109
 
88
110
  # Configure the column widths
@@ -97,32 +119,35 @@ def create_input_field(frame, label_text, row, var_type='entry', options=None, d
97
119
  custom_frame.config(highlightbackground=style_out['bg_color'], highlightthickness=1, bd=2)
98
120
 
99
121
  # Create and configure the label
100
- if font_loader:
101
- label = ttk.Label(custom_frame, text=label_text, background=style_out['bg_color'], foreground=style_out['fg_color'], font=font_loader.get_font(size=font_size), anchor='e', justify='right')
102
- label = ttk.Label(custom_frame, text=label_text, background=style_out['bg_color'], foreground=style_out['fg_color'], font=(style_out['font_family'], style_out['font_size']), anchor='e', justify='right')
122
+ label = tk.Label(custom_frame, text=label_text, bg=style_out['bg_color'], fg=style_out['fg_color'], font=font_loader.get_font(size=font_size), anchor='e', justify='right')
103
123
  label.grid(column=label_column, row=0, sticky=tk.W, padx=(5, 2), pady=5) # Place the label in the first row
104
124
 
105
125
  # Create and configure the input widget based on var_type
106
- if var_type == 'entry':
107
- var = tk.StringVar(value=default_value)
108
- entry = spacrEntry(custom_frame, textvariable=var, outline=False, width=size_dict['settings_width'])
109
- entry.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the entry in the second row
110
- return (label, entry, var, custom_frame) # Return both the label and the entry, and the variable
111
- elif var_type == 'check':
112
- var = tk.BooleanVar(value=default_value) # Set default value (True/False)
113
- check = spacrCheck(custom_frame, text="", variable=var)
114
- check.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the checkbutton in the second row
115
- return (label, check, var, custom_frame) # Return both the label and the checkbutton, and the variable
116
- elif var_type == 'combo':
117
- var = tk.StringVar(value=default_value) # Set default value
118
- combo = spacrCombo(custom_frame, textvariable=var, values=options, width=size_dict['settings_width']) # Apply TCombobox style
119
- combo.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the combobox in the second row
120
- if default_value:
121
- combo.set(default_value)
122
- return (label, combo, var, custom_frame) # Return both the label and the combobox, and the variable
123
- else:
124
- var = None # Placeholder in case of an undefined var_type
125
- return (label, None, var, custom_frame)
126
+ try:
127
+ if var_type == 'entry':
128
+ var = tk.StringVar(value=default_value)
129
+ entry = spacrEntry(custom_frame, textvariable=var, outline=False, width=size_dict['settings_width'])
130
+ entry.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the entry in the second row
131
+ return (label, entry, var, custom_frame) # Return both the label and the entry, and the variable
132
+ elif var_type == 'check':
133
+ var = tk.BooleanVar(value=default_value) # Set default value (True/False)
134
+ check = spacrCheck(custom_frame, text="", variable=var)
135
+ check.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the checkbutton in the second row
136
+ return (label, check, var, custom_frame) # Return both the label and the checkbutton, and the variable
137
+ elif var_type == 'combo':
138
+ var = tk.StringVar(value=default_value) # Set default value
139
+ combo = spacrCombo(custom_frame, textvariable=var, values=options, width=size_dict['settings_width']) # Apply TCombobox style
140
+ combo.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the combobox in the second row
141
+ if default_value:
142
+ combo.set(default_value)
143
+ return (label, combo, var, custom_frame) # Return both the label and the combobox, and the variable
144
+ else:
145
+ var = None # Placeholder in case of an undefined var_type
146
+ return (label, None, var, custom_frame)
147
+ except Exception as e:
148
+ traceback.print_exc()
149
+ print(f"Error creating input field: {e}")
150
+ print(f"Wrong type for {label_text} Expected {var_type}")
126
151
 
127
152
  def process_stdout_stderr(q):
128
153
  """
@@ -150,16 +175,6 @@ def cancel_after_tasks(frame):
150
175
  frame.after_cancel(task)
151
176
  frame.after_tasks.clear()
152
177
 
153
- def main_thread_update_function(root, q, fig_queue, canvas_widget):
154
- try:
155
- #ansi_escape_pattern = re.compile(r'\x1B\[[0-?]*[ -/]*[@-~]')
156
- while not q.empty():
157
- message = q.get_nowait()
158
- except Exception as e:
159
- print(f"Error updating GUI canvas: {e}")
160
- finally:
161
- root.after(100, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget))
162
-
163
178
  def annotate(settings):
164
179
  from .settings import set_annotate_default_settings
165
180
  settings = set_annotate_default_settings(settings)
@@ -341,7 +356,7 @@ def convert_settings_dict_for_gui(settings):
341
356
  'channels': ('combo', ['[0,1,2,3]', '[0,1,2]', '[0,1]', '[0]'], '[0,1,2,3]'),
342
357
  'train_channels': ('combo', ["['r','g','b']", "['r','g']", "['r','b']", "['g','b']", "['r']", "['g']", "['b']"], "['r','g','b']"),
343
358
  'channel_dims': ('combo', ['[0,1,2,3]', '[0,1,2]', '[0,1]', '[0]'], '[0,1,2,3]'),
344
- 'dataset_mode': ('combo', ['annotation', 'metadata', 'recruitment'], 'annotate'),
359
+ 'dataset_mode': ('combo', ['annotation', 'metadata', 'recruitment'], 'metadata'),
345
360
  'cell_mask_dim': ('combo', chans, None),
346
361
  'cell_chann_dim': ('combo', chans, None),
347
362
  'nucleus_mask_dim': ('combo', chans, None),
@@ -476,7 +491,7 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
476
491
  imports = 2
477
492
  elif settings_type == 'recruitment':
478
493
  function = analyze_recruitment
479
- imports = 2
494
+ imports = 1
480
495
  elif settings_type == 'umap':
481
496
  function = generate_image_umap
482
497
  imports = 1
@@ -490,7 +505,6 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
490
505
  finally:
491
506
  stop_requested.value = 1
492
507
 
493
-
494
508
  def hide_all_settings(vars_dict, categories):
495
509
  """
496
510
  Function to initially hide all settings in the GUI.
spacr/io.py CHANGED
@@ -1,9 +1,9 @@
1
- import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob
1
+ import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob, queue
2
2
  import numpy as np
3
3
  import pandas as pd
4
4
  import tifffile
5
- from PIL import Image
6
- from collections import defaultdict, Counter
5
+ from PIL import Image, ImageOps
6
+ from collections import defaultdict, Counter, deque
7
7
  from pathlib import Path
8
8
  from functools import partial
9
9
  from matplotlib.animation import FuncAnimation
@@ -17,12 +17,12 @@ import imageio.v2 as imageio2
17
17
  import matplotlib.pyplot as plt
18
18
  from io import BytesIO
19
19
  from IPython.display import display, clear_output
20
- from multiprocessing import Pool, cpu_count
21
- from torch.utils.data import Dataset
20
+ from multiprocessing import Pool, cpu_count, Process, Queue
21
+ from torch.utils.data import Dataset, DataLoader
22
22
  import matplotlib.pyplot as plt
23
23
  from torchvision.transforms import ToTensor
24
24
  import seaborn as sns
25
-
25
+ import atexit
26
26
 
27
27
  from .logger import log_function_call
28
28
 
@@ -444,20 +444,7 @@ class NoClassDataset(Dataset):
444
444
  # Return both the image and its filename
445
445
  return img, self.filenames[index]
446
446
 
447
- class MyDataset(Dataset):
448
- """
449
- A custom dataset class for loading and processing image data.
450
-
451
- Args:
452
- data_dir (str): The directory path where the image data is stored.
453
- loader_classes (list): A list of class names for the dataset.
454
- transform (callable, optional): A function/transform to apply to the image data. Default is None.
455
- shuffle (bool, optional): Whether to shuffle the dataset. Default is True.
456
- pin_memory (bool, optional): Whether to pin the loaded images to memory. Default is False.
457
- specific_files (list, optional): A list of specific file paths to include in the dataset. Default is None.
458
- specific_labels (list, optional): A list of specific labels corresponding to the specific files. Default is None.
459
- """
460
-
447
+ class spacrDataset(Dataset):
461
448
  def __init__(self, data_dir, loader_classes, transform=None, shuffle=True, pin_memory=False, specific_files=None, specific_labels=None):
462
449
  self.data_dir = data_dir
463
450
  self.classes = loader_classes
@@ -466,7 +453,7 @@ class MyDataset(Dataset):
466
453
  self.pin_memory = pin_memory
467
454
  self.filenames = []
468
455
  self.labels = []
469
-
456
+
470
457
  if specific_files and specific_labels:
471
458
  self.filenames = specific_files
472
459
  self.labels = specific_labels
@@ -479,33 +466,113 @@ class MyDataset(Dataset):
479
466
 
480
467
  if self.shuffle:
481
468
  self.shuffle_dataset()
482
-
469
+
483
470
  if self.pin_memory:
484
- self.images = [self.load_image(f) for f in self.filenames]
485
-
471
+ # Use multiprocessing to load images in parallel
472
+ with Pool(processes=cpu_count()) as pool:
473
+ self.images = pool.map(self.load_image, self.filenames)
474
+ else:
475
+ self.images = None
476
+
486
477
  def load_image(self, img_path):
487
478
  img = Image.open(img_path).convert('RGB')
479
+ img = ImageOps.exif_transpose(img) # Handle image orientation
488
480
  return img
489
-
481
+
490
482
  def __len__(self):
491
483
  return len(self.filenames)
492
-
484
+
493
485
  def shuffle_dataset(self):
494
486
  combined = list(zip(self.filenames, self.labels))
495
487
  random.shuffle(combined)
496
488
  self.filenames, self.labels = zip(*combined)
497
-
489
+
498
490
  def get_plate(self, filepath):
499
- filename = os.path.basename(filepath) # Get just the filename from the full path
491
+ filename = os.path.basename(filepath)
500
492
  return filename.split('_')[0]
501
-
493
+
502
494
  def __getitem__(self, index):
495
+ if self.pin_memory:
496
+ img = self.images[index]
497
+ else:
498
+ img = self.load_image(self.filenames[index])
503
499
  label = self.labels[index]
504
500
  filename = self.filenames[index]
505
- img = self.load_image(filename)
506
501
  if self.transform:
507
502
  img = self.transform(img)
508
503
  return img, label, filename
504
+
505
+ class spacrDataLoader(DataLoader):
506
+ def __init__(self, *args, preload_batches=1, **kwargs):
507
+ super().__init__(*args, **kwargs)
508
+ self.preload_batches = preload_batches
509
+ self.batch_queue = Queue(maxsize=preload_batches)
510
+ self.process = None
511
+ self.current_batch_index = 0
512
+ self._stop_event = False
513
+ self.pin_memory = kwargs.get('pin_memory', False)
514
+ atexit.register(self.cleanup)
515
+
516
+ def _preload_next_batches(self):
517
+ try:
518
+ for _ in range(self.preload_batches):
519
+ if self._stop_event:
520
+ break
521
+ batch = next(self._iterator)
522
+ if self.pin_memory:
523
+ batch = self._pin_memory_batch(batch)
524
+ self.batch_queue.put(batch)
525
+ except StopIteration:
526
+ pass
527
+
528
+ def _start_preloading(self):
529
+ if self.process is None or not self.process.is_alive():
530
+ self._iterator = iter(super().__iter__())
531
+ if not self.pin_memory:
532
+ self.process = Process(target=self._preload_next_batches)
533
+ self.process.start()
534
+ else:
535
+ self._preload_next_batches() # Directly load if pin_memory is True
536
+
537
+ def _pin_memory_batch(self, batch):
538
+ if isinstance(batch, (list, tuple)):
539
+ return [b.pin_memory() if isinstance(b, torch.Tensor) else b for b in batch]
540
+ elif isinstance(batch, torch.Tensor):
541
+ return batch.pin_memory()
542
+ else:
543
+ return batch
544
+
545
+ def __iter__(self):
546
+ self._start_preloading()
547
+ return self
548
+
549
+ def __next__(self):
550
+ if self.process and not self.process.is_alive() and self.batch_queue.empty():
551
+ raise StopIteration
552
+
553
+ try:
554
+ if self.pin_memory:
555
+ next_batch = self.batch_queue.get(timeout=60)
556
+ else:
557
+ next_batch = self.batch_queue.get(timeout=60)
558
+ self.current_batch_index += 1
559
+
560
+ # Start preloading the next batches
561
+ if self.batch_queue.qsize() < self.preload_batches:
562
+ self._start_preloading()
563
+
564
+ return next_batch
565
+ except queue.Empty:
566
+ raise StopIteration
567
+
568
+ def cleanup(self):
569
+ self._stop_event = True
570
+ if self.process and self.process.is_alive():
571
+ self.process.terminate()
572
+ self.process.join()
573
+
574
+ def __del__(self):
575
+ self.cleanup()
509
576
 
510
577
  class NoClassDataset(Dataset):
511
578
  def __init__(self, data_dir, transform=None, shuffle=True, load_to_memory=False):
@@ -2292,7 +2359,7 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
2292
2359
 
2293
2360
  def save_model_at_threshold(threshold, epoch, suffix=""):
2294
2361
  percentile = str(threshold * 100)
2295
- print(f'\rfound: {percentile}% accurate model')#, end='\r', flush=True)
2362
+ print(f'Found: {percentile}% accurate model')
2296
2363
  model_path = f'{dst}/{model_type}_epoch_{str(epoch)}{suffix}_acc_{percentile}_channels_{channels_str}.pth'
2297
2364
  torch.save(model, model_path)
2298
2365
  return model_path
@@ -2303,7 +2370,8 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
2303
2370
  return model_path
2304
2371
 
2305
2372
  for threshold in intermedeate_save:
2306
- if results_df['neg_accuracy'].dropna().mean() >= threshold and results_df['pos_accuracy'].dropna().mean() >= threshold:
2373
+ if results_df['neg_accuracy'] >= threshold and results_df['pos_accuracy'] >= threshold:
2374
+ print(f"Nc class accuracy: {results_df['neg_accuracy']} Pc class Accuracy: {results_df['pos_accuracy']}")
2307
2375
  model_path = save_model_at_threshold(threshold, epoch)
2308
2376
  break
2309
2377
  else:
@@ -2311,7 +2379,7 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
2311
2379
 
2312
2380
  return model_path
2313
2381
 
2314
- def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
2382
+ def _save_progress(dst, results_df, result_type='train'):
2315
2383
  """
2316
2384
  Save the progress of the classification model.
2317
2385
 
@@ -2325,18 +2393,13 @@ def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
2325
2393
  """
2326
2394
  # Save accuracy, loss, PRAUC
2327
2395
  os.makedirs(dst, exist_ok=True)
2328
- results_path = os.path.join(dst, 'acc_loss_prauc.csv')
2396
+ results_path = os.path.join(dst, f'{result_type}.csv')
2329
2397
  if not os.path.exists(results_path):
2330
2398
  results_df.to_csv(results_path, index=True, header=True, mode='w')
2331
2399
  else:
2332
2400
  results_df.to_csv(results_path, index=True, header=False, mode='a')
2333
-
2334
- training_metrics_path = os.path.join(dst, 'training_metrics.csv')
2335
- if not os.path.exists(training_metrics_path):
2336
- train_metrics_df.to_csv(training_metrics_path, index=True, header=True, mode='w')
2337
- else:
2338
- train_metrics_df.to_csv(training_metrics_path, index=True, header=False, mode='a')
2339
- if epoch == epochs:
2401
+
2402
+ if result_type == 'train':
2340
2403
  read_plot_model_stats(results_path, save=True)
2341
2404
  return
2342
2405
 
spacr/plot.py CHANGED
@@ -1,4 +1,4 @@
1
- import os,re, random, cv2, glob, time, math
1
+ import os,re, random, cv2, glob, time, math, torch
2
2
 
3
3
  import numpy as np
4
4
  import pandas as pd
@@ -1186,6 +1186,52 @@ def _imshow(img, labels, nrow=20, color='white', fontsize=12):
1186
1186
  y = row * img_height + 15
1187
1187
  plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')
1188
1188
  return fig
1189
+
1190
+ def _imshow_gpu(img, labels, nrow=20, color='white', fontsize=12):
1191
+ """
1192
+ Display multiple images in a grid with corresponding labels.
1193
+
1194
+ Args:
1195
+ img (torch.Tensor): A batch of images as a tensor.
1196
+ labels (list): List of labels corresponding to each image.
1197
+ nrow (int, optional): Number of images per row in the grid. Defaults to 20.
1198
+ color (str, optional): Color of the label text. Defaults to 'white'.
1199
+ fontsize (int, optional): Font size of the label text. Defaults to 12.
1200
+ """
1201
+ if img.is_cuda:
1202
+ img = img.cpu() # Move to CPU if the tensor is on GPU
1203
+
1204
+ n_images = len(labels)
1205
+ n_col = nrow
1206
+ n_row = int(np.ceil(n_images / n_col))
1207
+
1208
+ img_height = img.shape[2] # Height of the image
1209
+ img_width = img.shape[3] # Width of the image
1210
+
1211
+ # Prepare the canvas on CPU
1212
+ canvas = torch.zeros((img_height * n_row, img_width * n_col, 3))
1213
+
1214
+ for i in range(n_row):
1215
+ for j in range(n_col):
1216
+ idx = i * n_col + j
1217
+ if idx < n_images:
1218
+ # Place the image on the canvas
1219
+ canvas[i * img_height:(i + 1) * img_height, j * img_width:(j + 1) * img_width] = img[idx].permute(1, 2, 0)
1220
+
1221
+ canvas = canvas.numpy() # Convert to NumPy for plotting
1222
+
1223
+ fig = plt.figure(figsize=(50, 50))
1224
+ plt.imshow(canvas)
1225
+ plt.axis("off")
1226
+
1227
+ for i, label in enumerate(labels):
1228
+ row = i // n_col
1229
+ col = i % n_col
1230
+ x = col * img_width + 2
1231
+ y = row * img_height + 15
1232
+ plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')
1233
+
1234
+ return fig
1189
1235
 
1190
1236
  def _plot_histograms_and_stats(df):
1191
1237
  conditions = df['condition'].unique()