spacr 0.2.53__py3-none-any.whl → 0.2.61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/gui_elements.py CHANGED
@@ -516,6 +516,7 @@ class spacrDropdownMenu(tk.Frame):
516
516
  self.inactive_color = color_settings['inactive_color']
517
517
  self.active_color = color_settings['active_color']
518
518
  self.fg_color = color_settings['fg_color']
519
+ self.bg_color = style_out['bg_color']
519
520
 
520
521
  # Create the button with rounded edges
521
522
  self.button_bg = self.create_rounded_rectangle(2, 2, self.button_width + 2, self.size + 2, radius=20, fill=self.inactive_color, outline=self.inactive_color)
@@ -536,8 +537,8 @@ class spacrDropdownMenu(tk.Frame):
536
537
  self.canvas.bind("<Leave>", self.on_leave)
537
538
  self.canvas.bind("<Button-1>", self.on_click)
538
539
 
539
- # Create a popup menu
540
- self.menu = tk.Menu(self, tearoff=0)
540
+ # Create a popup menu with the desired background color
541
+ self.menu = tk.Menu(self, tearoff=0, bg=self.bg_color, fg=self.fg_color)
541
542
  for option in self.options:
542
543
  self.menu.add_command(label=option, command=lambda opt=option: self.on_select(opt))
543
544
 
@@ -591,7 +592,6 @@ class spacrDropdownMenu(tk.Frame):
591
592
  else:
592
593
  self.menu.entryconfig(idx, background=style_out['bg_color'], foreground=style_out['fg_color'])
593
594
 
594
-
595
595
  class spacrCheckbutton(ttk.Checkbutton):
596
596
  def __init__(self, parent, text="", variable=None, command=None, *args, **kwargs):
597
597
  super().__init__(parent, *args, **kwargs)
@@ -613,17 +613,26 @@ class spacrProgressBar(ttk.Progressbar):
613
613
  self.bg_color = style_out['bg_color']
614
614
  self.active_color = style_out['active_color']
615
615
  self.inactive_color = style_out['inactive_color']
616
+ self.font_size = style_out['font_size']
617
+ self.font_loader = style_out['font_loader']
616
618
 
617
619
  # Configure the style for the progress bar
618
620
  self.style = ttk.Style()
621
+
622
+ # Remove any borders and ensure the active color fills the entire space
619
623
  self.style.configure(
620
624
  "spacr.Horizontal.TProgressbar",
621
- troughcolor=self.bg_color,
622
- background=self.active_color,
623
- thickness=20,
624
- troughrelief='flat',
625
- borderwidth=0
625
+ troughcolor=self.inactive_color, # Set the trough to bg color
626
+ background=self.active_color, # Active part is the active color
627
+ borderwidth=0, # Remove border width
628
+ pbarrelief="flat", # Flat relief for the progress bar
629
+ troughrelief="flat", # Flat relief for the trough
630
+ thickness=20, # Set the thickness of the progress bar
631
+ darkcolor=self.active_color, # Ensure darkcolor matches the active color
632
+ lightcolor=self.active_color, # Ensure lightcolor matches the active color
633
+ bordercolor=self.bg_color # Set the border color to the background color to hide it
626
634
  )
635
+
627
636
  self.configure(style="spacr.Horizontal.TProgressbar")
628
637
 
629
638
  # Set initial value to 0
@@ -632,16 +641,23 @@ class spacrProgressBar(ttk.Progressbar):
632
641
  # Track whether to show the progress label
633
642
  self.label = label
634
643
 
635
- # Create the progress label (defer placement)
644
+ # Create the progress label with text wrapping
636
645
  if self.label:
637
- self.progress_label = tk.Label(parent, text="Processing: 0/0", anchor='w', justify='left', bg=self.inactive_color, fg=self.fg_color)
638
- self.progress_label.grid_forget() # Temporarily hide it
646
+ self.progress_label = tk.Label(
647
+ parent,
648
+ text="Processing: 0/0",
649
+ anchor='w',
650
+ justify='left',
651
+ bg=self.inactive_color,
652
+ fg=self.fg_color,
653
+ wraplength=300,
654
+ font=self.font_loader.get_font(size=self.font_size)
655
+ )
656
+ self.progress_label.grid_forget()
639
657
 
640
658
  # Initialize attributes for time and operation
641
659
  self.operation_type = None
642
- self.time_image = None
643
- self.time_batch = None
644
- self.time_left = None
660
+ self.additional_info = None
645
661
 
646
662
  def set_label_position(self):
