spacr 0.0.36__py3-none-any.whl → 0.0.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/io.py CHANGED
@@ -19,7 +19,6 @@ from io import BytesIO
19
19
  from IPython.display import display, clear_output
20
20
  from multiprocessing import Pool, cpu_count
21
21
  from torch.utils.data import Dataset
22
- import seaborn as sns
23
22
  import matplotlib.pyplot as plt
24
23
  from torchvision.transforms import ToTensor
25
24
 
@@ -88,15 +87,13 @@ def _load_images_and_labels(image_files, label_files, circular=False, invert=Fal
88
87
  print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {labels[0].shape}, image type: labels[0].shape')
89
88
  return images, labels, image_names, label_names
90
89
 
91
- def _load_normalized_images_and_labels(image_files, label_files, signal_thresholds=[1000], channels=None, percentiles=None, circular=False, invert=False, visualize=False):
90
+ def _load_normalized_images_and_labels(image_files, label_files, channels=None, percentiles=None, circular=False, invert=False, visualize=False, remove_background=False, background=0, Signal_to_noise=10):
92
91
 
93
92
  from .plot import normalize_and_visualize
94
93
  from .utils import invert_image, apply_mask
95
-
96
- if isinstance(signal_thresholds, int):
97
- signal_thresholds = [signal_thresholds] * (len(channels) if channels is not None else 1)
98
- elif not isinstance(signal_thresholds, list):
99
- signal_thresholds = [signal_thresholds]
94
+
95
+ signal_thresholds = background*Signal_to_noise
96
+ lower_percentile = 2
100
97
 
101
98
  images = []
102
99
  labels = []
@@ -113,16 +110,18 @@ def _load_normalized_images_and_labels(image_files, label_files, signal_threshol
113
110
 
114
111
  # Load images and check percentiles
115
112
  for i,img_file in enumerate(image_files):
116
- #print(img_file)
117
113
  image = cellpose.io.imread(img_file)
118
114
  if invert:
119
115
  image = invert_image(image)
120
116
  if circular:
121
117
  image = apply_mask(image, output_value=0)
122
- #print(image.shape)
118
+
123
119
  # If specific channels are specified, select them
124
120
  if channels is not None and image.ndim == 3:
125
121
  image = image[..., channels]
122
+
123
+ if remove_background:
124
+ image[image < background] = 0
126
125
 
127
126
  if image.ndim < 3:
128
127
  image = np.expand_dims(image, axis=-1)
@@ -130,11 +129,11 @@ def _load_normalized_images_and_labels(image_files, label_files, signal_threshol
130
129
  images.append(image)
131
130
  if percentiles is None:
132
131
  for c in range(image.shape[-1]):
133
- p1 = np.percentile(image[..., c], 1)
132
+ p1 = np.percentile(image[..., c], lower_percentile)
134
133
  percentiles_1[c].append(p1)
135
- for percentile in [99, 99.9, 99.99, 99.999]:
134
+ for percentile in [98, 99, 99.9, 99.99, 99.999]:
136
135
  p = np.percentile(image[..., c], percentile)
137
- if p > signal_thresholds[min(c, len(signal_thresholds)-1)]:
136
+ if p > signal_thresholds:
138
137
  percentiles_99[c].append(p)
139
138
  break
140
139
 
@@ -143,8 +142,8 @@ def _load_normalized_images_and_labels(image_files, label_files, signal_threshol
143
142
  for image in images:
144
143
  normalized_image = np.zeros_like(image, dtype=np.float32)
145
144
  for c in range(image.shape[-1]):
146
- high_p = np.percentile(image[..., c], percentiles[1])
147
145
  low_p = np.percentile(image[..., c], percentiles[0])
146
+ high_p = np.percentile(image[..., c], percentiles[1])
148
147
  normalized_image[..., c] = rescale_intensity(image[..., c], in_range=(low_p, high_p), out_range=(0, 1))
149
148
  normalized_images.append(normalized_image)
150
149
  if visualize:
@@ -155,17 +154,20 @@ def _load_normalized_images_and_labels(image_files, label_files, signal_threshol
155
154
  avg_p1 = [np.mean(p) for p in percentiles_1]
156
155
  avg_p99 = [np.mean(p) if len(p) > 0 else np.mean(percentiles_1[i]) for i, p in enumerate(percentiles_99)]
157
156
 
157
+ print(f'Average 1st percentiles: {avg_p1}, Average 99th percentiles: {avg_p99}')
158
+
158
159
  normalized_images = []
159
160
  for image in images:
160
161
  normalized_image = np.zeros_like(image, dtype=np.float32)
161
- for c in range(image.shape[-1]):
162
- normalized_image[..., c] = rescale_intensity(image[..., c], in_range=(avg_p1[c], avg_p99[c]), out_range=(0, 1))
163
- normalized_images.append(normalized_image)
164
- if visualize:
165
- normalize_and_visualize(image, normalized_image, title=f"Channel {c+1} Normalized")
162
+ for c in range(image.shape[-1]):
163
+ normalized_image[..., c] = rescale_intensity(image[..., c], in_range=(avg_p1[c], avg_p99[c]), out_range=(0, 1))
164
+ normalized_images.append(normalized_image)
165
+ if visualize:
166
+ normalize_and_visualize(image, normalized_image, title=f"Channel {c+1} Normalized")
166
167
 
167
168
  if not image_files is None:
168
169
  image_dir = os.path.dirname(image_files[0])
170
+
169
171
  else:
170
172
  image_dir = None
171
173
 
@@ -181,6 +183,7 @@ def _load_normalized_images_and_labels(image_files, label_files, signal_threshol
181
183
  return normalized_images, labels, image_names, label_names
182
184
 
183
185
  class CombineLoaders:
186
+
184
187
  """
185
188
  A class that combines multiple data loaders into a single iterator.
186
189
 
@@ -306,85 +309,6 @@ class NoClassDataset(Dataset):
306
309
  img = ToTensor()(img)
307
310
  # Return both the image and its filename
308
311
  return img, self.filenames[index]
309
-
310
- class MyDataset_v1(Dataset):
311
- """
312
- Custom dataset class for loading and processing image data.
313
-
314
- Args:
315
- data_dir (str): The directory path where the data is stored.
316
- loader_classes (list): List of class names.
317
- transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. Default is None.
318
- shuffle (bool, optional): Whether to shuffle the dataset. Default is True.
319
- load_to_memory (bool, optional): Whether to load images into memory. Default is False.
320
-
321
- Attributes:
322
- data_dir (str): The directory path where the data is stored.
323
- classes (list): List of class names.
324
- transform (callable): A function/transform that takes in an PIL image and returns a transformed version.
325
- shuffle (bool): Whether to shuffle the dataset.
326
- load_to_memory (bool): Whether to load images into memory.
327
- filenames (list): List of file paths.
328
- labels (list): List of labels corresponding to each file.
329
- images (list): List of loaded images.
330
- image_cache (Cache): Cache object for storing loaded images.
331
-
332
- Methods:
333
- load_image: Load an image from file.
334
- __len__: Get the length of the dataset.
335
- shuffle_dataset: Shuffle the dataset.
336
- __getitem__: Get an item from the dataset.
337
-
338
- """
339
-
340
- def __init__(self, data_dir, loader_classes, transform=None, shuffle=True, load_to_memory=False):
341
- from .utils import Cache
342
- self.data_dir = data_dir
343
- self.classes = loader_classes
344
- self.transform = transform
345
- self.shuffle = shuffle
346
- self.load_to_memory = load_to_memory
347
- self.filenames = []
348
- self.labels = []
349
- self.images = []
350
- self.image_cache = Cache(50)
351
- for class_name in self.classes:
352
- class_path = os.path.join(data_dir, class_name)
353
- class_files = [os.path.join(class_path, f) for f in os.listdir(class_path) if os.path.isfile(os.path.join(class_path, f))]
354
- self.filenames.extend(class_files)
355
- self.labels.extend([self.classes.index(class_name)] * len(class_files))
356
- if self.shuffle:
357
- self.shuffle_dataset()
358
- if self.load_to_memory:
359
- self.images = [self.load_image(f) for f in self.filenames]
360
-
361
- def load_image(self, img_path):
362
- img = self.image_cache.get(img_path)
363
- if img is None:
364
- img = Image.open(img_path).convert('RGB')
365
- self.image_cache.put(img_path, img)
366
- return img
367
-
368
- def _len__(self):
369
- return len(self.filenames)
370
-
371
- def shuffle_dataset(self):
372
- combined = list(zip(self.filenames, self.labels))
373
- random.shuffle(combined)
374
- self.filenames, self.labels = zip(*combined)
375
-
376
- def _getitem__(self, index):
377
- label = self.labels[index]
378
- filename = self.filenames[index]
379
- if self.load_to_memory:
380
- img = self.images[index]
381
- else:
382
- img = self.load_image(filename)
383
- if self.transform is not None:
384
- img = self.transform(img)
385
- else:
386
- img = ToTensor()(img)
387
- return img, label, filename
388
312
 
389
313
  class MyDataset(Dataset):
390
314
  """
@@ -602,64 +526,6 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
602
526
  shutil.move(os.path.join(src, filename), move)
603
527
  return
604
528
 
605
- def _merge_file_v1(chan_dirs, stack_dir, file):
606
- """
607
- Merge multiple channels into a single stack and save it as a numpy array.
608
-
609
- Args:
610
- chan_dirs (list): List of directories containing channel images.
611
- stack_dir (str): Directory to save the merged stack.
612
- file (str): File name of the channel image.
613
-
614
- Returns:
615
- None
616
- """
617
- chan1 = cv2.imread(str(file), -1)
618
- chan1 = np.expand_dims(chan1, axis=2)
619
- new_file = stack_dir / (file.stem + '.npy')
620
- if not new_file.exists():
621
- stack_dir.mkdir(exist_ok=True)
622
- channels = [chan1]
623
- for chan_dir in chan_dirs[1:]:
624
- img = cv2.imread(str(chan_dir / file.name), -1)
625
- chan = np.expand_dims(img, axis=2)
626
- channels.append(chan)
627
- stack = np.concatenate(channels, axis=2)
628
- np.save(new_file, stack)
629
-
630
- def _merge_file_v1(chan_dirs, stack_dir, file):
631
- """
632
- Merge multiple channels into a single stack and save it as a numpy array.
633
- Args:
634
- chan_dirs (list): List of directories containing channel images.
635
- stack_dir (str): Directory to save the merged stack.
636
- file (str): File name of the channel image.
637
-
638
- Returns:
639
- None
640
- """
641
- new_file = stack_dir / (file.stem + '.npy')
642
- if not new_file.exists():
643
- stack_dir.mkdir(exist_ok=True)
644
- channels = []
645
- for i, chan_dir in enumerate(chan_dirs):
646
- img_path = str(chan_dir / file.name)
647
- img = cv2.imread(img_path, -1)
648
- if img is None:
649
- print(f"Warning: Failed to read image {img_path}")
650
- continue
651
- chan = np.expand_dims(img, axis=2)
652
- channels.append(chan)
653
- del img # Explicitly delete the reference to the image to free up memory
654
- if i % 10 == 0: # Periodically suggest garbage collection
655
- gc.collect()
656
-
657
- if channels:
658
- stack = np.concatenate(channels, axis=2)
659
- np.save(new_file, stack)
660
- else:
661
- print(f"No valid channels to merge for file {file.name}")
662
-
663
529
  def _merge_file(chan_dirs, stack_dir, file_name):
664
530
  """
665
531
  Merge multiple channels into a single stack and save it as a numpy array, using os module for path handling.
@@ -1021,6 +887,223 @@ def _concatenate_channel(src, channels, randomize=True, timelapse=False, batch_s
1021
887
  print(f'All files concatenated and saved to:{channel_stack_loc}')
1022
888
  return channel_stack_loc
1023
889
 
890
+ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, batch_size=100, backgrounds=[100, 100, 100], remove_backgrounds=[False, False, False], lower_percentile=2, save_dtype=np.float32, signal_to_noise=[5, 5, 5], signal_thresholds=[1000, 1000, 1000]):
891
+ """
892
+ Concatenates and normalizes channel data from multiple files and saves the normalized data.
893
+
894
+ Args:
895
+ src (str): The source directory containing the channel data files.
896
+ channels (list): The list of channel indices to be concatenated and normalized.
897
+ randomize (bool, optional): Whether to randomize the order of the files. Defaults to True.
898
+ timelapse (bool, optional): Whether the channel data is from a timelapse experiment. Defaults to False.
899
+ batch_size (int, optional): The number of files to be processed in each batch. Defaults to 100.
900
+ backgrounds (list, optional): Background values for each channel. Defaults to [100, 100, 100].
901
+ remove_backgrounds (list, optional): Whether to remove background values for each channel. Defaults to [False, False, False].
902
+ lower_percentile (int, optional): Lower percentile value for normalization. Defaults to 2.
903
+ save_dtype (numpy.dtype, optional): Data type for saving the normalized stack. Defaults to np.float32.
904
+ signal_to_noise (list, optional): Signal-to-noise ratio thresholds for each channel. Defaults to [5, 5, 5].
905
+ signal_thresholds (list, optional): Signal thresholds for each channel. Defaults to [1000, 1000, 1000].
906
+
907
+ Returns:
908
+ str: The directory path where the concatenated and normalized channel data is saved.
909
+ """
910
+ channels = [item for item in channels if item is not None]
911
+ paths = []
912
+ output_fldr = os.path.join(os.path.dirname(src), 'norm_channel_stack')
913
+ os.makedirs(output_fldr, exist_ok=True)
914
+
915
+ if timelapse:
916
+ try:
917
+ time_stack_path_lists = _generate_time_lists(os.listdir(src))
918
+ for i, time_stack_list in enumerate(time_stack_path_lists):
919
+ stack_region = []
920
+ filenames_region = []
921
+ for idx, file in enumerate(time_stack_list):
922
+ path = os.path.join(src, file)
923
+ if idx == 0:
924
+ parts = file.split('_')
925
+ name = parts[0] + '_' + parts[1] + '_' + parts[2]
926
+ array = np.load(path)
927
+ array = np.take(array, channels, axis=2)
928
+ stack_region.append(array)
929
+ filenames_region.append(os.path.basename(path))
930
+ print(f'Region {i + 1}/ {len(time_stack_path_lists)}', end='\r', flush=True)
931
+ stack = np.stack(stack_region)
932
+ normalized_stack = _normalize_stack(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds)
933
+ save_loc = os.path.join(output_fldr, f'{name}_norm_timelapse.npz')
934
+ np.savez(save_loc, data=normalized_stack, filenames=filenames_region)
935
+ print(save_loc)
936
+ del stack, normalized_stack
937
+ except Exception as e:
938
+ print(f"Error processing files, make sure filenames metadata is structured plate_well_field_time.npy")
939
+ print(f"Error: {e}")
940
+ else:
941
+ for file in os.listdir(src):
942
+ if file.endswith('.npy'):
943
+ path = os.path.join(src, file)
944
+ paths.append(path)
945
+ if randomize:
946
+ random.shuffle(paths)
947
+ nr_files = len(paths)
948
+ batch_index = 0
949
+ stack_ls = []
950
+ filenames_batch = []
951
+
952
+ for i, path in enumerate(paths):
953
+ array = np.load(path)
954
+ array = np.take(array, channels, axis=2)
955
+ stack_ls.append(array)
956
+ filenames_batch.append(os.path.basename(path))
957
+ print(f'Concatenated: {i + 1}/{nr_files} files')
958
+
959
+ if (i + 1) % batch_size == 0 or i + 1 == nr_files:
960
+ unique_shapes = {arr.shape[:-1] for arr in stack_ls}
961
+ if len(unique_shapes) > 1:
962
+ max_dims = np.max(np.array(list(unique_shapes)), axis=0)
963
+ print(f'Warning: arrays with multiple shapes found in batch {i + 1}. Padding arrays to max X,Y dimensions {max_dims}')
964
+ padded_stack_ls = []
965
+ for arr in stack_ls:
966
+ pad_width = [(0, max_dim - dim) for max_dim, dim in zip(max_dims, arr.shape[:-1])]
967
+ pad_width.append((0, 0))
968
+ padded_arr = np.pad(arr, pad_width)
969
+ padded_stack_ls.append(padded_arr)
970
+ stack = np.stack(padded_stack_ls)
971
+ else:
972
+ stack = np.stack(stack_ls)
973
+
974
+ normalized_stack = _normalize_img_batch(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds)
975
+
976
+ save_loc = os.path.join(output_fldr, f'stack_{batch_index}_norm.npz')
977
+ np.savez(save_loc, data=normalized_stack, filenames=filenames_batch)
978
+ batch_index += 1
979
+ del stack, normalized_stack
980
+ stack_ls = []
981
+ filenames_batch = []
982
+ padded_stack_ls = []
983
+ print(f'All files concatenated and normalized. Saved to: {output_fldr}')
984
+ return output_fldr
985
+
986
+ def _normalize_img_batch(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds):
987
+ """
988
+ Normalize the stack of images.
989
+
990
+ Args:
991
+ stack (numpy.ndarray): The stack of images to normalize.
992
+ backgrounds (list): Background values for each channel.
993
+ remove_backgrounds (list): Whether to remove background values for each channel.
994
+ lower_percentile (int): Lower percentile value for normalization.
995
+ save_dtype (numpy.dtype): Data type for saving the normalized stack.
996
+ signal_to_noise (list): Signal-to-noise ratio thresholds for each channel.
997
+ signal_thresholds (list): Signal thresholds for each channel.
998
+
999
+ Returns:
1000
+ numpy.ndarray: The normalized stack.
1001
+ """
1002
+ normalized_stack = np.zeros_like(stack, dtype=np.float32)
1003
+
1004
+ for chan_index, channel in enumerate(range(stack.shape[-1])):
1005
+ single_channel = stack[:, :, :, channel]
1006
+ background = backgrounds[chan_index]
1007
+ signal_threshold = signal_thresholds[chan_index]
1008
+ remove_background = remove_backgrounds[chan_index]
1009
+
1010
+ print(f'Processing channel {chan_index}: background={background}, signal_threshold={signal_threshold}, remove_background={remove_background}')
1011
+
1012
+ # Step 3: Remove background if required
1013
+ if remove_background:
1014
+ single_channel[single_channel < background] = 0
1015
+
1016
+ # Step 4: Calculate global lower percentile for the channel
1017
+ non_zero_single_channel = single_channel[single_channel != 0]
1018
+ global_lower = np.percentile(non_zero_single_channel, lower_percentile)
1019
+
1020
+ # Step 5: Calculate global upper percentile for the channel
1021
+ global_upper = None
1022
+ for upper_p in np.linspace(98, 99.5, num=16):
1023
+ upper_value = np.percentile(non_zero_single_channel, upper_p)
1024
+ if upper_value >= signal_threshold:
1025
+ global_upper = upper_value
1026
+ break
1027
+
1028
+ if global_upper is None:
1029
+ global_upper = np.percentile(non_zero_single_channel, 99.5) # Fallback in case no upper percentile met the threshold
1030
+
1031
+ print(f'Channel {chan_index}: global_lower={global_lower}, global_upper={global_upper}, Signal-to-noise={global_upper / global_lower}')
1032
+
1033
+ # Step 6: Normalize each array from global_lower to global_upper between 0 and 1
1034
+ for array_index in range(single_channel.shape[0]):
1035
+ arr_2d = single_channel[array_index, :, :]
1036
+ arr_2d_normalized = exposure.rescale_intensity(arr_2d, in_range=(global_lower, global_upper), out_range=(0, 1))
1037
+ normalized_stack[array_index, :, :, channel] = arr_2d_normalized
1038
+
1039
+ return normalized_stack.astype(save_dtype)
1040
+
1041
+ def _normalize_img_batch_v1(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds):
1042
+ """
1043
+ Normalize the stack of images.
1044
+
1045
+ Args:
1046
+ stack (numpy.ndarray): The stack of images to normalize.
1047
+ backgrounds (list): Background values for each channel.
1048
+ remove_backgrounds (list): Whether to remove background values for each channel.
1049
+ lower_percentile (int): Lower percentile value for normalization.
1050
+ save_dtype (numpy.dtype): Data type for saving the normalized stack.
1051
+ signal_to_noise (list): Signal-to-noise ratio thresholds for each channel.
1052
+ signal_thresholds (list): Signal thresholds for each channel.
1053
+
1054
+ Returns:
1055
+ numpy.ndarray: The normalized stack.
1056
+ """
1057
+ normalized_stack = np.zeros_like(stack, dtype=np.float32)
1058
+ time_ls = []
1059
+
1060
+ for chan_index, channel in enumerate(range(stack.shape[-1])):
1061
+ single_channel = stack[:, :, :, channel]
1062
+ background = backgrounds[chan_index]
1063
+ signal_threshold = signal_thresholds[chan_index]
1064
+ remove_background = remove_backgrounds[chan_index]
1065
+ signal_2_noise = signal_to_noise[chan_index]
1066
+ print(f'chan_index:{chan_index} background:{background} signal_threshold:{signal_threshold} remove_background:{remove_background} signal_2_noise:{signal_2_noise}')
1067
+
1068
+ if remove_background:
1069
+ single_channel[single_channel < background] = 0
1070
+
1071
+ non_zero_single_channel = single_channel[single_channel != 0]
1072
+ global_lower = np.percentile(non_zero_single_channel, lower_percentile)
1073
+ for upper_p in np.linspace(98, 99.5, num=20).tolist():
1074
+ global_upper = np.percentile(non_zero_single_channel, upper_p)
1075
+ if global_upper >= signal_threshold:
1076
+ break
1077
+
1078
+ arr_2d_normalized = np.zeros_like(single_channel, dtype=single_channel.dtype)
1079
+ signal_to_noise_ratio_ls = []
1080
+ for array_index in range(single_channel.shape[0]):
1081
+ start = time.time()
1082
+ arr_2d = single_channel[array_index, :, :]
1083
+ non_zero_arr_2d = arr_2d[arr_2d != 0]
1084
+ if non_zero_arr_2d.size > 0:
1085
+ lower, upper = np.percentile(non_zero_arr_2d, (lower_percentile, upper_p))
1086
+ signal_to_noise_ratio = upper / lower
1087
+ else:
1088
+ signal_to_noise_ratio = 0
1089
+ signal_to_noise_ratio_ls.append(signal_to_noise_ratio)
1090
+ average_stnr = np.mean(signal_to_noise_ratio_ls) if len(signal_to_noise_ratio_ls) > 0 else 0
1091
+
1092
+ if signal_to_noise_ratio > signal_2_noise:
1093
+ arr_2d_rescaled = exposure.rescale_intensity(arr_2d, in_range=(lower, upper), out_range=(0, 1))
1094
+ arr_2d_normalized[array_index, :, :] = arr_2d_rescaled
1095
+ else:
1096
+ arr_2d_normalized[array_index, :, :] = arr_2d
1097
+ stop = time.time()
1098
+ duration = (stop - start) * single_channel.shape[0]
1099
+ time_ls.append(duration)
1100
+ average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
1101
+ print(f'Progress: channels:{chan_index}/{stack.shape[-1] - 1}, arrays:{array_index + 1}/{single_channel.shape[0]}, Signal:{upper:.1f}, noise:{lower:.1f}, Signal-to-noise:{average_stnr:.1f}, Time/channel:{average_time:.2f}sec')
1102
+
1103
+ normalized_stack[:, :, :, channel] = arr_2d_normalized
1104
+
1105
+ return normalized_stack.astype(save_dtype)
1106
+
1024
1107
  def _get_lists_for_normalization(settings):
1025
1108
  """
1026
1109
  Get lists for normalization based on the provided settings.
@@ -1035,7 +1118,8 @@ def _get_lists_for_normalization(settings):
1035
1118
  # Initialize the lists
1036
1119
  backgrounds = []
1037
1120
  signal_to_noise = []
1038
- signal_thresholds = []
1121
+ signal_thresholds = []
1122
+ remove_background = []
1039
1123
 
1040
1124
  # Iterate through the channels and append the corresponding values if the channel is not None
1041
1125
  for ch in settings['channels']:
@@ -1043,29 +1127,31 @@ def _get_lists_for_normalization(settings):
1043
1127
  backgrounds.append(settings['nucleus_background'])
1044
1128
  signal_to_noise.append(settings['nucleus_Signal_to_noise'])
1045
1129
  signal_thresholds.append(settings['nucleus_Signal_to_noise']*settings['nucleus_background'])
1130
+ remove_background.append(settings['remove_background_nucleus'])
1046
1131
  elif ch == settings['cell_channel']:
1047
1132
  backgrounds.append(settings['cell_background'])
1048
1133
  signal_to_noise.append(settings['cell_Signal_to_noise'])
1049
1134
  signal_thresholds.append(settings['cell_Signal_to_noise']*settings['cell_background'])
1135
+ remove_background.append(settings['remove_background_cell'])
1050
1136
  elif ch == settings['pathogen_channel']:
1051
1137
  backgrounds.append(settings['pathogen_background'])
1052
1138
  signal_to_noise.append(settings['pathogen_Signal_to_noise'])
1053
1139
  signal_thresholds.append(settings['pathogen_Signal_to_noise']*settings['pathogen_background'])
1054
- return backgrounds, signal_to_noise, signal_thresholds
1140
+ remove_background.append(settings['remove_background_pathogen'])
1141
+ return backgrounds, signal_to_noise, signal_thresholds, remove_background
1055
1142
 
1056
- def _normalize_stack(src, backgrounds=[100,100,100], remove_background=False, lower_quantile=0.01, save_dtype=np.float32, signal_to_noise=[5,5,5], signal_thresholds=[1000,1000,1000], correct_illumination=False):
1143
+ def _normalize_stack(src, backgrounds=[100, 100, 100], remove_backgrounds=[False, False, False], lower_percentile=2, save_dtype=np.float32, signal_to_noise=[5, 5, 5], signal_thresholds=[1000, 1000, 1000]):
1057
1144
  """
1058
1145
  Normalize the stack of images.
1059
1146
 
1060
1147
  Args:
1061
1148
  src (str): The source directory containing the stack of images.
1062
- backgrounds (list, optional): Background values for each channel. Defaults to [100,100,100].
1063
- remove_background (bool, optional): Whether to remove background values. Defaults to False.
1064
- lower_quantile (float, optional): Lower quantile value for normalization. Defaults to 0.01.
1149
+ backgrounds (list, optional): Background values for each channel. Defaults to [100, 100, 100].
1150
+ remove_background (list, optional): Whether to remove background values for each channel. Defaults to [False, False, False].
1151
+ lower_percentile (int, optional): Lower percentile value for normalization. Defaults to 2.
1065
1152
  save_dtype (numpy.dtype, optional): Data type for saving the normalized stack. Defaults to np.float32.
1066
- signal_to_noise (list, optional): Signal-to-noise ratio thresholds for each channel. Defaults to [5,5,5].
1067
- signal_thresholds (list, optional): Signal thresholds for each channel. Defaults to [1000,1000,1000].
1068
- correct_illumination (bool, optional): Whether to correct illumination. Defaults to False.
1153
+ signal_to_noise (list, optional): Signal-to-noise ratio thresholds for each channel. Defaults to [5, 5, 5].
1154
+ signal_thresholds (list, optional): Signal thresholds for each channel. Defaults to [1000, 1000, 1000].
1069
1155
 
1070
1156
  Returns:
1071
1157
  None
@@ -1074,11 +1160,13 @@ def _normalize_stack(src, backgrounds=[100,100,100], remove_background=False, lo
1074
1160
  output_fldr = os.path.join(os.path.dirname(src), 'norm_channel_stack')
1075
1161
  os.makedirs(output_fldr, exist_ok=True)
1076
1162
  time_ls = []
1163
+
1077
1164
  for file_index, path in enumerate(paths):
1078
1165
  with np.load(path) as data:
1079
1166
  stack = data['data']
1080
1167
  filenames = data['filenames']
1081
- normalized_stack = np.zeros_like(stack, dtype=stack.dtype)
1168
+
1169
+ normalized_stack = np.zeros_like(stack, dtype=np.float32)
1082
1170
  file = os.path.basename(path)
1083
1171
  name, _ = os.path.splitext(file)
1084
1172
 
@@ -1086,24 +1174,22 @@ def _normalize_stack(src, backgrounds=[100,100,100], remove_background=False, lo
1086
1174
  single_channel = stack[:, :, :, channel]
1087
1175
  background = backgrounds[chan_index]
1088
1176
  signal_threshold = signal_thresholds[chan_index]
1089
- #print(f'signal_threshold:{signal_threshold} in {signal_thresholds} for {chan_index}')
1090
-
1177
+ remove_background = remove_backgrounds[chan_index]
1091
1178
  signal_2_noise = signal_to_noise[chan_index]
1179
+ print(f'chan_index:{chan_index} background:{background} signal_threshold:{signal_threshold} remove_background:{remove_background} signal_2_noise:{signal_2_noise}')
1180
+
1092
1181
  if remove_background:
1093
1182
  single_channel[single_channel < background] = 0
1094
- if correct_illumination:
1095
- bg = filters.gaussian(single_channel, sigma=50)
1096
- single_channel = single_channel - bg
1097
1183
 
1098
- #Calculate the global lower and upper quantiles for non-zero pixels
1184
+ # Calculate the global lower and upper percentiles for non-zero pixels
1099
1185
  non_zero_single_channel = single_channel[single_channel != 0]
1100
- global_lower = np.quantile(non_zero_single_channel, lower_quantile)
1101
- for upper_p in np.linspace(0.98, 1.0, num=100).tolist():
1102
- global_upper = np.quantile(non_zero_single_channel, upper_p)
1186
+ global_lower = np.percentile(non_zero_single_channel, lower_percentile)
1187
+ for upper_p in np.linspace(98, 100, num=100).tolist():
1188
+ global_upper = np.percentile(non_zero_single_channel, upper_p)
1103
1189
  if global_upper >= signal_threshold:
1104
1190
  break
1105
1191
 
1106
- #Normalize the pixels in each image to the global quantiles and then dtype.
1192
+ # Normalize the pixels in each image to the global percentiles and then dtype.
1107
1193
  arr_2d_normalized = np.zeros_like(single_channel, dtype=single_channel.dtype)
1108
1194
  signal_to_noise_ratio_ls = []
1109
1195
  for array_index in range(single_channel.shape[0]):
@@ -1111,41 +1197,40 @@ def _normalize_stack(src, backgrounds=[100,100,100], remove_background=False, lo
1111
1197
  arr_2d = single_channel[array_index, :, :]
1112
1198
  non_zero_arr_2d = arr_2d[arr_2d != 0]
1113
1199
  if non_zero_arr_2d.size > 0:
1114
- lower, upper = np.quantile(non_zero_arr_2d, (lower_quantile, upper_p))
1115
- signal_to_noise_ratio = upper/lower
1200
+ lower, upper = np.percentile(non_zero_arr_2d, (lower_percentile, upper_p))
1201
+ signal_to_noise_ratio = upper / lower
1116
1202
  else:
1117
1203
  signal_to_noise_ratio = 0
1118
1204
  signal_to_noise_ratio_ls.append(signal_to_noise_ratio)
1119
1205
  average_stnr = np.mean(signal_to_noise_ratio_ls) if len(signal_to_noise_ratio_ls) > 0 else 0
1120
1206
 
1121
1207
  if signal_to_noise_ratio > signal_2_noise:
1122
- arr_2d_rescaled = exposure.rescale_intensity(arr_2d, in_range=(lower, upper), out_range=(global_lower, global_upper))
1208
+ arr_2d_rescaled = exposure.rescale_intensity(arr_2d, in_range=(lower, upper), out_range=(0, 1))
1123
1209
  arr_2d_normalized[array_index, :, :] = arr_2d_rescaled
1124
1210
  else:
1125
1211
  arr_2d_normalized[array_index, :, :] = arr_2d
1126
1212
  stop = time.time()
1127
- duration = (stop - start)*single_channel.shape[0]
1213
+ duration = (stop - start) * single_channel.shape[0]
1128
1214
  time_ls.append(duration)
1129
1215
  average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
1130
- #clear_output(wait=True)
1131
- print(f'Progress: files {file_index+1}/{len(paths)}, channels:{chan_index}/{stack.shape[-1]-1}, arrays:{array_index+1}/{single_channel.shape[0]}, Signal:{upper:.1f}, noise:{lower:.1f}, Signal-to-noise:{average_stnr:.1f}, Time/channel:{average_time:.2f}sec')
1132
- #print(f'Progress: files {file_index+1}/{len(paths)}, channels:{chan_index}/{stack.shape[-1]-1}, arrays:{array_index+1}/{single_channel.shape[0]}, Signal:{upper:.1f}, noise:{lower:.1f}, Signal-to-noise:{average_stnr:.1f}, Time/channel:{average_time:.2f}sec', end='\r', flush=True)
1133
- normalized_single_channel = exposure.rescale_intensity(arr_2d_normalized, out_range='dtype')
1134
- normalized_stack[:, :, :, channel] = normalized_single_channel
1135
- save_loc = output_fldr+'/'+name+'_norm_stack.npz'
1136
- normalized_stack = normalized_stack.astype(save_dtype)
1137
- np.savez(save_loc, data=normalized_stack, filenames=filenames)
1138
- del normalized_stack, single_channel, normalized_single_channel, stack, filenames
1216
+ print(f'Progress: files {file_index + 1}/{len(paths)}, channels:{chan_index}/{stack.shape[-1] - 1}, arrays:{array_index + 1}/{single_channel.shape[0]}, Signal:{upper:.1f}, noise:{lower:.1f}, Signal-to-noise:{average_stnr:.1f}, Time/channel:{average_time:.2f}sec')
1217
+
1218
+ normalized_stack[:, :, :, channel] = arr_2d_normalized
1219
+
1220
+ save_loc = os.path.join(output_fldr, f'{name}_norm_stack.npz')
1221
+ np.savez(save_loc, data=normalized_stack.astype(save_dtype), filenames=filenames)
1222
+ del normalized_stack, single_channel, arr_2d_normalized, stack, filenames
1139
1223
  gc.collect()
1140
- return print(f'Saved stacks:{output_fldr}')
1224
+
1225
+ return print(f'Saved stacks: {output_fldr}')
1141
1226
 
1142
- def _normalize_timelapse(src, lower_quantile=0.01, save_dtype=np.float32):
1227
+ def _normalize_timelapse(src, lower_percentile=2, save_dtype=np.float32):
1143
1228
  """
1144
1229
  Normalize the timelapse data by rescaling the intensity values based on percentiles.
1145
1230
 
1146
1231
  Args:
1147
1232
  src (str): The source directory containing the timelapse data files.
1148
- lower_quantile (float, optional): The lower quantile used to calculate the intensity range. Defaults to 0.01.
1233
+ lower_percentile (int, optional): The lower percentile used to calculate the intensity range. Defaults to 1.
1149
1234
  save_dtype (numpy.dtype, optional): The data type to save the normalized stack. Defaults to np.float32.
1150
1235
  """
1151
1236
  paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
@@ -1167,7 +1252,7 @@ def _normalize_timelapse(src, lower_quantile=0.01, save_dtype=np.float32):
1167
1252
  for array_index in range(single_channel.shape[0]):
1168
1253
  arr_2d = single_channel[array_index]
1169
1254
  # Calculate the lower (lower_percentile) and upper (98th) percentiles of the nonzero pixels in this specific image
1170
- q_low = np.percentile(arr_2d[arr_2d != 0], 2)
1255
+ q_low = np.percentile(arr_2d[arr_2d != 0], lower_percentile)
1171
1256
  q_high = np.percentile(arr_2d[arr_2d != 0], 98)
1172
1257
 
1173
1258
  # Rescale intensity based on the calculated percentiles to fill the dtype range
@@ -1261,7 +1346,7 @@ def delete_empty_subdirectories(folder_path):
1261
1346
  def preprocess_img_data(settings):
1262
1347
 
1263
1348
  from .plot import plot_arrays, _plot_4D_arrays
1264
- from .utils import _run_test_mode
1349
+ from .utils import _run_test_mode, _get_regex, set_default_settings_preprocess_img_data
1265
1350
 
1266
1351
  """
1267
1352
  Preprocesses image data by converting z-stack images to maximum intensity projection (MIP) images.
@@ -1280,9 +1365,8 @@ def preprocess_img_data(settings):
1280
1365
  timelapse (bool, optional): Whether the images are from a timelapse experiment. Defaults to False.
1281
1366
  remove_background (bool, optional): Whether to remove the background from the images. Defaults to False.
1282
1367
  backgrounds (int, optional): The number of background images to use for background removal. Defaults to 100.
1283
- lower_quantile (float, optional): The lower quantile used for background removal. Defaults to 0.01.
1368
+ lower_percentile (float, optional): The lower percentile used for background removal. Defaults to 1.
1284
1369
  save_dtype (type, optional): The data type used for saving the preprocessed images. Defaults to np.float32.
1285
- correct_illumination (bool, optional): Whether to correct the illumination of the images. Defaults to False.
1286
1370
  randomize (bool, optional): Whether to randomize the order of the images. Defaults to True.
1287
1371
  all_to_mip (bool, optional): Whether to convert all images to MIP. Defaults to False.
1288
1372
  pick_slice (bool, optional): Whether to pick a specific slice based on the provided skip mode. Defaults to False.
@@ -1301,7 +1385,6 @@ def preprocess_img_data(settings):
1301
1385
  most_common_extension = extension_counts.most_common(1)[0][0]
1302
1386
  img_format = None
1303
1387
 
1304
-
1305
1388
  delete_empty_subdirectories(src)
1306
1389
 
1307
1390
  # Check if the most common extension is one of the specified image formats
@@ -1318,47 +1401,15 @@ def preprocess_img_data(settings):
1318
1401
  print('Found existing norm_channel_stack folder. Skipping preprocessing')
1319
1402
  return settings, src
1320
1403
 
1321
- cmap = 'inferno'
1322
- figuresize = 20
1323
- normalize = True
1324
- save_dtype = 'uint16'
1325
- correct_illumination = False
1326
-
1327
- #mask_channels = [settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel']]
1328
- #backgrounds = [settings['nucleus_background'], settings['pathogen_background'], settings['cell_background']]
1329
1404
  mask_channels = [settings['nucleus_channel'], settings['cell_channel'], settings['pathogen_channel']]
1330
1405
  backgrounds = [settings['nucleus_background'], settings['cell_background'], settings['pathogen_background']]
1331
-
1332
- metadata_type = settings['metadata_type']
1333
- custom_regex = settings['custom_regex']
1334
- nr = settings['examples_to_plot']
1335
- plot = settings['plot']
1336
- batch_size = settings['batch_size']
1337
- timelapse = settings['timelapse']
1338
- remove_background = settings['remove_background']
1339
- lower_quantile = settings['lower_quantile']
1340
- randomize = settings['randomize']
1341
- all_to_mip = settings['all_to_mip']
1342
- pick_slice = settings['pick_slice']
1343
- skip_mode = settings['skip_mode']
1344
-
1345
- if not img_format == None:
1346
- if metadata_type == 'cellvoyager':
1347
- regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1348
- elif metadata_type == 'cq1':
1349
- regex = f'W(?P<wellID>.*)F(?P<fieldID>.*)T(?P<timeID>.*)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1350
- elif metadata_type == 'nikon':
1351
- regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1352
- elif metadata_type == 'zeis':
1353
- regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1354
- elif metadata_type == 'leica':
1355
- regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
1356
- elif metadata_type == 'custom':
1357
- regex = f'({custom_regex}){img_format}'
1358
-
1359
- print(f'regex mode:{metadata_type} regex:{regex}')
1360
1406
 
1361
- if settings.get('test_mode', False):
1407
+ settings, metadata_type, custom_regex, nr, plot, batch_size, timelapse, lower_percentile, randomize, all_to_mip, pick_slice, skip_mode, cmap, figuresize, normalize, save_dtype, test_mode, test_images, random_test = set_default_settings_preprocess_img_data(settings)
1408
+
1409
+ regex = _get_regex(metadata_type, img_format, custom_regex)
1410
+
1411
+ if test_mode:
1412
+
1362
1413
  print(f'Running spacr in test mode')
1363
1414
  settings['plot'] = True
1364
1415
  try:
@@ -1367,7 +1418,7 @@ def preprocess_img_data(settings):
1367
1418
  except OSError as e:
1368
1419
  pass
1369
1420
 
1370
- src = _run_test_mode(settings['src'], regex, timelapse=timelapse)
1421
+ src = _run_test_mode(settings['src'], regex, timelapse, test_images, random_test)
1371
1422
  settings['src'] = src
1372
1423
 
1373
1424
  if img_format == None:
@@ -1412,31 +1463,20 @@ def preprocess_img_data(settings):
1412
1463
  except Exception as e:
1413
1464
  print(f"Error: {e}")
1414
1465
 
1415
- print('concatinating cahnnels')
1416
- _concatenate_channel(src+'/stack',
1417
- channels=mask_channels,
1418
- randomize=randomize,
1419
- timelapse=timelapse,
1420
- batch_size=batch_size)
1421
-
1422
- if plot:
1423
- print(f'plotting {nr} images from {src}/channel_stack')
1424
- _plot_4D_arrays(src+'/channel_stack', figuresize, cmap, nr_npz=1, nr=nr)
1425
-
1426
- backgrounds, signal_to_noise, signal_thresholds = _get_lists_for_normalization(settings=settings)
1427
-
1428
- if not timelapse:
1429
- _normalize_stack(src+'/channel_stack',
1430
- backgrounds=backgrounds,
1431
- lower_quantile=lower_quantile,
1432
- save_dtype=save_dtype,
1433
- signal_thresholds=signal_thresholds,
1434
- correct_illumination=correct_illumination,
1435
- signal_to_noise=signal_to_noise,
1436
- remove_background=remove_background)
1437
- else:
1438
- _normalize_timelapse(src+'/channel_stack', lower_quantile=lower_quantile, save_dtype=np.float32)
1439
-
1466
+ backgrounds, signal_to_noise, signal_thresholds, remove_backgrounds = _get_lists_for_normalization(settings=settings)
1467
+
1468
+ concatenate_and_normalize(src+'/stack',
1469
+ mask_channels,
1470
+ randomize,
1471
+ timelapse,
1472
+ batch_size,
1473
+ backgrounds,
1474
+ remove_backgrounds,
1475
+ lower_percentile,
1476
+ np.float32,
1477
+ signal_to_noise,
1478
+ signal_thresholds)
1479
+
1440
1480
  if plot:
1441
1481
  _plot_4D_arrays(src+'/norm_channel_stack', nr_npz=1, nr=nr)
1442
1482
 
@@ -1490,27 +1530,6 @@ def _get_avg_object_size(masks):
1490
1530
  return sum(object_areas) / len(object_areas)
1491
1531
  else:
1492
1532
  return 0 # Return 0 if no objects are found
1493
-
1494
- def _save_figure_v1(fig, src, text, dpi=300, ):
1495
- """
1496
- Save a figure to a specified location.
1497
-
1498
- Parameters:
1499
- fig (matplotlib.figure.Figure): The figure to be saved.
1500
- src (str): The source file path.
1501
- text (str): The text to be included in the figure name.
1502
- dpi (int, optional): The resolution of the saved figure. Defaults to 300.
1503
- """
1504
- save_folder = os.path.dirname(src)
1505
- obj_type = os.path.basename(src)
1506
- name = os.path.basename(save_folder)
1507
- save_folder = os.path.join(save_folder, 'figure')
1508
- os.makedirs(save_folder, exist_ok=True)
1509
- fig_name = f'{obj_type}_{name}_{text}.pdf'
1510
- save_location = os.path.join(save_folder, fig_name)
1511
- fig.savefig(save_location, bbox_inches='tight', dpi=dpi)
1512
- print(f'Saved single cell figure: {save_location}')
1513
- plt.close()
1514
1533
 
1515
1534
  def _save_figure(fig, src, text, dpi=300, i=1, all_folders=1):
1516
1535
  """
@@ -1616,56 +1635,6 @@ def _save_settings_to_db(settings):
1616
1635
  settings_df.to_sql('settings', conn, if_exists='replace', index=False) # Replace the table if it already exists
1617
1636
  conn.close()
1618
1637
 
1619
- def _save_mask_timelapse_as_gif_v1(masks, path, cmap, norm, filenames):
1620
- """
1621
- Save a timelapse of masks as a GIF.
1622
-
1623
- Parameters:
1624
- masks (list): List of mask frames.
1625
- path (str): Path to save the GIF.
1626
- cmap: Colormap for displaying the masks.
1627
- norm: Normalization for the masks.
1628
- filenames (list): List of filenames corresponding to each mask frame.
1629
-
1630
- Returns:
1631
- None
1632
- """
1633
- def _update(frame):
1634
- """
1635
- Update the plot with the given frame.
1636
-
1637
- Parameters:
1638
- frame (int): The frame number to update the plot with.
1639
-
1640
- Returns:
1641
- None
1642
- """
1643
- nonlocal filename_text_obj
1644
- if filename_text_obj is not None:
1645
- filename_text_obj.remove()
1646
- ax.clear()
1647
- ax.axis('off')
1648
- current_mask = masks[frame]
1649
- ax.imshow(current_mask, cmap=cmap, norm=norm)
1650
- ax.set_title(f'Frame: {frame}', fontsize=24, color='white')
1651
- filename_text = filenames[frame]
1652
- filename_text_obj = fig.text(0.5, 0.01, filename_text, ha='center', va='center', fontsize=20, color='white')
1653
- for label_value in np.unique(current_mask):
1654
- if label_value == 0: continue # Skip background
1655
- y, x = np.mean(np.where(current_mask == label_value), axis=1)
1656
- ax.text(x, y, str(label_value), color='white', fontsize=24, ha='center', va='center')
1657
-
1658
- fig, ax = plt.subplots(figsize=(50, 50), facecolor='black')
1659
- ax.set_facecolor('black')
1660
- ax.axis('off')
1661
- plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
1662
-
1663
- filename_text_obj = None
1664
- anim = FuncAnimation(fig, _update, frames=len(masks), blit=False)
1665
- anim.save(path, writer='pillow', fps=2, dpi=80) # Adjust DPI for size/quality
1666
- plt.close(fig)
1667
- print(f'Saved timelapse to {path}')
1668
-
1669
1638
  def _save_mask_timelapse_as_gif(masks, tracks_df, path, cmap, norm, filenames):
1670
1639
  """
1671
1640
  Save a timelapse animation of masks as a GIF.
@@ -2409,10 +2378,9 @@ def convert_numpy_to_tiff(folder_path, limit=None):
2409
2378
  return
2410
2379
 
2411
2380
  def generate_cellpose_train_test(src, test_split=0.1):
2412
-
2413
2381
  mask_src = os.path.join(src, 'masks')
2414
2382
  img_paths = glob.glob(os.path.join(src, '*.tif'))
2415
- img_filenames = [os.path.basename(file) for file in img_paths + img_paths]
2383
+ img_filenames = [os.path.basename(file) for file in img_paths]
2416
2384
  img_filenames = [file for file in img_filenames if os.path.exists(os.path.join(mask_src, file))]
2417
2385
  print(f'Found {len(img_filenames)} images with masks')
2418
2386
 
@@ -2424,19 +2392,21 @@ def generate_cellpose_train_test(src, test_split=0.1):
2424
2392
  print(f'Split dataset into Train {len(train_files)} and Test {len(test_files)} files')
2425
2393
 
2426
2394
  train_dir = os.path.join(os.path.dirname(src), 'train')
2427
- train_dir_masks = os.path.join(train_dir, 'mask')
2395
+ train_dir_masks = os.path.join(train_dir, 'masks')
2428
2396
  test_dir = os.path.join(os.path.dirname(src), 'test')
2429
- test_dir_masks = os.path.join(test_dir, 'mask')
2397
+ test_dir_masks = os.path.join(test_dir, 'masks')
2430
2398
 
2399
+ os.makedirs(train_dir, exist_ok=True)
2431
2400
  os.makedirs(train_dir_masks, exist_ok=True)
2401
+ os.makedirs(test_dir, exist_ok=True)
2432
2402
  os.makedirs(test_dir_masks, exist_ok=True)
2403
+
2433
2404
  for i, ls in enumerate(list_of_lists):
2434
-
2435
2405
  if i == 0:
2436
2406
  dst = test_dir
2437
2407
  dst_mask = test_dir_masks
2438
2408
  _type = 'Test'
2439
- if i == 1:
2409
+ else:
2440
2410
  dst = train_dir
2441
2411
  dst_mask = train_dir_masks
2442
2412
  _type = 'Train'
@@ -2449,130 +2419,4 @@ def generate_cellpose_train_test(src, test_split=0.1):
2449
2419
  shutil.copy(img_path, new_img_path)
2450
2420
  shutil.copy(mask_path, new_mask_path)
2451
2421
  print(f'Copied {idx+1}/{len(ls)} images to {_type} set', end='\r', flush=True)
2452
-
2453
-
2454
-
2455
-
2456
-
2457
-
2458
-
2459
-
2460
-
2461
-
2462
-
2463
-
2464
-
2465
-
2466
-
2467
-
2468
-
2469
-
2470
-
2471
-
2472
-
2473
-
2474
-
2475
-
2476
-
2477
-
2478
-
2479
-
2480
-
2481
-
2482
-
2483
-
2484
-
2485
-
2486
-
2487
-
2488
-
2489
-
2490
-
2491
-
2492
-
2493
-
2494
-
2495
-
2496
-
2497
-
2498
-
2499
-
2500
-
2501
-
2502
-
2503
-
2504
-
2505
-
2506
-
2507
-
2508
-
2509
-
2510
-
2511
-
2512
-
2513
-
2514
-
2515
-
2516
-
2517
-
2518
-
2519
-
2520
-
2521
-
2522
-
2523
-
2524
-
2525
-
2526
-
2527
-
2528
-
2529
-
2530
-
2531
-
2532
-
2533
-
2534
-
2535
-
2536
-
2537
-
2538
-
2539
-
2540
-
2541
-
2542
-
2543
-
2544
-
2545
-
2546
-
2547
-
2548
-
2549
-
2550
-
2551
-
2552
-
2553
-
2554
-
2555
-
2556
-
2557
-
2558
-
2559
-
2560
-
2561
-
2562
-
2563
-
2564
-
2565
-
2566
-
2567
-
2568
-
2569
-
2570
-
2571
-
2572
-
2573
-
2574
-
2575
-
2576
-
2577
-
2578
-
2422
+