spacr 0.0.1__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/io.py CHANGED
@@ -1,6 +1,7 @@
1
- import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose
1
+ import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob
2
2
  import numpy as np
3
3
  import pandas as pd
4
+ import tifffile
4
5
  from PIL import Image
5
6
  from collections import defaultdict, Counter
6
7
  from pathlib import Path
@@ -18,13 +19,12 @@ from io import BytesIO
18
19
  from IPython.display import display, clear_output
19
20
  from multiprocessing import Pool, cpu_count
20
21
  from torch.utils.data import Dataset
21
- import seaborn as sns
22
22
  import matplotlib.pyplot as plt
23
23
  from torchvision.transforms import ToTensor
24
24
 
25
+
25
26
  from .logger import log_function_call
26
27
 
27
- @log_function_call
28
28
  def _load_images_and_labels(image_files, label_files, circular=False, invert=False, image_extension="*.tif", label_extension="*.tif"):
29
29
 
30
30
  from .utils import invert_image, apply_mask
@@ -44,19 +44,19 @@ def _load_images_and_labels(image_files, label_files, circular=False, invert=Fal
44
44
 
45
45
  if not image_files is None and not label_files is None:
46
46
  for img_file, lbl_file in zip(image_files, label_files):
47
- image = cellpose.imread(img_file)
47
+ image = cellpose.io.imread(img_file)
48
48
  if invert:
49
49
  image = invert_image(image)
50
50
  if circular:
51
51
  image = apply_mask(image, output_value=0)
52
- label = cellpose.imread(lbl_file)
52
+ label = cellpose.io.imread(lbl_file)
53
53
  if image.max() > 1:
54
54
  image = image / image.max()
55
55
  images.append(image)
56
56
  labels.append(label)
57
57
  elif not image_files is None:
58
58
  for img_file in image_files:
59
- image = cellpose.imread(img_file)
59
+ image = cellpose.io.imread(img_file)
60
60
  if invert:
61
61
  image = invert_image(image)
62
62
  if circular:
@@ -66,7 +66,7 @@ def _load_images_and_labels(image_files, label_files, circular=False, invert=Fal
66
66
  images.append(image)
67
67
  elif not image_files is None:
68
68
  for lbl_file in label_files:
69
- label = cellpose.imread(lbl_file)
69
+ label = cellpose.io.imread(lbl_file)
70
70
  if circular:
71
71
  label = apply_mask(label, output_value=0)
72
72
  labels.append(label)
@@ -87,16 +87,13 @@ def _load_images_and_labels(image_files, label_files, circular=False, invert=Fal
87
87
  print(f'image shape: {images[0].shape}, image dtype: {images[0].dtype}, mask shape: {labels[0].shape}, mask dtype: {labels[0].dtype}')
88
88
  return images, labels, image_names, label_names
89
89
 
90
- @log_function_call
91
- def _load_normalized_images_and_labels(image_files, label_files, signal_thresholds=[1000], channels=None, percentiles=None, circular=False, invert=False, visualize=False):
90
+ def _load_normalized_images_and_labels(image_files, label_files, channels=None, percentiles=None, circular=False, invert=False, visualize=False, remove_background=False, background=0, Signal_to_noise=10):
92
91
 
93
92
  from .plot import normalize_and_visualize
94
93
  from .utils import invert_image, apply_mask
95
-
96
- if isinstance(signal_thresholds, int):
97
- signal_thresholds = [signal_thresholds] * (len(channels) if channels is not None else 1)
98
- elif not isinstance(signal_thresholds, list):
99
- signal_thresholds = [signal_thresholds]
94
+
95
+ signal_thresholds = background*Signal_to_noise
96
+ lower_percentile = 2
100
97
 
101
98
  images = []
102
99
  labels = []
@@ -109,18 +106,22 @@ def _load_normalized_images_and_labels(image_files, label_files, signal_threshol
109
106
 
110
107
  if label_files is not None:
111
108
  label_names = [os.path.basename(f) for f in label_files]
109
+ label_dir = os.path.dirname(label_files[0])
112
110
 
113
111
  # Load images and check percentiles
114
112
  for i,img_file in enumerate(image_files):
115
- image = cellpose.imread(img_file)
113
+ image = cellpose.io.imread(img_file)
116
114
  if invert:
117
115
  image = invert_image(image)
118
116
  if circular:
119
117
  image = apply_mask(image, output_value=0)
120
-
118
+
121
119
  # If specific channels are specified, select them
122
120
  if channels is not None and image.ndim == 3:
123
121
  image = image[..., channels]
122
+
123
+ if remove_background:
124
+ image[image < background] = 0
124
125
 
125
126
  if image.ndim < 3:
126
127
  image = np.expand_dims(image, axis=-1)
@@ -128,11 +129,11 @@ def _load_normalized_images_and_labels(image_files, label_files, signal_threshol
128
129
  images.append(image)
129
130
  if percentiles is None:
130
131
  for c in range(image.shape[-1]):
131
- p1 = np.percentile(image[..., c], 1)
132
+ p1 = np.percentile(image[..., c], lower_percentile)
132
133
  percentiles_1[c].append(p1)
133
- for percentile in [99, 99.9, 99.99, 99.999]:
134
+ for percentile in [98, 99, 99.9, 99.99, 99.999]:
134
135
  p = np.percentile(image[..., c], percentile)
135
- if p > signal_thresholds[min(c, len(signal_thresholds)-1)]:
136
+ if p > signal_thresholds:
136
137
  percentiles_99[c].append(p)
137
138
  break
138
139
 
@@ -141,8 +142,8 @@ def _load_normalized_images_and_labels(image_files, label_files, signal_threshol
141
142
  for image in images:
142
143
  normalized_image = np.zeros_like(image, dtype=np.float32)
143
144
  for c in range(image.shape[-1]):
144
- high_p = np.percentile(image[..., c], percentiles[1])
145
145
  low_p = np.percentile(image[..., c], percentiles[0])
146
+ high_p = np.percentile(image[..., c], percentiles[1])
146
147
  normalized_image[..., c] = rescale_intensity(image[..., c], in_range=(low_p, high_p), out_range=(0, 1))
147
148
  normalized_images.append(normalized_image)
148
149
  if visualize:
@@ -153,23 +154,26 @@ def _load_normalized_images_and_labels(image_files, label_files, signal_threshol
153
154
  avg_p1 = [np.mean(p) for p in percentiles_1]
154
155
  avg_p99 = [np.mean(p) if len(p) > 0 else np.mean(percentiles_1[i]) for i, p in enumerate(percentiles_99)]
155
156
 
157
+ print(f'Average lower percentiles: {avg_p1}, average upper percentiles: {avg_p99}')
158
+
156
159
  normalized_images = []
157
160
  for image in images:
158
161
  normalized_image = np.zeros_like(image, dtype=np.float32)
159
- for c in range(image.shape[-1]):
160
- normalized_image[..., c] = rescale_intensity(image[..., c], in_range=(avg_p1[c], avg_p99[c]), out_range=(0, 1))
161
- normalized_images.append(normalized_image)
162
- if visualize:
163
- normalize_and_visualize(image, normalized_image, title=f"Channel {c+1} Normalized")
162
+ for c in range(image.shape[-1]):
163
+ normalized_image[..., c] = rescale_intensity(image[..., c], in_range=(avg_p1[c], avg_p99[c]), out_range=(0, 1))
164
+ normalized_images.append(normalized_image)
165
+ if visualize:
166
+ normalize_and_visualize(image, normalized_image, title=f"Channel {c+1} Normalized")
164
167
 
165
168
  if not image_files is None:
166
169
  image_dir = os.path.dirname(image_files[0])
170
+
167
171
  else:
168
172
  image_dir = None
169
173
 
170
174
  if label_files is not None:
171
175
  for lbl_file in label_files:
172
- labels.append(cellpose.imread(lbl_file))
176
+ labels.append(cellpose.io.imread(lbl_file))
173
177
  else:
174
178
  label_names = []
175
179
  label_dir = None
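Note: when no explicit percentiles are passed, the loader above gathers a per-image lower percentile plus the first high percentile whose value clears background*Signal_to_noise, averages those bounds over the dataset, and rescales every channel to [0, 1]. A minimal standalone sketch of that idea (single-channel images, scikit-image assumed; not the packaged implementation):

    import numpy as np
    from skimage.exposure import rescale_intensity

    def normalize_with_average_bounds(images, background=0, snr=10, lower_percentile=2):
        signal_threshold = background * snr
        lows, highs = [], []
        for img in images:
            lows.append(np.percentile(img, lower_percentile))
            for p in (98, 99, 99.9, 99.99, 99.999):
                value = np.percentile(img, p)
                if value > signal_threshold:   # first percentile that looks like real signal
                    highs.append(value)
                    break
        low = np.mean(lows)
        high = np.mean(highs) if highs else low
        return [rescale_intensity(img, in_range=(low, high), out_range=(0, 1)) for img in images]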
@@ -178,86 +182,8 @@ def _load_normalized_images_and_labels(image_files, label_files, signal_threshol
178
182
 
179
183
  return normalized_images, labels, image_names, label_names
180
184
 
181
- class MyDataset(Dataset):
182
- """
183
- Custom dataset class for loading and processing image data.
184
-
185
- Args:
186
- data_dir (str): The directory path where the data is stored.
187
- loader_classes (list): List of class names.
188
- transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. Default is None.
189
- shuffle (bool, optional): Whether to shuffle the dataset. Default is True.
190
- load_to_memory (bool, optional): Whether to load images into memory. Default is False.
191
-
192
- Attributes:
193
- data_dir (str): The directory path where the data is stored.
194
- classes (list): List of class names.
195
- transform (callable): A function/transform that takes in an PIL image and returns a transformed version.
196
- shuffle (bool): Whether to shuffle the dataset.
197
- load_to_memory (bool): Whether to load images into memory.
198
- filenames (list): List of file paths.
199
- labels (list): List of labels corresponding to each file.
200
- images (list): List of loaded images.
201
- image_cache (Cache): Cache object for storing loaded images.
202
-
203
- Methods:
204
- load_image: Load an image from file.
205
- __len__: Get the length of the dataset.
206
- shuffle_dataset: Shuffle the dataset.
207
- __getitem__: Get an item from the dataset.
208
-
209
- """
210
-
211
- def _init__(self, data_dir, loader_classes, transform=None, shuffle=True, load_to_memory=False):
212
- from .utils import Cache
213
- self.data_dir = data_dir
214
- self.classes = loader_classes
215
- self.transform = transform
216
- self.shuffle = shuffle
217
- self.load_to_memory = load_to_memory
218
- self.filenames = []
219
- self.labels = []
220
- self.images = []
221
- self.image_cache = Cache(50)
222
- for class_name in self.classes:
223
- class_path = os.path.join(data_dir, class_name)
224
- class_files = [os.path.join(class_path, f) for f in os.listdir(class_path) if os.path.isfile(os.path.join(class_path, f))]
225
- self.filenames.extend(class_files)
226
- self.labels.extend([self.classes.index(class_name)] * len(class_files))
227
- if self.shuffle:
228
- self.shuffle_dataset()
229
- if self.load_to_memory:
230
- self.images = [self.load_image(f) for f in self.filenames]
231
-
232
- def load_image(self, img_path):
233
- img = self.image_cache.get(img_path)
234
- if img is None:
235
- img = Image.open(img_path).convert('RGB')
236
- self.image_cache.put(img_path, img)
237
- return img
238
-
239
- def _len__(self):
240
- return len(self.filenames)
241
-
242
- def shuffle_dataset(self):
243
- combined = list(zip(self.filenames, self.labels))
244
- random.shuffle(combined)
245
- self.filenames, self.labels = zip(*combined)
246
-
247
- def _getitem__(self, index):
248
- label = self.labels[index]
249
- filename = self.filenames[index]
250
- if self.load_to_memory:
251
- img = self.images[index]
252
- else:
253
- img = self.load_image(filename)
254
- if self.transform is not None:
255
- img = self.transform(img)
256
- else:
257
- img = ToTensor()(img)
258
- return img, label, filename
259
-
260
185
  class CombineLoaders:
186
+
261
187
  """
262
188
  A class that combines multiple data loaders into a single iterator.
263
189
 
@@ -398,7 +324,7 @@ class MyDataset(Dataset):
398
324
  specific_labels (list, optional): A list of specific labels corresponding to the specific files. Default is None.
399
325
  """
400
326
 
401
- def _init__(self, data_dir, loader_classes, transform=None, shuffle=True, pin_memory=False, specific_files=None, specific_labels=None):
327
+ def __init__(self, data_dir, loader_classes, transform=None, shuffle=True, pin_memory=False, specific_files=None, specific_labels=None):
402
328
  self.data_dir = data_dir
403
329
  self.classes = loader_classes
404
330
  self.transform = transform
@@ -427,7 +353,7 @@ class MyDataset(Dataset):
427
353
  img = Image.open(img_path).convert('RGB')
428
354
  return img
429
355
 
430
- def _len__(self):
356
+ def __len__(self):
431
357
  return len(self.filenames)
432
358
 
433
359
  def shuffle_dataset(self):
@@ -439,7 +365,7 @@ class MyDataset(Dataset):
439
365
  filename = os.path.basename(filepath) # Get just the filename from the full path
440
366
  return filename.split('_')[0]
441
367
 
442
- def _getitem__(self, index):
368
+ def __getitem__(self, index):
443
369
  label = self.labels[index]
444
370
  filename = self.filenames[index]
445
371
  img = self.load_image(filename)
@@ -527,6 +453,7 @@ class TarImageDataset(Dataset):
527
453
 
528
454
  return img, m.name
529
455
 
456
+ #@log_function_call
530
457
  def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=False, skip_mode='01', metadata_type='', img_format='.tif'):
531
458
  """
532
459
  Convert z-stack images to maximum intensity projection (MIP) images.
@@ -599,40 +526,47 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
599
526
  shutil.move(os.path.join(src, filename), move)
600
527
  return
601
528
 
602
- def _merge_file(chan_dirs, stack_dir, file):
529
+ def _merge_file(chan_dirs, stack_dir, file_name):
603
530
  """
604
- Merge multiple channels into a single stack and save it as a numpy array.
605
-
531
+ Merge multiple channels into a single stack and save it as a numpy array, using the os module for path handling.
532
+
606
533
  Args:
607
534
  chan_dirs (list): List of directories containing channel images.
608
535
  stack_dir (str): Directory to save the merged stack.
609
- file (str): File name of the channel image.
536
+ file_name (str): File name of the channel image.
610
537
 
611
538
  Returns:
612
539
  None
613
540
  """
614
- chan1 = cv2.imread(str(file), -1)
615
- chan1 = np.expand_dims(chan1, axis=2)
616
- new_file = stack_dir / (file.stem + '.npy')
617
- if not new_file.exists():
618
- stack_dir.mkdir(exist_ok=True)
619
- channels = [chan1]
620
- for chan_dir in chan_dirs[1:]:
621
- img = cv2.imread(str(chan_dir / file.name), -1)
541
+ # Construct new file path
542
+ file_root, file_ext = os.path.splitext(file_name)
543
+ new_file = os.path.join(stack_dir, file_root + '.npy')
544
+
545
+ # Check if the new file exists and create the stack directory if it doesn't
546
+ if not os.path.exists(new_file):
547
+ os.makedirs(stack_dir, exist_ok=True)
548
+ channels = []
549
+ for i, chan_dir in enumerate(chan_dirs):
550
+ img_path = os.path.join(chan_dir, file_name)
551
+ img = cv2.imread(img_path, -1)
552
+ if img is None:
553
+ print(f"Warning: Failed to read image {img_path}")
554
+ continue
622
555
  chan = np.expand_dims(img, axis=2)
623
556
  channels.append(chan)
624
- stack = np.concatenate(channels, axis=2)
625
- np.save(new_file, stack)
557
+ del img # Explicitly delete the reference to the image to free up memory
558
+ if i % 10 == 0: # Periodically suggest garbage collection
559
+ gc.collect()
560
+
561
+ if channels:
562
+ stack = np.concatenate(channels, axis=2)
563
+ np.save(new_file, stack)
564
+ else:
565
+ print(f"No valid channels to merge for file {file_name}")
626
566
 
627
567
  def _is_dir_empty(dir_path):
628
568
  """
629
- Check if a directory is empty.
630
-
631
- Args:
632
- dir_path (str): The path to the directory.
633
-
634
- Returns:
635
- bool: True if the directory is empty, False otherwise.
569
+ Check if a directory is empty, using the os module.
636
570
  """
637
571
  return len(os.listdir(dir_path)) == 0
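The rewritten _merge_file above reads the same file name from each single-channel folder, adds a channel axis, concatenates, and saves one .npy per field. A rough standalone sketch of that stacking step (folder names are hypothetical; OpenCV and NumPy assumed):

    import os
    import cv2
    import numpy as np

    def merge_one_file(chan_dirs, stack_dir, file_name):
        os.makedirs(stack_dir, exist_ok=True)
        channels = []
        for chan_dir in chan_dirs:
            img = cv2.imread(os.path.join(chan_dir, file_name), -1)  # -1 keeps the original bit depth
            if img is None:
                continue  # skip unreadable channels, mirroring the warning above
            channels.append(np.expand_dims(img, axis=2))
        if channels:
            out = os.path.join(stack_dir, os.path.splitext(file_name)[0] + '.npy')
            np.save(out, np.concatenate(channels, axis=2))

    # hypothetical layout: plate/01 and plate/02 each hold one TIFF per field
    # merge_one_file(['plate/01', 'plate/02'], 'plate/stack', 'plate1_A01_1.tif')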
638
572
 
@@ -706,7 +640,7 @@ def _move_to_chan_folder(src, regex, timelapse=False, metadata_type=''):
706
640
  if metadata_type =='cq1':
707
641
  orig_wellID = wellID
708
642
  wellID = _convert_cq1_well_id(wellID)
709
- print(f'Converted Well ID: {orig_wellID} to {wellID}')
643
+ print(f'Converted Well ID: {orig_wellID} to {wellID}')#, end='\r', flush=True)
710
644
 
711
645
  newname = f"{plateID}_{wellID}_{fieldID}_{timeID if timelapse else ''}{ext}"
712
646
  newpath = src / chanID
@@ -732,7 +666,7 @@ def _move_to_chan_folder(src, regex, timelapse=False, metadata_type=''):
732
666
  shutil.move(os.path.join(src, filename), move)
733
667
  return
734
668
 
735
- def _merge_channels(src, plot=False):
669
+ def _merge_channels_v2(src, plot=False):
736
670
  from .plot import plot_arrays
737
671
  """
738
672
  Merge the channels in the given source directory and save the merged files in a 'stack' directory.
@@ -757,9 +691,11 @@ def _merge_channels(src, plot=False):
757
691
 
758
692
  # Create the 'stack' directory if it doesn't exist
759
693
  stack_dir.mkdir(exist_ok=True)
694
+ print(f'generated folder with merged arrays: {stack_dir}')
760
695
 
761
696
  if _is_dir_empty(stack_dir):
762
- with Pool(cpu_count()) as pool:
697
+ with Pool(max(cpu_count() // 2, 1)) as pool:
698
+ #with Pool(cpu_count()) as pool:
763
699
  merge_func = partial(_merge_file, chan_dirs, stack_dir)
764
700
  pool.map(merge_func, dir_files)
765
701
 
@@ -771,6 +707,47 @@ def _merge_channels(src, plot=False):
771
707
 
772
708
  return
773
709
 
710
+ def _merge_channels(src, plot=False):
711
+ """
712
+ Merge the channels in the given source directory and save the merged files in a 'stack' directory without using multiprocessing.
713
+ """
714
+
715
+ from .plot import plot_arrays
716
+
717
+ stack_dir = os.path.join(src, 'stack')
718
+ allowed_names = ['01', '02', '03', '04', '00', '1', '2', '3', '4', '0']
719
+
720
+ # List directories that match the allowed names
721
+ chan_dirs = [d for d in os.listdir(src) if os.path.isdir(os.path.join(src, d)) and d in allowed_names]
722
+ chan_dirs.sort()
723
+
724
+ print(f'List of folders in src: {chan_dirs}. Single channel folders.')
725
+ start_time = time.time()
726
+
727
+ # Assuming chan_dirs[0] is not empty and exists, adjust according to your logic
728
+ first_dir_path = os.path.join(src, chan_dirs[0])
729
+ dir_files = os.listdir(first_dir_path)
730
+
731
+ # Create the 'stack' directory if it doesn't exist
732
+ if not os.path.exists(stack_dir):
733
+ os.makedirs(stack_dir, exist_ok=True)
734
+ print(f'Generated folder with merged arrays: {stack_dir}')
735
+
736
+ if _is_dir_empty(stack_dir):
737
+ for file_name in dir_files:
738
+ full_file_path = os.path.join(first_dir_path, file_name)
739
+ if os.path.isfile(full_file_path):
740
+ _merge_file([os.path.join(src, d) for d in chan_dirs], stack_dir, file_name)
741
+
742
+ elapsed_time = time.time() - start_time
743
+ avg_time = elapsed_time / len(dir_files) if dir_files else 0
744
+ print(f'Average Time: {avg_time:.3f} sec, Total Elapsed Time: {elapsed_time:.3f} sec')
745
+
746
+ if plot:
747
+ plot_arrays(os.path.join(src, 'stack'))
748
+
749
+ return
750
+
774
751
  def _mip_all(src, include_first_chan=True):
775
752
 
776
753
  """
@@ -819,6 +796,7 @@ def _mip_all(src, include_first_chan=True):
819
796
  np.save(os.path.join(src, filename), concatenated)
820
797
  return
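_mip_all reduces a z-stack to a maximum intensity projection, which is just a max over the z axis; a minimal NumPy illustration (array layout assumed z, y, x, channel):

    import numpy as np

    zstack = np.random.rand(5, 64, 64, 2)   # hypothetical 5-slice, 2-channel stack
    mip = zstack.max(axis=0)                # maximum intensity projection -> (64, 64, 2)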
821
798
 
799
+ #@log_function_call
822
800
  def _concatenate_channel(src, channels, randomize=True, timelapse=False, batch_size=100):
823
801
  """
824
802
  Concatenates channel data from multiple files and saves the concatenated data as numpy arrays.
@@ -853,8 +831,8 @@ def _concatenate_channel(src, channels, randomize=True, timelapse=False, batch_s
853
831
  array = np.take(array, channels, axis=2)
854
832
  stack_region.append(array)
855
833
  filenames_region.append(os.path.basename(path))
856
- clear_output(wait=True)
857
- print(f'\033[KRegion {i+1}/ {len(time_stack_path_lists)}', end='\r', flush=True)
834
+ #clear_output(wait=True)
835
+ print(f'Region {i+1}/ {len(time_stack_path_lists)}', end='\r', flush=True)
858
836
  stack = np.stack(stack_region)
859
837
  save_loc = os.path.join(channel_stack_loc, f'{name}.npz')
860
838
  np.savez(save_loc, data=stack, filenames=filenames_region)
@@ -879,15 +857,17 @@ def _concatenate_channel(src, channels, randomize=True, timelapse=False, batch_s
879
857
  array = np.take(array, channels, axis=2)
880
858
  stack_ls.append(array)
881
859
  filenames_batch.append(os.path.basename(path)) # store the filename
882
- clear_output(wait=True)
883
- print(f'\033[KConcatenated: {i+1}/{nr_files} files', end='\r', flush=True)
860
+ #clear_output(wait=True)
861
+ print(f'Concatenated: {i+1}/{nr_files} files')
862
+ #print(f'Concatenated: {i+1}/{nr_files} files', end='\r', flush=True)
884
863
 
885
864
  if (i+1) % batch_size == 0 or i+1 == nr_files:
886
865
  unique_shapes = {arr.shape[:-1] for arr in stack_ls}
887
866
  if len(unique_shapes) > 1:
888
867
  max_dims = np.max(np.array(list(unique_shapes)), axis=0)
889
- clear_output(wait=True)
890
- print(f'\033[KWarning: arrays with multiple shapes found in batch {i+1}. Padding arrays to max X,Y dimentions {max_dims}', end='\r', flush=True)
868
+ #clear_output(wait=True)
869
+ print(f'Warning: arrays with multiple shapes found in batch {i+1}. Padding arrays to max X,Y dimensions {max_dims}')
870
+ #print(f'Warning: arrays with multiple shapes found in batch {i+1}. Padding arrays to max X,Y dimensions {max_dims}', end='\r', flush=True)
891
871
  padded_stack_ls = []
892
872
  for arr in stack_ls:
893
873
  pad_width = [(0, max_dim - dim) for max_dim, dim in zip(max_dims, arr.shape[:-1])]
@@ -904,9 +884,226 @@ def _concatenate_channel(src, channels, randomize=True, timelapse=False, batch_s
904
884
  stack_ls = [] # empty the list for the next batch
905
885
  filenames_batch = [] # empty the filenames list for the next batch
906
886
  padded_stack_ls = []
907
- #print(f'\nAll files concatenated and saved to:{channel_stack_loc}')
887
+ print(f'All files concatenated and saved to:{channel_stack_loc}')
908
888
  return channel_stack_loc
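When a batch mixes image sizes, the loop above pads every array up to the batch's maximum X,Y dimensions before stacking; the padding step in isolation (shapes are made up):

    import numpy as np

    arrays = [np.zeros((100, 120, 3)), np.zeros((110, 100, 3))]
    max_dims = np.max(np.array([a.shape[:-1] for a in arrays]), axis=0)   # [110, 120]
    padded = []
    for arr in arrays:
        pad_width = [(0, m - d) for m, d in zip(max_dims, arr.shape[:-1])]
        pad_width.append((0, 0))              # never pad the channel axis
        padded.append(np.pad(arr, pad_width))
    stack = np.stack(padded)                  # shape (2, 110, 120, 3)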
909
889
 
890
+ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, batch_size=100, backgrounds=[100, 100, 100], remove_backgrounds=[False, False, False], lower_percentile=2, save_dtype=np.float32, signal_to_noise=[5, 5, 5], signal_thresholds=[1000, 1000, 1000]):
891
+ """
892
+ Concatenates and normalizes channel data from multiple files and saves the normalized data.
893
+
894
+ Args:
895
+ src (str): The source directory containing the channel data files.
896
+ channels (list): The list of channel indices to be concatenated and normalized.
897
+ randomize (bool, optional): Whether to randomize the order of the files. Defaults to True.
898
+ timelapse (bool, optional): Whether the channel data is from a timelapse experiment. Defaults to False.
899
+ batch_size (int, optional): The number of files to be processed in each batch. Defaults to 100.
900
+ backgrounds (list, optional): Background values for each channel. Defaults to [100, 100, 100].
901
+ remove_backgrounds (list, optional): Whether to remove background values for each channel. Defaults to [False, False, False].
902
+ lower_percentile (int, optional): Lower percentile value for normalization. Defaults to 2.
903
+ save_dtype (numpy.dtype, optional): Data type for saving the normalized stack. Defaults to np.float32.
904
+ signal_to_noise (list, optional): Signal-to-noise ratio thresholds for each channel. Defaults to [5, 5, 5].
905
+ signal_thresholds (list, optional): Signal thresholds for each channel. Defaults to [1000, 1000, 1000].
906
+
907
+ Returns:
908
+ str: The directory path where the concatenated and normalized channel data is saved.
909
+ """
910
+ channels = [item for item in channels if item is not None]
911
+ paths = []
912
+ output_fldr = os.path.join(os.path.dirname(src), 'norm_channel_stack')
913
+ os.makedirs(output_fldr, exist_ok=True)
914
+
915
+ if timelapse:
916
+ try:
917
+ time_stack_path_lists = _generate_time_lists(os.listdir(src))
918
+ for i, time_stack_list in enumerate(time_stack_path_lists):
919
+ stack_region = []
920
+ filenames_region = []
921
+ for idx, file in enumerate(time_stack_list):
922
+ path = os.path.join(src, file)
923
+ if idx == 0:
924
+ parts = file.split('_')
925
+ name = parts[0] + '_' + parts[1] + '_' + parts[2]
926
+ array = np.load(path)
927
+ array = np.take(array, channels, axis=2)
928
+ stack_region.append(array)
929
+ filenames_region.append(os.path.basename(path))
930
+ print(f'Region {i + 1}/ {len(time_stack_path_lists)}', end='\r', flush=True)
931
+ stack = np.stack(stack_region)
932
+ normalized_stack = _normalize_stack(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds)
933
+ save_loc = os.path.join(output_fldr, f'{name}_norm_timelapse.npz')
934
+ np.savez(save_loc, data=normalized_stack, filenames=filenames_region)
935
+ print(save_loc)
936
+ del stack, normalized_stack
937
+ except Exception as e:
938
+ print(f"Error processing files, make sure filenames metadata is structured plate_well_field_time.npy")
939
+ print(f"Error: {e}")
940
+ else:
941
+ for file in os.listdir(src):
942
+ if file.endswith('.npy'):
943
+ path = os.path.join(src, file)
944
+ paths.append(path)
945
+ if randomize:
946
+ random.shuffle(paths)
947
+ nr_files = len(paths)
948
+ batch_index = 0
949
+ stack_ls = []
950
+ filenames_batch = []
951
+
952
+ for i, path in enumerate(paths):
953
+ array = np.load(path)
954
+ array = np.take(array, channels, axis=2)
955
+ stack_ls.append(array)
956
+ filenames_batch.append(os.path.basename(path))
957
+ print(f'Concatenated: {i + 1}/{nr_files} files')
958
+
959
+ if (i + 1) % batch_size == 0 or i + 1 == nr_files:
960
+ unique_shapes = {arr.shape[:-1] for arr in stack_ls}
961
+ if len(unique_shapes) > 1:
962
+ max_dims = np.max(np.array(list(unique_shapes)), axis=0)
963
+ print(f'Warning: arrays with multiple shapes found in batch {i + 1}. Padding arrays to max X,Y dimensions {max_dims}')
964
+ padded_stack_ls = []
965
+ for arr in stack_ls:
966
+ pad_width = [(0, max_dim - dim) for max_dim, dim in zip(max_dims, arr.shape[:-1])]
967
+ pad_width.append((0, 0))
968
+ padded_arr = np.pad(arr, pad_width)
969
+ padded_stack_ls.append(padded_arr)
970
+ stack = np.stack(padded_stack_ls)
971
+ else:
972
+ stack = np.stack(stack_ls)
973
+
974
+ normalized_stack = _normalize_img_batch(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds)
975
+
976
+ save_loc = os.path.join(output_fldr, f'stack_{batch_index}_norm.npz')
977
+ np.savez(save_loc, data=normalized_stack, filenames=filenames_batch)
978
+ batch_index += 1
979
+ del stack, normalized_stack
980
+ stack_ls = []
981
+ filenames_batch = []
982
+ padded_stack_ls = []
983
+ print(f'All files concatenated and normalized. Saved to: {output_fldr}')
984
+ return output_fldr
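A hedged usage sketch for the new concatenate_and_normalize; the folder path and per-channel values are illustrative only, and the import path is assumed from the module name:

    import numpy as np
    # from spacr.io import concatenate_and_normalize   # assumed import

    out_dir = concatenate_and_normalize(
        '/data/plate1/stack',                 # hypothetical folder of per-field .npy stacks
        channels=[0, 1, 2],
        randomize=True,
        timelapse=False,
        batch_size=100,
        backgrounds=[100, 100, 200],
        remove_backgrounds=[False, False, True],
        lower_percentile=2,
        save_dtype=np.float32,
        signal_to_noise=[5, 5, 5],
        signal_thresholds=[500, 500, 1000],
    )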
985
+
986
+ def _normalize_img_batch(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds):
987
+ """
988
+ Normalize the stack of images.
989
+
990
+ Args:
991
+ stack (numpy.ndarray): The stack of images to normalize.
992
+ backgrounds (list): Background values for each channel.
993
+ remove_backgrounds (list): Whether to remove background values for each channel.
994
+ lower_percentile (int): Lower percentile value for normalization.
995
+ save_dtype (numpy.dtype): Data type for saving the normalized stack.
996
+ signal_to_noise (list): Signal-to-noise ratio thresholds for each channel.
997
+ signal_thresholds (list): Signal thresholds for each channel.
998
+
999
+ Returns:
1000
+ numpy.ndarray: The normalized stack.
1001
+ """
1002
+ normalized_stack = np.zeros_like(stack, dtype=np.float32)
1003
+
1004
+ for chan_index, channel in enumerate(range(stack.shape[-1])):
1005
+ single_channel = stack[:, :, :, channel]
1006
+ background = backgrounds[chan_index]
1007
+ signal_threshold = signal_thresholds[chan_index]
1008
+ remove_background = remove_backgrounds[chan_index]
1009
+
1010
+ print(f'Processing channel {chan_index}: background={background}, signal_threshold={signal_threshold}, remove_background={remove_background}')
1011
+
1012
+ # Step 3: Remove background if required
1013
+ if remove_background:
1014
+ single_channel[single_channel < background] = 0
1015
+
1016
+ # Step 4: Calculate global lower percentile for the channel
1017
+ non_zero_single_channel = single_channel[single_channel != 0]
1018
+ global_lower = np.percentile(non_zero_single_channel, lower_percentile)
1019
+
1020
+ # Step 5: Calculate global upper percentile for the channel
1021
+ global_upper = None
1022
+ for upper_p in np.linspace(98, 99.5, num=16):
1023
+ upper_value = np.percentile(non_zero_single_channel, upper_p)
1024
+ if upper_value >= signal_threshold:
1025
+ global_upper = upper_value
1026
+ break
1027
+
1028
+ if global_upper is None:
1029
+ global_upper = np.percentile(non_zero_single_channel, 99.5) # Fallback in case no upper percentile met the threshold
1030
+
1031
+ print(f'Channel {chan_index}: global_lower={global_lower}, global_upper={global_upper}, Signal-to-noise={global_upper / global_lower}')
1032
+
1033
+ # Step 6: Normalize each array from global_lower to global_upper between 0 and 1
1034
+ for array_index in range(single_channel.shape[0]):
1035
+ arr_2d = single_channel[array_index, :, :]
1036
+ arr_2d_normalized = exposure.rescale_intensity(arr_2d, in_range=(global_lower, global_upper), out_range=(0, 1))
1037
+ normalized_stack[array_index, :, :, channel] = arr_2d_normalized
1038
+
1039
+ return normalized_stack.astype(save_dtype)
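The upper bound used by _normalize_img_batch is the first percentile in a 98-99.5 sweep whose value clears the channel's signal threshold, computed on non-zero pixels only; the search step on its own:

    import numpy as np

    def find_global_bounds(channel_stack, lower_percentile=2, signal_threshold=1000):
        non_zero = channel_stack[channel_stack != 0]
        lower = np.percentile(non_zero, lower_percentile)
        upper = None
        for p in np.linspace(98, 99.5, num=16):
            value = np.percentile(non_zero, p)
            if value >= signal_threshold:
                upper = value
                break
        if upper is None:                     # fallback, as above, when nothing clears the threshold
            upper = np.percentile(non_zero, 99.5)
        return lower, upper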
1040
+
1041
+ def _normalize_img_batch_v1(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds):
1042
+ """
1043
+ Normalize the stack of images.
1044
+
1045
+ Args:
1046
+ stack (numpy.ndarray): The stack of images to normalize.
1047
+ backgrounds (list): Background values for each channel.
1048
+ remove_backgrounds (list): Whether to remove background values for each channel.
1049
+ lower_percentile (int): Lower percentile value for normalization.
1050
+ save_dtype (numpy.dtype): Data type for saving the normalized stack.
1051
+ signal_to_noise (list): Signal-to-noise ratio thresholds for each channel.
1052
+ signal_thresholds (list): Signal thresholds for each channel.
1053
+
1054
+ Returns:
1055
+ numpy.ndarray: The normalized stack.
1056
+ """
1057
+ normalized_stack = np.zeros_like(stack, dtype=np.float32)
1058
+ time_ls = []
1059
+
1060
+ for chan_index, channel in enumerate(range(stack.shape[-1])):
1061
+ single_channel = stack[:, :, :, channel]
1062
+ background = backgrounds[chan_index]
1063
+ signal_threshold = signal_thresholds[chan_index]
1064
+ remove_background = remove_backgrounds[chan_index]
1065
+ signal_2_noise = signal_to_noise[chan_index]
1066
+ print(f'chan_index:{chan_index} background:{background} signal_threshold:{signal_threshold} remove_background:{remove_background} signal_2_noise:{signal_2_noise}')
1067
+
1068
+ if remove_background:
1069
+ single_channel[single_channel < background] = 0
1070
+
1071
+ non_zero_single_channel = single_channel[single_channel != 0]
1072
+ global_lower = np.percentile(non_zero_single_channel, lower_percentile)
1073
+ for upper_p in np.linspace(98, 99.5, num=20).tolist():
1074
+ global_upper = np.percentile(non_zero_single_channel, upper_p)
1075
+ if global_upper >= signal_threshold:
1076
+ break
1077
+
1078
+ arr_2d_normalized = np.zeros_like(single_channel, dtype=single_channel.dtype)
1079
+ signal_to_noise_ratio_ls = []
1080
+ for array_index in range(single_channel.shape[0]):
1081
+ start = time.time()
1082
+ arr_2d = single_channel[array_index, :, :]
1083
+ non_zero_arr_2d = arr_2d[arr_2d != 0]
1084
+ if non_zero_arr_2d.size > 0:
1085
+ lower, upper = np.percentile(non_zero_arr_2d, (lower_percentile, upper_p))
1086
+ signal_to_noise_ratio = upper / lower
1087
+ else:
1088
+ signal_to_noise_ratio = 0
1089
+ signal_to_noise_ratio_ls.append(signal_to_noise_ratio)
1090
+ average_stnr = np.mean(signal_to_noise_ratio_ls) if len(signal_to_noise_ratio_ls) > 0 else 0
1091
+
1092
+ if signal_to_noise_ratio > signal_2_noise:
1093
+ arr_2d_rescaled = exposure.rescale_intensity(arr_2d, in_range=(lower, upper), out_range=(0, 1))
1094
+ arr_2d_normalized[array_index, :, :] = arr_2d_rescaled
1095
+ else:
1096
+ arr_2d_normalized[array_index, :, :] = arr_2d
1097
+ stop = time.time()
1098
+ duration = (stop - start) * single_channel.shape[0]
1099
+ time_ls.append(duration)
1100
+ average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
1101
+ print(f'Progress: channels:{chan_index}/{stack.shape[-1] - 1}, arrays:{array_index + 1}/{single_channel.shape[0]}, Signal:{upper:.1f}, noise:{lower:.1f}, Signal-to-noise:{average_stnr:.1f}, Time/channel:{average_time:.2f}sec')
1102
+
1103
+ normalized_stack[:, :, :, channel] = arr_2d_normalized
1104
+
1105
+ return normalized_stack.astype(save_dtype)
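In the older _normalize_img_batch_v1 path, each 2D frame is rescaled only when its own upper/lower percentile ratio beats the requested signal-to-noise; roughly:

    import numpy as np
    from skimage import exposure

    def maybe_rescale(frame, lower_percentile=2, upper_percentile=99.5, min_snr=5):
        non_zero = frame[frame != 0]
        if non_zero.size == 0:
            return frame                                   # nothing to normalize
        lower, upper = np.percentile(non_zero, (lower_percentile, upper_percentile))
        if lower > 0 and upper / lower > min_snr:          # only frames with real signal get rescaled
            return exposure.rescale_intensity(frame, in_range=(lower, upper), out_range=(0, 1))
        return frame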
1106
+
910
1107
  def _get_lists_for_normalization(settings):
911
1108
  """
912
1109
  Get lists for normalization based on the provided settings.
@@ -921,7 +1118,8 @@ def _get_lists_for_normalization(settings):
921
1118
  # Initialize the lists
922
1119
  backgrounds = []
923
1120
  signal_to_noise = []
924
- signal_thresholds = []
1121
+ signal_thresholds = []
1122
+ remove_background = []
925
1123
 
926
1124
  # Iterate through the channels and append the corresponding values if the channel is not None
927
1125
  for ch in settings['channels']:
@@ -929,29 +1127,31 @@ def _get_lists_for_normalization(settings):
929
1127
  backgrounds.append(settings['nucleus_background'])
930
1128
  signal_to_noise.append(settings['nucleus_Signal_to_noise'])
931
1129
  signal_thresholds.append(settings['nucleus_Signal_to_noise']*settings['nucleus_background'])
1130
+ remove_background.append(settings['remove_background_nucleus'])
932
1131
  elif ch == settings['cell_channel']:
933
1132
  backgrounds.append(settings['cell_background'])
934
1133
  signal_to_noise.append(settings['cell_Signal_to_noise'])
935
1134
  signal_thresholds.append(settings['cell_Signal_to_noise']*settings['cell_background'])
1135
+ remove_background.append(settings['remove_background_cell'])
936
1136
  elif ch == settings['pathogen_channel']:
937
1137
  backgrounds.append(settings['pathogen_background'])
938
1138
  signal_to_noise.append(settings['pathogen_Signal_to_noise'])
939
1139
  signal_thresholds.append(settings['pathogen_Signal_to_noise']*settings['pathogen_background'])
940
- return backgrounds, signal_to_noise, signal_thresholds
1140
+ remove_background.append(settings['remove_background_pathogen'])
1141
+ return backgrounds, signal_to_noise, signal_thresholds, remove_background
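_get_lists_for_normalization now also returns a per-channel remove_background flag; with a hypothetical settings dict (keys as used above) the outputs line up with settings['channels'] like this:

    settings = {
        'channels': [0, 2],
        'nucleus_channel': 0, 'nucleus_background': 100, 'nucleus_Signal_to_noise': 5,
        'remove_background_nucleus': True,
        'cell_channel': 1, 'cell_background': 100, 'cell_Signal_to_noise': 5,
        'remove_background_cell': False,
        'pathogen_channel': 2, 'pathogen_background': 200, 'pathogen_Signal_to_noise': 10,
        'remove_background_pathogen': True,
    }
    # backgrounds        -> [100, 200]
    # signal_to_noise    -> [5, 10]
    # signal_thresholds  -> [500, 2000]
    # remove_background  -> [True, True]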
941
1142
 
942
- def _normalize_stack(src, backgrounds=[100,100,100], remove_background=False, lower_quantile=0.01, save_dtype=np.float32, signal_to_noise=[5,5,5], signal_thresholds=[1000,1000,1000], correct_illumination=False):
1143
+ def _normalize_stack(src, backgrounds=[100, 100, 100], remove_backgrounds=[False, False, False], lower_percentile=2, save_dtype=np.float32, signal_to_noise=[5, 5, 5], signal_thresholds=[1000, 1000, 1000]):
943
1144
  """
944
1145
  Normalize the stack of images.
945
1146
 
946
1147
  Args:
947
1148
  src (str): The source directory containing the stack of images.
948
- backgrounds (list, optional): Background values for each channel. Defaults to [100,100,100].
949
- remove_background (bool, optional): Whether to remove background values. Defaults to False.
950
- lower_quantile (float, optional): Lower quantile value for normalization. Defaults to 0.01.
1149
+ backgrounds (list, optional): Background values for each channel. Defaults to [100, 100, 100].
1150
+ remove_backgrounds (list, optional): Whether to remove background values for each channel. Defaults to [False, False, False].
1151
+ lower_percentile (int, optional): Lower percentile value for normalization. Defaults to 2.
951
1152
  save_dtype (numpy.dtype, optional): Data type for saving the normalized stack. Defaults to np.float32.
952
- signal_to_noise (list, optional): Signal-to-noise ratio thresholds for each channel. Defaults to [5,5,5].
953
- signal_thresholds (list, optional): Signal thresholds for each channel. Defaults to [1000,1000,1000].
954
- correct_illumination (bool, optional): Whether to correct illumination. Defaults to False.
1153
+ signal_to_noise (list, optional): Signal-to-noise ratio thresholds for each channel. Defaults to [5, 5, 5].
1154
+ signal_thresholds (list, optional): Signal thresholds for each channel. Defaults to [1000, 1000, 1000].
955
1155
 
956
1156
  Returns:
957
1157
  None
@@ -960,11 +1160,13 @@ def _normalize_stack(src, backgrounds=[100,100,100], remove_background=False, lo
960
1160
  output_fldr = os.path.join(os.path.dirname(src), 'norm_channel_stack')
961
1161
  os.makedirs(output_fldr, exist_ok=True)
962
1162
  time_ls = []
1163
+
963
1164
  for file_index, path in enumerate(paths):
964
1165
  with np.load(path) as data:
965
1166
  stack = data['data']
966
1167
  filenames = data['filenames']
967
- normalized_stack = np.zeros_like(stack, dtype=stack.dtype)
1168
+
1169
+ normalized_stack = np.zeros_like(stack, dtype=np.float32)
968
1170
  file = os.path.basename(path)
969
1171
  name, _ = os.path.splitext(file)
970
1172
 
@@ -972,24 +1174,22 @@ def _normalize_stack(src, backgrounds=[100,100,100], remove_background=False, lo
972
1174
  single_channel = stack[:, :, :, channel]
973
1175
  background = backgrounds[chan_index]
974
1176
  signal_threshold = signal_thresholds[chan_index]
975
- #print(f'signal_threshold:{signal_threshold} in {signal_thresholds} for {chan_index}')
976
-
1177
+ remove_background = remove_backgrounds[chan_index]
977
1178
  signal_2_noise = signal_to_noise[chan_index]
1179
+ print(f'chan_index:{chan_index} background:{background} signal_threshold:{signal_threshold} remove_background:{remove_background} signal_2_noise:{signal_2_noise}')
1180
+
978
1181
  if remove_background:
979
1182
  single_channel[single_channel < background] = 0
980
- if correct_illumination:
981
- bg = filters.gaussian(single_channel, sigma=50)
982
- single_channel = single_channel - bg
983
1183
 
984
- #Calculate the global lower and upper quantiles for non-zero pixels
1184
+ # Calculate the global lower and upper percentiles for non-zero pixels
985
1185
  non_zero_single_channel = single_channel[single_channel != 0]
986
- global_lower = np.quantile(non_zero_single_channel, lower_quantile)
987
- for upper_p in np.linspace(0.98, 1.0, num=100).tolist():
988
- global_upper = np.quantile(non_zero_single_channel, upper_p)
1186
+ global_lower = np.percentile(non_zero_single_channel, lower_percentile)
1187
+ for upper_p in np.linspace(98, 100, num=100).tolist():
1188
+ global_upper = np.percentile(non_zero_single_channel, upper_p)
989
1189
  if global_upper >= signal_threshold:
990
1190
  break
991
1191
 
992
- #Normalize the pixels in each image to the global quantiles and then dtype.
1192
+ # Normalize the pixels in each image to the global percentiles and then dtype.
993
1193
  arr_2d_normalized = np.zeros_like(single_channel, dtype=single_channel.dtype)
994
1194
  signal_to_noise_ratio_ls = []
995
1195
  for array_index in range(single_channel.shape[0]):
@@ -997,40 +1197,40 @@ def _normalize_stack(src, backgrounds=[100,100,100], remove_background=False, lo
997
1197
  arr_2d = single_channel[array_index, :, :]
998
1198
  non_zero_arr_2d = arr_2d[arr_2d != 0]
999
1199
  if non_zero_arr_2d.size > 0:
1000
- lower, upper = np.quantile(non_zero_arr_2d, (lower_quantile, upper_p))
1001
- signal_to_noise_ratio = upper/lower
1200
+ lower, upper = np.percentile(non_zero_arr_2d, (lower_percentile, upper_p))
1201
+ signal_to_noise_ratio = upper / lower
1002
1202
  else:
1003
1203
  signal_to_noise_ratio = 0
1004
1204
  signal_to_noise_ratio_ls.append(signal_to_noise_ratio)
1005
1205
  average_stnr = np.mean(signal_to_noise_ratio_ls) if len(signal_to_noise_ratio_ls) > 0 else 0
1006
1206
 
1007
1207
  if signal_to_noise_ratio > signal_2_noise:
1008
- arr_2d_rescaled = exposure.rescale_intensity(arr_2d, in_range=(lower, upper), out_range=(global_lower, global_upper))
1208
+ arr_2d_rescaled = exposure.rescale_intensity(arr_2d, in_range=(lower, upper), out_range=(0, 1))
1009
1209
  arr_2d_normalized[array_index, :, :] = arr_2d_rescaled
1010
1210
  else:
1011
1211
  arr_2d_normalized[array_index, :, :] = arr_2d
1012
1212
  stop = time.time()
1013
- duration = (stop - start)*single_channel.shape[0]
1213
+ duration = (stop - start) * single_channel.shape[0]
1014
1214
  time_ls.append(duration)
1015
1215
  average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
1016
- clear_output(wait=True)
1017
- print(f'\033[KProgress: files {file_index+1}/{len(paths)}, channels:{chan_index}/{stack.shape[-1]-1}, arrays:{array_index+1}/{single_channel.shape[0]}, Signal:{upper:.1f}, noise:{lower:.1f}, Signal-to-noise:{average_stnr:.1f}, Time/channel:{average_time:.2f}sec', end='\r', flush=True)
1018
- normalized_single_channel = exposure.rescale_intensity(arr_2d_normalized, out_range='dtype')
1019
- normalized_stack[:, :, :, channel] = normalized_single_channel
1020
- save_loc = output_fldr+'/'+name+'_norm_stack.npz'
1021
- normalized_stack = normalized_stack.astype(save_dtype)
1022
- np.savez(save_loc, data=normalized_stack, filenames=filenames)
1023
- del normalized_stack, single_channel, normalized_single_channel, stack, filenames
1216
+ print(f'Progress: files {file_index + 1}/{len(paths)}, channels:{chan_index}/{stack.shape[-1] - 1}, arrays:{array_index + 1}/{single_channel.shape[0]}, Signal:{upper:.1f}, noise:{lower:.1f}, Signal-to-noise:{average_stnr:.1f}, Time/channel:{average_time:.2f}sec')
1217
+
1218
+ normalized_stack[:, :, :, channel] = arr_2d_normalized
1219
+
1220
+ save_loc = os.path.join(output_fldr, f'{name}_norm_stack.npz')
1221
+ np.savez(save_loc, data=normalized_stack.astype(save_dtype), filenames=filenames)
1222
+ del normalized_stack, single_channel, arr_2d_normalized, stack, filenames
1024
1223
  gc.collect()
1025
- return print(f'Saved stacks:{output_fldr}')
1224
+
1225
+ return print(f'Saved stacks: {output_fldr}')
1026
1226
 
1027
- def _normalize_timelapse(src, lower_quantile=0.01, save_dtype=np.float32):
1227
+ def _normalize_timelapse(src, lower_percentile=2, save_dtype=np.float32):
1028
1228
  """
1029
1229
  Normalize the timelapse data by rescaling the intensity values based on percentiles.
1030
1230
 
1031
1231
  Args:
1032
1232
  src (str): The source directory containing the timelapse data files.
1033
- lower_quantile (float, optional): The lower quantile used to calculate the intensity range. Defaults to 0.01.
1233
+ lower_percentile (int, optional): The lower percentile used to calculate the intensity range. Defaults to 2.
1034
1234
  save_dtype (numpy.dtype, optional): The data type to save the normalized stack. Defaults to np.float32.
1035
1235
  """
1036
1236
  paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
@@ -1052,7 +1252,7 @@ def _normalize_timelapse(src, lower_quantile=0.01, save_dtype=np.float32):
1052
1252
  for array_index in range(single_channel.shape[0]):
1053
1253
  arr_2d = single_channel[array_index]
1054
1254
  # Calculate the 1% and 98% percentiles for this specific image
1055
- q_low = np.percentile(arr_2d[arr_2d != 0], 2)
1255
+ q_low = np.percentile(arr_2d[arr_2d != 0], lower_percentile)
1056
1256
  q_high = np.percentile(arr_2d[arr_2d != 0], 98)
1057
1257
 
1058
1258
  # Rescale intensity based on the calculated percentiles to fill the dtype range
@@ -1069,8 +1269,6 @@ def _normalize_timelapse(src, lower_quantile=0.01, save_dtype=np.float32):
1069
1269
 
1070
1270
  print(f'\nSaved normalized stacks: {output_fldr}')
1071
1271
 
1072
-
1073
-
1074
1272
  def _create_movies_from_npy_per_channel(src, fps=10):
1075
1273
  """
1076
1274
  Create movies from numpy files per channel.
@@ -1122,9 +1320,33 @@ def _create_movies_from_npy_per_channel(src, fps=10):
1122
1320
  channel_save_path = os.path.join(save_path, f'{plate}_{well}_{field}_channel_{channel}.mp4')
1123
1321
  _npz_to_movie(normalized_channel_arrays_3d, filenames, channel_save_path, fps)
1124
1322
 
1323
+ def delete_empty_subdirectories(folder_path):
1324
+ """
1325
+ Deletes all empty subdirectories in the specified folder.
1326
+
1327
+ Args:
1328
+ - folder_path (str): The path to the folder in which to look for empty subdirectories.
1329
+ """
1330
+ # Check each item in the specified folder
1331
+ for dirpath, dirnames, filenames in os.walk(folder_path, topdown=False):
1332
+ # os.walk is used with topdown=False to start from the innermost directories and work upwards.
1333
+ for dirname in dirnames:
1334
+ # Construct the full path to the subdirectory
1335
+ full_dir_path = os.path.join(dirpath, dirname)
1336
+ # Try to remove the directory and catch any error (like if the directory is not empty)
1337
+ try:
1338
+ os.rmdir(full_dir_path)
1339
+ print(f"Deleted empty directory: {full_dir_path}")
1340
+ except OSError as e:
1341
+ continue
1342
+ # An error occurred, likely because the directory is not empty
1343
+ #print(f"Skipping non-empty directory: {full_dir_path}")
1344
+
1345
+ #@log_function_call
1125
1346
  def preprocess_img_data(settings):
1126
1347
 
1127
1348
  from .plot import plot_arrays, _plot_4D_arrays
1349
+ from .utils import _run_test_mode, _get_regex, set_default_settings_preprocess_img_data
1128
1350
 
1129
1351
  """
1130
1352
  Preprocesses image data by converting z-stack images to maximum intensity projection (MIP) images.
@@ -1143,9 +1365,8 @@ def preprocess_img_data(settings):
1143
1365
  timelapse (bool, optional): Whether the images are from a timelapse experiment. Defaults to False.
1144
1366
  remove_background (bool, optional): Whether to remove the background from the images. Defaults to False.
1145
1367
  backgrounds (int, optional): The number of background images to use for background removal. Defaults to 100.
1146
- lower_quantile (float, optional): The lower quantile used for background removal. Defaults to 0.01.
1368
+ lower_percentile (float, optional): The lower percentile used for background removal. Defaults to 1.
1147
1369
  save_dtype (type, optional): The data type used for saving the preprocessed images. Defaults to np.float32.
1148
- correct_illumination (bool, optional): Whether to correct the illumination of the images. Defaults to False.
1149
1370
  randomize (bool, optional): Whether to randomize the order of the images. Defaults to True.
1150
1371
  all_to_mip (bool, optional): Whether to convert all images to MIP. Defaults to False.
1151
1372
  pick_slice (bool, optional): Whether to pick a specific slice based on the provided skip mode. Defaults to False.
@@ -1155,12 +1376,16 @@ def preprocess_img_data(settings):
1155
1376
  Returns:
1156
1377
  None
1157
1378
  """
1379
+
1158
1380
  src = settings['src']
1159
1381
  valid_ext = ['tif', 'tiff', 'png', 'jpeg']
1160
1382
  files = os.listdir(src)
1161
1383
  extensions = [file.split('.')[-1] for file in files]
1162
1384
  extension_counts = Counter(extensions)
1163
1385
  most_common_extension = extension_counts.most_common(1)[0][0]
1386
+ img_format = None
1387
+
1388
+ delete_empty_subdirectories(src)
1164
1389
 
1165
1390
  # Check if the most common extension is one of the specified image formats
1166
1391
  if most_common_extension in valid_ext:
@@ -1168,109 +1393,94 @@ def preprocess_img_data(settings):
1168
1393
  print(f'Found {extension_counts[most_common_extension]} {most_common_extension} files')
1169
1394
  else:
1170
1395
  print(f'Could not find any {valid_ext} files in {src}, only found {extension_counts[0]}')
1171
- return
1172
-
1173
- cmap = 'inferno'
1174
- figuresize = 20
1175
- normalize = True
1176
- save_dtype = 'uint16'
1177
- correct_illumination = False
1178
-
1179
- mask_channels = [settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel']]
1180
- backgrounds = [settings['nucleus_background'], settings['pathogen_background'], settings['cell_background']]
1181
-
1182
- metadata_type = settings['metadata_type']
1183
- custom_regex = settings['custom_regex']
1184
- nr = settings['examples_to_plot']
1185
- plot = settings['plot']
1186
- batch_size = settings['batch_size']
1187
- timelapse = settings['timelapse']
1188
- remove_background = settings['remove_background']
1189
- lower_quantile = settings['lower_quantile']
1190
- randomize = settings['randomize']
1191
- all_to_mip = settings['all_to_mip']
1192
- pick_slice = settings['pick_slice']
1193
- skip_mode = settings['skip_mode']
1194
-
1195
- if metadata_type == 'cellvoyager':
1196
- regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1197
- elif metadata_type == 'cq1':
1198
- regex = f'W(?P<wellID>.*)F(?P<fieldID>.*)T(?P<timeID>.*)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1199
- elif metadata_type == 'nikon':
1200
- regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1201
- elif metadata_type == 'zeis':
1202
- regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1203
- elif metadata_type == 'leica':
1204
- regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1205
- elif metadata_type == 'custom':
1206
- regex = f'({custom_regex}){img_format}'
1396
+ if os.path.exists(src+'/stack'):
1397
+ print('Found existing stack folder.')
1398
+ if os.path.exists(src+'/channel_stack'):
1399
+ print('Found existing channel_stack folder.')
1400
+ if os.path.exists(src+'/norm_channel_stack'):
1401
+ print('Found existing norm_channel_stack folder. Skipping preprocessing')
1402
+ return settings, src
1207
1403
 
1208
- print(f'regex mode:{metadata_type} regex:{regex}')
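For reference, the removed cellvoyager branch built its pattern with named groups; applied to a hypothetical filename it parses like this (standard re module, img_format assumed '.tif'):

    import re

    regex = r'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*).tif'
    m = re.match(regex, 'plate1_B03_T0001F001L01A01Z01C01.tif')   # made-up filename
    if m:
        print(m.group('plateID'), m.group('wellID'), m.group('fieldID'), m.group('chanID'))
        # plate1 B03 001 01

The newer code builds the same kind of pattern through _get_regex(metadata_type, img_format, custom_regex).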
1404
+ mask_channels = [settings['nucleus_channel'], settings['cell_channel'], settings['pathogen_channel']]
1405
+ backgrounds = [settings['nucleus_background'], settings['cell_background'], settings['pathogen_background']]
1406
+
1407
+ settings, metadata_type, custom_regex, nr, plot, batch_size, timelapse, lower_percentile, randomize, all_to_mip, pick_slice, skip_mode, cmap, figuresize, normalize, save_dtype, test_mode, test_images, random_test = set_default_settings_preprocess_img_data(settings)
1408
+
1409
+ regex = _get_regex(metadata_type, img_format, custom_regex)
1410
+
1411
+ if test_mode:
1412
+
1413
+ print(f'Running spacr in test mode')
1414
+ settings['plot'] = True
1415
+ try:
1416
+ os.rmdir(os.path.join(src, 'test'))
1417
+ print(f"Deleted test directory: {os.path.join(src, 'test')}")
1418
+ except OSError as e:
1419
+ pass
1420
+
1421
+ src = _run_test_mode(settings['src'], regex, timelapse, test_images, random_test)
1422
+ settings['src'] = src
1423
+
1424
+ if img_format == None:
1425
+ if not os.path.exists(src+'/stack'):
1426
+ _merge_channels(src, plot=False)
1209
1427
 
1210
1428
  if not os.path.exists(src+'/stack'):
1211
- if timelapse:
1212
- _move_to_chan_folder(src, regex, timelapse, metadata_type)
1213
- else:
1214
- #_z_to_mip(src, regex, batch_size, pick_slice, skip_mode, metadata_type, img_format)
1215
- _rename_and_organize_image_files(src, regex, batch_size, pick_slice, skip_mode, metadata_type, img_format)
1216
-
1217
- #Make sure no batches will be of only one image
1218
- all_imgs = len(src+'/stack')
1219
- full_batches = all_imgs // batch_size
1220
- last_batch_size = all_imgs % batch_size
1221
-
1222
- # Check if the last batch is of size 1
1223
- if last_batch_size == 1:
1224
- # If there's only one batch and its size is 1, it's also an issue
1225
- if full_batches == 0:
1226
- raise ValueError("Only one batch of size 1 detected. Adjust the batch size.")
1227
- # If the last batch is of size 1, merge it with the second last batch
1228
- elif full_batches > 0:
1229
- raise ValueError("Last batch of size 1 detected. Adjust the batch size.")
1230
-
1231
- _merge_channels(src, plot=False)
1232
- if timelapse:
1233
- _create_movies_from_npy_per_channel(src+'/stack', fps=2)
1234
-
1235
- if plot:
1236
- print(f'plotting {nr} images from {src}/stack')
1237
- plot_arrays(src+'/stack', figuresize, cmap, nr=nr, normalize=normalize)
1238
- if all_to_mip:
1239
- _mip_all(src+'/stack')
1240
- if plot:
1241
- print(f'plotting {nr} images from {src}/stack')
1242
- plot_arrays(src+'/stack', figuresize, cmap, nr=nr, normalize=normalize)
1243
- #nr_of_stacks = len(src+'/channel_stack')
1244
-
1245
- _concatenate_channel(src+'/stack',
1246
- channels=mask_channels,
1247
- randomize=randomize,
1248
- timelapse=timelapse,
1249
- batch_size=batch_size)
1250
-
1251
- if plot:
1252
- print(f'plotting {nr} images from {src}/channel_stack')
1253
- _plot_4D_arrays(src+'/channel_stack', figuresize, cmap, nr_npz=1, nr=nr)
1254
- nr_of_chan_stacks = len(src+'/channel_stack')
1255
-
1256
- backgrounds, signal_to_noise, signal_thresholds = _get_lists_for_normalization(settings=settings)
1257
-
1258
- if not timelapse:
1259
- _normalize_stack(src+'/channel_stack',
1260
- backgrounds=backgrounds,
1261
- lower_quantile=lower_quantile,
1262
- save_dtype=save_dtype,
1263
- signal_thresholds=signal_thresholds,
1264
- correct_illumination=correct_illumination,
1265
- signal_to_noise=signal_to_noise,
1266
- remove_background=remove_background)
1267
- else:
1268
- _normalize_timelapse(src+'/channel_stack', lower_quantile=lower_quantile, save_dtype=np.float32)
1429
+ try:
1430
+ if not img_format == None:
1431
+ if timelapse:
1432
+ _move_to_chan_folder(src, regex, timelapse, metadata_type)
1433
+ else:
1434
+ _rename_and_organize_image_files(src, regex, batch_size, pick_slice, skip_mode, metadata_type, img_format)
1435
+
1436
+ #Make sure no batches will be of only one image
1437
+ all_imgs = len(src+'/stack')
1438
+ full_batches = all_imgs // batch_size
1439
+ last_batch_size = all_imgs % batch_size
1440
+
1441
+ # Check if the last batch is of size 1
1442
+ if last_batch_size == 1:
1443
+ # If there's only one batch and its size is 1, it's also an issue
1444
+ if full_batches == 0:
1445
+ raise ValueError("Only one batch of size 1 detected. Adjust the batch size.")
1446
+ # If the last batch is of size 1, merge it with the second last batch
1447
+ elif full_batches > 0:
1448
+ raise ValueError("Last batch of size 1 detected. Adjust the batch size.")
1269
1449
 
1450
+ _merge_channels(src, plot=False)
1451
+
1452
+ if timelapse:
1453
+ _create_movies_from_npy_per_channel(src+'/stack', fps=2)
1454
+
1455
+ if plot:
1456
+ print(f'plotting {nr} images from {src}/stack')
1457
+ plot_arrays(src+'/stack', figuresize, cmap, nr=nr, normalize=normalize)
1458
+ if all_to_mip:
1459
+ _mip_all(src+'/stack')
1460
+ if plot:
1461
+ print(f'plotting {nr} images from {src}/stack')
1462
+ plot_arrays(src+'/stack', figuresize, cmap, nr=nr, normalize=normalize)
1463
+ except Exception as e:
1464
+ print(f"Error: {e}")
1465
+
1466
+ backgrounds, signal_to_noise, signal_thresholds, remove_backgrounds = _get_lists_for_normalization(settings=settings)
1467
+
1468
+ concatenate_and_normalize(src+'/stack',
1469
+ mask_channels,
1470
+ randomize,
1471
+ timelapse,
1472
+ batch_size,
1473
+ backgrounds,
1474
+ remove_backgrounds,
1475
+ lower_percentile,
1476
+ np.float32,
1477
+ signal_to_noise,
1478
+ signal_thresholds)
1479
+
1270
1480
  if plot:
1271
1481
  _plot_4D_arrays(src+'/norm_channel_stack', nr_npz=1, nr=nr)
1272
1482
 
1273
- return
1483
+ return settings, src
1274
1484
 
1275
1485
  def _check_masks(batch, batch_filenames, output_folder):
1276
1486
  """
@@ -1292,8 +1502,7 @@ def _check_masks(batch, batch_filenames, output_folder):
1292
1502
  filtered_filenames = [f for f, exists in zip(batch_filenames, existing_files_mask) if exists]
1293
1503
 
1294
1504
  return np.array(filtered_batch), filtered_filenames
1295
-
1296
-
1505
+
1297
1506
  def _get_avg_object_size(masks):
1298
1507
  """
1299
1508
  Calculate the average size of objects in a list of masks.
@@ -1321,27 +1530,6 @@ def _get_avg_object_size(masks):
1321
1530
  return sum(object_areas) / len(object_areas)
1322
1531
  else:
1323
1532
  return 0 # Return 0 if no objects are found
1324
-
1325
- def _save_figure_v1(fig, src, text, dpi=300, ):
1326
- """
1327
- Save a figure to a specified location.
1328
-
1329
- Parameters:
1330
- fig (matplotlib.figure.Figure): The figure to be saved.
1331
- src (str): The source file path.
1332
- text (str): The text to be included in the figure name.
1333
- dpi (int, optional): The resolution of the saved figure. Defaults to 300.
1334
- """
1335
- save_folder = os.path.dirname(src)
1336
- obj_type = os.path.basename(src)
1337
- name = os.path.basename(save_folder)
1338
- save_folder = os.path.join(save_folder, 'figure')
1339
- os.makedirs(save_folder, exist_ok=True)
1340
- fig_name = f'{obj_type}_{name}_{text}.pdf'
1341
- save_location = os.path.join(save_folder, fig_name)
1342
- fig.savefig(save_location, bbox_inches='tight', dpi=dpi)
1343
- print(f'Saved single cell figure: {save_location}')
1344
- plt.close()
1345
1533
 
1346
1534
  def _save_figure(fig, src, text, dpi=300, i=1, all_folders=1):
1347
1535
  """
@@ -1362,7 +1550,8 @@ def _save_figure(fig, src, text, dpi=300, i=1, all_folders=1):
1362
1550
  save_location = os.path.join(save_folder, fig_name)
1363
1551
  fig.savefig(save_location, bbox_inches='tight', dpi=dpi)
1364
1552
  clear_output(wait=True)
1365
- print(f'\033[KProgress: {i}/{all_folders}, Saved single cell figure: {os.path.basename(save_location)}', end='\r', flush=True)
1553
+ print(f'Progress: {i}/{all_folders}, Saved single cell figure: {os.path.basename(save_location)}')
1554
+ #print(f'Progress: {i}/{all_folders}, Saved single cell figure: {os.path.basename(save_location)}', end='\r', flush=True)
1366
1555
  # Close and delete the figure to free up memory
1367
1556
  plt.close(fig)
1368
1557
  del fig
@@ -1500,9 +1689,10 @@ def _save_mask_timelapse_as_gif(masks, tracks_df, path, cmap, norm, filenames):
1500
1689
  ax.text(x, y, str(label_value), color='white', fontsize=24, ha='center', va='center')
1501
1690
 
1502
1691
  # Overlay tracks
1503
- for track in tracks_df['track_id'].unique():
1504
- _track = tracks_df[tracks_df['track_id'] == track]
1505
- ax.plot(_track['x'], _track['y'], '-w', linewidth=1)
1692
+ if tracks_df is not None:
1693
+ for track in tracks_df['track_id'].unique():
1694
+ _track = tracks_df[tracks_df['track_id'] == track]
1695
+ ax.plot(_track['x'], _track['y'], '-w', linewidth=1)
1506
1696
 
1507
1697
  anim = FuncAnimation(fig, _update, frames=len(masks), blit=False)
1508
1698
  anim.save(path, writer='pillow', fps=2, dpi=80) # Adjust DPI for size/quality
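The hunk above only guards the track overlay when tracks_df is None; a compact sketch of the same pattern in isolation, assuming masks is a list of 2D label arrays and tracks_df, when present, has 'track_id', 'x' and 'y' columns (the function name is illustrative):

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

def save_mask_gif(masks, path, tracks_df=None, fps=2):
    fig, ax = plt.subplots()

    def update(frame):
        ax.clear()
        ax.imshow(masks[frame], interpolation='nearest')
        if tracks_df is not None:
            # Only draw trajectories when tracking data was actually produced
            for track_id in tracks_df['track_id'].unique():
                track = tracks_df[tracks_df['track_id'] == track_id]
                ax.plot(track['x'], track['y'], '-w', linewidth=1)

    anim = FuncAnimation(fig, update, frames=len(masks), blit=False)
    anim.save(path, writer='pillow', fps=fps, dpi=80)
    plt.close(fig)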
@@ -1616,56 +1806,65 @@ def _load_and_concatenate_arrays(src, channels, cell_chann_dim, nucleus_chann_di
1616
1806
 
1617
1807
  # Iterate through each file in the reference folder
1618
1808
  for filename in os.listdir(reference_folder):
1619
-
1620
1809
  stack_ls = []
1621
- array_path = []
1622
-
1623
1810
  if filename.endswith('.npy'):
1624
- count+=1
1625
- # Initialize the concatenated array with the array from the reference folder
1626
- concatenated_array = np.load(os.path.join(reference_folder, filename))
1627
- if channels is not None:
1628
- concatenated_array = np.take(concatenated_array, channels, axis=2)
1811
+ count += 1
1812
+
1813
+ # Check if this file exists in all the other specified folders
1814
+ exists_in_all_folders = all(os.path.isfile(os.path.join(folder, filename)) for folder in folder_paths)
1815
+
1816
+ if exists_in_all_folders:
1817
+ # Load and potentially modify the array from the reference folder
1818
+ ref_array_path = os.path.join(reference_folder, filename)
1819
+ concatenated_array = np.load(ref_array_path)
1820
+
1821
+ if channels is not None:
1822
+ concatenated_array = np.take(concatenated_array, channels, axis=2)
1823
+
1824
+ # Add the array from the reference folder to 'stack_ls'
1629
1825
  stack_ls.append(concatenated_array)
1630
- # For each of the other folders, load the array and concatenate it
1631
- for folder in folder_paths[1:]:
1632
- array_path = os.path.join(folder, filename)
1633
- if os.path.isfile(array_path):
1826
+
1827
+ # For each of the other folders, load the array and add it to 'stack_ls'
1828
+ for folder in folder_paths[1:]:
1829
+ array_path = os.path.join(folder, filename)
1634
1830
  array = np.load(array_path)
1635
1831
  if array.ndim == 2:
1636
- array = np.expand_dims(array, axis=-1) # add an extra dimension if the array is 2D
1832
+ array = np.expand_dims(array, axis=-1) # Add an extra dimension if the array is 2D
1637
1833
  stack_ls.append(array)
1638
1834
 
1639
- stack_ls = [np.expand_dims(arr, axis=-1) if arr.ndim == 2 else arr for arr in stack_ls]
1640
- unique_shapes = {arr.shape[:-1] for arr in stack_ls}
1641
- if len(unique_shapes) > 1:
1642
- #max_dims = np.max(np.array(list(unique_shapes)), axis=0)
1643
- # Determine the maximum length of tuples in unique_shapes
1644
- max_tuple_length = max(len(shape) for shape in unique_shapes)
1645
- # Pad shorter tuples with zeros to make them all the same length
1646
- padded_shapes = [shape + (0,) * (max_tuple_length - len(shape)) for shape in unique_shapes]
1647
- # Now create a NumPy array and find the maximum dimensions
1648
- max_dims = np.max(np.array(padded_shapes), axis=0)
1649
- clear_output(wait=True)
1650
- print(f'\033[KWarning: arrays with multiple shapes found. Padding arrays to max X,Y dimentions {max_dims}', end='\r', flush=True)
1651
- padded_stack_ls = []
1652
- for arr in stack_ls:
1653
- pad_width = [(0, max_dim - dim) for max_dim, dim in zip(max_dims, arr.shape[:-1])]
1654
- pad_width.append((0, 0))
1655
- padded_arr = np.pad(arr, pad_width)
1656
- padded_stack_ls.append(padded_arr)
1657
- # Concatenate the padded arrays along the channel dimension (last dimension)
1658
- stack = np.concatenate(padded_stack_ls, axis=-1)
1835
+ if len(stack_ls) > 0:
1836
+ stack_ls = [np.expand_dims(arr, axis=-1) if arr.ndim == 2 else arr for arr in stack_ls]
1837
+ unique_shapes = {arr.shape[:-1] for arr in stack_ls}
1838
+ if len(unique_shapes) > 1:
1839
+ #max_dims = np.max(np.array(list(unique_shapes)), axis=0)
1840
+ # Determine the maximum length of tuples in unique_shapes
1841
+ max_tuple_length = max(len(shape) for shape in unique_shapes)
1842
+ # Pad shorter tuples with zeros to make them all the same length
1843
+ padded_shapes = [shape + (0,) * (max_tuple_length - len(shape)) for shape in unique_shapes]
1844
+ # Now create a NumPy array and find the maximum dimensions
1845
+ max_dims = np.max(np.array(padded_shapes), axis=0)
1846
+ #clear_output(wait=True)
1847
+ print(f'Warning: arrays with multiple shapes found. Padding arrays to max X,Y dimensions {max_dims}')
1848
+ #print(f'Warning: arrays with multiple shapes found. Padding arrays to max X,Y dimensions {max_dims}', end='\r', flush=True)
1849
+ padded_stack_ls = []
1850
+ for arr in stack_ls:
1851
+ pad_width = [(0, max_dim - dim) for max_dim, dim in zip(max_dims, arr.shape[:-1])]
1852
+ pad_width.append((0, 0))
1853
+ padded_arr = np.pad(arr, pad_width)
1854
+ padded_stack_ls.append(padded_arr)
1855
+ # Concatenate the padded arrays along the channel dimension (last dimension)
1856
+ stack = np.concatenate(padded_stack_ls, axis=-1)
1659
1857
 
1660
- else:
1661
- stack = np.concatenate(stack_ls, axis=-1)
1858
+ else:
1859
+ stack = np.concatenate(stack_ls, axis=-1)
1662
1860
 
1663
- if stack.shape[-1] > concatenated_array.shape[-1]:
1664
- output_path = os.path.join(output_folder, filename)
1665
- np.save(output_path, stack)
1861
+ if stack.shape[-1] > concatenated_array.shape[-1]:
1862
+ output_path = os.path.join(output_folder, filename)
1863
+ np.save(output_path, stack)
1666
1864
 
1667
- clear_output(wait=True)
1668
- #print(f'\033[KFiles merged: {count}/{all_imgs}', end='\r', flush=True)
1865
+ #clear_output(wait=True)
1866
+ print(f'Files merged: {count}/{all_imgs}')
1867
+ #print(f'Files merged: {count}/{all_imgs}', end='\r', flush=True)
1669
1868
  return
1670
1869
 
1671
1870
  def _read_db(db_loc, tables):
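The rewritten hunk above pads arrays of unequal X,Y size to a common shape before concatenating them along the channel axis; the same idea as a standalone sketch, assuming every input is a 2D or 3D numpy array (pad_and_stack is an illustrative name, not a spacr function):

import numpy as np

def pad_and_stack(arrays):
    # Promote 2D arrays to (H, W, 1) so every input has a channel axis
    arrays = [a[..., None] if a.ndim == 2 else a for a in arrays]
    # Pad each array up to the largest spatial shape, then join the channels
    max_dims = np.max([a.shape[:-1] for a in arrays], axis=0)
    padded = []
    for a in arrays:
        pad_width = [(0, m - d) for m, d in zip(max_dims, a.shape[:-1])] + [(0, 0)]
        padded.append(np.pad(a, pad_width))
    return np.concatenate(padded, axis=-1)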
@@ -2139,133 +2338,85 @@ def _read_mask(mask_path):
2139
2338
  if mask.dtype != np.uint16:
2140
2339
  mask = img_as_uint(mask)
2141
2340
  return mask
2341
+
2342
+
2343
+ def convert_numpy_to_tiff(folder_path, limit=None):
2344
+ """
2345
+ Converts all numpy files in a folder to TIFF format and saves them in a subdirectory 'tiff'.
2142
2346
 
2347
+ Args:
2348
+ folder_path (str): The path to the folder containing numpy files.
+ limit (int, optional): Convert at most this many files; None converts all. Defaults to None.
2349
+ """
2350
+ # Create the subdirectory 'tiff' within the specified folder if it doesn't already exist
2351
+ tiff_subdir = os.path.join(folder_path, 'tiff')
2352
+ os.makedirs(tiff_subdir, exist_ok=True)
2353
+
2354
+ files = os.listdir(folder_path)
2355
+
2356
+ npy_files = [f for f in files if f.endswith('.npy')]
2143
2357
 
2358
+ # Iterate over all files in the folder
2359
+ for i, filename in enumerate(files):
2360
+ if limit is not None and i >= limit:
2361
+ break
2362
+ if not filename.endswith('.npy'):
2363
+ continue
2364
+
2365
+ # Construct the full file path
2366
+ file_path = os.path.join(folder_path, filename)
2367
+ # Load the numpy file
2368
+ numpy_array = np.load(file_path)
2369
+
2370
+ # Construct the output TIFF file path
2371
+ tiff_filename = os.path.splitext(filename)[0] + '.tif'
2372
+ tiff_file_path = os.path.join(tiff_subdir, tiff_filename)
2373
+
2374
+ # Save the numpy array as a TIFF file
2375
+ tifffile.imwrite(tiff_file_path, numpy_array)
2376
+
2377
+ print(f"Converted {filename} to {tiff_filename} and saved in 'tiff' subdirectory.")
2378
+ return
2144
2379
 
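A usage sketch for the converter above; the path is a placeholder:

# Convert the first five .npy stacks in a folder to TIFF files under <folder>/tiff
convert_numpy_to_tiff('/path/to/experiment/stack', limit=5)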
2145
-
2146
-
2147
-
2148
-
2149
-
2150
-
2151
-
2152
-
2153
-
2154
-
2155
-
2156
-
2157
-
2158
-
2159
-
2160
-
2161
-
2162
-
2163
-
2164
-
2165
-
2166
-
2167
-
2168
-
2169
-
2170
-
2171
-
2172
-
2173
-
2174
-
2175
-
2176
-
2177
-
2178
-
2179
-
2180
-
2181
-
2182
-
2183
-
2184
-
2185
-
2186
-
2187
-
2188
-
2189
-
2190
-
2191
-
2192
-
2193
-
2194
-
2195
-
2196
-
2197
-
2198
-
2199
-
2200
-
2201
-
2202
-
2203
-
2204
-
2205
-
2206
-
2207
-
2208
-
2209
-
2210
-
2211
-
2212
-
2213
-
2214
-
2215
-
2216
-
2217
-
2218
-
2219
-
2220
-
2221
-
2222
-
2223
-
2224
-
2225
-
2226
-
2227
-
2228
-
2229
-
2230
-
2231
-
2232
-
2233
-
2234
-
2235
-
2236
-
2237
-
2238
-
2239
-
2240
-
2241
-
2242
-
2243
-
2244
-
2245
-
2246
-
2247
-
2248
-
2249
-
2250
-
2251
-
2252
-
2253
-
2254
-
2255
-
2256
-
2257
-
2258
-
2259
-
2260
-
2261
-
2262
-
2263
-
2264
-
2265
-
2266
-
2267
-
2268
-
2269
-
2270
-
2271
-
2380
+ def generate_cellpose_train_test(src, test_split=0.1):
2381
+ mask_src = os.path.join(src, 'masks')
2382
+ img_paths = glob.glob(os.path.join(src, '*.tif'))
2383
+ img_filenames = [os.path.basename(file) for file in img_paths]
2384
+ img_filenames = [file for file in img_filenames if os.path.exists(os.path.join(mask_src, file))]
2385
+ print(f'Found {len(img_filenames)} images with masks')
2386
+
2387
+ random.shuffle(img_filenames)
2388
+ split_index = int(len(img_filenames) * test_split)
2389
+ train_files = img_filenames[split_index:]
2390
+ test_files = img_filenames[:split_index]
2391
+ list_of_lists = [test_files, train_files]
2392
+ print(f'Split dataset into Train {len(train_files)} and Test {len(test_files)} files')
2393
+
2394
+ train_dir = os.path.join(os.path.dirname(src), 'train')
2395
+ train_dir_masks = os.path.join(train_dir, 'masks')
2396
+ test_dir = os.path.join(os.path.dirname(src), 'test')
2397
+ test_dir_masks = os.path.join(test_dir, 'masks')
2398
+
2399
+ os.makedirs(train_dir, exist_ok=True)
2400
+ os.makedirs(train_dir_masks, exist_ok=True)
2401
+ os.makedirs(test_dir, exist_ok=True)
2402
+ os.makedirs(test_dir_masks, exist_ok=True)
2403
+
2404
+ for i, ls in enumerate(list_of_lists):
2405
+ if i == 0:
2406
+ dst = test_dir
2407
+ dst_mask = test_dir_masks
2408
+ _type = 'Test'
2409
+ else:
2410
+ dst = train_dir
2411
+ dst_mask = train_dir_masks
2412
+ _type = 'Train'
2413
+
2414
+ for idx, filename in enumerate(ls):
2415
+ img_path = os.path.join(src, filename)
2416
+ mask_path = os.path.join(mask_src, filename)
2417
+ new_img_path = os.path.join(dst, filename)
2418
+ new_mask_path = os.path.join(dst_mask, filename)
2419
+ shutil.copy(img_path, new_img_path)
2420
+ shutil.copy(mask_path, new_mask_path)
2421
+ print(f'Copied {idx+1}/{len(ls)} images to {_type} set', end='\r', flush=True)
2422
+
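A usage sketch for the split helper above, assuming images sit directly in src and their masks in src/masks with matching file names (the path is a placeholder):

# Hold out 10% of annotated images; copies land in sibling 'train' and 'test'
# folders created next to the source directory
generate_cellpose_train_test('/path/to/annotated_images', test_split=0.1)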