647
663
  if self.label and self.progress_label:
@@ -656,12 +672,19 @@ class spacrProgressBar(ttk.Progressbar):
656
672
  label_text = f"Processing: {self['value']}/{self['maximum']}"
657
673
  if self.operation_type:
658
674
  label_text += f", {self.operation_type}"
659
- if self.time_image:
660
- label_text += f", Time/image: {self.time_image:.3f} sec"
661
- if self.time_batch:
662
- label_text += f", Time/batch: {self.time_batch:.3f} sec"
663
- if self.time_left:
664
- label_text += f", Time_left: {self.time_left:.3f} min"
675
+ if hasattr(self, 'additional_info') and self.additional_info:
676
+ # Add a space between progress information and additional information
677
+ label_text += "\n\n"
678
+ # Split the additional_info into a list of items
679
+ items = self.additional_info.split(", ")
680
+ formatted_additional_info = ""
681
+ # Group the items in pairs, adding them to formatted_additional_info
682
+ for i in range(0, len(items), 2):
683
+ if i + 1 < len(items):
684
+ formatted_additional_info += f"{items[i]}, {items[i + 1]}\n\n"
685
+ else:
686
+ formatted_additional_info += f"{items[i]}\n\n" # If there's an odd item out, add it alone
687
+ label_text += formatted_additional_info.strip()
665
688
  self.progress_label.config(text=label_text)
666
689
 
667
690
  def spacrScrollbarStyle(style, inactive_color, active_color):
spacr/gui_utils.py CHANGED
@@ -1,10 +1,11 @@
1
- import os, io, sys, ast, ctypes, ast, sqlite3, requests, time, traceback
1
+ import os, io, sys, ast, ctypes, ast, sqlite3, requests, time, traceback, torch
2
2
  import tkinter as tk
3
3
  from tkinter import ttk
4
4
  import matplotlib
5
5
  import matplotlib.pyplot as plt
6
6
  matplotlib.use('Agg')
7
7
  from huggingface_hub import list_repo_files
8
+ import psutil
8
9
 
9
10
  from .gui_elements import AnnotateApp, spacrEntry, spacrCheck, spacrCombo
10
11
 
@@ -12,6 +13,36 @@ try:
12
13
  ctypes.windll.shcore.SetProcessDpiAwareness(True)
13
14
  except AttributeError:
14
15
  pass
