spacr 0.2.56__py3-none-any.whl → 0.2.65__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +1 -3
- spacr/core.py +135 -472
- spacr/deep_spacr.py +189 -270
- spacr/gui_core.py +294 -86
- spacr/gui_elements.py +335 -82
- spacr/gui_utils.py +100 -49
- spacr/io.py +104 -41
- spacr/plot.py +51 -5
- spacr/sequencing.py +4 -8
- spacr/settings.py +27 -31
- spacr/utils.py +15 -14
- {spacr-0.2.56.dist-info → spacr-0.2.65.dist-info}/METADATA +1 -1
- {spacr-0.2.56.dist-info → spacr-0.2.65.dist-info}/RECORD +17 -17
- {spacr-0.2.56.dist-info → spacr-0.2.65.dist-info}/LICENSE +0 -0
- {spacr-0.2.56.dist-info → spacr-0.2.65.dist-info}/WHEEL +0 -0
- {spacr-0.2.56.dist-info → spacr-0.2.65.dist-info}/entry_points.txt +0 -0
- {spacr-0.2.56.dist-info → spacr-0.2.65.dist-info}/top_level.txt +0 -0
spacr/gui_utils.py
CHANGED
@@ -1,10 +1,11 @@
-import os, io, sys, ast, ctypes, ast, sqlite3, requests, time, traceback
+import os, io, sys, ast, ctypes, ast, sqlite3, requests, time, traceback, torch
 import tkinter as tk
 from tkinter import ttk
 import matplotlib
 import matplotlib.pyplot as plt
 matplotlib.use('Agg')
 from huggingface_hub import list_repo_files
+import psutil
 
 from .gui_elements import AnnotateApp, spacrEntry, spacrCheck, spacrCombo
 
@@ -12,6 +13,36 @@ try:
     ctypes.windll.shcore.SetProcessDpiAwareness(True)
 except AttributeError:
     pass
+
+def initialize_cuda():
+    """
+    Initializes CUDA in the main process by performing a simple GPU operation.
+    """
+    if torch.cuda.is_available():
+        # Allocate a small tensor on the GPU
+        _ = torch.tensor([0.0], device='cuda')
+        print("CUDA initialized in the main process.")
+    else:
+        print("CUDA is not available.")
+
+def set_high_priority(process):
+    try:
+        p = psutil.Process(process.pid)
+        if os.name == 'nt':  # Windows
+            p.nice(psutil.HIGH_PRIORITY_CLASS)
+        else:  # Unix-like systems
+            p.nice(-10)  # Adjusted priority level
+        print(f"Successfully set high priority for process: {process.pid}")
+    except psutil.AccessDenied as e:
+        print(f"Access denied when trying to set high priority for process {process.pid}: {e}")
+    except psutil.NoSuchProcess as e:
+        print(f"No such process {process.pid}: {e}")
+    except Exception as e:
+        print(f"Failed to set high priority for process {process.pid}: {e}")
+
+def set_cpu_affinity(process):
+    p = psutil.Process(process.pid)
+    p.cpu_affinity(list(range(os.cpu_count())))
 
 def proceed_with_app(root, app_name, app_func):
     # Clear the current content frame
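The three helpers added above are thin wrappers around torch and psutil. A minimal, hypothetical sketch of how they could be combined around a worker process (illustrative only; the real call sites are in spacr's GUI code and may differ):

    import multiprocessing

    from spacr.gui_utils import initialize_cuda, set_cpu_affinity, set_high_priority

    def worker():
        print("long-running job goes here")

    if __name__ == "__main__":
        initialize_cuda()                    # touch the GPU once in the parent process
        proc = multiprocessing.Process(target=worker)
        proc.start()
        set_high_priority(proc)              # may need elevated privileges on Unix (nice -10)
        set_cpu_affinity(proc)               # let the worker run on every available core
        proc.join()

Note that psutil's cpu_affinity is not implemented on every platform (macOS, for example), and raising priority with a negative nice value usually requires root, which is why set_high_priority only logs psutil.AccessDenied instead of raising.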
@@ -44,17 +75,19 @@ def load_app(root, app_name, app_func):
     else:
         proceed_with_app(root, app_name, app_func)
 
-def parse_list_v1(value):
-    try:
-        parsed_value = ast.literal_eval(value)
-        if isinstance(parsed_value, list):
-            return parsed_value
-        else:
-            raise ValueError(f"Expected a list but got {type(parsed_value).__name__}")
-    except (ValueError, SyntaxError) as e:
-        raise ValueError(f"Invalid format for list: {value}. Error: {e}")
-
 def parse_list(value):
+    """
+    Parses a string representation of a list and returns the parsed list.
+
+    Args:
+        value (str): The string representation of the list.
+
+    Returns:
+        list: The parsed list.
+
+    Raises:
+        ValueError: If the input value is not a valid list format or contains mixed types or unsupported types.
+    """
     try:
         parsed_value = ast.literal_eval(value)
         if isinstance(parsed_value, list):
@@ -72,7 +105,26 @@ def parse_list(value):
 
 # Usage example in your create_input_field function
 def create_input_field(frame, label_text, row, var_type='entry', options=None, default_value=None):
+    """
+    Create an input field in the specified frame.
+
+    Args:
+        frame (tk.Frame): The frame in which the input field will be created.
+        label_text (str): The text to be displayed as the label for the input field.
+        row (int): The row in which the input field will be placed.
+        var_type (str, optional): The type of input field to create. Defaults to 'entry'.
+        options (list, optional): The list of options for a combo box input field. Defaults to None.
+        default_value (str, optional): The default value for the input field. Defaults to None.
+
+    Returns:
+        tuple: A tuple containing the label, input widget, variable, and custom frame.
+
+    Raises:
+        Exception: If an error occurs while creating the input field.
+
+    """
     from .gui_elements import set_dark_style, set_element_size
+
     label_column = 0
     widget_column = 0  # Both label and widget will be in the same column
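parse_list turns the string value of a GUI field back into a Python list via ast.literal_eval. Assuming it behaves as its docstring describes, typical calls look like this (illustrative sketch):

    from spacr.gui_utils import parse_list

    channels = parse_list('[0, 1, 2, 3]')    # -> [0, 1, 2, 3]
    colors = parse_list("['r', 'g', 'b']")   # -> ['r', 'g', 'b']

    try:
        parse_list('not a list')             # anything that is not a list literal
    except ValueError as err:
        print(f'Rejected: {err}')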
@@ -83,6 +135,7 @@ def create_input_field(frame, label_text, row, var_type='entry', options=None, d
     size_dict['settings_width'] = size_dict['settings_width'] - int(size_dict['settings_width']*0.1)
 
     # Replace underscores with spaces and capitalize the first letter
+
     label_text = label_text.replace('_', ' ').capitalize()
 
     # Configure the column widths
@@ -97,32 +150,35 @@ def create_input_field(frame, label_text, row, var_type='entry', options=None, d
     custom_frame.config(highlightbackground=style_out['bg_color'], highlightthickness=1, bd=2)
 
     # Create and configure the label
-
-    label = ttk.Label(custom_frame, text=label_text, background=style_out['bg_color'], foreground=style_out['fg_color'], font=font_loader.get_font(size=font_size), anchor='e', justify='right')
-    label = ttk.Label(custom_frame, text=label_text, background=style_out['bg_color'], foreground=style_out['fg_color'], font=(style_out['font_family'], style_out['font_size']), anchor='e', justify='right')
+    label = tk.Label(custom_frame, text=label_text, bg=style_out['bg_color'], fg=style_out['fg_color'], font=font_loader.get_font(size=font_size), anchor='e', justify='right')
     label.grid(column=label_column, row=0, sticky=tk.W, padx=(5, 2), pady=5)  # Place the label in the first row
 
     # Create and configure the input widget based on var_type
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    try:
+        if var_type == 'entry':
+            var = tk.StringVar(value=default_value)
+            entry = spacrEntry(custom_frame, textvariable=var, outline=False, width=size_dict['settings_width'])
+            entry.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5)  # Place the entry in the second row
+            return (label, entry, var, custom_frame)  # Return both the label and the entry, and the variable
+        elif var_type == 'check':
+            var = tk.BooleanVar(value=default_value)  # Set default value (True/False)
+            check = spacrCheck(custom_frame, text="", variable=var)
+            check.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5)  # Place the checkbutton in the second row
+            return (label, check, var, custom_frame)  # Return both the label and the checkbutton, and the variable
+        elif var_type == 'combo':
+            var = tk.StringVar(value=default_value)  # Set default value
+            combo = spacrCombo(custom_frame, textvariable=var, values=options, width=size_dict['settings_width'])  # Apply TCombobox style
+            combo.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5)  # Place the combobox in the second row
+            if default_value:
+                combo.set(default_value)
+            return (label, combo, var, custom_frame)  # Return both the label and the combobox, and the variable
+        else:
+            var = None  # Placeholder in case of an undefined var_type
+            return (label, None, var, custom_frame)
+    except Exception as e:
+        traceback.print_exc()
+        print(f"Error creating input field: {e}")
+        print(f"Wrong type for {label_text} Expected {var_type}")
 
 def process_stdout_stderr(q):
     """
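create_input_field returns the label, the widget, the Tk variable and the enclosing row frame, so callers keep a handle on every setting. A rough, hypothetical call-site sketch (it assumes a running Tk display and spacr's style helpers; the real callers live in gui_core.py):

    import tkinter as tk
    from spacr.gui_utils import create_input_field

    root = tk.Tk()
    settings_frame = tk.Frame(root)
    settings_frame.pack()

    # One row per setting; var_type picks which spacr widget is built.
    label, widget, mode_var, row_frame = create_input_field(
        settings_frame, 'dataset_mode', row=0, var_type='combo',
        options=['annotation', 'metadata', 'recruitment'], default_value='metadata')

    print(mode_var.get())  # -> 'metadata'
    root.destroy()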
@@ -150,16 +206,6 @@ def cancel_after_tasks(frame):
         frame.after_cancel(task)
     frame.after_tasks.clear()
 
-def main_thread_update_function(root, q, fig_queue, canvas_widget):
-    try:
-        #ansi_escape_pattern = re.compile(r'\x1B\[[0-?]*[ -/]*[@-~]')
-        while not q.empty():
-            message = q.get_nowait()
-    except Exception as e:
-        print(f"Error updating GUI canvas: {e}")
-    finally:
-        root.after(100, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget))
-
 def annotate(settings):
     from .settings import set_annotate_default_settings
     settings = set_annotate_default_settings(settings)
@@ -190,6 +236,12 @@ def annotate(settings):
 
 def generate_annotate_fields(frame):
     from .settings import set_annotate_default_settings
+    from .gui_elements import set_dark_style
+
+    style_out = set_dark_style(ttk.Style())
+    font_loader = style_out['font_loader']
+    font_size = style_out['font_size'] - 2
+
     vars_dict = {}
     settings = set_annotate_default_settings(settings={})
 
@@ -201,8 +253,8 @@ def generate_annotate_fields(frame):
 
     # Arrange input fields and labels
     for row, (name, data) in enumerate(vars_dict.items()):
-
-
+        tk.Label(frame, text=f"{name.replace('_', ' ').capitalize()}:", bg=style_out['bg_color'], fg=style_out['fg_color'], font=font_loader.get_font(size=font_size)).grid(row=row, column=0)
+        #ttk.Label(frame, text=f"{name.replace('_', ' ').capitalize()}:", background="black", foreground="white").grid(row=row, column=0)
         if isinstance(data['value'], list):
             # Convert lists to comma-separated strings
             data['entry'].insert(0, ','.join(map(str, data['value'])))
@@ -341,7 +393,7 @@ def convert_settings_dict_for_gui(settings):
         'channels': ('combo', ['[0,1,2,3]', '[0,1,2]', '[0,1]', '[0]'], '[0,1,2,3]'),
         'train_channels': ('combo', ["['r','g','b']", "['r','g']", "['r','b']", "['g','b']", "['r']", "['g']", "['b']"], "['r','g','b']"),
         'channel_dims': ('combo', ['[0,1,2,3]', '[0,1,2]', '[0,1]', '[0]'], '[0,1,2,3]'),
-        'dataset_mode': ('combo', ['annotation', 'metadata', 'recruitment'], '
+        'dataset_mode': ('combo', ['annotation', 'metadata', 'recruitment'], 'metadata'),
         'cell_mask_dim': ('combo', chans, None),
         'cell_chann_dim': ('combo', chans, None),
         'nucleus_mask_dim': ('combo', chans, None),
@@ -476,7 +528,7 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
         imports = 2
     elif settings_type == 'recruitment':
         function = analyze_recruitment
-        imports =
+        imports = 1
     elif settings_type == 'umap':
         function = generate_image_umap
         imports = 1
@@ -490,7 +542,6 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
     finally:
         stop_requested.value = 1
 
-
 def hide_all_settings(vars_dict, categories):
     """
     Function to initially hide all settings in the GUI.
spacr/io.py
CHANGED
@@ -1,9 +1,9 @@
-import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob
+import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob, queue
 import numpy as np
 import pandas as pd
 import tifffile
-from PIL import Image
-from collections import defaultdict, Counter
+from PIL import Image, ImageOps
+from collections import defaultdict, Counter, deque
 from pathlib import Path
 from functools import partial
 from matplotlib.animation import FuncAnimation
@@ -17,12 +17,12 @@ import imageio.v2 as imageio2
 import matplotlib.pyplot as plt
 from io import BytesIO
 from IPython.display import display, clear_output
-from multiprocessing import Pool, cpu_count
-from torch.utils.data import Dataset
+from multiprocessing import Pool, cpu_count, Process, Queue
+from torch.utils.data import Dataset, DataLoader
 import matplotlib.pyplot as plt
 from torchvision.transforms import ToTensor
 import seaborn as sns
-
+import atexit
 
 from .logger import log_function_call
 
@@ -444,20 +444,7 @@ class NoClassDataset(Dataset):
         # Return both the image and its filename
         return img, self.filenames[index]
 
-class MyDataset(Dataset):
-    """
-    A custom dataset class for loading and processing image data.
-
-    Args:
-        data_dir (str): The directory path where the image data is stored.
-        loader_classes (list): A list of class names for the dataset.
-        transform (callable, optional): A function/transform to apply to the image data. Default is None.
-        shuffle (bool, optional): Whether to shuffle the dataset. Default is True.
-        pin_memory (bool, optional): Whether to pin the loaded images to memory. Default is False.
-        specific_files (list, optional): A list of specific file paths to include in the dataset. Default is None.
-        specific_labels (list, optional): A list of specific labels corresponding to the specific files. Default is None.
-    """
-
+class spacrDataset(Dataset):
     def __init__(self, data_dir, loader_classes, transform=None, shuffle=True, pin_memory=False, specific_files=None, specific_labels=None):
         self.data_dir = data_dir
         self.classes = loader_classes
@@ -466,7 +453,7 @@ class MyDataset(Dataset):
         self.pin_memory = pin_memory
         self.filenames = []
         self.labels = []
-
+
         if specific_files and specific_labels:
             self.filenames = specific_files
             self.labels = specific_labels
@@ -479,33 +466,113 @@ class MyDataset(Dataset):
 
         if self.shuffle:
             self.shuffle_dataset()
-
+
         if self.pin_memory:
-
-
+            # Use multiprocessing to load images in parallel
+            with Pool(processes=cpu_count()) as pool:
+                self.images = pool.map(self.load_image, self.filenames)
+        else:
+            self.images = None
+
     def load_image(self, img_path):
         img = Image.open(img_path).convert('RGB')
+        img = ImageOps.exif_transpose(img)  # Handle image orientation
         return img
-
+
     def __len__(self):
         return len(self.filenames)
-
+
     def shuffle_dataset(self):
         combined = list(zip(self.filenames, self.labels))
         random.shuffle(combined)
         self.filenames, self.labels = zip(*combined)
-
+
     def get_plate(self, filepath):
-        filename = os.path.basename(filepath)
+        filename = os.path.basename(filepath)
         return filename.split('_')[0]
-
+
     def __getitem__(self, index):
+        if self.pin_memory:
+            img = self.images[index]
+        else:
+            img = self.load_image(self.filenames[index])
         label = self.labels[index]
         filename = self.filenames[index]
-        img = self.load_image(filename)
         if self.transform:
             img = self.transform(img)
         return img, label, filename
+
+class spacrDataLoader(DataLoader):
+    def __init__(self, *args, preload_batches=1, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.preload_batches = preload_batches
+        self.batch_queue = Queue(maxsize=preload_batches)
+        self.process = None
+        self.current_batch_index = 0
+        self._stop_event = False
+        self.pin_memory = kwargs.get('pin_memory', False)
+        atexit.register(self.cleanup)
+
+    def _preload_next_batches(self):
+        try:
+            for _ in range(self.preload_batches):
+                if self._stop_event:
+                    break
+                batch = next(self._iterator)
+                if self.pin_memory:
+                    batch = self._pin_memory_batch(batch)
+                self.batch_queue.put(batch)
+        except StopIteration:
+            pass
+
+    def _start_preloading(self):
+        if self.process is None or not self.process.is_alive():
+            self._iterator = iter(super().__iter__())
+            if not self.pin_memory:
+                self.process = Process(target=self._preload_next_batches)
+                self.process.start()
+            else:
+                self._preload_next_batches()  # Directly load if pin_memory is True
+
+    def _pin_memory_batch(self, batch):
+        if isinstance(batch, (list, tuple)):
+            return [b.pin_memory() if isinstance(b, torch.Tensor) else b for b in batch]
+        elif isinstance(batch, torch.Tensor):
+            return batch.pin_memory()
+        else:
+            return batch
+
+    def __iter__(self):
+        self._start_preloading()
+        return self
+
+    def __next__(self):
+        if self.process and not self.process.is_alive() and self.batch_queue.empty():
+            raise StopIteration
+
+        try:
+            if self.pin_memory:
+                next_batch = self.batch_queue.get(timeout=60)
+            else:
+                next_batch = self.batch_queue.get(timeout=60)
+            self.current_batch_index += 1
+
+            # Start preloading the next batches
+            if self.batch_queue.qsize() < self.preload_batches:
+                self._start_preloading()
+
+            return next_batch
+        except queue.Empty:
+            raise StopIteration
+
+    def cleanup(self):
+        self._stop_event = True
+        if self.process and self.process.is_alive():
+            self.process.terminate()
+            self.process.join()
+
+    def __del__(self):
+        self.cleanup()
 
 class NoClassDataset(Dataset):
     def __init__(self, data_dir, transform=None, shuffle=True, load_to_memory=False):
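spacrDataset keeps the old MyDataset constructor, while spacrDataLoader subclasses DataLoader and hands preloaded batches over through a multiprocessing.Queue. A minimal sketch of how the pair might be driven (directory layout, class names and transform are placeholders, not spacr defaults):

    from torchvision import transforms
    from spacr.io import spacrDataset, spacrDataLoader

    # Hypothetical folder with one sub-directory per class, e.g. data/nc and data/pc.
    transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    dataset = spacrDataset('data', loader_classes=['nc', 'pc'], transform=transform,
                           shuffle=True, pin_memory=False)

    loader = spacrDataLoader(dataset, batch_size=32, preload_batches=2)
    for images, labels, filenames in loader:
        print(images.shape, filenames[0])
        break

Because batches travel through a multiprocessing.Queue, cleanup() is registered with atexit so the preloading process is terminated even if iteration is abandoned early.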
@@ -2292,7 +2359,7 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
 
     def save_model_at_threshold(threshold, epoch, suffix=""):
         percentile = str(threshold * 100)
-        print(f'
+        print(f'Found: {percentile}% accurate model')
         model_path = f'{dst}/{model_type}_epoch_{str(epoch)}{suffix}_acc_{percentile}_channels_{channels_str}.pth'
         torch.save(model, model_path)
         return model_path
@@ -2303,7 +2370,8 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
         return model_path
 
     for threshold in intermedeate_save:
-        if results_df['neg_accuracy']
+        if results_df['neg_accuracy'] >= threshold and results_df['pos_accuracy'] >= threshold:
+            print(f"Nc class accuracy: {results_df['neg_accuracy']} Pc class Accuracy: {results_df['pos_accuracy']}")
             model_path = save_model_at_threshold(threshold, epoch)
             break
         else:
@@ -2311,7 +2379,7 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
 
     return model_path
 
-def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
+def _save_progress(dst, results_df, result_type='train'):
     """
     Save the progress of the classification model.
 
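The intermediate-save loop now checkpoints only when both the negative- and positive-class accuracies clear a threshold. The pattern in isolation, with invented numbers rather than spacr's real results frame:

    # Standalone illustration of threshold-gated checkpointing; the values are made up.
    intermediate_save = [0.99, 0.98, 0.95]       # checked from strictest to loosest
    neg_accuracy, pos_accuracy = 0.97, 0.96

    for threshold in intermediate_save:
        if neg_accuracy >= threshold and pos_accuracy >= threshold:
            print(f'Both classes above {threshold:.0%}; saving checkpoint')
            break                                # save once, at the strictest threshold met
    else:
        print('No threshold met; nothing saved')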
@@ -2325,18 +2393,13 @@ def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
     """
     # Save accuracy, loss, PRAUC
     os.makedirs(dst, exist_ok=True)
-    results_path = os.path.join(dst, '
+    results_path = os.path.join(dst, f'{result_type}.csv')
     if not os.path.exists(results_path):
         results_df.to_csv(results_path, index=True, header=True, mode='w')
     else:
         results_df.to_csv(results_path, index=True, header=False, mode='a')
-
-
-    if not os.path.exists(training_metrics_path):
-        train_metrics_df.to_csv(training_metrics_path, index=True, header=True, mode='w')
-    else:
-        train_metrics_df.to_csv(training_metrics_path, index=True, header=False, mode='a')
-    if epoch == epochs:
+
+    if result_type == 'train':
         read_plot_model_stats(results_path, save=True)
     return
 
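_save_progress now appends one CSV per result type instead of maintaining a separate training-metrics file. The write-header-only-once pattern it relies on is plain pandas to_csv, sketched here with a throwaway frame (append_results is a hypothetical name, not spacr API):

    import os
    import pandas as pd

    def append_results(dst, results_df, result_type='train'):
        # Write the header only when the file is first created, then append.
        os.makedirs(dst, exist_ok=True)
        results_path = os.path.join(dst, f'{result_type}.csv')
        if not os.path.exists(results_path):
            results_df.to_csv(results_path, index=True, header=True, mode='w')
        else:
            results_df.to_csv(results_path, index=True, header=False, mode='a')

    append_results('/tmp/spacr_demo', pd.DataFrame({'epoch': [1], 'accuracy': [0.91]}))
    append_results('/tmp/spacr_demo', pd.DataFrame({'epoch': [2], 'accuracy': [0.93]}))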
spacr/plot.py
CHANGED
@@ -1,4 +1,4 @@
-import os,re, random, cv2, glob, time, math
+import os,re, random, cv2, glob, time, math, torch
 
 import numpy as np
 import pandas as pd
@@ -125,7 +125,7 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
 
     return
 
-def plot_masks(batch, masks, flows, cmap='inferno', figuresize=
+def plot_masks(batch, masks, flows, cmap='inferno', figuresize=10, nr=1, file_type='.npz', print_object_number=True):
     """
     Plot the masks and flows for a given batch of images.
 
@@ -476,7 +476,7 @@ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mas
 
     return stack
 
-def plot_arrays(src, figuresize=
+def plot_arrays(src, figuresize=10, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
     """
     Plot randomly selected arrays from a given directory.
 
@@ -870,7 +870,7 @@ def _save_scimg_plot(src, nr_imgs=16, channel_indices=[0,1,2], um_per_pixel=0.1,
 
     return
 
-def _plot_cropped_arrays(stack, filename, figuresize=
+def _plot_cropped_arrays(stack, filename, figuresize=10, cmap='inferno', threshold=500):
     """
     Plot cropped arrays.
 
@@ -997,7 +997,7 @@ def _display_gif(path):
     with open(path, 'rb') as file:
         display(ipyimage(file.read()))
 
-def _plot_recruitment(df, df_type, channel_of_interest, target, columns=[], figuresize=
+def _plot_recruitment(df, df_type, channel_of_interest, target, columns=[], figuresize=10):
     """
     Plot recruitment data for different conditions and pathogens.
 
@@ -1186,6 +1186,52 @@ def _imshow(img, labels, nrow=20, color='white', fontsize=12):
         y = row * img_height + 15
         plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')
     return fig
+
+def _imshow_gpu(img, labels, nrow=20, color='white', fontsize=12):
+    """
+    Display multiple images in a grid with corresponding labels.
+
+    Args:
+        img (torch.Tensor): A batch of images as a tensor.
+        labels (list): List of labels corresponding to each image.
+        nrow (int, optional): Number of images per row in the grid. Defaults to 20.
+        color (str, optional): Color of the label text. Defaults to 'white'.
+        fontsize (int, optional): Font size of the label text. Defaults to 12.
+    """
+    if img.is_cuda:
+        img = img.cpu()  # Move to CPU if the tensor is on GPU
+
+    n_images = len(labels)
+    n_col = nrow
+    n_row = int(np.ceil(n_images / n_col))
+
+    img_height = img.shape[2]  # Height of the image
+    img_width = img.shape[3]  # Width of the image
+
+    # Prepare the canvas on CPU
+    canvas = torch.zeros((img_height * n_row, img_width * n_col, 3))
+
+    for i in range(n_row):
+        for j in range(n_col):
+            idx = i * n_col + j
+            if idx < n_images:
+                # Place the image on the canvas
+                canvas[i * img_height:(i + 1) * img_height, j * img_width:(j + 1) * img_width] = img[idx].permute(1, 2, 0)
+
+    canvas = canvas.numpy()  # Convert to NumPy for plotting
+
+    fig = plt.figure(figsize=(50, 50))
+    plt.imshow(canvas)
+    plt.axis("off")
+
+    for i, label in enumerate(labels):
+        row = i // n_col
+        col = i % n_col
+        x = col * img_width + 2
+        y = row * img_height + 15
+        plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')
+
+    return fig
 
 def _plot_histograms_and_stats(df):
     conditions = df['condition'].unique()
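_imshow_gpu mirrors _imshow but accepts a torch batch in NCHW layout and moves it to the CPU before assembling the grid. A quick smoke test with random data (a sketch; the leading underscore marks it as a private helper that may change):

    import torch
    from spacr.plot import _imshow_gpu

    # Fake batch: 8 RGB images of 64x64 pixels with values already in [0, 1].
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch = torch.rand(8, 3, 64, 64, device=device)
    labels = [f'img_{i}' for i in range(8)]

    fig = _imshow_gpu(batch, labels, nrow=4)
    fig.savefig('grid_preview.png')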
spacr/sequencing.py
CHANGED
@@ -223,14 +223,10 @@ def save_to_hdf(queue, output_file, complevel=9, compression='zlib'):
     Save data from a queue to an HDF file.
 
     Parameters:
-    - queue: Queue object
-
-    -
-
-    - complevel: int, optional
-        The compression level to use (default is 9).
-    - compression: str, optional
-        The compression algorithm to use (default is 'zlib').
+    - queue: Queue object containing chunks of data to be saved
+    - output_file: Path to the output HDF file
+    - complevel: Compression level (default: 9)
+    - compression: Compression algorithm (default: 'zlib')
 
     Returns:
     None
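The cleaned-up docstring describes a producer/consumer pattern: parser processes push chunks onto a queue and a single writer drains them into one HDF5 store. A generic sketch of that pattern with pandas (an illustration of the idea, not spacr's save_to_hdf implementation; the 'reads' key and the None sentinel are assumptions, and pandas needs PyTables for HDF output):

    import multiprocessing as mp
    import pandas as pd

    def writer(q, output_file, complevel=9, compression='zlib'):
        # Drain DataFrame chunks from the queue and append them to one HDF5 table.
        with pd.HDFStore(output_file, mode='w', complevel=complevel, complib=compression) as store:
            while True:
                chunk = q.get()
                if chunk is None:        # sentinel: the producer is finished
                    break
                store.append('reads', chunk, index=False)

    if __name__ == '__main__':
        q = mp.Queue()
        proc = mp.Process(target=writer, args=(q, 'counts.h5'))
        proc.start()
        q.put(pd.DataFrame({'barcode': ['ACGT', 'TTGA'], 'count': [12, 7]}))
        q.put(None)
        proc.join()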