spacr 0.2.56__py3-none-any.whl → 0.2.65__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/gui_utils.py CHANGED
@@ -1,10 +1,11 @@
1
- import os, io, sys, ast, ctypes, ast, sqlite3, requests, time, traceback
1
+ import os, io, sys, ast, ctypes, ast, sqlite3, requests, time, traceback, torch
2
2
  import tkinter as tk
3
3
  from tkinter import ttk
4
4
  import matplotlib
5
5
  import matplotlib.pyplot as plt
6
6
  matplotlib.use('Agg')
7
7
  from huggingface_hub import list_repo_files
8
+ import psutil
8
9
 
9
10
  from .gui_elements import AnnotateApp, spacrEntry, spacrCheck, spacrCombo
10
11
 
@@ -12,6 +13,36 @@ try:
12
13
  ctypes.windll.shcore.SetProcessDpiAwareness(True)
13
14
  except AttributeError:
14
15
  pass
16
+
17
def initialize_cuda():
    """
    Warm up CUDA in the calling process.

    When ``torch.cuda.is_available()`` reports a usable device, a one-element
    tensor is created on it (the "simple GPU operation" that triggers CUDA
    initialization) and a confirmation is printed. Otherwise a notice is
    printed and nothing else happens.

    Returns:
        None
    """
    if not torch.cuda.is_available():
        print("CUDA is not available.")
        return

    # The tensor itself is irrelevant; only the side effect of touching the
    # device matters, so the result is discarded.
    torch.tensor([0.0], device='cuda')
    print("CUDA initialized in the main process.")
27
+
28
def set_high_priority(process):
    """
    Best-effort attempt to raise the scheduling priority of ``process``.

    On Windows the process is moved to ``psutil.HIGH_PRIORITY_CLASS``; on
    other systems its niceness is set to -10 (which normally requires
    elevated privileges). Failures are reported on stdout and never raised.

    Args:
        process: Any object exposing a ``pid`` attribute (for example a
            started ``multiprocessing.Process``).

    Returns:
        None
    """
    pid = getattr(process, 'pid', None)
    if pid is None:
        # psutil.Process(None) refers to the *current* process, so an
        # unstarted child would silently re-prioritise the caller instead.
        print("Cannot set high priority: process has not been started (no pid).")
        return

    try:
        p = psutil.Process(pid)
        if os.name == 'nt':  # Windows
            p.nice(psutil.HIGH_PRIORITY_CLASS)
        else:  # Unix-like systems
            p.nice(-10)  # Negative niceness == higher priority
        print(f"Successfully set high priority for process: {pid}")
    except psutil.AccessDenied as e:
        print(f"Access denied when trying to set high priority for process {pid}: {e}")
    except psutil.NoSuchProcess as e:
        print(f"No such process {pid}: {e}")
    except Exception as e:
        # Deliberately broad: priority tuning must never break the caller.
        print(f"Failed to set high priority for process {pid}: {e}")
42
+
43
def set_cpu_affinity(process):
    """
    Best-effort attempt to allow ``process`` to run on every logical CPU.

    Mirrors the error-handling style of ``set_high_priority``: problems are
    reported on stdout instead of being raised, because affinity tuning is an
    optimisation and must not break the caller.

    Args:
        process: Any object exposing a ``pid`` attribute (for example a
            started ``multiprocessing.Process``).

    Returns:
        None
    """
    pid = getattr(process, 'pid', None)
    if pid is None:
        # psutil.Process(None) would target the current process instead.
        print("Cannot set CPU affinity: process has not been started (no pid).")
        return

    cpu_total = os.cpu_count()
    if not cpu_total:
        # os.cpu_count() returns None when the count cannot be determined;
        # range(None) would raise TypeError.
        print(f"Could not determine CPU count; leaving affinity of process {pid} unchanged.")
        return

    try:
        p = psutil.Process(pid)
        if not hasattr(p, 'cpu_affinity'):
            # psutil only provides cpu_affinity on some platforms
            # (it is absent on macOS, for instance).
            print(f"CPU affinity is not supported on this platform; skipping process {pid}.")
            return
        p.cpu_affinity(list(range(cpu_total)))
    except psutil.AccessDenied as e:
        print(f"Access denied when trying to set CPU affinity for process {pid}: {e}")
    except psutil.NoSuchProcess as e:
        print(f"No such process {pid}: {e}")
    except Exception as e:
        # Deliberately broad: affinity tuning must never break the caller.
        print(f"Failed to set CPU affinity for process {pid}: {e}")
