spacr 0.2.5__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +1 -11
- spacr/core.py +226 -287
- spacr/deep_spacr.py +248 -269
- spacr/gui.py +41 -19
- spacr/gui_core.py +404 -151
- spacr/gui_elements.py +778 -179
- spacr/gui_utils.py +163 -106
- spacr/io.py +116 -45
- spacr/measure.py +1 -0
- spacr/plot.py +51 -5
- spacr/sequencing.py +477 -587
- spacr/settings.py +211 -66
- spacr/utils.py +34 -14
- {spacr-0.2.5.dist-info → spacr-0.2.8.dist-info}/METADATA +46 -39
- {spacr-0.2.5.dist-info → spacr-0.2.8.dist-info}/RECORD +19 -19
- {spacr-0.2.5.dist-info → spacr-0.2.8.dist-info}/WHEEL +1 -1
- {spacr-0.2.5.dist-info → spacr-0.2.8.dist-info}/LICENSE +0 -0
- {spacr-0.2.5.dist-info → spacr-0.2.8.dist-info}/entry_points.txt +0 -0
- {spacr-0.2.5.dist-info → spacr-0.2.8.dist-info}/top_level.txt +0 -0
spacr/io.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1
|
-
import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob
|
1
|
+
import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob, queue
|
2
2
|
import numpy as np
|
3
3
|
import pandas as pd
|
4
4
|
import tifffile
|
5
|
-
from PIL import Image
|
6
|
-
from collections import defaultdict, Counter
|
5
|
+
from PIL import Image, ImageOps
|
6
|
+
from collections import defaultdict, Counter, deque
|
7
7
|
from pathlib import Path
|
8
8
|
from functools import partial
|
9
9
|
from matplotlib.animation import FuncAnimation
|
@@ -17,12 +17,12 @@ import imageio.v2 as imageio2
|
|
17
17
|
import matplotlib.pyplot as plt
|
18
18
|
from io import BytesIO
|
19
19
|
from IPython.display import display, clear_output
|
20
|
-
from multiprocessing import Pool, cpu_count
|
21
|
-
from torch.utils.data import Dataset
|
20
|
+
from multiprocessing import Pool, cpu_count, Process, Queue
|
21
|
+
from torch.utils.data import Dataset, DataLoader
|
22
22
|
import matplotlib.pyplot as plt
|
23
23
|
from torchvision.transforms import ToTensor
|
24
24
|
import seaborn as sns
|
25
|
-
|
25
|
+
import atexit
|
26
26
|
|
27
27
|
from .logger import log_function_call
|
28
28
|
|
@@ -444,20 +444,7 @@ class NoClassDataset(Dataset):
|
|
444
444
|
# Return both the image and its filename
|
445
445
|
return img, self.filenames[index]
|
446
446
|
|
447
|
-
class
|
448
|
-
"""
|
449
|
-
A custom dataset class for loading and processing image data.
|
450
|
-
|
451
|
-
Args:
|
452
|
-
data_dir (str): The directory path where the image data is stored.
|
453
|
-
loader_classes (list): A list of class names for the dataset.
|
454
|
-
transform (callable, optional): A function/transform to apply to the image data. Default is None.
|
455
|
-
shuffle (bool, optional): Whether to shuffle the dataset. Default is True.
|
456
|
-
pin_memory (bool, optional): Whether to pin the loaded images to memory. Default is False.
|
457
|
-
specific_files (list, optional): A list of specific file paths to include in the dataset. Default is None.
|
458
|
-
specific_labels (list, optional): A list of specific labels corresponding to the specific files. Default is None.
|
459
|
-
"""
|
460
|
-
|
447
|
+
class spacrDataset(Dataset):
|
461
448
|
def __init__(self, data_dir, loader_classes, transform=None, shuffle=True, pin_memory=False, specific_files=None, specific_labels=None):
|
462
449
|
self.data_dir = data_dir
|
463
450
|
self.classes = loader_classes
|
@@ -466,7 +453,7 @@ class MyDataset(Dataset):
|
|
466
453
|
self.pin_memory = pin_memory
|
467
454
|
self.filenames = []
|
468
455
|
self.labels = []
|
469
|
-
|
456
|
+
|
470
457
|
if specific_files and specific_labels:
|
471
458
|
self.filenames = specific_files
|
472
459
|
self.labels = specific_labels
|
@@ -479,33 +466,113 @@ class MyDataset(Dataset):
|
|
479
466
|
|
480
467
|
if self.shuffle:
|
481
468
|
self.shuffle_dataset()
|
482
|
-
|
469
|
+
|
483
470
|
if self.pin_memory:
|
484
|
-
|
485
|
-
|
471
|
+
# Use multiprocessing to load images in parallel
|
472
|
+
with Pool(processes=cpu_count()) as pool:
|
473
|
+
self.images = pool.map(self.load_image, self.filenames)
|
474
|
+
else:
|
475
|
+
self.images = None
|
476
|
+
|
486
477
|
def load_image(self, img_path):
    """Open *img_path*, apply its EXIF orientation and return it as an RGB ``PIL.Image``.

    The orientation is applied to the freshly opened image, before ``convert``.
    Some formats (e.g. TIFF) expose the orientation tag only on the opened file
    object, so it would be lost on the converted copy. The ``with`` block
    guarantees the underlying file is closed once the pixels have been read.

    Args:
        img_path (str): Path of the image file to load.

    Returns:
        PIL.Image.Image: The orientation-corrected image in RGB mode.
    """
    with Image.open(img_path) as src:
        # exif_transpose returns a new, fully loaded image, so it stays valid
        # after the source file is closed.
        img = ImageOps.exif_transpose(src).convert('RGB')
    return img