16
+
17
def initialize_cuda():
    """
    Initializes CUDA in the main process by performing a simple GPU operation.

    The warm-up is best effort: a failure while touching the GPU is reported
    on stdout instead of being raised.

    Returns:
        bool: True if a tensor was allocated on the GPU, False if CUDA is not
        available or the allocation raised a RuntimeError.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available.")
        return False

    try:
        # Allocate a small tensor on the GPU; the value itself is discarded.
        _ = torch.tensor([0.0], device='cuda')
    except RuntimeError as e:
        # is_available() can be True while the first real GPU call still fails;
        # report it rather than crash the caller.
        print(f"CUDA initialization failed: {e}")
        return False

    print("CUDA initialized in the main process.")
    return True
27
+
28
def set_high_priority(process):
    """
    Raise the scheduling priority of a process (best effort).

    Failures are printed, never raised.

    Args:
        process: Any object exposing a ``pid`` attribute; only ``process.pid``
            is read.

    Returns:
        bool: True if the priority was changed, False otherwise.
    """
    pid = process.pid
    if pid is None:
        # psutil.Process(None) means "the calling process", so an unstarted
        # child would silently re-prioritise the caller instead.
        print("Cannot set high priority: process has no pid (has it been started?)")
        return False

    try:
        p = psutil.Process(pid)
        if os.name == 'nt':  # Windows
            p.nice(psutil.HIGH_PRIORITY_CLASS)
        else:  # Unix-like systems
            p.nice(-10)  # Adjusted priority level
    except psutil.AccessDenied as e:
        print(f"Access denied when trying to set high priority for process {pid}: {e}")
        return False
    except psutil.NoSuchProcess as e:
        print(f"No such process {pid}: {e}")
        return False
    except Exception as e:
        print(f"Failed to set high priority for process {pid}: {e}")
        return False

    print(f"Successfully set high priority for process: {pid}")
    return True
42
+
43
def set_cpu_affinity(process):
    """
    Allow a process to run on every CPU reported by ``os.cpu_count()``.

    Best effort, like ``set_high_priority``: failures are printed, never raised.

    Args:
        process: Any object exposing a ``pid`` attribute; only ``process.pid``
            is read.

    Returns:
        bool: True if the affinity was applied, False otherwise.
    """
    pid = process.pid
    if pid is None:
        # psutil.Process(None) means "the calling process"; refuse rather than
        # change the affinity of the wrong process.
        print("Cannot set CPU affinity: process has no pid (has it been started?)")
        return False

    cpu_total = os.cpu_count()
    if cpu_total is None:
        # os.cpu_count() returns None when the count cannot be determined.
        print(f"Cannot set CPU affinity for process {pid}: CPU count is unknown")
        return False

    try:
        p = psutil.Process(pid)
        if not hasattr(p, "cpu_affinity"):
            # psutil only provides cpu_affinity on some platforms.
            print(f"CPU affinity is not supported on this platform (process {pid})")
            return False
        p.cpu_affinity(list(range(cpu_total)))
    except (psutil.Error, ValueError, OSError) as e:
        print(f"Failed to set CPU affinity for process {pid}: {e}")
        return False

    return True
15
46
 
16
47
  def proceed_with_app(root, app_name, app_func):
17
48
  # Clear the current content frame
@@ -48,12 +79,18 @@ def parse_list(value):
48
79
  try:
49
80
  parsed_value = ast.literal_eval(value)
50
81
  if isinstance(parsed_value, list):
51
- return parsed_value
82
+ # Check if the list elements are homogeneous (all int or all str)
83
+ if all(isinstance(item, int) for item in parsed_value):
84
+ return parsed_value
85
+ elif all(isinstance(item, str) for item in parsed_value):
86
+ return parsed_value
87
+ else:
88
+ raise ValueError("List contains mixed types or unsupported types")
52
89
  else:
53
90
  raise ValueError(f"Expected a list but got {type(parsed_value).__name__}")
54
91
  except (ValueError, SyntaxError) as e:
55
92
  raise ValueError(f"Invalid format for list: {value}. Error: {e}")
56
-
93
+
57
94
  # Usage example in your create_input_field function
58
95
  def create_input_field(frame, label_text, row, var_type='entry', options=None, default_value=None):
59
96
  from .gui_elements import set_dark_style, set_element_size
@@ -67,6 +104,7 @@ def create_input_field(frame, label_text, row, var_type='entry', options=None, d
67
104
  size_dict['settings_width'] = size_dict['settings_width'] - int(size_dict['settings_width']*0.1)
68
105
 
69
106
  # Replace underscores with spaces and capitalize the first letter
107
+
70
108
  label_text = label_text.replace('_', ' ').capitalize()
71
109
 
72
110
  # Configure the column widths
@@ -81,32 +119,35 @@ def create_input_field(frame, label_text, row, var_type='entry', options=None, d
81
119
  custom_frame.config(highlightbackground=style_out['bg_color'], highlightthickness=1, bd=2)
82
120
 
83
121
  # Create and configure the label
84
- if font_loader:
85
- label = ttk.Label(custom_frame, text=label_text, background=style_out['bg_color'], foreground=style_out['fg_color'], font=font_loader.get_font(size=font_size), anchor='e', justify='right')
86
- label = ttk.Label(custom_frame, text=label_text, background=style_out['bg_color'], foreground=style_out['fg_color'], font=(style_out['font_family'], style_out['font_size']), anchor='e', justify='right')
122
+ label = tk.Label(custom_frame, text=label_text, bg=style_out['bg_color'], fg=style_out['fg_color'], font=font_loader.get_font(size=font_size), anchor='e', justify='right')
87
123
  label.grid(column=label_column, row=0, sticky=tk.W, padx=(5, 2), pady=5) # Place the label in the first row
88
124
 
89
125
  # Create and configure the input widget based on var_type
90
- if var_type == 'entry':
91
- var = tk.StringVar(value=default_value)
92
- entry = spacrEntry(custom_frame, textvariable=var, outline=False, width=size_dict['settings_width'])
93
- entry.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the entry in the second row
94
- return (label, entry, var, custom_frame) # Return both the label and the entry, and the variable
95
- elif var_type == 'check':
96
- var = tk.BooleanVar(value=default_value) # Set default value (True/False)
97
- check = spacrCheck(custom_frame, text="", variable=var)
98
- check.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the checkbutton in the second row
99
- return (label, check, var, custom_frame) # Return both the label and the checkbutton, and the variable
100
- elif var_type == 'combo':
101
- var = tk.StringVar(value=default_value) # Set default value
102
- combo = spacrCombo(custom_frame, textvariable=var, values=options, width=size_dict['settings_width']) # Apply TCombobox style
103
- combo.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the combobox in the second row
104
- if default_value:
105
- combo.set(default_value)
106
- return (label, combo, var, custom_frame) # Return both the label and the combobox, and the variable
107
- else:
108
- var = None # Placeholder in case of an undefined var_type
109
- return (label, None, var, custom_frame)
126
+ try:
127
+ if var_type == 'entry':
128
+ var = tk.StringVar(value=default_value)
129
+ entry = spacrEntry(custom_frame, textvariable=var, outline=False, width=size_dict['settings_width'])
130
+ entry.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the entry in the second row
131
+ return (label, entry, var, custom_frame) # Return both the label and the entry, and the variable
132
+ elif var_type == 'check':
133
+ var = tk.BooleanVar(value=default_value) # Set default value (True/False)
134
+ check = spacrCheck(custom_frame, text="", variable=var)
135
+ check.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the checkbutton in the second row
136
+ return (label, check, var, custom_frame) # Return both the label and the checkbutton, and the variable
137
+ elif var_type == 'combo':
138
+ var = tk.StringVar(value=default_value) # Set default value
139
+ combo = spacrCombo(custom_frame, textvariable=var, values=options, width=size_dict['settings_width']) # Apply TCombobox style
140
+ combo.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the combobox in the second row
141
+ if default_value:
142
+ combo.set(default_value)
143
+ return (label, combo, var, custom_frame) # Return both the label and the combobox, and the variable
144
+ else:
145
+ var = None # Placeholder in case of an undefined var_type
146
+ return (label, None, var, custom_frame)
147
+ except Exception as e:
148
+ traceback.print_exc()
149
+ print(f"Error creating input field: {e}")
150
+ print(f"Wrong type for {label_text} Expected {var_type}")
110
151
 
111
152
  def process_stdout_stderr(q):
112
153
  """
