spacr 0.2.5__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/io.py CHANGED
@@ -1,9 +1,9 @@
1
- import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob
1
+ import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob, queue
2
2
  import numpy as np
3
3
  import pandas as pd
4
4
  import tifffile
5
- from PIL import Image
6
- from collections import defaultdict, Counter
5
+ from PIL import Image, ImageOps
6
+ from collections import defaultdict, Counter, deque
7
7
  from pathlib import Path
8
8
  from functools import partial
9
9
  from matplotlib.animation import FuncAnimation
@@ -17,12 +17,12 @@ import imageio.v2 as imageio2
17
17
  import matplotlib.pyplot as plt
18
18
  from io import BytesIO
19
19
  from IPython.display import display, clear_output
20
- from multiprocessing import Pool, cpu_count
21
- from torch.utils.data import Dataset
20
+ from multiprocessing import Pool, cpu_count, Process, Queue
21
+ from torch.utils.data import Dataset, DataLoader
22
22
  import matplotlib.pyplot as plt
23
23
  from torchvision.transforms import ToTensor
24
24
  import seaborn as sns
25
-
25
+ import atexit
26
26
 
27
27
  from .logger import log_function_call
28
28
 
@@ -444,20 +444,7 @@ class NoClassDataset(Dataset):
444
444
  # Return both the image and its filename
445
445
  return img, self.filenames[index]
446
446
 
447
- class MyDataset(Dataset):
448
- """
449
- A custom dataset class for loading and processing image data.
450
-
451
- Args:
452
- data_dir (str): The directory path where the image data is stored.
453
- loader_classes (list): A list of class names for the dataset.
454
- transform (callable, optional): A function/transform to apply to the image data. Default is None.
455
- shuffle (bool, optional): Whether to shuffle the dataset. Default is True.
456
- pin_memory (bool, optional): Whether to pin the loaded images to memory. Default is False.
457
- specific_files (list, optional): A list of specific file paths to include in the dataset. Default is None.
458
- specific_labels (list, optional): A list of specific labels corresponding to the specific files. Default is None.
459
- """
460
-
447
+ class spacrDataset(Dataset):
461
448
  def __init__(self, data_dir, loader_classes, transform=None, shuffle=True, pin_memory=False, specific_files=None, specific_labels=None):
462
449
  self.data_dir = data_dir
463
450
  self.classes = loader_classes
@@ -466,7 +453,7 @@ class MyDataset(Dataset):
466
453
  self.pin_memory = pin_memory
467
454
  self.filenames = []
468
455
  self.labels = []
469
-
456
+
470
457
  if specific_files and specific_labels:
471
458
  self.filenames = specific_files
472
459
  self.labels = specific_labels
@@ -479,33 +466,113 @@ class MyDataset(Dataset):
479
466
 
480
467
  if self.shuffle:
481
468
  self.shuffle_dataset()
482
-
469
+
483
470
  if self.pin_memory:
484
- self.images = [self.load_image(f) for f in self.filenames]
485
-
471
+ # Use multiprocessing to load images in parallel
472
+ with Pool(processes=cpu_count()) as pool:
473
+ self.images = pool.map(self.load_image, self.filenames)
474
+ else:
475
+ self.images = None
476
+
486
477
  def load_image(self, img_path):
487
478
  img = Image.open(img_path).convert('RGB')
479
+ img = ImageOps.exif_transpose(img) # Handle image orientation
488
480
  return img
489
-
481
+
490
482
  def __len__(self):
491
483
  return len(self.filenames)
492
-
484
+
493
485
  def shuffle_dataset(self):
494
486
  combined = list(zip(self.filenames, self.labels))
495
487
  random.shuffle(combined)
496
488
  self.filenames, self.labels = zip(*combined)
497
-
489
+
498
490
  def get_plate(self, filepath):
499
- filename = os.path.basename(filepath) # Get just the filename from the full path
491
+ filename = os.path.basename(filepath)
500
492
  return filename.split('_')[0]