|
489
|
-
|
481
|
+
|
490
482
|
def __len__(self):
|
491
483
|
return len(self.filenames)
|
492
|
-
|
484
|
+
|
493
485
|
def shuffle_dataset(self):
|
494
486
|
combined = list(zip(self.filenames, self.labels))
|
495
487
|
random.shuffle(combined)
|
496
488
|
self.filenames, self.labels = zip(*combined)
|
497
|
-
|
489
|
+
|
498
490
|
def get_plate(self, filepath):
    """Return the plate identifier: the part of the file's basename before the first '_'.

    A basename without an underscore is returned whole.
    """
    head, _sep, _rest = os.path.basename(filepath).partition('_')
    return head
|
501
|
-
|
493
|
+
|
502
494
|
def __getitem__(self, index):
|
495
|
+
if self.pin_memory:
|
496
|
+
img = self.images[index]
|
497
|
+
else:
|
498
|
+
img = self.load_image(self.filenames[index])
|
503
499
|
label = self.labels[index]
|
504
500
|
filename = self.filenames[index]
|
505
|
-
img = self.load_image(filename)
|
506
501
|
if self.transform:
|
507
502
|
img = self.transform(img)
|
508
503
|
return img, label, filename
|
504
|
+
|
505
|
+
class spacrDataLoader(DataLoader):
    """DataLoader that prefetches upcoming batches on a background thread.

    All positional and keyword arguments are forwarded to
    ``torch.utils.data.DataLoader``. While the caller works on one batch, a
    producer thread pulls further batches from the single underlying DataLoader
    iterator into a bounded queue. Batch order and epoch length are therefore
    the same as for a plain DataLoader.

    Args:
        *args: Positional arguments forwarded to ``DataLoader``.
        preload_batches (int, optional): Maximum number of batches buffered
            ahead of the consumer. Values below 1 are treated as 1. Defaults to 1.
        **kwargs: Keyword arguments forwarded to ``DataLoader``.

    Note:
        Earlier versions prefetched in a ``multiprocessing.Process``. That
        process consumed a forked copy of the batch iterator, and the iterator
        was re-created for every refill, so epochs kept restarting at the first
        batch. A thread sharing the one real iterator avoids this.
        ``num_workers`` still provides process-level parallelism for the actual
        loading.
    """

    # Queued by the producer after the last batch of an epoch.
    _END_OF_EPOCH = object()

    def __init__(self, *args, preload_batches=1, **kwargs):
        import weakref

        super().__init__(*args, **kwargs)
        # queue.Queue(maxsize=0) would be unbounded, so always keep >= 1 slot.
        self.preload_batches = max(1, int(preload_batches))
        self.batch_queue = queue.Queue(maxsize=self.preload_batches)
        self.process = None           # producer thread of the current epoch
        self.current_batch_index = 0  # batches returned by __next__ so far (all epochs)
        self._stop_event = False      # asks the producer to exit early
        self._producer_error = None   # exception raised while loading; re-raised in __next__
        # self.pin_memory has already been stored by DataLoader.__init__
        # (also when it was passed positionally).

        # Give atexit only a weak reference, so registering does not keep the
        # loader (and its possibly large dataset) alive for the interpreter's lifetime.
        loader_ref = weakref.ref(self)

        def _cleanup_at_exit():
            loader = loader_ref()
            if loader is not None:
                loader.cleanup()

        atexit.register(_cleanup_at_exit)

    def _put_or_stop(self, batch_queue, item):
        """Put *item* on *batch_queue*, giving up as soon as a stop is requested.

        Returns:
            bool: True if the item was queued, False if the producer should stop.
        """
        while not self._stop_event:
            try:
                batch_queue.put(item, timeout=0.1)
                return True
            except queue.Full:
                continue  # consumer is behind; re-check the stop flag and retry
        return False

    def _preload_next_batches(self):
        """Producer loop: move batches from the epoch iterator into the queue.

        Runs on the producer thread until the iterator is exhausted or
        ``cleanup`` sets the stop flag. Any loading error is stored, and an
        end-of-epoch marker is queued last (unless stopping). The consumer
        therefore always learns when the epoch is over.
        """
        # Bind the queue and iterator of *this* epoch, so a later epoch's
        # objects are never touched.
        batch_queue = self.batch_queue
        iterator = self._iterator
        try:
            for batch in iterator:
                if self._stop_event:
                    return
                if self.pin_memory:
                    batch = self._pin_memory_batch(batch)
                if not self._put_or_stop(batch_queue, batch):
                    return
        except Exception as exc:  # surfaced to the consumer by __next__
            self._producer_error = exc
        finally:
            self._put_or_stop(batch_queue, self._END_OF_EPOCH)

    def _start_preloading(self):
        """Start a producer thread for a fresh epoch, unless one is already running."""
        import threading

        if self.process is not None and self.process.is_alive():
            return
        self._stop_event = False
        self._producer_error = None
        self.batch_queue = queue.Queue(maxsize=self.preload_batches)
        self._iterator = super().__iter__()
        self.process = threading.Thread(
            target=self._preload_next_batches,
            name='spacrDataLoader-prefetch',
            daemon=True,  # never block interpreter shutdown
        )
        self.process.start()

    def _pin_memory_batch(self, batch):
        """Pin the tensors of *batch* (a tensor or a flat list/tuple) in page-locked memory.

        Non-tensor entries pass through unchanged, and list/tuple batches are
        returned as lists. Tensors are left as they are when CUDA is unavailable,
        because ``Tensor.pin_memory()`` would raise in that case.
        """
        can_pin = torch.cuda.is_available()

        def _pin(item):
            if can_pin and isinstance(item, torch.Tensor):
                return item.pin_memory()
            return item

        if isinstance(batch, (list, tuple)):
            return [_pin(item) for item in batch]
        return _pin(batch)

    def __iter__(self):
        """Begin a new epoch: stop any unfinished producer, then start prefetching."""
        self.cleanup()
        self._start_preloading()
        return self

    def __next__(self):
        """Return the next prefetched batch.

        Raises:
            StopIteration: At the end of the epoch, or if the producer is gone.
            Exception: Whatever the underlying DataLoader raised while loading.
        """
        if self.process is None:
            # next() called without iter(): start the epoch lazily.
            self._start_preloading()

        while True:
            try:
                batch = self.batch_queue.get(timeout=1.0)
                break
            except queue.Empty:
                if not self.process.is_alive():
                    # The producer has exited; take anything it queued just before.
                    try:
                        batch = self.batch_queue.get_nowait()
                        break
                    except queue.Empty:
                        raise StopIteration from None

        if batch is self._END_OF_EPOCH:
            error, self._producer_error = self._producer_error, None
            if error is not None:
                raise error
            raise StopIteration

        self.current_batch_index += 1
        return batch

    def cleanup(self):
        """Ask the producer thread to stop and wait for it to exit.

        The producer re-checks the stop flag at least every 0.1 s while waiting
        for queue space. A batch that is already being loaded is finished first.
        """
        import threading

        self._stop_event = True
        producer = getattr(self, 'process', None)
        if (producer is not None and producer.is_alive()
                and producer is not threading.current_thread()):
            producer.join()

    def __del__(self):
        """Best-effort stop of the producer when the loader is garbage-collected."""
        self.cleanup()