@@ -134,16 +175,6 @@ def cancel_after_tasks(frame):
134
175
  frame.after_cancel(task)
135
176
  frame.after_tasks.clear()
136
177
 
137
- def main_thread_update_function(root, q, fig_queue, canvas_widget):
138
- try:
139
- #ansi_escape_pattern = re.compile(r'\x1B\[[0-?]*[ -/]*[@-~]')
140
- while not q.empty():
141
- message = q.get_nowait()
142
- except Exception as e:
143
- print(f"Error updating GUI canvas: {e}")
144
- finally:
145
- root.after(100, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget))
146
-
147
178
  def annotate(settings):
148
179
  from .settings import set_annotate_default_settings
149
180
  settings = set_annotate_default_settings(settings)
@@ -323,7 +354,9 @@ def convert_settings_dict_for_gui(settings):
323
354
  special_cases = {
324
355
  'metadata_type': ('combo', ['cellvoyager', 'cq1', 'nikon', 'zeis', 'custom'], 'cellvoyager'),
325
356
  'channels': ('combo', ['[0,1,2,3]', '[0,1,2]', '[0,1]', '[0]'], '[0,1,2,3]'),
357
+ 'train_channels': ('combo', ["['r','g','b']", "['r','g']", "['r','b']", "['g','b']", "['r']", "['g']", "['b']"], "['r','g','b']"),
326
358
  'channel_dims': ('combo', ['[0,1,2,3]', '[0,1,2]', '[0,1]', '[0]'], '[0,1,2,3]'),
359
+ 'dataset_mode': ('combo', ['annotation', 'metadata', 'recruitment'], 'metadata'),
327
360
  'cell_mask_dim': ('combo', chans, None),
328
361
  'cell_chann_dim': ('combo', chans, None),
329
362
  'nucleus_mask_dim': ('combo', chans, None),
@@ -369,6 +402,7 @@ def convert_settings_dict_for_gui(settings):
369
402
  variables[key] = ('entry', None, str(value))
370
403
  else:
371
404
  variables[key] = ('entry', None, str(value))
405
+
372
406
  return variables
373
407
 
374
408
 
@@ -413,13 +447,14 @@ def function_gui_wrapper(function=None, settings={}, q=None, fig_queue=None, imp
413
447
  plt.show = original_show
414
448
 
415
449
  def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
450
+
416
451
  from .gui_utils import process_stdout_stderr
417
- from .core import preprocess_generate_masks, generate_ml_scores, identify_masks_finetune, check_cellpose_models, analyze_recruitment, train_cellpose, compare_cellpose_masks, analyze_plaques, generate_dataset, apply_model_to_tar
452
+ from .core import generate_image_umap, preprocess_generate_masks, generate_ml_scores, identify_masks_finetune, check_cellpose_models, analyze_recruitment, train_cellpose, compare_cellpose_masks, analyze_plaques, generate_dataset, apply_model_to_tar
418
453
  from .io import generate_cellpose_train_test
419
454
  from .measure import measure_crop
420
455
  from .sim import run_multiple_simulations
421
- from .deep_spacr import train_test_model
422
- from .sequencing import analyze_reads, map_barcodes_folder, perform_regression
456
+ from .deep_spacr import deep_spacr
457
+ from .sequencing import generate_barecode_mapping, perform_regression
423
458
  process_stdout_stderr(q)
424
459
 
425
460
  print(f'run_function_gui settings_type: {settings_type}')
@@ -433,12 +468,9 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
433
468
  elif settings_type == 'simulation':
434
469
  function = run_multiple_simulations
435
470
  imports = 1
436
- elif settings_type == 'sequencing':
437
- function = analyze_reads
438
- imports = 1
439
471
  elif settings_type == 'classify':
440
- function = train_test_model
441
- imports = 2
472
+ function = deep_spacr
473
+ imports = 1
442
474
  elif settings_type == 'train_cellpose':
443
475
  function = train_cellpose
444
476
  imports = 1
@@ -452,14 +484,17 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
452
484
  function = check_cellpose_models
453
485
  imports = 1
454
486
  elif settings_type == 'map_barcodes':
455
- function = map_barcodes_folder
487
+ function = generate_barecode_mapping
456
488
  imports = 1
457
489
  elif settings_type == 'regression':
458
490
  function = perform_regression
459
491
  imports = 2
460
492
  elif settings_type == 'recruitment':
461
493
  function = analyze_recruitment
462
- imports = 2
494
+ imports = 1
495
+ elif settings_type == 'umap':
496
+ function = generate_image_umap
497
+ imports = 1
463
498
  else:
464
499
  raise ValueError(f"Invalid settings type: {settings_type}")
465
500
  try:
@@ -470,7 +505,6 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
470
505
  finally:
471
506
  stop_requested.value = 1
472
507
 
473
-
474
508
  def hide_all_settings(vars_dict, categories):
475
509
  """
476
510
  Function to initially hide all settings in the GUI.
spacr/io.py CHANGED
@@ -1,9 +1,9 @@
1
- import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob
1
+ import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob, queue
2
2
  import numpy as np
3
3
  import pandas as pd
4
4
  import tifffile
5
- from PIL import Image
6
- from collections import defaultdict, Counter
5
+ from PIL import Image, ImageOps
6
+ from collections import defaultdict, Counter, deque
7
7
  from pathlib import Path
8
8
  from functools import partial
9
9
  from matplotlib.animation import FuncAnimation
@@ -17,12 +17,12 @@ import imageio.v2 as imageio2
17
17
  import matplotlib.pyplot as plt
18
18
  from io import BytesIO
19
19
  from IPython.display import display, clear_output
20
- from multiprocessing import Pool, cpu_count
21
- from torch.utils.data import Dataset
20
+ from multiprocessing import Pool, cpu_count, Process, Queue
21
+ from torch.utils.data import Dataset, DataLoader
22
22
  import matplotlib.pyplot as plt
23
23
  from torchvision.transforms import ToTensor
24
24
  import seaborn as sns
25
-
25
+ import atexit
26
26
 
27
27
  from .logger import log_function_call
28
28
 
@@ -444,20 +444,7 @@ class NoClassDataset(Dataset):
444
444
  # Return both the image and its filename
445
445
  return img, self.filenames[index]
446
446
 
447
- class MyDataset(Dataset):
448
- """
449
- A custom dataset class for loading and processing image data.
450
-
451
- Args:
452
- data_dir (str): The directory path where the image data is stored.
453
- loader_classes (list): A list of class names for the dataset.
454
- transform (callable, optional): A function/transform to apply to the image data. Default is None.
455
- shuffle (bool, optional): Whether to shuffle the dataset. Default is True.
456
- pin_memory (bool, optional): Whether to pin the loaded images to memory. Default is False.
457
- specific_files (list, optional): A list of specific file paths to include in the dataset. Default is None.
458
- specific_labels (list, optional): A list of specific labels corresponding to the specific files. Default is None.
459
- """
460
-
447
+ class spacrDataset(Dataset):
461
448
  def __init__(self, data_dir, loader_classes, transform=None, shuffle=True, pin_memory=False, specific_files=None, specific_labels=None):
462
449
  self.data_dir = data_dir
463
450
  self.classes = loader_classes
@@ -466,7 +453,7 @@ class MyDataset(Dataset):
466
453
  self.pin_memory = pin_memory
467
454
  self.filenames = []
468
455
  self.labels = []
469
-
456
+
470
457
  if specific_files and specific_labels:
471
458
  self.filenames = specific_files
472
459
  self.labels = specific_labels
@@ -479,33 +466,113 @@ class MyDataset(Dataset):
479
466
 
480
467
  if self.shuffle:
481
468
  self.shuffle_dataset()
482
-
469
+
483
470
  if self.pin_memory:
484
- self.images = [self.load_image(f) for f in self.filenames]
485
-
471
+ # Use multiprocessing to load images in parallel
472
+ with Pool(processes=cpu_count()) as pool:
473
+ self.images = pool.map(self.load_image, self.filenames)
474
+ else:
475
+ self.images = None
476
+
486
477
  def load_image(self, img_path):
487
478
  img = Image.open(img_path).convert('RGB')
479
+ img = ImageOps.exif_transpose(img) # Handle image orientation
488
480
  return img
489
-
481
+
490
482
  def __len__(self):
491
483
  return len(self.filenames)
492
-
484
+
493
485
  def shuffle_dataset(self):
494
486
  combined = list(zip(self.filenames, self.labels))
495
487
  random.shuffle(combined)
496
488
  self.filenames, self.labels = zip(*combined)
497
-
489
+
498
490
  def get_plate(self, filepath):
499
- filename = os.path.basename(filepath) # Get just the filename from the full path
491
+ filename = os.path.basename(filepath)
500
492
  return filename.split('_')[0]
501
-
493
+
502
494
  def __getitem__(self, index):
495
+ if self.pin_memory:
496
+ img = self.images[index]
497
+ else:
498
+ img = self.load_image(self.filenames[index])
503
499
  label = self.labels[index]
504
500
  filename = self.filenames[index]
505
- img = self.load_image(filename)
506
501
  if self.transform:
507
502
  img = self.transform(img)
508
503
  return img, label, filename
504
+
505
class spacrDataLoader(DataLoader):
    """
    DataLoader that prepares upcoming batches in a background thread.

    Each ``iter()`` call starts a new pass over the data. A daemon thread pulls
    batches from the regular ``DataLoader`` iterator and parks at most
    ``preload_batches`` of them in a bounded queue, from which ``__next__``
    hands them out in order. An exception raised while loading a batch is
    re-raised in the consuming thread.

    Args:
        *args: Positional arguments forwarded to ``torch.utils.data.DataLoader``.
        preload_batches (int, optional): Maximum number of batches kept ready
            ahead of the consumer. Values below 1 are treated as 1.
            Defaults to 1.
        **kwargs: Keyword arguments forwarded to ``torch.utils.data.DataLoader``.

    Attributes:
        preload_batches: The value passed to the constructor.
        batch_queue: ``queue.Queue`` holding the prefetched batches of the
            current pass.
        process: The background worker (a ``threading.Thread``) while a pass is
            running, otherwise None.
        current_batch_index: Number of batches handed out so far; it is not
            reset between passes.
    """

    def __init__(self, *args, preload_batches=1, **kwargs):
        import weakref

        super().__init__(*args, **kwargs)
        # self.pin_memory is already set by DataLoader.__init__ (from either a
        # positional or a keyword argument), so it is not reassigned here.
        self.preload_batches = preload_batches
        self.batch_queue = queue.Queue(maxsize=self._queue_size())
        self.process = None
        self.current_batch_index = 0
        self._stop_event = None
        # Register through a weak reference so the exit hook does not keep
        # every loader (and its dataset) alive until interpreter shutdown.
        atexit.register(self._cleanup_at_exit, weakref.ref(self))

    @staticmethod
    def _cleanup_at_exit(loader_ref):
        """Stop the worker of a loader that is still alive at interpreter exit."""
        loader = loader_ref()
        if loader is not None:
            loader.cleanup()

    def _queue_size(self):
        """Return the queue capacity; 0 or less would make the queue unbounded."""
        return max(1, int(self.preload_batches))

    @staticmethod
    def _preload_next_batches(source, out_queue, stop_event, pin_fn=None):
        """
        Worker body: move batches from ``source`` into ``out_queue``.

        Every queue item is a ``(kind, payload)`` pair where kind is "batch",
        "end" or "error". The function deliberately takes no ``self`` so the
        running thread does not keep the loader alive.
        """
        def _deliver(item):
            # Bounded put that gives up as soon as a stop has been requested.
            while not stop_event.is_set():
                try:
                    out_queue.put(item, timeout=0.1)
                    return True
                except queue.Full:
                    continue
            return False

        try:
            for batch in source:
                if stop_event.is_set():
                    return
                if pin_fn is not None:
                    batch = pin_fn(batch)
                if not _deliver(("batch", batch)):
                    return
        except Exception as exc:
            # Hand the failure to the consumer instead of dying silently.
            _deliver(("error", exc))
        else:
            _deliver(("end", None))

    def _start_preloading(self):
        """Start the worker thread for a new pass unless one is still running."""
        import threading

        if self.process is not None and self.process.is_alive():
            return
        self._stop_event = threading.Event()
        self.batch_queue = queue.Queue(maxsize=self._queue_size())
        # One base iterator per pass; it is kept in a local (not in
        # self._iterator, which DataLoader uses for persistent workers).
        source = super().__iter__()
        pin_fn = self._pin_memory_batch if self.pin_memory else None
        self.process = threading.Thread(
            target=self._preload_next_batches,
            args=(source, self.batch_queue, self._stop_event, pin_fn),
            daemon=True,
        )
        self.process.start()

    @staticmethod
    def _pin_memory_batch(batch):
        """
        Pin the tensors of ``batch`` that are not pinned yet.

        Lists and tuples come back as a list, a single tensor as a tensor, and
        anything else unchanged. Without CUDA the batch is returned as is.
        """
        if not torch.cuda.is_available():
            return batch

        def _pin(tensor):
            return tensor if tensor.is_pinned() else tensor.pin_memory()

        if isinstance(batch, (list, tuple)):
            return [_pin(b) if isinstance(b, torch.Tensor) else b for b in batch]
        if isinstance(batch, torch.Tensor):
            return _pin(batch)
        return batch

    def __iter__(self):
        # Abandon any unfinished previous pass, then start from the first batch.
        self.cleanup()
        self._start_preloading()
        return self

    def __next__(self):
        worker = self.process
        if worker is None:
            # Iteration was never started or has already finished.
            raise StopIteration

        while True:
            try:
                kind, payload = self.batch_queue.get(timeout=0.1)
                break
            except queue.Empty:
                # A slow batch only means "keep waiting". The pass is over
                # only when the worker is gone and nothing is left to read.
                if not worker.is_alive() and self.batch_queue.empty():
                    self.cleanup()
                    raise StopIteration from None

        if kind == "batch":
            self.current_batch_index += 1
            return payload

        self.cleanup()
        if kind == "error":
            raise payload
        raise StopIteration

    def cleanup(self):
        """Ask the worker to stop and wait briefly for it to exit."""
        stop_event = getattr(self, "_stop_event", None)
        worker = getattr(self, "process", None)
        if stop_event is not None:
            stop_event.set()
        if worker is not None and worker.is_alive():
            # Bounded wait: a worker stuck loading a batch is a daemon thread
            # and exits once it sees the stop request.
            worker.join(timeout=5.0)
        self.process = None

    def __del__(self):
        self.cleanup()
509
576
 
510
577
  class NoClassDataset(Dataset):
511
578
  def __init__(self, data_dir, transform=None, shuffle=True, load_to_memory=False):
@@ -2292,18 +2359,27 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
2292
2359
 
2293
2360
  def save_model_at_threshold(threshold, epoch, suffix=""):
2294
2361
  percentile = str(threshold * 100)
2295
- print(f'\rfound: {percentile}% accurate model')#, end='\r', flush=True)
2296
- torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}{suffix}_acc_{percentile}_channels_{channels_str}.pth')
2362
+ print(f'Found: {percentile}% accurate model')
2363
+ model_path = f'{dst}/{model_type}_epoch_{str(epoch)}{suffix}_acc_{percentile}_channels_{channels_str}.pth'
2364
+ torch.save(model, model_path)
2365
+ return model_path
2297
2366
 
2298
2367
  if epoch % 100 == 0 or epoch == epochs:
2299
- torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_channels_{channels_str}.pth')
2368
+ model_path = f'{dst}/{model_type}_epoch_{str(epoch)}_channels_{channels_str}.pth'
2369
+ torch.save(model, model_path)
2370
+ return model_path
2300
2371
 
2301
2372
  for threshold in intermedeate_save:
2302
- if results_df['neg_accuracy'].dropna().mean() >= threshold and results_df['pos_accuracy'].dropna().mean() >= threshold:
2303
- save_model_at_threshold(threshold, epoch)
2304
- break # Ensure we only save for the highest matching threshold
2373
+ if results_df['neg_accuracy'] >= threshold and results_df['pos_accuracy'] >= threshold:
2374
+ print(f"Nc class accuracy: {results_df['neg_accuracy']} Pc class Accuracy: {results_df['pos_accuracy']}")
2375
+ model_path = save_model_at_threshold(threshold, epoch)
2376
+ break
2377
+ else:
2378
+ model_path = None
2379
+
2380
+ return model_path
2305
2381
 
2306
- def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
2382
+ def _save_progress(dst, results_df, result_type='train'):
2307
2383
  """
2308
2384
  Save the progress of the classification model.
2309
2385
 
@@ -2317,18 +2393,13 @@ def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
2317
2393
  """
2318
2394
  # Save accuracy, loss, PRAUC
2319
2395
  os.makedirs(dst, exist_ok=True)
2320
- results_path = os.path.join(dst, 'acc_loss_prauc.csv')
2396
+ results_path = os.path.join(dst, f'{result_type}.csv')
2321
2397
  if not os.path.exists(results_path):
2322
2398
  results_df.to_csv(results_path, index=True, header=True, mode='w')
2323
2399
  else:
2324
2400
  results_df.to_csv(results_path, index=True, header=False, mode='a')
2325
-
2326
- training_metrics_path = os.path.join(dst, 'training_metrics.csv')
2327
- if not os.path.exists(training_metrics_path):
2328
- train_metrics_df.to_csv(training_metrics_path, index=True, header=True, mode='w')
2329
- else:
2330
- train_metrics_df.to_csv(training_metrics_path, index=True, header=False, mode='a')
2331
- if epoch == epochs:
2401
+
2402
+ if result_type == 'train':
2332
2403
  read_plot_model_stats(results_path, save=True)
2333
2404
  return
2334
2405
 
spacr/plot.py CHANGED
@@ -1,4 +1,4 @@
1
- import os,re, random, cv2, glob, time, math
1
+ import os,re, random, cv2, glob, time, math, torch
2
2
 
3
3
  import numpy as np
4
4
  import pandas as pd
@@ -1186,6 +1186,52 @@ def _imshow(img, labels, nrow=20, color='white', fontsize=12):
1186
1186
  y = row * img_height + 15
1187
1187
  plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')
1188
1188
  return fig
1189
+
1190
def _imshow_gpu(img, labels, nrow=20, color='white', fontsize=12):
    """
    Display multiple images in a grid with corresponding labels.

    Args:
        img (torch.Tensor): A batch of images as a tensor, indexed as
            (image, channel, height, width). It may live on any device and may
            require grad; it is detached and copied to the CPU for plotting.
        labels (list): List of labels corresponding to each image. Labels
            beyond the number of images in ``img`` are ignored.
        nrow (int, optional): Number of images per row in the grid. Defaults to 20.
        color (str, optional): Color of the label text. Defaults to 'white'.
        fontsize (int, optional): Font size of the label text. Defaults to 12.

    Returns:
        The figure created by ``plt.figure``.
    """
    # detach() first: copying a grad-tracking tensor into the canvas would make
    # the canvas track grad too, and numpy() refuses such tensors.
    img = img.detach().cpu()

    # Never index past the batch when more labels than images are supplied.
    n_images = min(len(labels), img.shape[0])
    n_col = nrow
    n_row = int(np.ceil(n_images / n_col))

    img_height = img.shape[2]  # Height of the image
    img_width = img.shape[3]  # Width of the image

    # Prepare the canvas on CPU (height, width, 3 channels)
    canvas = torch.zeros((img_height * n_row, img_width * n_col, 3))

    for idx in range(n_images):
        row, col = divmod(idx, n_col)
        top, left = row * img_height, col * img_width
        # Channel-first image -> channel-last tile on the canvas
        canvas[top:top + img_height, left:left + img_width] = img[idx].permute(1, 2, 0)

    canvas = canvas.numpy()  # Convert to NumPy for plotting

    fig = plt.figure(figsize=(50, 50))
    plt.imshow(canvas)
    plt.axis("off")

    for i, label in enumerate(labels[:n_images]):
        row, col = divmod(i, n_col)
        x = col * img_width + 2
        y = row * img_height + 15
        plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')

    return fig
1189
1235
 
1190
1236
  def _plot_histograms_and_stats(df):
1191
1237
  conditions = df['condition'].unique()