15
46
 
16
47
  def proceed_with_app(root, app_name, app_func):
17
48
  # Clear the current content frame
@@ -44,17 +75,19 @@ def load_app(root, app_name, app_func):
44
75
  else:
45
76
  proceed_with_app(root, app_name, app_func)
46
77
 
47
- def parse_list_v1(value):
48
- try:
49
- parsed_value = ast.literal_eval(value)
50
- if isinstance(parsed_value, list):
51
- return parsed_value
52
- else:
53
- raise ValueError(f"Expected a list but got {type(parsed_value).__name__}")
54
- except (ValueError, SyntaxError) as e:
55
- raise ValueError(f"Invalid format for list: {value}. Error: {e}")
56
-
57
78
  def parse_list(value):
79
+ """
80
+ Parses a string representation of a list and returns the parsed list.
81
+
82
+ Args:
83
+ value (str): The string representation of the list.
84
+
85
+ Returns:
86
+ list: The parsed list.
87
+
88
+ Raises:
89
+ ValueError: If the input value is not a valid list format or contains mixed types or unsupported types.
90
+ """
58
91
  try:
59
92
  parsed_value = ast.literal_eval(value)
60
93
  if isinstance(parsed_value, list):
@@ -72,7 +105,26 @@ def parse_list(value):
72
105
 
73
106
  # Usage example in your create_input_field function
74
107
  def create_input_field(frame, label_text, row, var_type='entry', options=None, default_value=None):
108
+ """
109
+ Create an input field in the specified frame.
110
+
111
+ Args:
112
+ frame (tk.Frame): The frame in which the input field will be created.
113
+ label_text (str): The text to be displayed as the label for the input field.
114
+ row (int): The row in which the input field will be placed.
115
+ var_type (str, optional): The type of input field to create. Defaults to 'entry'.
116
+ options (list, optional): The list of options for a combo box input field. Defaults to None.
117
+ default_value (str, optional): The default value for the input field. Defaults to None.
118
+
119
+ Returns:
120
+ tuple: A tuple containing the label, input widget, variable, and custom frame.
121
+
122
+ Raises:
123
+ Exception: If an error occurs while creating the input field.
124
+
125
+ """
75
126
  from .gui_elements import set_dark_style, set_element_size
127
+
76
128
  label_column = 0
77
129
  widget_column = 0 # Both label and widget will be in the same column
78
130
 