|
509
576
|
|
510
577
|
class NoClassDataset(Dataset):
|
511
578
|
def __init__(self, data_dir, transform=None, shuffle=True, load_to_memory=False):
|
@@ -2292,18 +2359,27 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
|
|
2292
2359
|
|
2293
2360
|
def save_model_at_threshold(threshold, epoch, suffix=""):
|
2294
2361
|
percentile = str(threshold * 100)
|
2295
|
-
print(f'
|
2296
|
-
|
2362
|
+
print(f'Found: {percentile}% accurate model')
|
2363
|
+
model_path = f'{dst}/{model_type}_epoch_{str(epoch)}{suffix}_acc_{percentile}_channels_{channels_str}.pth'
|
2364
|
+
torch.save(model, model_path)
|
2365
|
+
return model_path
|
2297
2366
|
|
2298
2367
|
if epoch % 100 == 0 or epoch == epochs:
|
2299
|
-
|
2368
|
+
model_path = f'{dst}/{model_type}_epoch_{str(epoch)}_channels_{channels_str}.pth'
|
2369
|
+
torch.save(model, model_path)
|
2370
|
+
return model_path
|
2300
2371
|
|
2301
2372
|
for threshold in intermedeate_save:
|
2302
|
-
if results_df['neg_accuracy']
|
2303
|
-
|
2304
|
-
|
2373
|
+
if results_df['neg_accuracy'] >= threshold and results_df['pos_accuracy'] >= threshold:
|
2374
|
+
print(f"Nc class accuracy: {results_df['neg_accuracy']} Pc class Accuracy: {results_df['pos_accuracy']}")
|
2375
|
+
model_path = save_model_at_threshold(threshold, epoch)
|
2376
|
+
break
|
2377
|
+
else:
|
2378
|
+
model_path = None
|
2379
|
+
|
2380
|
+
return model_path
|
2305
2381
|
|
2306
|
-
def _save_progress(dst, results_df,
|
2382
|
+
def _save_progress(dst, results_df, result_type='train'):
|
2307
2383
|
"""
|
2308
2384
|
Save the progress of the classification model.
|
2309
2385
|
|
@@ -2317,18 +2393,13 @@ def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
|
|
2317
2393
|
"""
|
2318
2394
|
# Save accuracy, loss, PRAUC
|
2319
2395
|
os.makedirs(dst, exist_ok=True)
|
2320
|
-
results_path = os.path.join(dst, '
|
2396
|
+
results_path = os.path.join(dst, f'{result_type}.csv')
|
2321
2397
|
if not os.path.exists(results_path):
|
2322
2398
|
results_df.to_csv(results_path, index=True, header=True, mode='w')
|
2323
2399
|
else:
|
2324
2400
|
results_df.to_csv(results_path, index=True, header=False, mode='a')
|
2325
|
-
|
2326
|
-
|
2327
|
-
if not os.path.exists(training_metrics_path):
|
2328
|
-
train_metrics_df.to_csv(training_metrics_path, index=True, header=True, mode='w')
|
2329
|
-
else:
|
2330
|
-
train_metrics_df.to_csv(training_metrics_path, index=True, header=False, mode='a')
|
2331
|
-
if epoch == epochs:
|
2401
|
+
|
2402
|
+
if result_type == 'train':
|
2332
2403
|
read_plot_model_stats(results_path, save=True)
|
2333
2404
|
return
|
2334
2405
|
|
spacr/measure.py
CHANGED
@@ -1060,6 +1060,7 @@ def measure_crop(settings):
|
|
1060
1060
|
files = [f for f in os.listdir(settings['src']) if f.endswith('.npy')]
|
1061
1061
|
n_jobs = settings['n_jobs']
|
1062
1062
|
print(f'using {n_jobs} cpu cores')
|
1063
|
+
print_progress(files_processed=0, files_to_process=len(files), n_jobs=n_jobs, time_ls=[], operation_type='Measure and Crop')
|
1063
1064
|
|
1064
1065
|
def job_callback(result):
|
1065
1066
|
completed_jobs.add(result[0])
|
spacr/plot.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
import os,re, random, cv2, glob, time, math
|
1
|
+
import os,re, random, cv2, glob, time, math, torch
|
2
2
|
|
3
3
|
import numpy as np
|
4
4
|
import pandas as pd
|
@@ -125,7 +125,7 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
|
|
125
125
|
|
126
126
|
return
|
127
127
|
|
128
|
-
def plot_masks(batch, masks, flows, cmap='inferno', figuresize=
|
128
|
+
def plot_masks(batch, masks, flows, cmap='inferno', figuresize=10, nr=1, file_type='.npz', print_object_number=True):
|
129
129
|
"""
|
130
130
|
Plot the masks and flows for a given batch of images.
|
131
131
|
|
@@ -476,7 +476,7 @@ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mas
|
|
476
476
|
|
477
477
|
return stack
|
478
478
|
|
479
|
-
def plot_arrays(src, figuresize=
|
479
|
+
def plot_arrays(src, figuresize=10, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
|
480
480
|
"""
|
481
481
|
Plot randomly selected arrays from a given directory.
|
482
482
|
|
@@ -870,7 +870,7 @@ def _save_scimg_plot(src, nr_imgs=16, channel_indices=[0,1,2], um_per_pixel=0.1,
|
|
870
870
|
|
871
871
|
return
|
872
872
|
|
873
|
-
def _plot_cropped_arrays(stack, filename, figuresize=
|
873
|
+
def _plot_cropped_arrays(stack, filename, figuresize=10, cmap='inferno', threshold=500):
|
874
874
|
"""
|
875
875
|
Plot cropped arrays.
|
876
876
|
|
@@ -997,7 +997,7 @@ def _display_gif(path):
|
|
997
997
|
with open(path, 'rb') as file:
|
998
998
|
display(ipyimage(file.read()))
|
999
999
|
|
1000
|
-
def _plot_recruitment(df, df_type, channel_of_interest, target, columns=[], figuresize=
|
1000
|
+
def _plot_recruitment(df, df_type, channel_of_interest, target, columns=[], figuresize=10):
|
1001
1001
|
"""
|
1002
1002
|
Plot recruitment data for different conditions and pathogens.
|
1003
1003
|
|
@@ -1186,6 +1186,52 @@ def _imshow(img, labels, nrow=20, color='white', fontsize=12):
|
|
1186
1186
|
y = row * img_height + 15
|
1187
1187
|
plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')
|
1188
1188
|
return fig
|
1189
|
+
|
1190
|
+
def _imshow_gpu(img, labels, nrow=20, color='white', fontsize=12, figsize=(50, 50)):
    """
    Display multiple images in a grid with corresponding labels.

    The batch is detached from any autograd graph and copied to the CPU before
    drawing. Tensors on a GPU (or any other device), and tensors that require
    gradients, can therefore be passed directly.

    Args:
        img (torch.Tensor): A batch of images indexed as (N, C, H, W). C is
            expected to be 1 (broadcast to grey), 3 (RGB) or 4 (RGBA).
        labels (list): Labels for the images. One tile is drawn per label; if
            there are more labels than images, the extra labels are ignored.
        nrow (int, optional): Number of images per row in the grid. Defaults to 20.
        color (str, optional): Color of the label text. Defaults to 'white'.
        fontsize (int, optional): Font size of the label text. Defaults to 12.
        figsize (tuple, optional): Matplotlib figure size in inches. Defaults to (50, 50).

    Returns:
        matplotlib.figure.Figure: The figure containing the labelled grid.
    """
    # detach(): copying a grad-tracking tensor into the canvas would make the
    # canvas require grad, and canvas.numpy() would then raise.
    img = img.detach().cpu()

    # Never index past the batch, even if more labels than images are given.
    n_images = min(len(labels), img.shape[0])
    n_col = nrow
    n_row = int(np.ceil(n_images / n_col))

    img_height = img.shape[2]
    img_width = img.shape[3]

    # RGBA batches keep their alpha channel; 1- and 3-channel batches use an RGB canvas.
    n_channels = 4 if img.shape[1] == 4 else 3
    canvas = torch.zeros((img_height * n_row, img_width * n_col, n_channels))

    for idx in range(n_images):
        row, col = divmod(idx, n_col)
        top, left = row * img_height, col * img_width
        canvas[top:top + img_height, left:left + img_width] = img[idx].permute(1, 2, 0)

    canvas = canvas.numpy()  # matplotlib needs a NumPy array

    fig = plt.figure(figsize=figsize)
    plt.imshow(canvas)
    plt.axis("off")

    for i, label in enumerate(labels):
        if i >= n_images:
            break
        row, col = divmod(i, n_col)
        x = col * img_width + 2
        y = row * img_height + 15
        plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')

    return fig
|
1189
1235
|
|
1190
1236
|
def _plot_histograms_and_stats(df):
|
1191
1237
|
conditions = df['condition'].unique()
|