501
-
493
+
502
494
  def __getitem__(self, index):
495
+ if self.pin_memory:
496
+ img = self.images[index]
497
+ else:
498
+ img = self.load_image(self.filenames[index])
503
499
  label = self.labels[index]
504
500
  filename = self.filenames[index]
505
- img = self.load_image(filename)
506
501
  if self.transform:
507
502
  img = self.transform(img)
508
503
  return img, label, filename
504
+
505
class spacrDataLoader(DataLoader):
    """DataLoader that prefetches upcoming batches on a background thread.

    Each call to ``iter(loader)`` starts a new epoch. A fresh iterator over the
    parent ``DataLoader`` is created, and a daemon thread pulls batches from it
    into a bounded queue. Up to ``preload_batches`` batches are therefore
    prepared while the consumer works on the current one.

    A thread (not a child process) is used on purpose. The producer must
    advance the *same* iterator the consumer reads from. A child process only
    advances its own copy, which made every refill restart the epoch from the
    first batch.

    Args:
        *args: Positional arguments forwarded to ``DataLoader``.
        preload_batches (int, optional): Maximum number of batches buffered
            ahead of the consumer. Values below 1 are treated as 1. Default 1.
        **kwargs: Keyword arguments forwarded to ``DataLoader``. ``pin_memory``
            is also read here to decide whether prefetched tensors are pinned.

    Attributes:
        batch_queue (queue.Queue): Bounded queue of prefetched batches for the
            current epoch.
        process (threading.Thread | None): Producer thread of the current
            epoch, or None when idle (name kept for backward compatibility).
        current_batch_index (int): Number of batches returned so far, counted
            across epochs.

    Note:
        If a loop over the loader is abandoned early, the producer stays parked
        until the next ``iter(loader)`` or an explicit ``cleanup()`` call.
    """

    # Private markers put on the queue by the producer thread; they can never
    # be produced by collate, so identity checks are unambiguous.
    _END_OF_EPOCH = object()
    _ERROR = object()

    def __init__(self, *args, preload_batches=1, **kwargs):
        import threading
        import weakref

        super().__init__(*args, **kwargs)
        self.preload_batches = max(1, int(preload_batches))
        self.batch_queue = queue.Queue(maxsize=self.preload_batches)
        self.process = None
        self.current_batch_index = 0
        self._stop_event = threading.Event()
        # Deliberately not named `_iterator`: DataLoader uses that attribute
        # itself for persistent workers.
        self._batch_iterator = None
        self.pin_memory = kwargs.get('pin_memory', False)

        # Registering the bound method would keep every loader alive until
        # interpreter exit (so __del__ never ran); go through a weak reference.
        self_ref = weakref.ref(self)

        def _cleanup_at_exit():
            loader = self_ref()
            if loader is not None:
                loader.cleanup()

        atexit.register(_cleanup_at_exit)

    @staticmethod
    def _put(batch_queue, item, stop_event):
        """Put *item* on *batch_queue*, giving up once *stop_event* is set.

        Returns:
            bool: True if the item was enqueued, False if stopped first.
        """
        while not stop_event.is_set():
            try:
                batch_queue.put(item, timeout=0.1)
                return True
            except queue.Full:
                continue
        return False

    def _preload_next_batches(self, iterator=None, batch_queue=None, stop_event=None):
        """Producer loop: move batches from *iterator* into *batch_queue*.

        The loop ends in one of three ways:
        - the iterator is exhausted, and the end-of-epoch marker is enqueued;
        - an exception occurs, and it is enqueued for the consumer to re-raise;
        - *stop_event* is set.

        Arguments default to the loader's current-epoch objects.
        ``_start_preloading`` passes them explicitly so that a stale thread can
        never touch the queue or event of a newer epoch.
        """
        iterator = self._batch_iterator if iterator is None else iterator
        batch_queue = self.batch_queue if batch_queue is None else batch_queue
        stop_event = self._stop_event if stop_event is None else stop_event
        try:
            for batch in iterator:
                if stop_event.is_set():
                    return
                if self.pin_memory:
                    batch = self._pin_memory_batch(batch)
                if not self._put(batch_queue, batch, stop_event):
                    return
        except Exception as exc:
            # Forward to the consumer thread, which re-raises it in __next__.
            self._put(batch_queue, (self._ERROR, exc), stop_event)
            return
        self._put(batch_queue, self._END_OF_EPOCH, stop_event)

    def _start_preloading(self):
        """Stop any running producer and start a new one over a fresh epoch iterator."""
        import threading

        self.cleanup()
        self._stop_event = threading.Event()
        self.batch_queue = queue.Queue(maxsize=self.preload_batches)
        self._batch_iterator = super().__iter__()
        self.process = threading.Thread(
            target=self._preload_next_batches,
            args=(self._batch_iterator, self.batch_queue, self._stop_event),
            name='spacrDataLoader-prefetch',
            daemon=True,
        )
        self.process.start()

    def _pin_memory_batch(self, batch):
        """Return *batch* with its tensors copied to page-locked memory.

        Only a tensor batch, or tensors directly inside a list/tuple batch, are
        pinned (a list/tuple comes back as a list). The batch is returned
        unchanged when CUDA is unavailable, because ``Tensor.pin_memory``
        raises in that case.
        """
        if not torch.cuda.is_available():
            return batch
        if isinstance(batch, (list, tuple)):
            return [b.pin_memory() if isinstance(b, torch.Tensor) else b for b in batch]
        if isinstance(batch, torch.Tensor):
            return batch.pin_memory()
        return batch

    def __iter__(self):
        """Start a new epoch of prefetching and return the loader as its own iterator."""
        self._start_preloading()
        return self

    def __next__(self):
        """Return the next prefetched batch.

        Waits as long as the producer thread is still running. There is no
        fixed timeout, so a slow batch no longer silently ends the epoch early.

        Raises:
            StopIteration: At the end of the epoch, or if no epoch is running.
            Exception: Any exception raised while producing a batch.
        """
        producer = self.process
        if producer is None:
            raise StopIteration
        while True:
            try:
                item = self.batch_queue.get(timeout=0.1)
                break
            except queue.Empty:
                # Producer exited without an end marker (e.g. it was stopped).
                if not producer.is_alive() and self.batch_queue.empty():
                    self.process = None
                    raise StopIteration from None
        if item is self._END_OF_EPOCH:
            producer.join()
            self.process = None
            raise StopIteration
        if isinstance(item, tuple) and len(item) == 2 and item[0] is self._ERROR:
            self.process = None
            raise item[1]
        self.current_batch_index += 1
        return item

    def cleanup(self):
        """Signal the producer thread to stop and wait briefly for it to exit.

        Safe to call repeatedly, from ``__del__``, and at interpreter exit. The
        thread is a daemon, so one stuck inside the parent iterator cannot
        block shutdown.
        """
        import threading

        stop_event = getattr(self, '_stop_event', None)
        if stop_event is not None:
            stop_event.set()
        producer = getattr(self, 'process', None)
        if (producer is not None and producer.is_alive()
                and producer is not threading.current_thread()):
            producer.join(timeout=5)
        self.process = None

    def __del__(self):
        """Best-effort cleanup when the loader is garbage-collected."""
        try:
            self.cleanup()
        except Exception:
            # During interpreter shutdown module globals may already be gone;
            # nothing useful can be done from a finalizer at that point.
            pass
