spacr 0.0.81__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/io.py CHANGED
@@ -21,6 +21,7 @@ from multiprocessing import Pool, cpu_count
21
21
  from torch.utils.data import Dataset
22
22
  import matplotlib.pyplot as plt
23
23
  from torchvision.transforms import ToTensor
24
+ import seaborn as sns
24
25
 
25
26
 
26
27
  from .logger import log_function_call
@@ -87,7 +88,7 @@ def _load_images_and_labels(image_files, label_files, circular=False, invert=Fal
87
88
  print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {labels[0].shape}, image type: labels[0].shape')
88
89
  return images, labels, image_names, label_names
89
90
 
90
- def _load_normalized_images_and_labels(image_files, label_files, channels=None, percentiles=None, circular=False, invert=False, visualize=False, remove_background=False, background=0, Signal_to_noise=10):
91
+ def _load_normalized_images_and_labels_v1(image_files, label_files, channels=None, percentiles=None, circular=False, invert=False, visualize=False, remove_background=False, background=0, Signal_to_noise=10):
91
92
 
92
93
  from .plot import normalize_and_visualize
93
94
  from .utils import invert_image, apply_mask
@@ -182,6 +183,115 @@ def _load_normalized_images_and_labels(image_files, label_files, channels=None,
182
183
 
183
184
  return normalized_images, labels, image_names, label_names
184
185
 
186
def _load_normalized_images_and_labels(image_files, label_files, channels=None, percentiles=None, circular=False, invert=False, visualize=False, remove_background=False, background=0, Signal_to_noise=10, target_height=None, target_width=None):
    """
    Load images (and optionally label masks), optionally resize them, and
    rescale every image channel to the range [0, 1].

    Args:
        image_files (list): Paths of the images to load.
        label_files (list or None): Paths of the label masks to load, or None
            when there are no labels.
        channels (list or int, optional): Index/indices applied to the last
            axis of 3-D images. Defaults to None (keep all channels).
        percentiles (sequence, optional): (low, high) percentiles. When given,
            each image channel is rescaled between its own percentiles. When
            None, the bounds are averaged over all images per channel.
        circular (bool, optional): Pass the image through ``apply_mask`` with
            ``output_value=0``. Defaults to False.
        invert (bool, optional): Pass the image through ``invert_image``.
            Defaults to False.
        visualize (bool, optional): Show each image next to its normalized
            version, and a resize overview when labels exist. Defaults to False.
        remove_background (bool, optional): Set pixels below ``background`` to
            zero before normalizing. Defaults to False.
        background (int, optional): Background intensity. Defaults to 0.
        Signal_to_noise (int, optional): Multiplied with ``background`` to get
            the intensity an upper percentile must exceed. Defaults to 10.
        target_height (int, optional): Height to resize to. Resizing only
            happens when both target dimensions are given.
        target_width (int, optional): Width to resize to.

    Returns:
        tuple: (normalized_images, labels, image_names, label_names, orig_dims)
        where ``orig_dims`` holds the (height, width) of every image as loaded.
    """
    from .plot import normalize_and_visualize, plot_resize
    from .utils import invert_image, apply_mask
    from skimage.transform import resize as resizescikit

    signal_thresholds = background * Signal_to_noise
    lower_percentile = 2
    upper_percentile_candidates = [98, 99, 99.9, 99.99, 99.999]
    do_resize = target_height is not None and target_width is not None

    images = []
    labels = []
    orig_dims = []

    # One list per channel, grown on demand so any channel count is supported
    # (a fixed size of 4 failed for images with more channels and produced NaN
    # averages for channels that never occurred).
    percentiles_1 = []
    percentiles_99 = []

    image_names = [os.path.basename(f) for f in image_files]
    image_dir = os.path.dirname(image_files[0]) if len(image_files) > 0 else None

    if label_files is not None:
        label_names = [os.path.basename(f) for f in label_files]
        label_dir = os.path.dirname(label_files[0]) if len(label_files) > 0 else None
    else:
        label_names = []
        label_dir = None

    # Load, collect percentile statistics, and resize images
    for img_file in image_files:
        image = cellpose.io.imread(img_file)
        orig_dims.append((image.shape[0], image.shape[1]))
        if invert:
            image = invert_image(image)
        if circular:
            image = apply_mask(image, output_value=0)

        # If specific channels are specified, select them
        if channels is not None and image.ndim == 3:
            image = image[..., channels]

        if remove_background:
            image[image < background] = 0

        # From here on every image has an explicit trailing channel axis
        if image.ndim < 3:
            image = np.expand_dims(image, axis=-1)

        if percentiles is None:
            while len(percentiles_1) < image.shape[-1]:
                percentiles_1.append([])
                percentiles_99.append([])

            for c in range(image.shape[-1]):
                channel_data = image[..., c]
                percentiles_1[c].append(np.percentile(channel_data, lower_percentile))
                # Keep the lowest candidate percentile that clears the signal
                # threshold; images that never clear it contribute nothing.
                for percentile in upper_percentile_candidates:
                    p = np.percentile(channel_data, percentile)
                    if p > signal_thresholds:
                        percentiles_99[c].append(p)
                        break

        # Resize image; statistics above are taken before resizing
        if do_resize:
            image_shape = (target_height, target_width) + tuple(image.shape[2:])
            image = resizescikit(image, image_shape, preserve_range=True, anti_aliasing=True).astype(image.dtype)

        images.append(image)

    if percentiles is None:
        # Calculate average percentiles for normalization
        avg_p1 = [np.mean(p) for p in percentiles_1]
        # NOTE(review): when no image cleared the threshold for a channel the
        # upper bound falls back to the lower bound, giving a zero-width input
        # range; confirm that this degenerate rescale is intended.
        avg_p99 = [np.mean(p) if len(p) > 0 else avg_p1[i] for i, p in enumerate(percentiles_99)]

        print(f'Average 1st percentiles: {avg_p1}, Average 99th percentiles: {avg_p99}')

    normalized_images = []
    for image in images:
        normalized_image = np.zeros_like(image, dtype=np.float32)
        for c in range(image.shape[-1]):
            if percentiles is None:
                in_range = (avg_p1[c], avg_p99[c])
            else:
                # Per-image bounds, computed on the (possibly resized) image
                in_range = (np.percentile(image[..., c], percentiles[0]),
                            np.percentile(image[..., c], percentiles[1]))
            normalized_image[..., c] = rescale_intensity(image[..., c], in_range=in_range, out_range=(0, 1))
        normalized_images.append(normalized_image)
        if visualize:
            # The title carries the number of the last channel processed
            normalize_and_visualize(image, normalized_image, title=f"Channel {image.shape[-1]} Normalized")

    if label_files is not None:
        for lbl_file in label_files:
            label = cellpose.io.imread(lbl_file)
            # Resize label with nearest-neighbour sampling (order=0) and no
            # anti-aliasing so that label values are not blended
            if do_resize:
                label = resizescikit(label, (target_height, target_width), order=0, preserve_range=True, anti_aliasing=False).astype(label.dtype)
            labels.append(label)

    print(f'Loaded and normalized {len(normalized_images)} images and {len(labels)} labels from {image_dir} and {label_dir}')

    if visualize and images and labels:
        plot_resize(images, normalized_images, labels, labels)

    return normalized_images, labels, image_names, label_names, orig_dims
294
+
185
295
  class CombineLoaders:
186
296
 
187
297
  """
@@ -203,14 +313,14 @@ class CombineLoaders:
203
313
 
204
314
  """
205
315
 
206
- def _init__(self, train_loaders):
316
+ def __init__(self, train_loaders):
207
317
  self.train_loaders = train_loaders
208
318
  self.loader_iters = [iter(loader) for loader in train_loaders]
209
319
 
210
- def _iter__(self):
320
+ def __iter__(self):
211
321
  return self
212
322
 
213
- def _next__(self):
323
+ def __next__(self):
214
324
  while self.loader_iters:
215
325
  random.shuffle(self.loader_iters) # Shuffle the loader_iters list
216
326
  for i, loader_iter in enumerate(self.loader_iters):
@@ -233,7 +343,7 @@ class CombinedDataset(Dataset):
233
343
  shuffle (bool, optional): Whether to shuffle the combined dataset. Defaults to True.
234
344
  """
235
345
 
236
- def _init__(self, datasets, shuffle=True):
346
+ def __init__(self, datasets, shuffle=True):
237
347
  self.datasets = datasets
238
348
  self.lengths = [len(dataset) for dataset in datasets]
239
349
  self.total_length = sum(self.lengths)
@@ -243,14 +353,14 @@ class CombinedDataset(Dataset):
243
353
  random.shuffle(self.indices)
244
354
  else:
245
355
  self.indices = None
246
- def _getitem__(self, index):
356
+ def __getitem__(self, index):
247
357
  if self.shuffle:
248
358
  index = self.indices[index]
249
359
  for dataset, length in zip(self.datasets, self.lengths):
250
360
  if index < length:
251
361
  return dataset[index]
252
362
  index -= length
253
- def _len__(self):
363
+ def __len__(self):
254
364
  return self.total_length
255
365
 
256
366
  class NoClassDataset(Dataset):
@@ -434,7 +544,7 @@ class NoClassDataset(Dataset):
434
544
 
435
545
 
436
546
  class TarImageDataset(Dataset):
437
- def _init__(self, tar_path, transform=None):
547
+ def __init__(self, tar_path, transform=None):
438
548
  self.tar_path = tar_path
439
549
  self.transform = transform
440
550
 
@@ -442,10 +552,10 @@ class TarImageDataset(Dataset):
442
552
  with tarfile.open(self.tar_path, 'r') as f:
443
553
  self.members = [m for m in f.getmembers() if m.isfile()]
444
554
 
445
- def _len__(self):
555
+ def __len__(self):
446
556
  return len(self.members)
447
557
 
448
- def _getitem__(self, idx):
558
+ def __getitem__(self, idx):
449
559
  with tarfile.open(self.tar_path, 'r') as f:
450
560
  m = self.members[idx]
451
561
  img_file = f.extractfile(m)
@@ -890,7 +1000,75 @@ def _concatenate_channel(src, channels, randomize=True, timelapse=False, batch_s
890
1000
  print(f'All files concatenated and saved to:{channel_stack_loc}')
891
1001
  return channel_stack_loc
892
1002
 
893
- def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, batch_size=100, backgrounds=[100, 100, 100], remove_backgrounds=[False, False, False], lower_percentile=2, save_dtype=np.float32, signal_to_noise=[5, 5, 5], signal_thresholds=[1000, 1000, 1000]):
1003
def _normalize_img_batch(stack, channels, save_dtype, settings):
    """
    Normalize selected channels of a stack of images to the range [0, 1].

    For every channel in ``channels`` a lower and an upper intensity bound are
    computed over the non-zero pixels of the whole stack, and each image of
    the stack is rescaled between those bounds. Channels not listed in
    ``channels`` stay zero in the output.

    Args:
        stack (numpy.ndarray): The stack of images to normalize, indexed as
            [image, row, column, channel]. It is not modified.
        channels (list): Channel indices (last axis of ``stack``) to normalize.
        save_dtype (numpy.dtype): Data type of the returned stack.
        settings (dict): Keyword arguments. For each of 'nucleus', 'cell' and
            'pathogen' the keys '<name>_channel', '<name>_background',
            '<name>_Signal_to_noise' and 'remove_background_<name>' are read,
            as well as 'lower_percentile'.

    Returns:
        numpy.ndarray: The normalized stack, same shape as ``stack``.

    Raises:
        ValueError: If a channel matches none of the nucleus, cell or pathogen
            channels in ``settings``, so no parameters exist for it.
    """
    normalized_stack = np.zeros_like(stack, dtype=np.float32)

    # Checked in this order without stopping at the first hit, so when two
    # object types share a channel the later one wins.
    object_types = ('nucleus', 'cell', 'pathogen')
    upper_percentile_candidates = np.linspace(98, 99.5, num=16)

    for channel in channels:
        matched = False
        for object_type in object_types:
            if channel == settings[f'{object_type}_channel']:
                background = settings[f'{object_type}_background']
                signal_threshold = settings[f'{object_type}_Signal_to_noise'] * settings[f'{object_type}_background']
                remove_background = settings[f'remove_background_{object_type}']
                matched = True

        if not matched:
            # Without this check the parameters would be undefined, or stale
            # values from the previous channel would be reused silently.
            raise ValueError(f'Channel {channel} matches none of nucleus_channel, cell_channel or pathogen_channel in settings.')

        single_channel = stack[:, :, :, channel]

        print(f'Processing channel {channel}: background={background}, signal_threshold={signal_threshold}, remove_background={remove_background}')

        # Step 3: Remove background if required (on a copy: the slice above is
        # a view, and writing to it would alter the caller's stack)
        if remove_background:
            single_channel = single_channel.copy()
            single_channel[single_channel < background] = 0

        # Step 4: Calculate global lower percentile for the channel
        non_zero_single_channel = single_channel[single_channel != 0]
        if non_zero_single_channel.size == 0:
            # Nothing to derive bounds from; the channel stays zero
            print(f'Channel {channel}: no non-zero pixels, channel left at zero')
            continue
        global_lower = np.percentile(non_zero_single_channel, settings['lower_percentile'])

        # Step 5: Calculate global upper percentile for the channel: the lowest
        # candidate reaching the signal threshold, else the highest candidate
        upper_values = np.percentile(non_zero_single_channel, upper_percentile_candidates)
        reaches_threshold = upper_values >= signal_threshold
        if reaches_threshold.any():
            global_upper = upper_values[np.argmax(reaches_threshold)]
        else:
            global_upper = upper_values[-1]  # Fallback in case no upper percentile met the threshold

        print(f'Channel {channel}: global_lower={global_lower}, global_upper={global_upper}, Signal-to-noise={global_upper / global_lower}')

        # Step 6: Normalize each array from global_lower to global_upper between 0 and 1
        for array_index in range(single_channel.shape[0]):
            arr_2d = single_channel[array_index, :, :]
            arr_2d_normalized = exposure.rescale_intensity(arr_2d, in_range=(global_lower, global_upper), out_range=(0, 1))
            normalized_stack[array_index, :, :, channel] = arr_2d_normalized

    return normalized_stack.astype(save_dtype)
1069
+
1070
+ def concatenate_and_normalize(src, channels, save_dtype=np.float32, settings={}):
1071
+
894
1072
  """
895
1073
  Concatenates and normalizes channel data from multiple files and saves the normalized data.
896
1074
 
@@ -910,12 +1088,14 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
910
1088
  Returns:
911
1089
  str: The directory path where the concatenated and normalized channel data is saved.
912
1090
  """
1091
+ # n c p
913
1092
  channels = [item for item in channels if item is not None]
1093
+
914
1094
  paths = []
915
1095
  output_fldr = os.path.join(os.path.dirname(src), 'norm_channel_stack')
916
1096
  os.makedirs(output_fldr, exist_ok=True)
917
1097
 
918
- if timelapse:
1098
+ if settings['timelapse']:
919
1099
  try:
920
1100
  time_stack_path_lists = _generate_time_lists(os.listdir(src))
921
1101
  for i, time_stack_list in enumerate(time_stack_path_lists):
@@ -927,12 +1107,19 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
927
1107
  parts = file.split('_')
928
1108
  name = parts[0] + '_' + parts[1] + '_' + parts[2]
929
1109
  array = np.load(path)
930
- array = np.take(array, channels, axis=2)
1110
+ #array = np.take(array, channels, axis=2)
931
1111
  stack_region.append(array)
932
1112
  filenames_region.append(os.path.basename(path))
933
1113
  print(f'Region {i + 1}/ {len(time_stack_path_lists)}', end='\r', flush=True)
934
1114
  stack = np.stack(stack_region)
935
- normalized_stack = _normalize_stack(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds)
1115
+
1116
+ normalized_stack = _normalize_img_batch(stack=stack,
1117
+ channels=channels,
1118
+ save_dtype=save_dtype,
1119
+ settings=settings)
1120
+
1121
+ normalized_stack = normalized_stack[..., channels]
1122
+
936
1123
  save_loc = os.path.join(output_fldr, f'{name}_norm_timelapse.npz')
937
1124
  np.savez(save_loc, data=normalized_stack, filenames=filenames_region)
938
1125
  print(save_loc)
@@ -945,7 +1132,7 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
945
1132
  if file.endswith('.npy'):
946
1133
  path = os.path.join(src, file)
947
1134
  paths.append(path)
948
- if randomize:
1135
+ if settings['randomize']:
949
1136
  random.shuffle(paths)
950
1137
  nr_files = len(paths)
951
1138
  batch_index = 0
@@ -954,12 +1141,12 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
954
1141
 
955
1142
  for i, path in enumerate(paths):
956
1143
  array = np.load(path)
957
- array = np.take(array, channels, axis=2)
1144
+ #array = np.take(array, channels, axis=2)
958
1145
  stack_ls.append(array)
959
1146
  filenames_batch.append(os.path.basename(path))
960
1147
  print(f'Concatenated: {i + 1}/{nr_files} files')
961
1148
 
962
- if (i + 1) % batch_size == 0 or i + 1 == nr_files:
1149
+ if (i + 1) % settings['batch_size'] == 0 or i + 1 == nr_files:
963
1150
  unique_shapes = {arr.shape[:-1] for arr in stack_ls}
964
1151
  if len(unique_shapes) > 1:
965
1152
  max_dims = np.max(np.array(list(unique_shapes)), axis=0)
@@ -973,8 +1160,13 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
973
1160
  stack = np.stack(padded_stack_ls)
974
1161
  else:
975
1162
  stack = np.stack(stack_ls)
976
-
977
- normalized_stack = _normalize_img_batch(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds)
1163
+
1164
+ normalized_stack = _normalize_img_batch(stack=stack,
1165
+ channels=channels,
1166
+ save_dtype=save_dtype,
1167
+ settings=settings)
1168
+
1169
+ normalized_stack = normalized_stack[..., channels]
978
1170
 
979
1171
  save_loc = os.path.join(output_fldr, f'stack_{batch_index}_norm.npz')
980
1172
  np.savez(save_loc, data=normalized_stack, filenames=filenames_batch)
@@ -983,64 +1175,10 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
983
1175
  stack_ls = []
984
1176
  filenames_batch = []
985
1177
  padded_stack_ls = []
1178
+
986
1179
  print(f'All files concatenated and normalized. Saved to: {output_fldr}')
987
1180
  return output_fldr
988
1181
 
989
- def _normalize_img_batch(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds):
990
- """
991
- Normalize the stack of images.
992
-
993
- Args:
994
- stack (numpy.ndarray): The stack of images to normalize.
995
- backgrounds (list): Background values for each channel.
996
- remove_backgrounds (list): Whether to remove background values for each channel.
997
- lower_percentile (int): Lower percentile value for normalization.
998
- save_dtype (numpy.dtype): Data type for saving the normalized stack.
999
- signal_to_noise (list): Signal-to-noise ratio thresholds for each channel.
1000
- signal_thresholds (list): Signal thresholds for each channel.
1001
-
1002
- Returns:
1003
- numpy.ndarray: The normalized stack.
1004
- """
1005
- normalized_stack = np.zeros_like(stack, dtype=np.float32)
1006
-
1007
- for chan_index, channel in enumerate(range(stack.shape[-1])):
1008
- single_channel = stack[:, :, :, channel]
1009
- background = backgrounds[chan_index]
1010
- signal_threshold = signal_thresholds[chan_index]
1011
- remove_background = remove_backgrounds[chan_index]
1012
-
1013
- print(f'Processing channel {chan_index}: background={background}, signal_threshold={signal_threshold}, remove_background={remove_background}')
1014
-
1015
- # Step 3: Remove background if required
1016
- if remove_background:
1017
- single_channel[single_channel < background] = 0
1018
-
1019
- # Step 4: Calculate global lower percentile for the channel
1020
- non_zero_single_channel = single_channel[single_channel != 0]
1021
- global_lower = np.percentile(non_zero_single_channel, lower_percentile)
1022
-
1023
- # Step 5: Calculate global upper percentile for the channel
1024
- global_upper = None
1025
- for upper_p in np.linspace(98, 99.5, num=16):
1026
- upper_value = np.percentile(non_zero_single_channel, upper_p)
1027
- if upper_value >= signal_threshold:
1028
- global_upper = upper_value
1029
- break
1030
-
1031
- if global_upper is None:
1032
- global_upper = np.percentile(non_zero_single_channel, 99.5) # Fallback in case no upper percentile met the threshold
1033
-
1034
- print(f'Channel {chan_index}: global_lower={global_lower}, global_upper={global_upper}, Signal-to-noise={global_upper / global_lower}')
1035
-
1036
- # Step 6: Normalize each array from global_lower to global_upper between 0 and 1
1037
- for array_index in range(single_channel.shape[0]):
1038
- arr_2d = single_channel[array_index, :, :]
1039
- arr_2d_normalized = exposure.rescale_intensity(arr_2d, in_range=(global_lower, global_upper), out_range=(0, 1))
1040
- normalized_stack[array_index, :, :, channel] = arr_2d_normalized
1041
-
1042
- return normalized_stack.astype(save_dtype)
1043
-
1044
1182
  def _get_lists_for_normalization(settings):
1045
1183
  """
1046
1184
  Get lists for normalization based on the provided settings.
@@ -1059,22 +1197,25 @@ def _get_lists_for_normalization(settings):
1059
1197
  remove_background = []
1060
1198
 
1061
1199
  # Iterate through the channels and append the corresponding values if the channel is not None
1062
- for ch in settings['channels']:
1063
- if ch == settings['nucleus_channel']:
1064
- backgrounds.append(settings['nucleus_background'])
1065
- signal_to_noise.append(settings['nucleus_Signal_to_noise'])
1066
- signal_thresholds.append(settings['nucleus_Signal_to_noise']*settings['nucleus_background'])
1067
- remove_background.append(settings['remove_background_nucleus'])
1068
- elif ch == settings['cell_channel']:
1069
- backgrounds.append(settings['cell_background'])
1070
- signal_to_noise.append(settings['cell_Signal_to_noise'])
1071
- signal_thresholds.append(settings['cell_Signal_to_noise']*settings['cell_background'])
1072
- remove_background.append(settings['remove_background_cell'])
1073
- elif ch == settings['pathogen_channel']:
1074
- backgrounds.append(settings['pathogen_background'])
1075
- signal_to_noise.append(settings['pathogen_Signal_to_noise'])
1076
- signal_thresholds.append(settings['pathogen_Signal_to_noise']*settings['pathogen_background'])
1077
- remove_background.append(settings['remove_background_pathogen'])
1200
+ # for ch in settings['channels']:
1201
+ for ch in [settings['nucleus_channel'], settings['cell_channel'], settings['pathogen_channel']]:
1202
+ if not ch is None:
1203
+ if ch == settings['nucleus_channel']:
1204
+ backgrounds.append(settings['nucleus_background'])
1205
+ signal_to_noise.append(settings['nucleus_Signal_to_noise'])
1206
+ signal_thresholds.append(settings['nucleus_Signal_to_noise']*settings['nucleus_background'])
1207
+ remove_background.append(settings['remove_background_nucleus'])
1208
+ elif ch == settings['cell_channel']:
1209
+ backgrounds.append(settings['cell_background'])
1210
+ signal_to_noise.append(settings['cell_Signal_to_noise'])
1211
+ signal_thresholds.append(settings['cell_Signal_to_noise']*settings['cell_background'])
1212
+ remove_background.append(settings['remove_background_cell'])
1213
+ elif ch == settings['pathogen_channel']:
1214
+ backgrounds.append(settings['pathogen_background'])
1215
+ signal_to_noise.append(settings['pathogen_Signal_to_noise'])
1216
+ signal_thresholds.append(settings['pathogen_Signal_to_noise']*settings['pathogen_background'])
1217
+ remove_background.append(settings['remove_background_pathogen'])
1218
+
1078
1219
  return backgrounds, signal_to_noise, signal_thresholds, remove_background
1079
1220
 
1080
1221
  def _normalize_stack(src, backgrounds=[100, 100, 100], remove_backgrounds=[False, False, False], lower_percentile=2, save_dtype=np.float32, signal_to_noise=[5, 5, 5], signal_thresholds=[1000, 1000, 1000]):
@@ -1283,7 +1424,8 @@ def delete_empty_subdirectories(folder_path):
1283
1424
  def preprocess_img_data(settings):
1284
1425
 
1285
1426
  from .plot import plot_arrays, _plot_4D_arrays
1286
- from .utils import _run_test_mode, _get_regex, set_default_settings_preprocess_img_data
1427
+ from .utils import _run_test_mode, _get_regex
1428
+ from .settings import set_default_settings_preprocess_img_data
1287
1429
 
1288
1430
  """
1289
1431
  Preprocesses image data by converting z-stack images to maximum intensity projection (MIP) images.
@@ -1400,19 +1542,10 @@ def preprocess_img_data(settings):
1400
1542
  except Exception as e:
1401
1543
  print(f"Error: {e}")
1402
1544
 
1403
- backgrounds, signal_to_noise, signal_thresholds, remove_backgrounds = _get_lists_for_normalization(settings=settings)
1404
-
1405
- concatenate_and_normalize(src+'/stack',
1406
- mask_channels,
1407
- randomize,
1408
- timelapse,
1409
- batch_size,
1410
- backgrounds,
1411
- remove_backgrounds,
1412
- lower_percentile,
1413
- np.float32,
1414
- signal_to_noise,
1415
- signal_thresholds)
1545
+ concatenate_and_normalize(src=src+'/stack',
1546
+ channels=mask_channels,
1547
+ save_dtype=np.float32,
1548
+ settings=settings)
1416
1549
 
1417
1550
  if plot:
1418
1551
  _plot_4D_arrays(src+'/norm_channel_stack', nr_npz=1, nr=nr)
@@ -1494,13 +1627,13 @@ def _save_figure(fig, src, text, dpi=300, i=1, all_folders=1):
1494
1627
  del fig
1495
1628
  gc.collect()
1496
1629
 
1497
- def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus', 'pathogen', 'parasite', 'png_list']):
1630
+ def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus', 'pathogen', 'png_list']):
1498
1631
  """
1499
1632
  Reads and joins tables from a SQLite database.
1500
1633
 
1501
1634
  Args:
1502
1635
  db_path (str): The path to the SQLite database file.
1503
- table_names (list, optional): The names of the tables to read and join. Defaults to ['cell', 'cytoplasm', 'nucleus', 'pathogen', 'parasite', 'png_list'].
1636
+ table_names (list, optional): The names of the tables to read and join. Defaults to ['cell', 'cytoplasm', 'nucleus', 'pathogen', 'png_list'].
1504
1637
 
1505
1638
  Returns:
1506
1639
  pandas.DataFrame: The joined DataFrame containing the data from the specified tables, or None if an error occurs.
@@ -1522,9 +1655,9 @@ def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus',
1522
1655
  join_cols = ['object_label', 'plate', 'row', 'col']
1523
1656
  dataframes['cell'] = pd.merge(dataframes['cell'], png_list_df, on=join_cols, how='left')
1524
1657
  else:
1525
- print("Cell table not found. Cannot join with png_list.")
1526
- return None
1527
- for entity in ['nucleus', 'pathogen', 'parasite']:
1658
+ print("Cell table not found in database tables.")
1659
+ return png_list_df
1660
+ for entity in ['nucleus', 'pathogen']:
1528
1661
  if entity in dataframes:
1529
1662
  numeric_cols = dataframes[entity].select_dtypes(include=[np.number]).columns.tolist()
1530
1663
  non_numeric_cols = dataframes[entity].select_dtypes(exclude=[np.number]).columns.tolist()
@@ -1537,14 +1670,11 @@ def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus',
1537
1670
  joined_df = None
1538
1671
  if 'cell' in dataframes:
1539
1672
  joined_df = dataframes['cell']
1540
- if 'cytoplasm' in dataframes:
1541
- joined_df = pd.merge(joined_df, dataframes['cytoplasm'], on=['object_label', 'prcf'], how='left', suffixes=('', '_cytoplasm'))
1542
- for entity in ['nucleus', 'pathogen']:
1543
- if entity in dataframes:
1544
- joined_df = pd.merge(joined_df, dataframes[entity], left_on=['object_label', 'prcf'], right_index=True, how='left', suffixes=('', f'_{entity}'))
1545
- else:
1546
- print("Cell table not found. Cannot proceed with joining.")
1547
- return None
1673
+ if 'cytoplasm' in dataframes:
1674
+ joined_df = pd.merge(joined_df, dataframes['cytoplasm'], on=['object_label', 'prcf'], how='left', suffixes=('', '_cytoplasm'))
1675
+ for entity in ['nucleus', 'pathogen']:
1676
+ if entity in dataframes:
1677
+ joined_df = pd.merge(joined_df, dataframes[entity], left_on=['object_label', 'prcf'], right_index=True, how='left', suffixes=('', f'_{entity}'))
1548
1678
  return joined_df
1549
1679
 
1550
1680
  def _save_settings_to_db(settings):
@@ -1993,8 +2123,75 @@ def _results_to_csv(src, df, df_well):
1993
2123
  ###################################################
1994
2124
  # Classify
1995
2125
  ###################################################
2126
+
2127
def read_plot_model_stats(file_path, save=False):
    """
    Split a model statistics CSV into train and validation tables and plot
    each metric against the epoch.

    The CSV is read with its first column as index. Rows whose index contains
    '_train' or '_val' form the train and validation tables; the integer
    before the first underscore of the index is stored in a new 'epoch'
    column. Both tables are always written next to the input file as
    'train.csv' and 'validation.csv'.

    Args:
        file_path (str): Path to the statistics CSV.
        save (bool, optional): If True, write one '<metric>.pdf' per metric to
            the folder of ``file_path``; otherwise show the figures.
            Defaults to False.
    """

    def _plot_and_save(train_df, val_df, column='accuracy', save=False, path=None, dpi=600):
        # Plot one metric for train (left) and validation (right) on a shared y-axis.

        # A metric that was not recorded must not abort the remaining plots
        if column not in train_df.columns or column not in val_df.columns:
            print(f'Column {column} not found in model stats, skipping plot')
            return

        # Create subplots
        fig, axes = plt.subplots(1, 2, figsize=(20, 10), sharey=True)

        # Plotting
        sns.lineplot(ax=axes[0], x='epoch', y=column, data=train_df, marker='o', color='red')
        sns.lineplot(ax=axes[1], x='epoch', y=column, data=val_df, marker='o', color='blue')

        # Set titles and labels
        axes[0].set_title(f'Train {column} vs. Epoch', fontsize=20)
        axes[0].set_xlabel('Epoch', fontsize=16)
        axes[0].set_ylabel(column, fontsize=16)
        axes[0].tick_params(axis='both', which='major', labelsize=12)

        axes[1].set_title(f'Validation {column} vs. Epoch', fontsize=20)
        axes[1].set_xlabel('Epoch', fontsize=16)
        axes[1].tick_params(axis='both', which='major', labelsize=12)

        fig.tight_layout()

        if save:
            pdf_path = os.path.join(path, f'{column}.pdf')
            fig.savefig(pdf_path, format='pdf', dpi=dpi)
            # Release the figure; pyplot otherwise keeps every saved figure open
            plt.close(fig)
        else:
            plt.show()

    # Read the CSV into a dataframe
    df = pd.read_csv(file_path, index_col=0)

    # Split the dataframe into train and validation based on the index
    train_df = df.filter(like='_train', axis=0).copy()
    val_df = df.filter(like='_val', axis=0).copy()

    output_folder = os.path.dirname(file_path)
    train_csv_path = os.path.join(output_folder, 'train.csv')
    val_csv_path = os.path.join(output_folder, 'validation.csv')

    # Extract epochs from index
    train_df['epoch'] = [int(idx.split('_')[0]) for idx in train_df.index]
    val_df['epoch'] = [int(idx.split('_')[0]) for idx in val_df.index]

    # Save dataframes to a CSV file
    train_df.to_csv(train_csv_path)
    val_df.to_csv(val_csv_path)

    if save:
        # Setting the style
        sns.set(style="whitegrid")

    for column in ('accuracy', 'neg_accuracy', 'pos_accuracy', 'loss', 'prauc', 'optimal_threshold'):
        _plot_and_save(train_df, val_df, column=column, save=save, path=output_folder)
2193
+
2194
+ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94], channels=['r','g','b']):
1998
2195
  """
1999
2196
  Save the model based on certain conditions during training.
2000
2197
 
@@ -2007,35 +2204,25 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
2007
2204
  epochs (int): The total number of epochs.
2008
2205
  intermedeate_save (list, optional): List of accuracy thresholds to trigger intermediate model saves.
2009
2206
  Defaults to [0.99, 0.98, 0.95, 0.94].
2207
+ channels (list, optional): List of channels used. Defaults to ['r', 'g', 'b'].
2010
2208
  """