@@ -83,6 +135,7 @@ def create_input_field(frame, label_text, row, var_type='entry', options=None, d
83
135
  size_dict['settings_width'] = size_dict['settings_width'] - int(size_dict['settings_width']*0.1)
84
136
 
85
137
  # Replace underscores with spaces and capitalize the first letter
138
+
86
139
  label_text = label_text.replace('_', ' ').capitalize()
87
140
 
88
141
  # Configure the column widths
@@ -97,32 +150,35 @@ def create_input_field(frame, label_text, row, var_type='entry', options=None, d
97
150
  custom_frame.config(highlightbackground=style_out['bg_color'], highlightthickness=1, bd=2)
98
151
 
99
152
  # Create and configure the label
100
- if font_loader:
101
- label = ttk.Label(custom_frame, text=label_text, background=style_out['bg_color'], foreground=style_out['fg_color'], font=font_loader.get_font(size=font_size), anchor='e', justify='right')
102
- label = ttk.Label(custom_frame, text=label_text, background=style_out['bg_color'], foreground=style_out['fg_color'], font=(style_out['font_family'], style_out['font_size']), anchor='e', justify='right')
153
+ label = tk.Label(custom_frame, text=label_text, bg=style_out['bg_color'], fg=style_out['fg_color'], font=font_loader.get_font(size=font_size), anchor='e', justify='right')
103
154
  label.grid(column=label_column, row=0, sticky=tk.W, padx=(5, 2), pady=5) # Place the label in the first row
104
155
 
105
156
  # Create and configure the input widget based on var_type
106
- if var_type == 'entry':
107
- var = tk.StringVar(value=default_value)
108
- entry = spacrEntry(custom_frame, textvariable=var, outline=False, width=size_dict['settings_width'])
109
- entry.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the entry in the second row
110
- return (label, entry, var, custom_frame) # Return both the label and the entry, and the variable
111
- elif var_type == 'check':
112
- var = tk.BooleanVar(value=default_value) # Set default value (True/False)
113
- check = spacrCheck(custom_frame, text="", variable=var)
114
- check.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the checkbutton in the second row
115
- return (label, check, var, custom_frame) # Return both the label and the checkbutton, and the variable
116
- elif var_type == 'combo':
117
- var = tk.StringVar(value=default_value) # Set default value
118
- combo = spacrCombo(custom_frame, textvariable=var, values=options, width=size_dict['settings_width']) # Apply TCombobox style
119
- combo.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the combobox in the second row
120
- if default_value:
121
- combo.set(default_value)
122
- return (label, combo, var, custom_frame) # Return both the label and the combobox, and the variable
123
- else:
124
- var = None # Placeholder in case of an undefined var_type
125
- return (label, None, var, custom_frame)
157
+ try:
158
+ if var_type == 'entry':
159
+ var = tk.StringVar(value=default_value)
160
+ entry = spacrEntry(custom_frame, textvariable=var, outline=False, width=size_dict['settings_width'])
161
+ entry.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the entry in the second row
162
+ return (label, entry, var, custom_frame) # Return both the label and the entry, and the variable
163
+ elif var_type == 'check':
164
+ var = tk.BooleanVar(value=default_value) # Set default value (True/False)
165
+ check = spacrCheck(custom_frame, text="", variable=var)
166
+ check.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the checkbutton in the second row
167
+ return (label, check, var, custom_frame) # Return both the label and the checkbutton, and the variable
168
+ elif var_type == 'combo':
169
+ var = tk.StringVar(value=default_value) # Set default value
170
+ combo = spacrCombo(custom_frame, textvariable=var, values=options, width=size_dict['settings_width']) # Apply TCombobox style
171
+ combo.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the combobox in the second row
172
+ if default_value:
173
+ combo.set(default_value)
174
+ return (label, combo, var, custom_frame) # Return both the label and the combobox, and the variable
175
+ else:
176
+ var = None # Placeholder in case of an undefined var_type
177
+ return (label, None, var, custom_frame)
178
+ except Exception as e:
179
+ traceback.print_exc()
180
+ print(f"Error creating input field: {e}")
181
+ print(f"Wrong type for {label_text} Expected {var_type}")
126
182
 
127
183
  def process_stdout_stderr(q):
128
184
  """
@@ -150,16 +206,6 @@ def cancel_after_tasks(frame):
150
206
  frame.after_cancel(task)
151
207
  frame.after_tasks.clear()
152
208
 
153
- def main_thread_update_function(root, q, fig_queue, canvas_widget):
154
- try:
155
- #ansi_escape_pattern = re.compile(r'\x1B\[[0-?]*[ -/]*[@-~]')
156
- while not q.empty():
157
- message = q.get_nowait()
158
- except Exception as e:
159
- print(f"Error updating GUI canvas: {e}")
160
- finally:
161
- root.after(100, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget))
162
-
163
209
  def annotate(settings):
164
210
  from .settings import set_annotate_default_settings
165
211
  settings = set_annotate_default_settings(settings)
@@ -190,6 +236,12 @@ def annotate(settings):
190
236
 
191
237
  def generate_annotate_fields(frame):
192
238
  from .settings import set_annotate_default_settings
239
+ from .gui_elements import set_dark_style
240
+
241
+ style_out = set_dark_style(ttk.Style())
242
+ font_loader = style_out['font_loader']
243
+ font_size = style_out['font_size'] - 2
244
+
193
245
  vars_dict = {}
194
246
  settings = set_annotate_default_settings(settings={})
195
247
 
@@ -201,8 +253,8 @@ def generate_annotate_fields(frame):
201
253
 
202
254
  # Arrange input fields and labels
203
255
  for row, (name, data) in enumerate(vars_dict.items()):
204
- ttk.Label(frame, text=f"{name.replace('_', ' ').capitalize()}:",
205
- background="black", foreground="white").grid(row=row, column=0)
256
+ tk.Label(frame, text=f"{name.replace('_', ' ').capitalize()}:", bg=style_out['bg_color'], fg=style_out['fg_color'], font=font_loader.get_font(size=font_size)).grid(row=row, column=0)
257
+ #ttk.Label(frame, text=f"{name.replace('_', ' ').capitalize()}:", background="black", foreground="white").grid(row=row, column=0)
206
258
  if isinstance(data['value'], list):
207
259
  # Convert lists to comma-separated strings
208
260
  data['entry'].insert(0, ','.join(map(str, data['value'])))
@@ -341,7 +393,7 @@ def convert_settings_dict_for_gui(settings):
341
393
  'channels': ('combo', ['[0,1,2,3]', '[0,1,2]', '[0,1]', '[0]'], '[0,1,2,3]'),
342
394
  'train_channels': ('combo', ["['r','g','b']", "['r','g']", "['r','b']", "['g','b']", "['r']", "['g']", "['b']"], "['r','g','b']"),
343
395
  'channel_dims': ('combo', ['[0,1,2,3]', '[0,1,2]', '[0,1]', '[0]'], '[0,1,2,3]'),
344
- 'dataset_mode': ('combo', ['annotation', 'metadata', 'recruitment'], 'annotate'),
396
+ 'dataset_mode': ('combo', ['annotation', 'metadata', 'recruitment'], 'metadata'),
345
397
  'cell_mask_dim': ('combo', chans, None),
346
398
  'cell_chann_dim': ('combo', chans, None),
347
399
  'nucleus_mask_dim': ('combo', chans, None),
@@ -476,7 +528,7 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
476
528
  imports = 2
477
529
  elif settings_type == 'recruitment':
478
530
  function = analyze_recruitment
479
- imports = 2
531
+ imports = 1
480
532
  elif settings_type == 'umap':
481
533
  function = generate_image_umap
482
534
  imports = 1
@@ -490,7 +542,6 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
490
542
  finally:
491
543
  stop_requested.value = 1
492
544
 
493
-
494
545
  def hide_all_settings(vars_dict, categories):
495
546
  """
496
547
  Function to initially hide all settings in the GUI.
spacr/io.py CHANGED
@@ -1,9 +1,9 @@
1
- import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob
1
+ import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob, queue
2
2
  import numpy as np
3
3
  import pandas as pd
4
4
  import tifffile
5
- from PIL import Image
6
- from collections import defaultdict, Counter
5
+ from PIL import Image, ImageOps
6
+ from collections import defaultdict, Counter, deque
7
7
  from pathlib import Path
8
8
  from functools import partial
9
9
  from matplotlib.animation import FuncAnimation
@@ -17,12 +17,12 @@ import imageio.v2 as imageio2
17
17
  import matplotlib.pyplot as plt
18
18
  from io import BytesIO
19
19
  from IPython.display import display, clear_output
20
- from multiprocessing import Pool, cpu_count
21
- from torch.utils.data import Dataset
20
+ from multiprocessing import Pool, cpu_count, Process, Queue
21
+ from torch.utils.data import Dataset, DataLoader
22
22
  import matplotlib.pyplot as plt
23
23
  from torchvision.transforms import ToTensor
24
24
  import seaborn as sns
25
-
25
+ import atexit
26
26
 
27
27
  from .logger import log_function_call
28
28
 
@@ -444,20 +444,7 @@ class NoClassDataset(Dataset):
444
444
  # Return both the image and its filename
445
445
  return img, self.filenames[index]
446
446
 
447
- class MyDataset(Dataset):
448
- """
449
- A custom dataset class for loading and processing image data.
450
-
451
- Args:
452
- data_dir (str): The directory path where the image data is stored.
453
- loader_classes (list): A list of class names for the dataset.
454
- transform (callable, optional): A function/transform to apply to the image data. Default is None.
455
- shuffle (bool, optional): Whether to shuffle the dataset. Default is True.
456
- pin_memory (bool, optional): Whether to pin the loaded images to memory. Default is False.
457
- specific_files (list, optional): A list of specific file paths to include in the dataset. Default is None.
458
- specific_labels (list, optional): A list of specific labels corresponding to the specific files. Default is None.
459
- """
460
-
447
+ class spacrDataset(Dataset):
461
448
  def __init__(self, data_dir, loader_classes, transform=None, shuffle=True, pin_memory=False, specific_files=None, specific_labels=None):
462
449
  self.data_dir = data_dir
463
450
  self.classes = loader_classes
@@ -466,7 +453,7 @@ class MyDataset(Dataset):
466
453
  self.pin_memory = pin_memory
467
454
  self.filenames = []
468
455
  self.labels = []
469
-
456
+
470
457
  if specific_files and specific_labels:
471
458
  self.filenames = specific_files
472
459
  self.labels = specific_labels
@@ -479,33 +466,113 @@ class MyDataset(Dataset):
479
466
 
480
467
  if self.shuffle:
481
468
  self.shuffle_dataset()
482
-
469
+
483
470
  if self.pin_memory:
484
- self.images = [self.load_image(f) for f in self.filenames]
485
-
471
+ # Use multiprocessing to load images in parallel
472
+ with Pool(processes=cpu_count()) as pool:
473
+ self.images = pool.map(self.load_image, self.filenames)
474
+ else:
475
+ self.images = None
476
+
486
477
  def load_image(self, img_path):
487
478
  img = Image.open(img_path).convert('RGB')
479
+ img = ImageOps.exif_transpose(img) # Handle image orientation
488
480
  return img
489
-
481
+
490
482
  def __len__(self):
491
483
  return len(self.filenames)
492
-
484
+
493
485
  def shuffle_dataset(self):
494
486
  combined = list(zip(self.filenames, self.labels))
495
487
  random.shuffle(combined)
496
488
  self.filenames, self.labels = zip(*combined)
497
-
489
+
498
490
  def get_plate(self, filepath):
499
- filename = os.path.basename(filepath) # Get just the filename from the full path
491
+ filename = os.path.basename(filepath)
500
492
  return filename.split('_')[0]
501
-
493
+
502
494
  def __getitem__(self, index):
495
+ if self.pin_memory:
496
+ img = self.images[index]
497
+ else:
498
+ img = self.load_image(self.filenames[index])
503
499
  label = self.labels[index]
504
500
  filename = self.filenames[index]
505
- img = self.load_image(filename)
506
501
  if self.transform:
507
502
  img = self.transform(img)
508
503
  return img, label, filename
504
+
505
class spacrDataLoader(DataLoader):
    """
    ``DataLoader`` that prefetches up to ``preload_batches`` batches on a
    background thread while the consumer works on the current one.

    One iterator of the underlying ``DataLoader`` is created per epoch (per
    call to ``__iter__``) and consumed exactly once, so every batch of the
    dataset is yielded exactly once per epoch, in the order the base loader
    produces them.

    Notes:
        * Pinning is left to the base ``DataLoader`` (``pin_memory=True``);
          batches are not pinned a second time here.
        * Process-level parallelism for loading is still available through
          the base class's ``num_workers`` argument; the thread used here
          only overlaps "fetch next batch" with "consume current batch".
        * Exceptions raised while producing a batch are re-raised in the
          consumer from ``__next__``.

    Args:
        *args: Positional arguments forwarded to ``DataLoader``.
        preload_batches (int, optional): Maximum number of batches buffered
            ahead of the consumer. Values below 1 are treated as 1.
            Defaults to 1.
        **kwargs: Keyword arguments forwarded to ``DataLoader``.
    """

    def __init__(self, *args, preload_batches=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.preload_batches = max(1, int(preload_batches))
        self.batch_queue = queue.Queue(maxsize=self.preload_batches)
        # Kept for backward compatibility with code that inspects it; the
        # prefetcher is a thread now, so this is always None.
        self.process = None
        # Number of batches handed to the consumer since construction.
        self.current_batch_index = 0
        self._stop_event = None
        self._worker = None

    @staticmethod
    def _preload_next_batches(iterator, batch_queue, stop_event):
        """
        Worker body: drain ``iterator`` into ``batch_queue``.

        Items are ``(kind, payload)`` tuples where ``kind`` is ``"batch"``,
        ``"error"`` (payload is the exception) or ``"done"``. Implemented as
        a static method so the thread holds no reference to the loader.
        """
        def offer(item):
            # Bounded put that gives up as soon as a stop is requested, so
            # cleanup() never has to wait on a full queue.
            while not stop_event.is_set():
                try:
                    batch_queue.put(item, timeout=0.1)
                    return True
                except queue.Full:
                    continue
            return False

        try:
            for batch in iterator:
                if not offer(("batch", batch)):
                    return
            offer(("done", None))
        except Exception as exc:  # forwarded to the consumer, not swallowed
            offer(("error", exc))

    def _start_preloading(self):
        """Stop any previous prefetcher and start a fresh one for a new epoch."""
        import threading  # local import: not part of this module's top-level imports

        self.cleanup()
        self.batch_queue = queue.Queue(maxsize=self.preload_batches)
        self._stop_event = threading.Event()
        # Exactly one base iterator per epoch; it is only advanced by the worker.
        base_iterator = super().__iter__()
        self._worker = threading.Thread(
            target=self._preload_next_batches,
            args=(base_iterator, self.batch_queue, self._stop_event),
            daemon=True,  # never keeps the interpreter alive at exit
        )
        self._worker.start()

    def __iter__(self):
        self._start_preloading()
        return self

    def __next__(self):
        worker = self._worker
        if worker is None:
            # Epoch finished, cleaned up, or iter() was never called.
            raise StopIteration

        while True:
            try:
                kind, payload = self.batch_queue.get(timeout=0.5)
                break
            except queue.Empty:
                # Slow batches are simply waited for; only give up when the
                # worker is gone and left nothing behind.
                if not worker.is_alive() and self.batch_queue.empty():
                    self.cleanup()
                    raise StopIteration from None

        if kind == "error":
            self.cleanup()
            raise payload
        if kind == "done":
            self.cleanup()
            raise StopIteration

        self.current_batch_index += 1
        return payload

    def cleanup(self):
        """Ask the prefetch thread to stop and wait briefly for it to exit."""
        stop_event = getattr(self, "_stop_event", None)
        worker = getattr(self, "_worker", None)
        if stop_event is not None:
            stop_event.set()
        if worker is not None and worker.is_alive():
            try:
                worker.join(timeout=5)
            except RuntimeError:
                # join() on the current thread (e.g. GC ran inside the worker).
                pass
        self._worker = None

    def __del__(self):
        # __init__ may have failed part-way; never raise from a finalizer.
        try:
            self.cleanup()
        except Exception:
            pass
509
576
 
510
577
  class NoClassDataset(Dataset):
511
578
  def __init__(self, data_dir, transform=None, shuffle=True, load_to_memory=False):
@@ -2292,7 +2359,7 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
2292
2359
 
2293
2360
  def save_model_at_threshold(threshold, epoch, suffix=""):
2294
2361
  percentile = str(threshold * 100)
2295
- print(f'\rfound: {percentile}% accurate model')#, end='\r', flush=True)
2362
+ print(f'Found: {percentile}% accurate model')
2296
2363
  model_path = f'{dst}/{model_type}_epoch_{str(epoch)}{suffix}_acc_{percentile}_channels_{channels_str}.pth'
2297
2364
  torch.save(model, model_path)
2298
2365
  return model_path
@@ -2303,7 +2370,8 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
2303
2370
  return model_path
2304
2371
 
2305
2372
  for threshold in intermedeate_save:
2306
- if results_df['neg_accuracy'].dropna().mean() >= threshold and results_df['pos_accuracy'].dropna().mean() >= threshold:
2373
+ if results_df['neg_accuracy'] >= threshold and results_df['pos_accuracy'] >= threshold:
2374
+ print(f"Nc class accuracy: {results_df['neg_accuracy']} Pc class Accuracy: {results_df['pos_accuracy']}")
2307
2375
  model_path = save_model_at_threshold(threshold, epoch)
2308
2376
  break
2309
2377
  else:
@@ -2311,7 +2379,7 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
2311
2379
 
2312
2380
  return model_path
2313
2381
 
2314
- def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
2382
+ def _save_progress(dst, results_df, result_type='train'):
2315
2383
  """
2316
2384
  Save the progress of the classification model.
2317
2385
 
@@ -2325,18 +2393,13 @@ def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
2325
2393
  """
2326
2394
  # Save accuracy, loss, PRAUC
2327
2395
  os.makedirs(dst, exist_ok=True)
2328
- results_path = os.path.join(dst, 'acc_loss_prauc.csv')
2396
+ results_path = os.path.join(dst, f'{result_type}.csv')
2329
2397
  if not os.path.exists(results_path):
2330
2398
  results_df.to_csv(results_path, index=True, header=True, mode='w')
2331
2399
  else:
2332
2400
  results_df.to_csv(results_path, index=True, header=False, mode='a')
2333
-
2334
- training_metrics_path = os.path.join(dst, 'training_metrics.csv')
2335
- if not os.path.exists(training_metrics_path):
2336
- train_metrics_df.to_csv(training_metrics_path, index=True, header=True, mode='w')
2337
- else:
2338
- train_metrics_df.to_csv(training_metrics_path, index=True, header=False, mode='a')
2339
- if epoch == epochs:
2401
+
2402
+ if result_type == 'train':
2340
2403
  read_plot_model_stats(results_path, save=True)
2341
2404
  return
2342
2405
 
spacr/plot.py CHANGED
@@ -1,4 +1,4 @@
1
- import os,re, random, cv2, glob, time, math
1
+ import os,re, random, cv2, glob, time, math, torch
2
2
 
3
3
  import numpy as np
4
4
  import pandas as pd
@@ -125,7 +125,7 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
125
125
 
126
126
  return
127
127
 
128
- def plot_masks(batch, masks, flows, cmap='inferno', figuresize=20, nr=1, file_type='.npz', print_object_number=True):
128
+ def plot_masks(batch, masks, flows, cmap='inferno', figuresize=10, nr=1, file_type='.npz', print_object_number=True):
129
129
  """
130
130
  Plot the masks and flows for a given batch of images.
131
131
 
@@ -476,7 +476,7 @@ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mas
476
476
 
477
477
  return stack
478
478
 
479
- def plot_arrays(src, figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
479
+ def plot_arrays(src, figuresize=10, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
480
480
  """
481
481
  Plot randomly selected arrays from a given directory.
482
482
 
@@ -870,7 +870,7 @@ def _save_scimg_plot(src, nr_imgs=16, channel_indices=[0,1,2], um_per_pixel=0.1,
870
870
 
871
871
  return
872
872
 
873
- def _plot_cropped_arrays(stack, filename, figuresize=20, cmap='inferno', threshold=500):
873
+ def _plot_cropped_arrays(stack, filename, figuresize=10, cmap='inferno', threshold=500):
874
874
  """
875
875
  Plot cropped arrays.
876
876
 
@@ -997,7 +997,7 @@ def _display_gif(path):
997
997
  with open(path, 'rb') as file:
998
998
  display(ipyimage(file.read()))
999
999
 
1000
- def _plot_recruitment(df, df_type, channel_of_interest, target, columns=[], figuresize=50):
1000
+ def _plot_recruitment(df, df_type, channel_of_interest, target, columns=[], figuresize=10):
1001
1001
  """
1002
1002
  Plot recruitment data for different conditions and pathogens.
1003
1003
 
@@ -1186,6 +1186,52 @@ def _imshow(img, labels, nrow=20, color='white', fontsize=12):
1186
1186
  y = row * img_height + 15
1187
1187
  plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')
1188
1188
  return fig
1189
+
1190
def _imshow_gpu(img, labels, nrow=20, color='white', fontsize=12):
    """
    Tile a batch of images into one grid figure and overlay a label on each tile.

    Args:
        img (torch.Tensor): Batch of images, indexed as
            ``img[i]`` -> ``(channels, height, width)``. A CUDA tensor is
            copied to the CPU first.
        labels (list): One label per image; its length determines how many
            images are drawn.
        nrow (int, optional): Number of images per grid row. Defaults to 20.
        color (str, optional): Label text colour. Defaults to 'white'.
        fontsize (int, optional): Label font size. Defaults to 12.

    Returns:
        matplotlib.figure.Figure: The figure containing the grid.
    """
    img = img.cpu() if img.is_cuda else img

    n_images = len(labels)
    n_col = nrow
    n_row = int(np.ceil(n_images / n_col))
    img_height, img_width = img.shape[2], img.shape[3]

    # Blank (black) RGB canvas large enough for the full grid; tiles past the
    # last image stay zero.
    canvas = torch.zeros((img_height * n_row, img_width * n_col, 3))
    for idx in range(n_images):
        grid_row, grid_col = divmod(idx, n_col)
        top = grid_row * img_height
        left = grid_col * img_width
        # CHW -> HWC so the tile matches the canvas layout
        canvas[top:top + img_height, left:left + img_width] = img[idx].permute(1, 2, 0)

    fig = plt.figure(figsize=(50, 50))
    plt.imshow(canvas.numpy())
    plt.axis("off")

    # Labels go near the top-left corner of their tile.
    for idx, label in enumerate(labels):
        grid_row, grid_col = divmod(idx, n_col)
        plt.text(grid_col * img_width + 2, grid_row * img_height + 15, label,
                 color=color, fontsize=fontsize, fontweight='bold')

    return fig
1189
1235
 
1190
1236
  def _plot_histograms_and_stats(df):
1191
1237
  conditions = df['condition'].unique()
spacr/sequencing.py CHANGED
@@ -223,14 +223,10 @@ def save_to_hdf(queue, output_file, complevel=9, compression='zlib'):
223
223
  Save data from a queue to an HDF file.
224
224
 
225
225
  Parameters:
226
- - queue: Queue object
227
- The queue containing the data to be saved.
228
- - output_file: strs
229
- The path to the output HDF file.
230
- - complevel: int, optional
231
- The compression level to use (default is 9).
232
- - compression: str, optional
233
- The compression algorithm to use (default is 'zlib').
226
+ - queue: Queue object containing chunks of data to be saved
227
+ - output_file: Path to the output HDF file
228
+ - complevel: Compression level (default: 9)
229
+ - compression: Compression algorithm (default: 'zlib')
234
230
 
235
231
  Returns:
236
232
  None