509
576
 
510
577
  class NoClassDataset(Dataset):
511
578
  def __init__(self, data_dir, transform=None, shuffle=True, load_to_memory=False):
@@ -2292,18 +2359,27 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
2292
2359
 
2293
2360
  def save_model_at_threshold(threshold, epoch, suffix=""):
2294
2361
  percentile = str(threshold * 100)
2295
- print(f'\rfound: {percentile}% accurate model')#, end='\r', flush=True)
2296
- torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}{suffix}_acc_{percentile}_channels_{channels_str}.pth')
2362
+ print(f'Found: {percentile}% accurate model')
2363
+ model_path = f'{dst}/{model_type}_epoch_{str(epoch)}{suffix}_acc_{percentile}_channels_{channels_str}.pth'
2364
+ torch.save(model, model_path)
2365
+ return model_path
2297
2366
 
2298
2367
  if epoch % 100 == 0 or epoch == epochs:
2299
- torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_channels_{channels_str}.pth')
2368
+ model_path = f'{dst}/{model_type}_epoch_{str(epoch)}_channels_{channels_str}.pth'
2369
+ torch.save(model, model_path)
2370
+ return model_path
2300
2371
 
2301
2372
  for threshold in intermedeate_save:
2302
- if results_df['neg_accuracy'].dropna().mean() >= threshold and results_df['pos_accuracy'].dropna().mean() >= threshold:
2303
- save_model_at_threshold(threshold, epoch)
2304
- break # Ensure we only save for the highest matching threshold
2373
+ if results_df['neg_accuracy'] >= threshold and results_df['pos_accuracy'] >= threshold:
2374
+ print(f"Nc class accuracy: {results_df['neg_accuracy']} Pc class Accuracy: {results_df['pos_accuracy']}")
2375
+ model_path = save_model_at_threshold(threshold, epoch)
2376
+ break
2377
+ else:
2378
+ model_path = None
2379
+
2380
+ return model_path
2305
2381
 