2011
-
2012
- if epoch % 100 == 0:
2013
- torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}.pth')
2014
-
2015
- if epoch == epochs:
2016
- torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}.pth')
2017
-
2018
- if results_df['neg_accuracy'].dropna().mean() >= intermedeate_save[0] and results_df['pos_accuracy'].dropna().mean() >= intermedeate_save[0]:
2019
- percentile = str(intermedeate_save[0]*100)
2020
- print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
2021
- torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{str(percentile)}.pth')
2022
2209
 
2023
- elif results_df['neg_accuracy'].dropna().mean() >= intermedeate_save[1] and results_df['pos_accuracy'].dropna().mean() >= intermedeate_save[1]:
2024
- percentile = str(intermedeate_save[1]*100)
2025
- print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
2026
- torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{str(percentile)}.pth')
2210
+ channels_str = ''.join(channels)
2027
2211
 
2028
- elif results_df['neg_accuracy'].dropna().mean() >= intermedeate_save[2] and results_df['pos_accuracy'].dropna().mean() >= intermedeate_save[2]:
2029
- percentile = str(intermedeate_save[2]*100)
2212
+ def save_model_at_threshold(threshold, epoch, suffix=""):
2213
+ percentile = str(threshold * 100)
2030
2214
  print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
2031
- torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{str(percentile)}.pth')
2032
-
2033
- elif results_df['neg_accuracy'].dropna().mean() >= intermedeate_save[3] and results_df['pos_accuracy'].dropna().mean() >= intermedeate_save[3]:
2034
- percentile = str(intermedeate_save[3]*100)
2035
- print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
2036
- torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{str(percentile)}.pth')
2215
+ torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}{suffix}_acc_{percentile}_channels_{channels_str}.pth')
2216
+
2217
+ if epoch % 100 == 0 or epoch == epochs:
2218
+ torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_channels_{channels_str}.pth')
2037
2219
 
