spacr 0.0.2__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
spacr/io.py CHANGED
@@ -1,4 +1,4 @@
1
- import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose
1
+ import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob
2
2
  import numpy as np
3
3
  import pandas as pd
4
4
  import tifffile
@@ -19,7 +19,6 @@ from io import BytesIO
19
19
  from IPython.display import display, clear_output
20
20
  from multiprocessing import Pool, cpu_count
21
21
  from torch.utils.data import Dataset
22
- import seaborn as sns
23
22
  import matplotlib.pyplot as plt
24
23
  from torchvision.transforms import ToTensor
25
24
 
@@ -45,19 +44,19 @@ def _load_images_and_labels(image_files, label_files, circular=False, invert=Fal
45
44
 
46
45
  if not image_files is None and not label_files is None:
47
46
  for img_file, lbl_file in zip(image_files, label_files):
48
- image = cellpose.imread(img_file)
47
+ image = cellpose.io.imread(img_file)
49
48
  if invert:
50
49
  image = invert_image(image)
51
50
  if circular:
52
51
  image = apply_mask(image, output_value=0)
53
- label = cellpose.imread(lbl_file)
52
+ label = cellpose.io.imread(lbl_file)
54
53
  if image.max() > 1:
55
54
  image = image / image.max()
56
55
  images.append(image)
57
56
  labels.append(label)
58
57
  elif not image_files is None:
59
58
  for img_file in image_files:
60
- image = cellpose.imread(img_file)
59
+ image = cellpose.io.imread(img_file)
61
60
  if invert:
62
61
  image = invert_image(image)
63
62
  if circular:
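
Note: both versions of this loader rely on cellpose's image reader; 0.0.6 simply calls it through the cellpose.io namespace, which is where current cellpose releases expose it. A minimal sketch of the call, assuming an illustrative single-channel TIFF path:

# Hedged sketch, not part of the package: read an image the way the updated loader does.
from cellpose import io as cellpose_io

image = cellpose_io.imread('example_channel.tif')  # path is illustrative
print(image.shape, image.dtype)
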
@@ -67,7 +66,7 @@ def _load_images_and_labels(image_files, label_files, circular=False, invert=Fal
67
66
  images.append(image)
68
67
  elif not image_files is None:
69
68
  for lbl_file in label_files:
70
- label = cellpose.imread(lbl_file)
69
+ label = cellpose.io.imread(lbl_file)
71
70
  if circular:
72
71
  label = apply_mask(label, output_value=0)
73
72
  labels.append(label)
@@ -88,15 +87,13 @@ def _load_images_and_labels(image_files, label_files, circular=False, invert=Fal
88
87
  print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {labels[0].shape}, image type: labels[0].shape')
89
88
  return images, labels, image_names, label_names
90
89
 
91
- def _load_normalized_images_and_labels(image_files, label_files, signal_thresholds=[1000], channels=None, percentiles=None, circular=False, invert=False, visualize=False):
90
+ def _load_normalized_images_and_labels(image_files, label_files, channels=None, percentiles=None, circular=False, invert=False, visualize=False, remove_background=False, background=0, Signal_to_noise=10):
92
91
 
93
92
  from .plot import normalize_and_visualize
94
93
  from .utils import invert_image, apply_mask
95
-
96
- if isinstance(signal_thresholds, int):
97
- signal_thresholds = [signal_thresholds] * (len(channels) if channels is not None else 1)
98
- elif not isinstance(signal_thresholds, list):
99
- signal_thresholds = [signal_thresholds]
94
+
95
+ signal_thresholds = background*Signal_to_noise
96
+ lower_percentile = 2
100
97
 
101
98
  images = []
102
99
  labels = []
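
In 0.0.6 the per-channel signal_thresholds argument is gone: the loader derives one threshold as background * Signal_to_noise and, further down in this function, walks a list of high percentiles until one exceeds that threshold. A standalone sketch of that selection logic, assuming synthetic data and illustrative parameter values:

import numpy as np

def pick_upper_percentile(channel, background=0, signal_to_noise=10,
                          candidates=(98, 99, 99.9, 99.99, 99.999)):
    # Mirrors the loop in the diff: return the first high percentile whose value clears
    # background * signal_to_noise; as a simplification, fall back to the last candidate.
    threshold = background * signal_to_noise
    fallback = np.percentile(channel, candidates[-1])
    for p in candidates:
        value = np.percentile(channel, p)
        if value > threshold:
            return value
    return fallback

rng = np.random.default_rng(0)
synthetic = rng.poisson(120, size=(256, 256)).astype(np.float32)
print(pick_upper_percentile(synthetic, background=100, signal_to_noise=1.2))
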
@@ -109,18 +106,22 @@ def _load_normalized_images_and_labels(image_files, label_files, signal_threshol
109
106
 
110
107
  if label_files is not None:
111
108
  label_names = [os.path.basename(f) for f in label_files]
109
+ label_dir = os.path.dirname(label_files[0])
112
110
 
113
111
  # Load images and check percentiles
114
112
  for i,img_file in enumerate(image_files):
115
- image = cellpose.imread(img_file)
113
+ image = cellpose.io.imread(img_file)
116
114
  if invert:
117
115
  image = invert_image(image)
118
116
  if circular:
119
117
  image = apply_mask(image, output_value=0)
120
-
118
+
121
119
  # If specific channels are specified, select them
122
120
  if channels is not None and image.ndim == 3:
123
121
  image = image[..., channels]
122
+
123
+ if remove_background:
124
+ image[image < background] = 0
124
125
 
125
126
  if image.ndim < 3:
126
127
  image = np.expand_dims(image, axis=-1)
@@ -128,11 +129,11 @@ def _load_normalized_images_and_labels(image_files, label_files, signal_threshol
128
129
  images.append(image)
129
130
  if percentiles is None:
130
131
  for c in range(image.shape[-1]):
131
- p1 = np.percentile(image[..., c], 1)
132
+ p1 = np.percentile(image[..., c], lower_percentile)
132
133
  percentiles_1[c].append(p1)
133
- for percentile in [99, 99.9, 99.99, 99.999]:
134
+ for percentile in [98, 99, 99.9, 99.99, 99.999]:
134
135
  p = np.percentile(image[..., c], percentile)
135
- if p > signal_thresholds[min(c, len(signal_thresholds)-1)]:
136
+ if p > signal_thresholds:
136
137
  percentiles_99[c].append(p)
137
138
  break
138
139
 
@@ -141,8 +142,8 @@ def _load_normalized_images_and_labels(image_files, label_files, signal_threshol
141
142
  for image in images:
142
143
  normalized_image = np.zeros_like(image, dtype=np.float32)
143
144
  for c in range(image.shape[-1]):
144
- high_p = np.percentile(image[..., c], percentiles[1])
145
145
  low_p = np.percentile(image[..., c], percentiles[0])
146
+ high_p = np.percentile(image[..., c], percentiles[1])
146
147
  normalized_image[..., c] = rescale_intensity(image[..., c], in_range=(low_p, high_p), out_range=(0, 1))
147
148
  normalized_images.append(normalized_image)
148
149
  if visualize:
@@ -153,23 +154,26 @@ def _load_normalized_images_and_labels(image_files, label_files, signal_threshol
153
154
  avg_p1 = [np.mean(p) for p in percentiles_1]
154
155
  avg_p99 = [np.mean(p) if len(p) > 0 else np.mean(percentiles_1[i]) for i, p in enumerate(percentiles_99)]
155
156
 
157
+ print(f'Average 1st percentiles: {avg_p1}, Average 99th percentiles: {avg_p99}')
158
+
156
159
  normalized_images = []
157
160
  for image in images:
158
161
  normalized_image = np.zeros_like(image, dtype=np.float32)
159
- for c in range(image.shape[-1]):
160
- normalized_image[..., c] = rescale_intensity(image[..., c], in_range=(avg_p1[c], avg_p99[c]), out_range=(0, 1))
161
- normalized_images.append(normalized_image)
162
- if visualize:
163
- normalize_and_visualize(image, normalized_image, title=f"Channel {c+1} Normalized")
162
+ for c in range(image.shape[-1]):
163
+ normalized_image[..., c] = rescale_intensity(image[..., c], in_range=(avg_p1[c], avg_p99[c]), out_range=(0, 1))
164
+ normalized_images.append(normalized_image)
165
+ if visualize:
166
+ normalize_and_visualize(image, normalized_image, title=f"Channel {c+1} Normalized")
164
167
 
165
168
  if not image_files is None:
166
169
  image_dir = os.path.dirname(image_files[0])
170
+
167
171
  else:
168
172
  image_dir = None
169
173
 
170
174
  if label_files is not None:
171
175
  for lbl_file in label_files:
172
- labels.append(cellpose.imread(lbl_file))
176
+ labels.append(cellpose.io.imread(lbl_file))
173
177
  else:
174
178
  label_names = []
175
179
  label_dir = None
@@ -178,86 +182,8 @@ def _load_normalized_images_and_labels(image_files, label_files, signal_threshol
178
182
 
179
183
  return normalized_images, labels, image_names, label_names
180
184
 
181
- class MyDataset(Dataset):
182
- """
183
- Custom dataset class for loading and processing image data.
184
-
185
- Args:
186
- data_dir (str): The directory path where the data is stored.
187
- loader_classes (list): List of class names.
188
- transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. Default is None.
189
- shuffle (bool, optional): Whether to shuffle the dataset. Default is True.
190
- load_to_memory (bool, optional): Whether to load images into memory. Default is False.
191
-
192
- Attributes:
193
- data_dir (str): The directory path where the data is stored.
194
- classes (list): List of class names.
195
- transform (callable): A function/transform that takes in an PIL image and returns a transformed version.
196
- shuffle (bool): Whether to shuffle the dataset.
197
- load_to_memory (bool): Whether to load images into memory.
198
- filenames (list): List of file paths.
199
- labels (list): List of labels corresponding to each file.
200
- images (list): List of loaded images.
201
- image_cache (Cache): Cache object for storing loaded images.
202
-
203
- Methods:
204
- load_image: Load an image from file.
205
- __len__: Get the length of the dataset.
206
- shuffle_dataset: Shuffle the dataset.
207
- __getitem__: Get an item from the dataset.
208
-
209
- """
210
-
211
- def _init__(self, data_dir, loader_classes, transform=None, shuffle=True, load_to_memory=False):
212
- from .utils import Cache
213
- self.data_dir = data_dir
214
- self.classes = loader_classes
215
- self.transform = transform
216
- self.shuffle = shuffle
217
- self.load_to_memory = load_to_memory
218
- self.filenames = []
219
- self.labels = []
220
- self.images = []
221
- self.image_cache = Cache(50)
222
- for class_name in self.classes:
223
- class_path = os.path.join(data_dir, class_name)
224
- class_files = [os.path.join(class_path, f) for f in os.listdir(class_path) if os.path.isfile(os.path.join(class_path, f))]
225
- self.filenames.extend(class_files)
226
- self.labels.extend([self.classes.index(class_name)] * len(class_files))
227
- if self.shuffle:
228
- self.shuffle_dataset()
229
- if self.load_to_memory:
230
- self.images = [self.load_image(f) for f in self.filenames]
231
-
232
- def load_image(self, img_path):
233
- img = self.image_cache.get(img_path)
234
- if img is None:
235
- img = Image.open(img_path).convert('RGB')
236
- self.image_cache.put(img_path, img)
237
- return img
238
-
239
- def _len__(self):
240
- return len(self.filenames)
241
-
242
- def shuffle_dataset(self):
243
- combined = list(zip(self.filenames, self.labels))
244
- random.shuffle(combined)
245
- self.filenames, self.labels = zip(*combined)
246
-
247
- def _getitem__(self, index):
248
- label = self.labels[index]
249
- filename = self.filenames[index]
250
- if self.load_to_memory:
251
- img = self.images[index]
252
- else:
253
- img = self.load_image(filename)
254
- if self.transform is not None:
255
- img = self.transform(img)
256
- else:
257
- img = ToTensor()(img)
258
- return img, label, filename
259
-
260
185
  class CombineLoaders:
186
+
261
187
  """
262
188
  A class that combines multiple data loaders into a single iterator.
263
189
 
@@ -398,7 +324,7 @@ class MyDataset(Dataset):
398
324
  specific_labels (list, optional): A list of specific labels corresponding to the specific files. Default is None.
399
325
  """
400
326
 
401
- def _init__(self, data_dir, loader_classes, transform=None, shuffle=True, pin_memory=False, specific_files=None, specific_labels=None):
327
+ def __init__(self, data_dir, loader_classes, transform=None, shuffle=True, pin_memory=False, specific_files=None, specific_labels=None):
402
328
  self.data_dir = data_dir
403
329
  self.classes = loader_classes
404
330
  self.transform = transform
@@ -427,7 +353,7 @@ class MyDataset(Dataset):
427
353
  img = Image.open(img_path).convert('RGB')
428
354
  return img
429
355
 
430
- def _len__(self):
356
+ def __len__(self):
431
357
  return len(self.filenames)
432
358
 
433
359
  def shuffle_dataset(self):
@@ -439,7 +365,7 @@ class MyDataset(Dataset):
439
365
  filename = os.path.basename(filepath) # Get just the filename from the full path
440
366
  return filename.split('_')[0]
441
367
 
442
- def _getitem__(self, index):
368
+ def __getitem__(self, index):
443
369
  label = self.labels[index]
444
370
  filename = self.filenames[index]
445
371
  img = self.load_image(filename)
@@ -527,7 +453,7 @@ class TarImageDataset(Dataset):
527
453
 
528
454
  return img, m.name
529
455
 
530
- @log_function_call
456
+ #@log_function_call
531
457
  def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=False, skip_mode='01', metadata_type='', img_format='.tif'):
532
458
  """
533
459
  Convert z-stack images to maximum intensity projection (MIP) images.
@@ -600,40 +526,47 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
600
526
  shutil.move(os.path.join(src, filename), move)
601
527
  return
602
528
 
603
- def _merge_file(chan_dirs, stack_dir, file):
529
+ def _merge_file(chan_dirs, stack_dir, file_name):
604
530
  """
605
- Merge multiple channels into a single stack and save it as a numpy array.
606
-
531
+ Merge multiple channels into a single stack and save it as a numpy array, using os module for path handling.
532
+
607
533
  Args:
608
534
  chan_dirs (list): List of directories containing channel images.
609
535
  stack_dir (str): Directory to save the merged stack.
610
- file (str): File name of the channel image.
536
+ file_name (str): File name of the channel image.
611
537
 
612
538
  Returns:
613
539
  None
614
540
  """
615
- chan1 = cv2.imread(str(file), -1)
616
- chan1 = np.expand_dims(chan1, axis=2)
617
- new_file = stack_dir / (file.stem + '.npy')
618
- if not new_file.exists():
619
- stack_dir.mkdir(exist_ok=True)
620
- channels = [chan1]
621
- for chan_dir in chan_dirs[1:]:
622
- img = cv2.imread(str(chan_dir / file.name), -1)
541
+ # Construct new file path
542
+ file_root, file_ext = os.path.splitext(file_name)
543
+ new_file = os.path.join(stack_dir, file_root + '.npy')
544
+
545
+ # Check if the new file exists and create the stack directory if it doesn't
546
+ if not os.path.exists(new_file):
547
+ os.makedirs(stack_dir, exist_ok=True)
548
+ channels = []
549
+ for i, chan_dir in enumerate(chan_dirs):
550
+ img_path = os.path.join(chan_dir, file_name)
551
+ img = cv2.imread(img_path, -1)
552
+ if img is None:
553
+ print(f"Warning: Failed to read image {img_path}")
554
+ continue
623
555
  chan = np.expand_dims(img, axis=2)
624
556
  channels.append(chan)
625
- stack = np.concatenate(channels, axis=2)
626
- np.save(new_file, stack)
557
+ del img # Explicitly delete the reference to the image to free up memory
558
+ if i % 10 == 0: # Periodically suggest garbage collection
559
+ gc.collect()
560
+
561
+ if channels:
562
+ stack = np.concatenate(channels, axis=2)
563
+ np.save(new_file, stack)
564
+ else:
565
+ print(f"No valid channels to merge for file {file_name}")
627
566
 
628
567
  def _is_dir_empty(dir_path):
629
568
  """
630
- Check if a directory is empty.
631
-
632
- Args:
633
- dir_path (str): The path to the directory.
634
-
635
- Returns:
636
- bool: True if the directory is empty, False otherwise.
569
+ Check if a directory is empty using os module.
637
570
  """
638
571
  return len(os.listdir(dir_path)) == 0
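
The reworked _merge_file above takes a bare file name, builds paths with os.path, skips unreadable channels, and stacks the remaining channel images along a new third axis before saving. A minimal, self-contained sketch of that stacking step with synthetic arrays (no cv2 or file I/O):

import numpy as np

# Three fake single-channel images standing in for cv2.imread results.
channel_images = [np.full((4, 4), i, dtype=np.uint16) for i in range(3)]
stack = np.concatenate([np.expand_dims(img, axis=2) for img in channel_images], axis=2)
print(stack.shape)  # (4, 4, 3) -- the array _merge_file would np.save() into the stack folder
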
639
572
 
@@ -733,7 +666,7 @@ def _move_to_chan_folder(src, regex, timelapse=False, metadata_type=''):
733
666
  shutil.move(os.path.join(src, filename), move)
734
667
  return
735
668
 
736
- def _merge_channels(src, plot=False):
669
+ def _merge_channels_v2(src, plot=False):
737
670
  from .plot import plot_arrays
738
671
  """
739
672
  Merge the channels in the given source directory and save the merged files in a 'stack' directory.
@@ -761,7 +694,8 @@ def _merge_channels(src, plot=False):
761
694
  print(f'generated folder with merged arrays: {stack_dir}')
762
695
 
763
696
  if _is_dir_empty(stack_dir):
764
- with Pool(cpu_count()) as pool:
697
+ with Pool(max(cpu_count() // 2, 1)) as pool:
698
+ #with Pool(cpu_count()) as pool:
765
699
  merge_func = partial(_merge_file, chan_dirs, stack_dir)
766
700
  pool.map(merge_func, dir_files)
767
701
 
@@ -773,6 +707,47 @@ def _merge_channels(src, plot=False):
773
707
 
774
708
  return
775
709
 
710
+ def _merge_channels(src, plot=False):
711
+ """
712
+ Merge the channels in the given source directory and save the merged files in a 'stack' directory without using multiprocessing.
713
+ """
714
+
715
+ from .plot import plot_arrays
716
+
717
+ stack_dir = os.path.join(src, 'stack')
718
+ allowed_names = ['01', '02', '03', '04', '00', '1', '2', '3', '4', '0']
719
+
720
+ # List directories that match the allowed names
721
+ chan_dirs = [d for d in os.listdir(src) if os.path.isdir(os.path.join(src, d)) and d in allowed_names]
722
+ chan_dirs.sort()
723
+
724
+ print(f'List of folders in src: {chan_dirs}. Single channel folders.')
725
+ start_time = time.time()
726
+
727
+ # Assuming chan_dirs[0] is not empty and exists, adjust according to your logic
728
+ first_dir_path = os.path.join(src, chan_dirs[0])
729
+ dir_files = os.listdir(first_dir_path)
730
+
731
+ # Create the 'stack' directory if it doesn't exist
732
+ if not os.path.exists(stack_dir):
733
+ os.makedirs(stack_dir, exist_ok=True)
734
+ print(f'Generated folder with merged arrays: {stack_dir}')
735
+
736
+ if _is_dir_empty(stack_dir):
737
+ for file_name in dir_files:
738
+ full_file_path = os.path.join(first_dir_path, file_name)
739
+ if os.path.isfile(full_file_path):
740
+ _merge_file([os.path.join(src, d) for d in chan_dirs], stack_dir, file_name)
741
+
742
+ elapsed_time = time.time() - start_time
743
+ avg_time = elapsed_time / len(dir_files) if dir_files else 0
744
+ print(f'Average Time: {avg_time:.3f} sec, Total Elapsed Time: {elapsed_time:.3f} sec')
745
+
746
+ if plot:
747
+ plot_arrays(os.path.join(src, 'stack'))
748
+
749
+ return
750
+
776
751
  def _mip_all(src, include_first_chan=True):
777
752
 
778
753
  """
@@ -821,7 +796,7 @@ def _mip_all(src, include_first_chan=True):
821
796
  np.save(os.path.join(src, filename), concatenated)
822
797
  return
823
798
 
824
- @log_function_call
799
+ #@log_function_call
825
800
  def _concatenate_channel(src, channels, randomize=True, timelapse=False, batch_size=100):
826
801
  """
827
802
  Concatenates channel data from multiple files and saves the concatenated data as numpy arrays.
@@ -912,6 +887,223 @@ def _concatenate_channel(src, channels, randomize=True, timelapse=False, batch_s
912
887
  print(f'All files concatenated and saved to:{channel_stack_loc}')
913
888
  return channel_stack_loc
914
889
 
890
+ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, batch_size=100, backgrounds=[100, 100, 100], remove_backgrounds=[False, False, False], lower_percentile=2, save_dtype=np.float32, signal_to_noise=[5, 5, 5], signal_thresholds=[1000, 1000, 1000]):
891
+ """
892
+ Concatenates and normalizes channel data from multiple files and saves the normalized data.
893
+
894
+ Args:
895
+ src (str): The source directory containing the channel data files.
896
+ channels (list): The list of channel indices to be concatenated and normalized.
897
+ randomize (bool, optional): Whether to randomize the order of the files. Defaults to True.
898
+ timelapse (bool, optional): Whether the channel data is from a timelapse experiment. Defaults to False.
899
+ batch_size (int, optional): The number of files to be processed in each batch. Defaults to 100.
900
+ backgrounds (list, optional): Background values for each channel. Defaults to [100, 100, 100].
901
+ remove_backgrounds (list, optional): Whether to remove background values for each channel. Defaults to [False, False, False].
902
+ lower_percentile (int, optional): Lower percentile value for normalization. Defaults to 2.
903
+ save_dtype (numpy.dtype, optional): Data type for saving the normalized stack. Defaults to np.float32.
904
+ signal_to_noise (list, optional): Signal-to-noise ratio thresholds for each channel. Defaults to [5, 5, 5].
905
+ signal_thresholds (list, optional): Signal thresholds for each channel. Defaults to [1000, 1000, 1000].
906
+
907
+ Returns:
908
+ str: The directory path where the concatenated and normalized channel data is saved.
909
+ """
910
+ channels = [item for item in channels if item is not None]
911
+ paths = []
912
+ output_fldr = os.path.join(os.path.dirname(src), 'norm_channel_stack')
913
+ os.makedirs(output_fldr, exist_ok=True)
914
+
915
+ if timelapse:
916
+ try:
917
+ time_stack_path_lists = _generate_time_lists(os.listdir(src))
918
+ for i, time_stack_list in enumerate(time_stack_path_lists):
919
+ stack_region = []
920
+ filenames_region = []
921
+ for idx, file in enumerate(time_stack_list):
922
+ path = os.path.join(src, file)
923
+ if idx == 0:
924
+ parts = file.split('_')
925
+ name = parts[0] + '_' + parts[1] + '_' + parts[2]
926
+ array = np.load(path)
927
+ array = np.take(array, channels, axis=2)
928
+ stack_region.append(array)
929
+ filenames_region.append(os.path.basename(path))
930
+ print(f'Region {i + 1}/ {len(time_stack_path_lists)}', end='\r', flush=True)
931
+ stack = np.stack(stack_region)
932
+ normalized_stack = _normalize_stack(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds)
933
+ save_loc = os.path.join(output_fldr, f'{name}_norm_timelapse.npz')
934
+ np.savez(save_loc, data=normalized_stack, filenames=filenames_region)
935
+ print(save_loc)
936
+ del stack, normalized_stack
937
+ except Exception as e:
938
+ print(f"Error processing files, make sure filenames metadata is structured plate_well_field_time.npy")
939
+ print(f"Error: {e}")
940
+ else:
941
+ for file in os.listdir(src):
942
+ if file.endswith('.npy'):
943
+ path = os.path.join(src, file)
944
+ paths.append(path)
945
+ if randomize:
946
+ random.shuffle(paths)
947
+ nr_files = len(paths)
948
+ batch_index = 0
949
+ stack_ls = []
950
+ filenames_batch = []
951
+
952
+ for i, path in enumerate(paths):
953
+ array = np.load(path)
954
+ array = np.take(array, channels, axis=2)
955
+ stack_ls.append(array)
956
+ filenames_batch.append(os.path.basename(path))
957
+ print(f'Concatenated: {i + 1}/{nr_files} files')
958
+
959
+ if (i + 1) % batch_size == 0 or i + 1 == nr_files:
960
+ unique_shapes = {arr.shape[:-1] for arr in stack_ls}
961
+ if len(unique_shapes) > 1:
962
+ max_dims = np.max(np.array(list(unique_shapes)), axis=0)
963
+ print(f'Warning: arrays with multiple shapes found in batch {i + 1}. Padding arrays to max X,Y dimensions {max_dims}')
964
+ padded_stack_ls = []
965
+ for arr in stack_ls:
966
+ pad_width = [(0, max_dim - dim) for max_dim, dim in zip(max_dims, arr.shape[:-1])]
967
+ pad_width.append((0, 0))
968
+ padded_arr = np.pad(arr, pad_width)
969
+ padded_stack_ls.append(padded_arr)
970
+ stack = np.stack(padded_stack_ls)
971
+ else:
972
+ stack = np.stack(stack_ls)
973
+
974
+ normalized_stack = _normalize_img_batch(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds)
975
+
976
+ save_loc = os.path.join(output_fldr, f'stack_{batch_index}_norm.npz')
977
+ np.savez(save_loc, data=normalized_stack, filenames=filenames_batch)
978
+ batch_index += 1
979
+ del stack, normalized_stack
980
+ stack_ls = []
981
+ filenames_batch = []
982
+ padded_stack_ls = []
983
+ print(f'All files concatenated and normalized. Saved to: {output_fldr}')
984
+ return output_fldr
985
+
986
+ def _normalize_img_batch(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds):
987
+ """
988
+ Normalize the stack of images.
989
+
990
+ Args:
991
+ stack (numpy.ndarray): The stack of images to normalize.
992
+ backgrounds (list): Background values for each channel.
993
+ remove_backgrounds (list): Whether to remove background values for each channel.
994
+ lower_percentile (int): Lower percentile value for normalization.
995
+ save_dtype (numpy.dtype): Data type for saving the normalized stack.
996
+ signal_to_noise (list): Signal-to-noise ratio thresholds for each channel.
997
+ signal_thresholds (list): Signal thresholds for each channel.
998
+
999
+ Returns:
1000
+ numpy.ndarray: The normalized stack.
1001
+ """
1002
+ normalized_stack = np.zeros_like(stack, dtype=np.float32)
1003
+
1004
+ for chan_index, channel in enumerate(range(stack.shape[-1])):
1005
+ single_channel = stack[:, :, :, channel]
1006
+ background = backgrounds[chan_index]
1007
+ signal_threshold = signal_thresholds[chan_index]
1008
+ remove_background = remove_backgrounds[chan_index]
1009
+
1010
+ print(f'Processing channel {chan_index}: background={background}, signal_threshold={signal_threshold}, remove_background={remove_background}')
1011
+
1012
+ # Step 3: Remove background if required
1013
+ if remove_background:
1014
+ single_channel[single_channel < background] = 0
1015
+
1016
+ # Step 4: Calculate global lower percentile for the channel
1017
+ non_zero_single_channel = single_channel[single_channel != 0]
1018
+ global_lower = np.percentile(non_zero_single_channel, lower_percentile)
1019
+
1020
+ # Step 5: Calculate global upper percentile for the channel
1021
+ global_upper = None
1022
+ for upper_p in np.linspace(98, 99.5, num=16):
1023
+ upper_value = np.percentile(non_zero_single_channel, upper_p)
1024
+ if upper_value >= signal_threshold:
1025
+ global_upper = upper_value
1026
+ break
1027
+
1028
+ if global_upper is None:
1029
+ global_upper = np.percentile(non_zero_single_channel, 99.5) # Fallback in case no upper percentile met the threshold
1030
+
1031
+ print(f'Channel {chan_index}: global_lower={global_lower}, global_upper={global_upper}, Signal-to-noise={global_upper / global_lower}')
1032
+
1033
+ # Step 6: Normalize each array from global_lower to global_upper between 0 and 1
1034
+ for array_index in range(single_channel.shape[0]):
1035
+ arr_2d = single_channel[array_index, :, :]
1036
+ arr_2d_normalized = exposure.rescale_intensity(arr_2d, in_range=(global_lower, global_upper), out_range=(0, 1))
1037
+ normalized_stack[array_index, :, :, channel] = arr_2d_normalized
1038
+
1039
+ return normalized_stack.astype(save_dtype)
1040
+
1041
+ def _normalize_img_batch_v1(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds):
1042
+ """
1043
+ Normalize the stack of images.
1044
+
1045
+ Args:
1046
+ stack (numpy.ndarray): The stack of images to normalize.
1047
+ backgrounds (list): Background values for each channel.
1048
+ remove_backgrounds (list): Whether to remove background values for each channel.
1049
+ lower_percentile (int): Lower percentile value for normalization.
1050
+ save_dtype (numpy.dtype): Data type for saving the normalized stack.
1051
+ signal_to_noise (list): Signal-to-noise ratio thresholds for each channel.
1052
+ signal_thresholds (list): Signal thresholds for each channel.
1053
+
1054
+ Returns:
1055
+ numpy.ndarray: The normalized stack.
1056
+ """
1057
+ normalized_stack = np.zeros_like(stack, dtype=np.float32)
1058
+ time_ls = []
1059
+
1060
+ for chan_index, channel in enumerate(range(stack.shape[-1])):
1061
+ single_channel = stack[:, :, :, channel]
1062
+ background = backgrounds[chan_index]
1063
+ signal_threshold = signal_thresholds[chan_index]
1064
+ remove_background = remove_backgrounds[chan_index]
1065
+ signal_2_noise = signal_to_noise[chan_index]
1066
+ print(f'chan_index:{chan_index} background:{background} signal_threshold:{signal_threshold} remove_background:{remove_background} signal_2_noise:{signal_2_noise}')
1067
+
1068
+ if remove_background:
1069
+ single_channel[single_channel < background] = 0
1070
+
1071
+ non_zero_single_channel = single_channel[single_channel != 0]
1072
+ global_lower = np.percentile(non_zero_single_channel, lower_percentile)
1073
+ for upper_p in np.linspace(98, 99.5, num=20).tolist():
1074
+ global_upper = np.percentile(non_zero_single_channel, upper_p)
1075
+ if global_upper >= signal_threshold:
1076
+ break
1077
+
1078
+ arr_2d_normalized = np.zeros_like(single_channel, dtype=single_channel.dtype)
1079
+ signal_to_noise_ratio_ls = []
1080
+ for array_index in range(single_channel.shape[0]):
1081
+ start = time.time()
1082
+ arr_2d = single_channel[array_index, :, :]
1083
+ non_zero_arr_2d = arr_2d[arr_2d != 0]
1084
+ if non_zero_arr_2d.size > 0:
1085
+ lower, upper = np.percentile(non_zero_arr_2d, (lower_percentile, upper_p))
1086
+ signal_to_noise_ratio = upper / lower
1087
+ else:
1088
+ signal_to_noise_ratio = 0
1089
+ signal_to_noise_ratio_ls.append(signal_to_noise_ratio)
1090
+ average_stnr = np.mean(signal_to_noise_ratio_ls) if len(signal_to_noise_ratio_ls) > 0 else 0
1091
+
1092
+ if signal_to_noise_ratio > signal_2_noise:
1093
+ arr_2d_rescaled = exposure.rescale_intensity(arr_2d, in_range=(lower, upper), out_range=(0, 1))
1094
+ arr_2d_normalized[array_index, :, :] = arr_2d_rescaled
1095
+ else:
1096
+ arr_2d_normalized[array_index, :, :] = arr_2d
1097
+ stop = time.time()
1098
+ duration = (stop - start) * single_channel.shape[0]
1099
+ time_ls.append(duration)
1100
+ average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
1101
+ print(f'Progress: channels:{chan_index}/{stack.shape[-1] - 1}, arrays:{array_index + 1}/{single_channel.shape[0]}, Signal:{upper:.1f}, noise:{lower:.1f}, Signal-to-noise:{average_stnr:.1f}, Time/channel:{average_time:.2f}sec')
1102
+
1103
+ normalized_stack[:, :, :, channel] = arr_2d_normalized
1104
+
1105
+ return normalized_stack.astype(save_dtype)
1106
+
915
1107
  def _get_lists_for_normalization(settings):
916
1108
  """
917
1109
  Get lists for normalization based on the provided settings.
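
concatenate_and_normalize above replaces the old concatenate-then-normalize passes: it batches the per-field .npy arrays, pads mismatched shapes, and hands each batch to _normalize_img_batch, which rescales every channel to a global [0, 1] range. The sketch below reproduces that per-channel normalization on a synthetic (N, H, W) stack; the percentile and threshold values are illustrative, not package defaults:

import numpy as np
from skimage import exposure

rng = np.random.default_rng(1)
single_channel = rng.gamma(2.0, 200.0, size=(8, 64, 64)).astype(np.float32)

lower_percentile, signal_threshold = 2, 600
non_zero = single_channel[single_channel != 0]
global_lower = np.percentile(non_zero, lower_percentile)

# Same search as _normalize_img_batch: first upper percentile that clears the threshold,
# otherwise fall back to the 99.5th percentile.
global_upper = None
for upper_p in np.linspace(98, 99.5, num=16):
    value = np.percentile(non_zero, upper_p)
    if value >= signal_threshold:
        global_upper = value
        break
if global_upper is None:
    global_upper = np.percentile(non_zero, 99.5)

# The package rescales frame by frame; applying the same global range to the whole
# stack gives an identical result.
normalized = exposure.rescale_intensity(single_channel,
                                        in_range=(global_lower, global_upper),
                                        out_range=(0, 1))
print(normalized.min(), normalized.max())
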
@@ -926,7 +1118,8 @@ def _get_lists_for_normalization(settings):
926
1118
  # Initialize the lists
927
1119
  backgrounds = []
928
1120
  signal_to_noise = []
929
- signal_thresholds = []
1121
+ signal_thresholds = []
1122
+ remove_background = []
930
1123
 
931
1124
  # Iterate through the channels and append the corresponding values if the channel is not None
932
1125
  for ch in settings['channels']:
@@ -934,29 +1127,31 @@ def _get_lists_for_normalization(settings):
934
1127
  backgrounds.append(settings['nucleus_background'])
935
1128
  signal_to_noise.append(settings['nucleus_Signal_to_noise'])
936
1129
  signal_thresholds.append(settings['nucleus_Signal_to_noise']*settings['nucleus_background'])
1130
+ remove_background.append(settings['remove_background_nucleus'])
937
1131
  elif ch == settings['cell_channel']:
938
1132
  backgrounds.append(settings['cell_background'])
939
1133
  signal_to_noise.append(settings['cell_Signal_to_noise'])
940
1134
  signal_thresholds.append(settings['cell_Signal_to_noise']*settings['cell_background'])
1135
+ remove_background.append(settings['remove_background_cell'])
941
1136
  elif ch == settings['pathogen_channel']:
942
1137
  backgrounds.append(settings['pathogen_background'])
943
1138
  signal_to_noise.append(settings['pathogen_Signal_to_noise'])
944
1139
  signal_thresholds.append(settings['pathogen_Signal_to_noise']*settings['pathogen_background'])
945
- return backgrounds, signal_to_noise, signal_thresholds
1140
+ remove_background.append(settings['remove_background_pathogen'])
1141
+ return backgrounds, signal_to_noise, signal_thresholds, remove_background
946
1142
 
947
- def _normalize_stack(src, backgrounds=[100,100,100], remove_background=False, lower_quantile=0.01, save_dtype=np.float32, signal_to_noise=[5,5,5], signal_thresholds=[1000,1000,1000], correct_illumination=False):
1143
+ def _normalize_stack(src, backgrounds=[100, 100, 100], remove_backgrounds=[False, False, False], lower_percentile=2, save_dtype=np.float32, signal_to_noise=[5, 5, 5], signal_thresholds=[1000, 1000, 1000]):
948
1144
  """
949
1145
  Normalize the stack of images.
950
1146
 
951
1147
  Args:
952
1148
  src (str): The source directory containing the stack of images.
953
- backgrounds (list, optional): Background values for each channel. Defaults to [100,100,100].
954
- remove_background (bool, optional): Whether to remove background values. Defaults to False.
955
- lower_quantile (float, optional): Lower quantile value for normalization. Defaults to 0.01.
1149
+ backgrounds (list, optional): Background values for each channel. Defaults to [100, 100, 100].
1150
+ remove_background (list, optional): Whether to remove background values for each channel. Defaults to [False, False, False].
1151
+ lower_percentile (int, optional): Lower percentile value for normalization. Defaults to 2.
956
1152
  save_dtype (numpy.dtype, optional): Data type for saving the normalized stack. Defaults to np.float32.
957
- signal_to_noise (list, optional): Signal-to-noise ratio thresholds for each channel. Defaults to [5,5,5].
958
- signal_thresholds (list, optional): Signal thresholds for each channel. Defaults to [1000,1000,1000].
959
- correct_illumination (bool, optional): Whether to correct illumination. Defaults to False.
1153
+ signal_to_noise (list, optional): Signal-to-noise ratio thresholds for each channel. Defaults to [5, 5, 5].
1154
+ signal_thresholds (list, optional): Signal thresholds for each channel. Defaults to [1000, 1000, 1000].
960
1155
 
961
1156
  Returns:
962
1157
  None
@@ -965,11 +1160,13 @@ def _normalize_stack(src, backgrounds=[100,100,100], remove_background=False, lo
965
1160
  output_fldr = os.path.join(os.path.dirname(src), 'norm_channel_stack')
966
1161
  os.makedirs(output_fldr, exist_ok=True)
967
1162
  time_ls = []
1163
+
968
1164
  for file_index, path in enumerate(paths):
969
1165
  with np.load(path) as data:
970
1166
  stack = data['data']
971
1167
  filenames = data['filenames']
972
- normalized_stack = np.zeros_like(stack, dtype=stack.dtype)
1168
+
1169
+ normalized_stack = np.zeros_like(stack, dtype=np.float32)
973
1170
  file = os.path.basename(path)
974
1171
  name, _ = os.path.splitext(file)
975
1172
 
@@ -977,24 +1174,22 @@ def _normalize_stack(src, backgrounds=[100,100,100], remove_background=False, lo
977
1174
  single_channel = stack[:, :, :, channel]
978
1175
  background = backgrounds[chan_index]
979
1176
  signal_threshold = signal_thresholds[chan_index]
980
- #print(f'signal_threshold:{signal_threshold} in {signal_thresholds} for {chan_index}')
981
-
1177
+ remove_background = remove_backgrounds[chan_index]
982
1178
  signal_2_noise = signal_to_noise[chan_index]
1179
+ print(f'chan_index:{chan_index} background:{background} signal_threshold:{signal_threshold} remove_background:{remove_background} signal_2_noise:{signal_2_noise}')
1180
+
983
1181
  if remove_background:
984
1182
  single_channel[single_channel < background] = 0
985
- if correct_illumination:
986
- bg = filters.gaussian(single_channel, sigma=50)
987
- single_channel = single_channel - bg
988
1183
 
989
- #Calculate the global lower and upper quantiles for non-zero pixels
1184
+ # Calculate the global lower and upper percentiles for non-zero pixels
990
1185
  non_zero_single_channel = single_channel[single_channel != 0]
991
- global_lower = np.quantile(non_zero_single_channel, lower_quantile)
992
- for upper_p in np.linspace(0.98, 1.0, num=100).tolist():
993
- global_upper = np.quantile(non_zero_single_channel, upper_p)
1186
+ global_lower = np.percentile(non_zero_single_channel, lower_percentile)
1187
+ for upper_p in np.linspace(98, 100, num=100).tolist():
1188
+ global_upper = np.percentile(non_zero_single_channel, upper_p)
994
1189
  if global_upper >= signal_threshold:
995
1190
  break
996
1191
 
997
- #Normalize the pixels in each image to the global quantiles and then dtype.
1192
+ # Normalize the pixels in each image to the global percentiles and then dtype.
998
1193
  arr_2d_normalized = np.zeros_like(single_channel, dtype=single_channel.dtype)
999
1194
  signal_to_noise_ratio_ls = []
1000
1195
  for array_index in range(single_channel.shape[0]):
@@ -1002,41 +1197,40 @@ def _normalize_stack(src, backgrounds=[100,100,100], remove_background=False, lo
1002
1197
  arr_2d = single_channel[array_index, :, :]
1003
1198
  non_zero_arr_2d = arr_2d[arr_2d != 0]
1004
1199
  if non_zero_arr_2d.size > 0:
1005
- lower, upper = np.quantile(non_zero_arr_2d, (lower_quantile, upper_p))
1006
- signal_to_noise_ratio = upper/lower
1200
+ lower, upper = np.percentile(non_zero_arr_2d, (lower_percentile, upper_p))
1201
+ signal_to_noise_ratio = upper / lower
1007
1202
  else:
1008
1203
  signal_to_noise_ratio = 0
1009
1204
  signal_to_noise_ratio_ls.append(signal_to_noise_ratio)
1010
1205
  average_stnr = np.mean(signal_to_noise_ratio_ls) if len(signal_to_noise_ratio_ls) > 0 else 0
1011
1206
 
1012
1207
  if signal_to_noise_ratio > signal_2_noise:
1013
- arr_2d_rescaled = exposure.rescale_intensity(arr_2d, in_range=(lower, upper), out_range=(global_lower, global_upper))
1208
+ arr_2d_rescaled = exposure.rescale_intensity(arr_2d, in_range=(lower, upper), out_range=(0, 1))
1014
1209
  arr_2d_normalized[array_index, :, :] = arr_2d_rescaled
1015
1210
  else:
1016
1211
  arr_2d_normalized[array_index, :, :] = arr_2d
1017
1212
  stop = time.time()
1018
- duration = (stop - start)*single_channel.shape[0]
1213
+ duration = (stop - start) * single_channel.shape[0]
1019
1214
  time_ls.append(duration)
1020
1215
  average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
1021
- #clear_output(wait=True)
1022
- print(f'Progress: files {file_index+1}/{len(paths)}, channels:{chan_index}/{stack.shape[-1]-1}, arrays:{array_index+1}/{single_channel.shape[0]}, Signal:{upper:.1f}, noise:{lower:.1f}, Signal-to-noise:{average_stnr:.1f}, Time/channel:{average_time:.2f}sec')
1023
- #print(f'Progress: files {file_index+1}/{len(paths)}, channels:{chan_index}/{stack.shape[-1]-1}, arrays:{array_index+1}/{single_channel.shape[0]}, Signal:{upper:.1f}, noise:{lower:.1f}, Signal-to-noise:{average_stnr:.1f}, Time/channel:{average_time:.2f}sec', end='\r', flush=True)
1024
- normalized_single_channel = exposure.rescale_intensity(arr_2d_normalized, out_range='dtype')
1025
- normalized_stack[:, :, :, channel] = normalized_single_channel
1026
- save_loc = output_fldr+'/'+name+'_norm_stack.npz'
1027
- normalized_stack = normalized_stack.astype(save_dtype)
1028
- np.savez(save_loc, data=normalized_stack, filenames=filenames)
1029
- del normalized_stack, single_channel, normalized_single_channel, stack, filenames
1216
+ print(f'Progress: files {file_index + 1}/{len(paths)}, channels:{chan_index}/{stack.shape[-1] - 1}, arrays:{array_index + 1}/{single_channel.shape[0]}, Signal:{upper:.1f}, noise:{lower:.1f}, Signal-to-noise:{average_stnr:.1f}, Time/channel:{average_time:.2f}sec')
1217
+
1218
+ normalized_stack[:, :, :, channel] = arr_2d_normalized
1219
+
1220
+ save_loc = os.path.join(output_fldr, f'{name}_norm_stack.npz')
1221
+ np.savez(save_loc, data=normalized_stack.astype(save_dtype), filenames=filenames)
1222
+ del normalized_stack, single_channel, arr_2d_normalized, stack, filenames
1030
1223
  gc.collect()
1031
- return print(f'Saved stacks:{output_fldr}')
1224
+
1225
+ return print(f'Saved stacks: {output_fldr}')
1032
1226
 
1033
- def _normalize_timelapse(src, lower_quantile=0.01, save_dtype=np.float32):
1227
+ def _normalize_timelapse(src, lower_percentile=2, save_dtype=np.float32):
1034
1228
  """
1035
1229
  Normalize the timelapse data by rescaling the intensity values based on percentiles.
1036
1230
 
1037
1231
  Args:
1038
1232
  src (str): The source directory containing the timelapse data files.
1039
- lower_quantile (float, optional): The lower quantile used to calculate the intensity range. Defaults to 0.01.
1233
+ lower_percentile (int, optional): The lower percentile used to calculate the intensity range. Defaults to 1.
1040
1234
  save_dtype (numpy.dtype, optional): The data type to save the normalized stack. Defaults to np.float32.
1041
1235
  """
1042
1236
  paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
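
In the reworked _normalize_stack each 2-D frame is only rescaled when its own upper/lower percentile ratio beats the per-channel signal_to_noise value, and rescaled frames now land in a float32 [0, 1] range instead of the old dtype range. A minimal sketch of that gate with synthetic data (the extra lower > 0 guard is an addition here, not in the package):

import numpy as np
from skimage import exposure

def maybe_rescale(frame, lower_percentile=2, upper_p=99.0, signal_2_noise=5):
    # Rescale to [0, 1] only if the frame's percentile ratio clears the SNR gate.
    non_zero = frame[frame != 0]
    if non_zero.size == 0:
        return frame
    lower, upper = np.percentile(non_zero, (lower_percentile, upper_p))
    if lower > 0 and upper / lower > signal_2_noise:
        return exposure.rescale_intensity(frame, in_range=(lower, upper), out_range=(0, 1))
    return frame

frame = np.random.default_rng(2).gamma(2.0, 300.0, size=(64, 64)).astype(np.float32)
print(maybe_rescale(frame).max())  # 1.0 when the gate passes, otherwise the raw maximum
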
@@ -1058,7 +1252,7 @@ def _normalize_timelapse(src, lower_quantile=0.01, save_dtype=np.float32):
1058
1252
  for array_index in range(single_channel.shape[0]):
1059
1253
  arr_2d = single_channel[array_index]
1060
1254
  # Calculate the 1% and 98% percentiles for this specific image
1061
- q_low = np.percentile(arr_2d[arr_2d != 0], 2)
1255
+ q_low = np.percentile(arr_2d[arr_2d != 0], lower_percentile)
1062
1256
  q_high = np.percentile(arr_2d[arr_2d != 0], 98)
1063
1257
 
1064
1258
  # Rescale intensity based on the calculated percentiles to fill the dtype range
@@ -1148,11 +1342,11 @@ def delete_empty_subdirectories(folder_path):
1148
1342
  # An error occurred, likely because the directory is not empty
1149
1343
  #print(f"Skipping non-empty directory: {full_dir_path}")
1150
1344
 
1151
- @log_function_call
1345
+ #@log_function_call
1152
1346
  def preprocess_img_data(settings):
1153
1347
 
1154
1348
  from .plot import plot_arrays, _plot_4D_arrays
1155
- from .utils import _run_test_mode
1349
+ from .utils import _run_test_mode, _get_regex, set_default_settings_preprocess_img_data
1156
1350
 
1157
1351
  """
1158
1352
  Preprocesses image data by converting z-stack images to maximum intensity projection (MIP) images.
@@ -1171,9 +1365,8 @@ def preprocess_img_data(settings):
1171
1365
  timelapse (bool, optional): Whether the images are from a timelapse experiment. Defaults to False.
1172
1366
  remove_background (bool, optional): Whether to remove the background from the images. Defaults to False.
1173
1367
  backgrounds (int, optional): The number of background images to use for background removal. Defaults to 100.
1174
- lower_quantile (float, optional): The lower quantile used for background removal. Defaults to 0.01.
1368
+ lower_percentile (float, optional): The lower percentile used for background removal. Defaults to 1.
1175
1369
  save_dtype (type, optional): The data type used for saving the preprocessed images. Defaults to np.float32.
1176
- correct_illumination (bool, optional): Whether to correct the illumination of the images. Defaults to False.
1177
1370
  randomize (bool, optional): Whether to randomize the order of the images. Defaults to True.
1178
1371
  all_to_mip (bool, optional): Whether to convert all images to MIP. Defaults to False.
1179
1372
  pick_slice (bool, optional): Whether to pick a specific slice based on the provided skip mode. Defaults to False.
@@ -1191,7 +1384,7 @@ def preprocess_img_data(settings):
1191
1384
  extension_counts = Counter(extensions)
1192
1385
  most_common_extension = extension_counts.most_common(1)[0][0]
1193
1386
  img_format = None
1194
-
1387
+
1195
1388
  delete_empty_subdirectories(src)
1196
1389
 
1197
1390
  # Check if the most common extension is one of the specified image formats
@@ -1206,56 +1399,31 @@ def preprocess_img_data(settings):
1206
1399
  print('Found existing channel_stack folder.')
1207
1400
  if os.path.exists(src+'/norm_channel_stack'):
1208
1401
  print('Found existing norm_channel_stack folder. Skipping preprocessing')
1209
- return
1210
-
1211
- cmap = 'inferno'
1212
- figuresize = 20
1213
- normalize = True
1214
- save_dtype = 'uint16'
1215
- correct_illumination = False
1216
-
1217
- mask_channels = [settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel']]
1218
- backgrounds = [settings['nucleus_background'], settings['pathogen_background'], settings['cell_background']]
1219
-
1220
- metadata_type = settings['metadata_type']
1221
- custom_regex = settings['custom_regex']
1222
- nr = settings['examples_to_plot']
1223
- plot = settings['plot']
1224
- batch_size = settings['batch_size']
1225
- timelapse = settings['timelapse']
1226
- remove_background = settings['remove_background']
1227
- lower_quantile = settings['lower_quantile']
1228
- randomize = settings['randomize']
1229
- all_to_mip = settings['all_to_mip']
1230
- pick_slice = settings['pick_slice']
1231
- skip_mode = settings['skip_mode']
1232
-
1233
-
1234
- if not img_format == None:
1235
- if metadata_type == 'cellvoyager':
1236
- regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1237
- elif metadata_type == 'cq1':
1238
- regex = f'W(?P<wellID>.*)F(?P<fieldID>.*)T(?P<timeID>.*)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1239
- elif metadata_type == 'nikon':
1240
- regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1241
- elif metadata_type == 'zeis':
1242
- regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1243
- elif metadata_type == 'leica':
1244
- regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1245
- elif metadata_type == 'custom':
1246
- regex = f'({custom_regex}){img_format}'
1402
+ return settings, src
1247
1403
 
1248
- print(f'regex mode:{metadata_type} regex:{regex}')
1404
+ mask_channels = [settings['nucleus_channel'], settings['cell_channel'], settings['pathogen_channel']]
1405
+ backgrounds = [settings['nucleus_background'], settings['cell_background'], settings['pathogen_background']]
1406
+
1407
+ settings, metadata_type, custom_regex, nr, plot, batch_size, timelapse, lower_percentile, randomize, all_to_mip, pick_slice, skip_mode, cmap, figuresize, normalize, save_dtype, test_mode, test_images, random_test = set_default_settings_preprocess_img_data(settings)
1408
+
1409
+ regex = _get_regex(metadata_type, img_format, custom_regex)
1410
+
1411
+ if test_mode:
1249
1412
 
1250
- if settings.get('test_mode', False):
1413
+ print(f'Running spacr in test mode')
1414
+ settings['plot'] = True
1251
1415
  try:
1252
1416
  os.rmdir(os.path.join(src, 'test'))
1253
1417
  print(f"Deleted test directory: {os.path.join(src, 'test')}")
1254
1418
  except OSError as e:
1255
1419
  pass
1256
1420
 
1257
- src = _run_test_mode(settings['src'], regex, timelapse=timelapse)
1421
+ src = _run_test_mode(settings['src'], regex, timelapse, test_images, random_test)
1258
1422
  settings['src'] = src
1423
+
1424
+ if img_format == None:
1425
+ if not os.path.exists(src+'/stack'):
1426
+ _merge_channels(src, plot=False)
1259
1427
 
1260
1428
  if not os.path.exists(src+'/stack'):
1261
1429
  try:
@@ -1295,31 +1463,20 @@ def preprocess_img_data(settings):
1295
1463
  except Exception as e:
1296
1464
  print(f"Error: {e}")
1297
1465
 
1298
- print('concatinating cahnnels')
1299
- _concatenate_channel(src+'/stack',
1300
- channels=mask_channels,
1301
- randomize=randomize,
1302
- timelapse=timelapse,
1303
- batch_size=batch_size)
1304
-
1305
- if plot:
1306
- print(f'plotting {nr} images from {src}/channel_stack')
1307
- _plot_4D_arrays(src+'/channel_stack', figuresize, cmap, nr_npz=1, nr=nr)
1308
-
1309
- backgrounds, signal_to_noise, signal_thresholds = _get_lists_for_normalization(settings=settings)
1310
-
1311
- if not timelapse:
1312
- _normalize_stack(src+'/channel_stack',
1313
- backgrounds=backgrounds,
1314
- lower_quantile=lower_quantile,
1315
- save_dtype=save_dtype,
1316
- signal_thresholds=signal_thresholds,
1317
- correct_illumination=correct_illumination,
1318
- signal_to_noise=signal_to_noise,
1319
- remove_background=remove_background)
1320
- else:
1321
- _normalize_timelapse(src+'/channel_stack', lower_quantile=lower_quantile, save_dtype=np.float32)
1322
-
1466
+ backgrounds, signal_to_noise, signal_thresholds, remove_backgrounds = _get_lists_for_normalization(settings=settings)
1467
+
1468
+ concatenate_and_normalize(src+'/stack',
1469
+ mask_channels,
1470
+ randomize,
1471
+ timelapse,
1472
+ batch_size,
1473
+ backgrounds,
1474
+ remove_backgrounds,
1475
+ lower_percentile,
1476
+ np.float32,
1477
+ signal_to_noise,
1478
+ signal_thresholds)
1479
+
1323
1480
  if plot:
1324
1481
  _plot_4D_arrays(src+'/norm_channel_stack', nr_npz=1, nr=nr)
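
Net effect of the preprocess_img_data changes above: defaults now come from set_default_settings_preprocess_img_data, the regex is resolved by _get_regex, channel folders are merged when no stack exists yet, and the old _concatenate_channel plus _normalize_stack pair is replaced by a single concatenate_and_normalize call on src/stack. A hedged outline of that flow, grounded in the hunks above (argument values are illustrative):

# 1) settings, metadata_type, custom_regex, ... = set_default_settings_preprocess_img_data(settings)
# 2) regex = _get_regex(metadata_type, img_format, custom_regex)
# 3) optional test mode: src = _run_test_mode(settings['src'], regex, timelapse, test_images, random_test)
# 4) if no src/stack folder exists yet: _merge_channels(src, plot=False)
# 5) backgrounds, signal_to_noise, signal_thresholds, remove_backgrounds = _get_lists_for_normalization(settings)
# 6) concatenate_and_normalize(src + '/stack', mask_channels, randomize, timelapse, batch_size,
#                              backgrounds, remove_backgrounds, lower_percentile, np.float32,
#                              signal_to_noise, signal_thresholds)
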
1325
1482
 
@@ -1373,27 +1530,6 @@ def _get_avg_object_size(masks):
1373
1530
  return sum(object_areas) / len(object_areas)
1374
1531
  else:
1375
1532
  return 0 # Return 0 if no objects are found
1376
-
1377
- def _save_figure_v1(fig, src, text, dpi=300, ):
1378
- """
1379
- Save a figure to a specified location.
1380
-
1381
- Parameters:
1382
- fig (matplotlib.figure.Figure): The figure to be saved.
1383
- src (str): The source file path.
1384
- text (str): The text to be included in the figure name.
1385
- dpi (int, optional): The resolution of the saved figure. Defaults to 300.
1386
- """
1387
- save_folder = os.path.dirname(src)
1388
- obj_type = os.path.basename(src)
1389
- name = os.path.basename(save_folder)
1390
- save_folder = os.path.join(save_folder, 'figure')
1391
- os.makedirs(save_folder, exist_ok=True)
1392
- fig_name = f'{obj_type}_{name}_{text}.pdf'
1393
- save_location = os.path.join(save_folder, fig_name)
1394
- fig.savefig(save_location, bbox_inches='tight', dpi=dpi)
1395
- print(f'Saved single cell figure: {save_location}')
1396
- plt.close()
1397
1533
 
1398
1534
  def _save_figure(fig, src, text, dpi=300, i=1, all_folders=1):
1399
1535
  """
@@ -1499,56 +1635,6 @@ def _save_settings_to_db(settings):
1499
1635
  settings_df.to_sql('settings', conn, if_exists='replace', index=False) # Replace the table if it already exists
1500
1636
  conn.close()
1501
1637
 
1502
- def _save_mask_timelapse_as_gif_v1(masks, path, cmap, norm, filenames):
1503
- """
1504
- Save a timelapse of masks as a GIF.
1505
-
1506
- Parameters:
1507
- masks (list): List of mask frames.
1508
- path (str): Path to save the GIF.
1509
- cmap: Colormap for displaying the masks.
1510
- norm: Normalization for the masks.
1511
- filenames (list): List of filenames corresponding to each mask frame.
1512
-
1513
- Returns:
1514
- None
1515
- """
1516
- def _update(frame):
1517
- """
1518
- Update the plot with the given frame.
1519
-
1520
- Parameters:
1521
- frame (int): The frame number to update the plot with.
1522
-
1523
- Returns:
1524
- None
1525
- """
1526
- nonlocal filename_text_obj
1527
- if filename_text_obj is not None:
1528
- filename_text_obj.remove()
1529
- ax.clear()
1530
- ax.axis('off')
1531
- current_mask = masks[frame]
1532
- ax.imshow(current_mask, cmap=cmap, norm=norm)
1533
- ax.set_title(f'Frame: {frame}', fontsize=24, color='white')
1534
- filename_text = filenames[frame]
1535
- filename_text_obj = fig.text(0.5, 0.01, filename_text, ha='center', va='center', fontsize=20, color='white')
1536
- for label_value in np.unique(current_mask):
1537
- if label_value == 0: continue # Skip background
1538
- y, x = np.mean(np.where(current_mask == label_value), axis=1)
1539
- ax.text(x, y, str(label_value), color='white', fontsize=24, ha='center', va='center')
1540
-
1541
- fig, ax = plt.subplots(figsize=(50, 50), facecolor='black')
1542
- ax.set_facecolor('black')
1543
- ax.axis('off')
1544
- plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
1545
-
1546
- filename_text_obj = None
1547
- anim = FuncAnimation(fig, _update, frames=len(masks), blit=False)
1548
- anim.save(path, writer='pillow', fps=2, dpi=80) # Adjust DPI for size/quality
1549
- plt.close(fig)
1550
- print(f'Saved timelapse to {path}')
1551
-
1552
1638
  def _save_mask_timelapse_as_gif(masks, tracks_df, path, cmap, norm, filenames):
1553
1639
  """
1554
1640
  Save a timelapse animation of masks as a GIF.
@@ -2273,6 +2359,8 @@ def convert_numpy_to_tiff(folder_path, limit=None):
2273
2359
  for i, filename in enumerate(files):
2274
2360
  if limit is not None and i >= limit:
2275
2361
  break
2362
+ if not filename.endswith('.npy'):
2363
+ continue
2276
2364
 
2277
2365
  # Construct the full file path
2278
2366
  file_path = os.path.join(folder_path, filename)
@@ -2289,131 +2377,46 @@ def convert_numpy_to_tiff(folder_path, limit=None):
2289
2377
  print(f"Converted {filename} to {tiff_filename} and saved in 'tiff' subdirectory.")
2290
2378
  return
2291
2379
 
(0.0.6 removes 128 consecutive blank lines here: old lines 2292-2419.)
2380
+ def generate_cellpose_train_test(src, test_split=0.1):
2381
+ mask_src = os.path.join(src, 'masks')
2382
+ img_paths = glob.glob(os.path.join(src, '*.tif'))
2383
+ img_filenames = [os.path.basename(file) for file in img_paths]
2384
+ img_filenames = [file for file in img_filenames if os.path.exists(os.path.join(mask_src, file))]
2385
+ print(f'Found {len(img_filenames)} images with masks')
2386
+
2387
+ random.shuffle(img_filenames)
2388
+ split_index = int(len(img_filenames) * test_split)
2389
+ train_files = img_filenames[split_index:]
2390
+ test_files = img_filenames[:split_index]
2391
+ list_of_lists = [test_files, train_files]
2392
+ print(f'Split dataset into Train {len(train_files)} and Test {len(test_files)} files')
2393
+
2394
+ train_dir = os.path.join(os.path.dirname(src), 'train')
2395
+ train_dir_masks = os.path.join(train_dir, 'masks')
2396
+ test_dir = os.path.join(os.path.dirname(src), 'test')
2397
+ test_dir_masks = os.path.join(test_dir, 'masks')
2398
+
2399
+ os.makedirs(train_dir, exist_ok=True)
2400
+ os.makedirs(train_dir_masks, exist_ok=True)
2401
+ os.makedirs(test_dir, exist_ok=True)
2402
+ os.makedirs(test_dir_masks, exist_ok=True)
2403
+
2404
+ for i, ls in enumerate(list_of_lists):
2405
+ if i == 0:
2406
+ dst = test_dir
2407
+ dst_mask = test_dir_masks
2408
+ _type = 'Test'
2409
+ else:
2410
+ dst = train_dir
2411
+ dst_mask = train_dir_masks
2412
+ _type = 'Train'
2413
+
2414
+ for idx, filename in enumerate(ls):
2415
+ img_path = os.path.join(src, filename)
2416
+ mask_path = os.path.join(mask_src, filename)
2417
+ new_img_path = os.path.join(dst, filename)
2418
+ new_mask_path = os.path.join(dst_mask, filename)
2419
+ shutil.copy(img_path, new_img_path)
2420
+ shutil.copy(mask_path, new_mask_path)
2421
+ print(f'Copied {idx+1}/{len(ls)} images to {_type} set', end='\r', flush=True)
2422
+
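
The new generate_cellpose_train_test expects images as src/*.tif with identically named masks under src/masks, and copies a shuffled split into train/ and test/ folders (each with its own masks/ subfolder) created next to src. A hedged usage sketch (paths are illustrative):

# Assumed layout before the call:
#   /data/cellpose_dataset/        image_001.tif, image_002.tif, ...
#   /data/cellpose_dataset/masks/  image_001.tif, image_002.tif, ...
# from spacr.io import generate_cellpose_train_test
# generate_cellpose_train_test('/data/cellpose_dataset', test_split=0.1)
# Afterwards: /data/train/{*.tif, masks/} and /data/test/{*.tif, masks/},
# written beside the source folder (os.path.dirname(src)).
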