2306
- def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
2382
+ def _save_progress(dst, results_df, result_type='train'):
2307
2383
  """
2308
2384
  Save the progress of the classification model.
2309
2385
 
@@ -2317,18 +2393,13 @@ def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
2317
2393
  """
2318
2394
  # Save accuracy, loss, PRAUC
2319
2395
  os.makedirs(dst, exist_ok=True)
2320
- results_path = os.path.join(dst, 'acc_loss_prauc.csv')
2396
+ results_path = os.path.join(dst, f'{result_type}.csv')
2321
2397
  if not os.path.exists(results_path):
2322
2398
  results_df.to_csv(results_path, index=True, header=True, mode='w')
2323
2399
  else:
2324
2400
  results_df.to_csv(results_path, index=True, header=False, mode='a')
2325
-
2326
- training_metrics_path = os.path.join(dst, 'training_metrics.csv')
2327
- if not os.path.exists(training_metrics_path):
2328
- train_metrics_df.to_csv(training_metrics_path, index=True, header=True, mode='w')
2329
- else:
2330
- train_metrics_df.to_csv(training_metrics_path, index=True, header=False, mode='a')
2331
- if epoch == epochs:
2401
+
2402
+ if result_type == 'train':
2332
2403
  read_plot_model_stats(results_path, save=True)
2333
2404
  return
2334
2405
 
spacr/measure.py CHANGED
@@ -1060,6 +1060,7 @@ def measure_crop(settings):
1060
1060
  files = [f for f in os.listdir(settings['src']) if f.endswith('.npy')]
1061
1061
  n_jobs = settings['n_jobs']
1062
1062
  print(f'using {n_jobs} cpu cores')
1063
+ print_progress(files_processed=0, files_to_process=len(files), n_jobs=n_jobs, time_ls=[], operation_type='Measure and Crop')
1063
1064
 
1064
1065
  def job_callback(result):
1065
1066
  completed_jobs.add(result[0])
spacr/plot.py CHANGED
@@ -1,4 +1,4 @@
1
- import os,re, random, cv2, glob, time, math
1
+ import os,re, random, cv2, glob, time, math, torch
2
2
 
3
3
  import numpy as np
4
4
  import pandas as pd
@@ -125,7 +125,7 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
125
125
 
126
126
  return
127
127
 