2038
- def _save_progress(dst, results_df, train_metrics_df):
2220
+ for threshold in intermedeate_save:
2221
+ if results_df['neg_accuracy'].dropna().mean() >= threshold and results_df['pos_accuracy'].dropna().mean() >= threshold:
2222
+ save_model_at_threshold(threshold, epoch)
2223
+ break # Ensure we only save for the highest matching threshold
2224
+
2225
+ def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
2039
2226
  """
2040
2227
  Save the progress of the classification model.
2041
2228
 
@@ -2054,11 +2241,14 @@ def _save_progress(dst, results_df, train_metrics_df):
2054
2241
  results_df.to_csv(results_path, index=True, header=True, mode='w')
2055
2242
  else:
2056
2243
  results_df.to_csv(results_path, index=True, header=False, mode='a')
2244
+
2057
2245
  training_metrics_path = os.path.join(dst, 'training_metrics.csv')
2058
2246
  if not os.path.exists(training_metrics_path):
2059
2247
  train_metrics_df.to_csv(training_metrics_path, index=True, header=True, mode='w')
2060
2248
  else:
2061
2249
  train_metrics_df.to_csv(training_metrics_path, index=True, header=False, mode='a')
2250
+ if epoch == epochs:
2251
+ read_plot_model_stats(results_path, save=True)
2062
2252
  return
2063
2253
 
2064
2254
  def _save_settings(settings, src):