128
- def plot_masks(batch, masks, flows, cmap='inferno', figuresize=20, nr=1, file_type='.npz', print_object_number=True):
128
+ def plot_masks(batch, masks, flows, cmap='inferno', figuresize=10, nr=1, file_type='.npz', print_object_number=True):
129
129
  """
130
130
  Plot the masks and flows for a given batch of images.
131
131
 
@@ -476,7 +476,7 @@ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mas
476
476
 
477
477
  return stack
478
478
 
479
- def plot_arrays(src, figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
479
+ def plot_arrays(src, figuresize=10, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
480
480
  """
481
481
  Plot randomly selected arrays from a given directory.
482
482
 
@@ -870,7 +870,7 @@ def _save_scimg_plot(src, nr_imgs=16, channel_indices=[0,1,2], um_per_pixel=0.1,
870
870
 
871
871
  return
872
872
 
873
- def _plot_cropped_arrays(stack, filename, figuresize=20, cmap='inferno', threshold=500):
873
+ def _plot_cropped_arrays(stack, filename, figuresize=10, cmap='inferno', threshold=500):
874
874
  """
875
875
  Plot cropped arrays.
876
876
 
@@ -997,7 +997,7 @@ def _display_gif(path):
997
997
  with open(path, 'rb') as file:
998
998
  display(ipyimage(file.read()))
999
999
 
1000
- def _plot_recruitment(df, df_type, channel_of_interest, target, columns=[], figuresize=50):
1000
+ def _plot_recruitment(df, df_type, channel_of_interest, target, columns=[], figuresize=10):
1001
1001
  """
1002
1002
  Plot recruitment data for different conditions and pathogens.
1003
1003
 
@@ -1186,6 +1186,52 @@ def _imshow(img, labels, nrow=20, color='white', fontsize=12):
1186
1186
  y = row * img_height + 15
1187
1187
  plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')
1188
1188
  return fig
1189
+
1190
def _imshow_gpu(img, labels, nrow=20, color='white', fontsize=12):
    """
    Display multiple images in a grid with corresponding labels.

    The tensor is detached and moved to the CPU first. It may therefore live on
    any device and may require grad. The canvas keeps the input dtype, so integer
    images (e.g. uint8 in [0, 255]) are not clipped to white, as they were with
    a float32 canvas.

    Args:
        img (torch.Tensor): Batch of images indexed as ``img[i]`` -> (C, H, W).
            Assumes C is 1 (broadcast to RGB) or 3; TODO confirm against callers.
        labels (list): Labels, one per image. ``len(labels)`` images are drawn.
        nrow (int, optional): Number of images per row in the grid. Defaults to 20.
        color (str, optional): Color of the label text. Defaults to 'white'.
        fontsize (int, optional): Font size of the label text. Defaults to 12.

    Returns:
        matplotlib.figure.Figure: The figure holding the image grid.
    """
    # .detach() avoids "Can't call numpy() on Tensor that requires grad";
    # .cpu() is a no-op for CPU tensors and also covers non-CUDA devices.
    img = img.detach().cpu()

    n_images = len(labels)
    n_col = nrow
    n_row = int(np.ceil(n_images / n_col))

    img_height = img.shape[2]
    img_width = img.shape[3]

    # Keep the input dtype: float values of 0-255 would be clipped to 1.0 by imshow.
    canvas = torch.zeros((img_height * n_row, img_width * n_col, 3), dtype=img.dtype)

    for idx in range(n_images):
        row, col = divmod(idx, n_col)
        # (C, H, W) -> (H, W, C); a single channel broadcasts across RGB.
        canvas[row * img_height:(row + 1) * img_height,
               col * img_width:(col + 1) * img_width] = img[idx].permute(1, 2, 0)

    canvas = canvas.numpy()

    fig = plt.figure(figsize=(50, 50))
    plt.imshow(canvas)
    plt.axis("off")

    for i, label in enumerate(labels):
        row, col = divmod(i, n_col)
        x = col * img_width + 2
        y = row * img_height + 15
        plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')

    return fig
1189
1235
 
1190
1236
  def _plot_histograms_and_stats(df):
1191
1237
  conditions = df['condition'].unique()