spacr 0.0.82__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/io.py CHANGED
@@ -21,6 +21,7 @@ from multiprocessing import Pool, cpu_count
21
21
  from torch.utils.data import Dataset
22
22
  import matplotlib.pyplot as plt
23
23
  from torchvision.transforms import ToTensor
24
+ import seaborn as sns
24
25
 
25
26
 
26
27
  from .logger import log_function_call
@@ -193,7 +194,8 @@ def _load_normalized_images_and_labels(image_files, label_files, channels=None,
193
194
 
194
195
  images = []
195
196
  labels = []
196
-
197
+ orig_dims = []
198
+
197
199
  num_channels = 4
198
200
  percentiles_1 = [[] for _ in range(num_channels)]
199
201
  percentiles_99 = [[] for _ in range(num_channels)]
@@ -204,10 +206,11 @@ def _load_normalized_images_and_labels(image_files, label_files, channels=None,
204
206
  if label_files is not None:
205
207
  label_names = [os.path.basename(f) for f in label_files]
206
208
  label_dir = os.path.dirname(label_files[0])
207
-
209
+
208
210
  # Load, normalize, and resize images
209
211
  for i, img_file in enumerate(image_files):
210
212
  image = cellpose.io.imread(img_file)
213
+ orig_dims.append((image.shape[0], image.shape[1]))
211
214
  if invert:
212
215
  image = invert_image(image)
213
216
  if circular:
@@ -287,7 +290,7 @@ def _load_normalized_images_and_labels(image_files, label_files, channels=None,
287
290
  if visualize and images and labels:
288
291
  plot_resize(images, normalized_images, labels, labels)
289
292
 
290
- return normalized_images, labels, image_names, label_names
293
+ return normalized_images, labels, image_names, label_names, orig_dims
291
294
 
292
295
  class CombineLoaders:
293
296
 
@@ -310,14 +313,14 @@ class CombineLoaders:
310
313
 
311
314
  """
312
315
 
313
- def _init__(self, train_loaders):
316
+ def __init__(self, train_loaders):
314
317
  self.train_loaders = train_loaders
315
318
  self.loader_iters = [iter(loader) for loader in train_loaders]
316
319
 
317
- def _iter__(self):
320
+ def __iter__(self):
318
321
  return self
319
322
 
320
- def _next__(self):
323
+ def __next__(self):
321
324
  while self.loader_iters:
322
325
  random.shuffle(self.loader_iters) # Shuffle the loader_iters list
323
326
  for i, loader_iter in enumerate(self.loader_iters):
@@ -340,7 +343,7 @@ class CombinedDataset(Dataset):
340
343
  shuffle (bool, optional): Whether to shuffle the combined dataset. Defaults to True.
341
344
  """
342
345
 
343
- def _init__(self, datasets, shuffle=True):
346
+ def __init__(self, datasets, shuffle=True):
344
347
  self.datasets = datasets
345
348
  self.lengths = [len(dataset) for dataset in datasets]
346
349
  self.total_length = sum(self.lengths)
@@ -350,14 +353,14 @@ class CombinedDataset(Dataset):
350
353
  random.shuffle(self.indices)
351
354
  else:
352
355
  self.indices = None
353
- def _getitem__(self, index):
356
+ def __getitem__(self, index):
354
357
  if self.shuffle:
355
358
  index = self.indices[index]
356
359
  for dataset, length in zip(self.datasets, self.lengths):
357
360
  if index < length:
358
361
  return dataset[index]
359
362
  index -= length
360
- def _len__(self):
363
+ def __len__(self):
361
364
  return self.total_length
362
365
 
363
366
  class NoClassDataset(Dataset):
@@ -541,7 +544,7 @@ class NoClassDataset(Dataset):
541
544
 
542
545
 
543
546
  class TarImageDataset(Dataset):
544
- def _init__(self, tar_path, transform=None):
547
+ def __init__(self, tar_path, transform=None):
545
548
  self.tar_path = tar_path
546
549
  self.transform = transform
547
550
 
@@ -549,10 +552,10 @@ class TarImageDataset(Dataset):
549
552
  with tarfile.open(self.tar_path, 'r') as f:
550
553
  self.members = [m for m in f.getmembers() if m.isfile()]
551
554
 
552
- def _len__(self):
555
+ def __len__(self):
553
556
  return len(self.members)
554
557
 
555
- def _getitem__(self, idx):
558
+ def __getitem__(self, idx):
556
559
  with tarfile.open(self.tar_path, 'r') as f:
557
560
  m = self.members[idx]
558
561
  img_file = f.extractfile(m)
@@ -997,7 +1000,75 @@ def _concatenate_channel(src, channels, randomize=True, timelapse=False, batch_s
997
1000
  print(f'All files concatenated and saved to:{channel_stack_loc}')
998
1001
  return channel_stack_loc
999
1002
 
1000
- def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, batch_size=100, backgrounds=[100, 100, 100], remove_backgrounds=[False, False, False], lower_percentile=2, save_dtype=np.float32, signal_to_noise=[5, 5, 5], signal_thresholds=[1000, 1000, 1000]):
1003
+ def _normalize_img_batch(stack, channels, save_dtype, settings):
1004
+
1005
+ """
1006
+ Normalize the stack of images.
1007
+
1008
+ Args:
1009
+ stack (numpy.ndarray): The stack of images to normalize.
1010
+ channels (list): Indices of the channels to normalize; `lower_percentile` and the per-channel background settings are read from `settings`.
1011
+ save_dtype (numpy.dtype): Data type for saving the normalized stack.
1012
+ settings (dict): keyword arguments
1013
+
1014
+ Returns:
1015
+ numpy.ndarray: The normalized stack.
1016
+ """
1017
+
1018
+ normalized_stack = np.zeros_like(stack, dtype=np.float32)
1019
+
1020
+ #for channel in range(stack.shape[-1]):
1021
+ for channel in channels:
1022
+ if channel == settings['nucleus_channel']:
1023
+ background = settings['nucleus_background']
1024
+ signal_threshold = settings['nucleus_Signal_to_noise']*settings['nucleus_background']
1025
+ remove_background = settings['remove_background_nucleus']
1026
+
1027
+ if channel == settings['cell_channel']:
1028
+ background = settings['cell_background']
1029
+ signal_threshold = settings['cell_Signal_to_noise']*settings['cell_background']
1030
+ remove_background = settings['remove_background_cell']
1031
+
1032
+ if channel == settings['pathogen_channel']:
1033
+ background = settings['pathogen_background']
1034
+ signal_threshold = settings['pathogen_Signal_to_noise']*settings['pathogen_background']
1035
+ remove_background = settings['remove_background_pathogen']
1036
+
1037
+ single_channel = stack[:, :, :, channel]
1038
+
1039
+ print(f'Processing channel {channel}: background={background}, signal_threshold={signal_threshold}, remove_background={remove_background}')
1040
+
1041
+ # Step 3: Remove background if required
1042
+ if remove_background:
1043
+ single_channel[single_channel < background] = 0
1044
+
1045
+ # Step 4: Calculate global lower percentile for the channel
1046
+ non_zero_single_channel = single_channel[single_channel != 0]
1047
+ global_lower = np.percentile(non_zero_single_channel, settings['lower_percentile'])
1048
+
1049
+ # Step 5: Calculate global upper percentile for the channel
1050
+ global_upper = None
1051
+ for upper_p in np.linspace(98, 99.5, num=16):
1052
+ upper_value = np.percentile(non_zero_single_channel, upper_p)
1053
+ if upper_value >= signal_threshold:
1054
+ global_upper = upper_value
1055
+ break
1056
+
1057
+ if global_upper is None:
1058
+ global_upper = np.percentile(non_zero_single_channel, 99.5) # Fallback in case no upper percentile met the threshold
1059
+
1060
+ print(f'Channel {channel}: global_lower={global_lower}, global_upper={global_upper}, Signal-to-noise={global_upper / global_lower}')
1061
+
1062
+ # Step 6: Normalize each array from global_lower to global_upper between 0 and 1
1063
+ for array_index in range(single_channel.shape[0]):
1064
+ arr_2d = single_channel[array_index, :, :]
1065
+ arr_2d_normalized = exposure.rescale_intensity(arr_2d, in_range=(global_lower, global_upper), out_range=(0, 1))
1066
+ normalized_stack[array_index, :, :, channel] = arr_2d_normalized
1067
+
1068
+ return normalized_stack.astype(save_dtype)
1069
+
1070
+ def concatenate_and_normalize(src, channels, save_dtype=np.float32, settings={}):
1071
+
1001
1072
  """
1002
1073
  Concatenates and normalizes channel data from multiple files and saves the normalized data.
1003
1074
 
@@ -1017,12 +1088,14 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
1017
1088
  Returns:
1018
1089
  str: The directory path where the concatenated and normalized channel data is saved.
1019
1090
  """
1091
+ # n c p
1020
1092
  channels = [item for item in channels if item is not None]
1093
+
1021
1094
  paths = []
1022
1095
  output_fldr = os.path.join(os.path.dirname(src), 'norm_channel_stack')
1023
1096
  os.makedirs(output_fldr, exist_ok=True)
1024
1097
 
1025
- if timelapse:
1098
+ if settings['timelapse']:
1026
1099
  try:
1027
1100
  time_stack_path_lists = _generate_time_lists(os.listdir(src))
1028
1101
  for i, time_stack_list in enumerate(time_stack_path_lists):
@@ -1034,12 +1107,19 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
1034
1107
  parts = file.split('_')
1035
1108
  name = parts[0] + '_' + parts[1] + '_' + parts[2]
1036
1109
  array = np.load(path)
1037
- array = np.take(array, channels, axis=2)
1110
+ #array = np.take(array, channels, axis=2)
1038
1111
  stack_region.append(array)
1039
1112
  filenames_region.append(os.path.basename(path))
1040
1113
  print(f'Region {i + 1}/ {len(time_stack_path_lists)}', end='\r', flush=True)
1041
1114
  stack = np.stack(stack_region)
1042
- normalized_stack = _normalize_stack(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds)
1115
+
1116
+ normalized_stack = _normalize_img_batch(stack=stack,
1117
+ channels=channels,
1118
+ save_dtype=save_dtype,
1119
+ settings=settings)
1120
+
1121
+ normalized_stack = normalized_stack[..., channels]
1122
+
1043
1123
  save_loc = os.path.join(output_fldr, f'{name}_norm_timelapse.npz')
1044
1124
  np.savez(save_loc, data=normalized_stack, filenames=filenames_region)
1045
1125
  print(save_loc)
@@ -1052,7 +1132,7 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
1052
1132
  if file.endswith('.npy'):
1053
1133
  path = os.path.join(src, file)
1054
1134
  paths.append(path)
1055
- if randomize:
1135
+ if settings['randomize']:
1056
1136
  random.shuffle(paths)
1057
1137
  nr_files = len(paths)
1058
1138
  batch_index = 0
@@ -1061,12 +1141,12 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
1061
1141
 
1062
1142
  for i, path in enumerate(paths):
1063
1143
  array = np.load(path)
1064
- array = np.take(array, channels, axis=2)
1144
+ #array = np.take(array, channels, axis=2)
1065
1145
  stack_ls.append(array)
1066
1146
  filenames_batch.append(os.path.basename(path))
1067
1147
  print(f'Concatenated: {i + 1}/{nr_files} files')
1068
1148
 
1069
- if (i + 1) % batch_size == 0 or i + 1 == nr_files:
1149
+ if (i + 1) % settings['batch_size'] == 0 or i + 1 == nr_files:
1070
1150
  unique_shapes = {arr.shape[:-1] for arr in stack_ls}
1071
1151
  if len(unique_shapes) > 1:
1072
1152
  max_dims = np.max(np.array(list(unique_shapes)), axis=0)
@@ -1080,8 +1160,13 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
1080
1160
  stack = np.stack(padded_stack_ls)
1081
1161
  else:
1082
1162
  stack = np.stack(stack_ls)
1083
-
1084
- normalized_stack = _normalize_img_batch(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds)
1163
+
1164
+ normalized_stack = _normalize_img_batch(stack=stack,
1165
+ channels=channels,
1166
+ save_dtype=save_dtype,
1167
+ settings=settings)
1168
+
1169
+ normalized_stack = normalized_stack[..., channels]
1085
1170
 
1086
1171
  save_loc = os.path.join(output_fldr, f'stack_{batch_index}_norm.npz')
1087
1172
  np.savez(save_loc, data=normalized_stack, filenames=filenames_batch)
@@ -1090,64 +1175,10 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
1090
1175
  stack_ls = []
1091
1176
  filenames_batch = []
1092
1177
  padded_stack_ls = []
1178
+
1093
1179
  print(f'All files concatenated and normalized. Saved to: {output_fldr}')
1094
1180
  return output_fldr
1095
1181
 
1096
- def _normalize_img_batch(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds):
1097
- """
1098
- Normalize the stack of images.
1099
-
1100
- Args:
1101
- stack (numpy.ndarray): The stack of images to normalize.
1102
- backgrounds (list): Background values for each channel.
1103
- remove_backgrounds (list): Whether to remove background values for each channel.
1104
- lower_percentile (int): Lower percentile value for normalization.
1105
- save_dtype (numpy.dtype): Data type for saving the normalized stack.
1106
- signal_to_noise (list): Signal-to-noise ratio thresholds for each channel.
1107
- signal_thresholds (list): Signal thresholds for each channel.
1108
-
1109
- Returns:
1110
- numpy.ndarray: The normalized stack.
1111
- """
1112
- normalized_stack = np.zeros_like(stack, dtype=np.float32)
1113
-
1114
- for chan_index, channel in enumerate(range(stack.shape[-1])):
1115
- single_channel = stack[:, :, :, channel]
1116
- background = backgrounds[chan_index]
1117
- signal_threshold = signal_thresholds[chan_index]
1118
- remove_background = remove_backgrounds[chan_index]
1119
-
1120
- print(f'Processing channel {chan_index}: background={background}, signal_threshold={signal_threshold}, remove_background={remove_background}')
1121
-
1122
- # Step 3: Remove background if required
1123
- if remove_background:
1124
- single_channel[single_channel < background] = 0
1125
-
1126
- # Step 4: Calculate global lower percentile for the channel
1127
- non_zero_single_channel = single_channel[single_channel != 0]
1128
- global_lower = np.percentile(non_zero_single_channel, lower_percentile)
1129
-
1130
- # Step 5: Calculate global upper percentile for the channel
1131
- global_upper = None
1132
- for upper_p in np.linspace(98, 99.5, num=16):
1133
- upper_value = np.percentile(non_zero_single_channel, upper_p)
1134
- if upper_value >= signal_threshold:
1135
- global_upper = upper_value
1136
- break
1137
-
1138
- if global_upper is None:
1139
- global_upper = np.percentile(non_zero_single_channel, 99.5) # Fallback in case no upper percentile met the threshold
1140
-
1141
- print(f'Channel {chan_index}: global_lower={global_lower}, global_upper={global_upper}, Signal-to-noise={global_upper / global_lower}')
1142
-
1143
- # Step 6: Normalize each array from global_lower to global_upper between 0 and 1
1144
- for array_index in range(single_channel.shape[0]):
1145
- arr_2d = single_channel[array_index, :, :]
1146
- arr_2d_normalized = exposure.rescale_intensity(arr_2d, in_range=(global_lower, global_upper), out_range=(0, 1))
1147
- normalized_stack[array_index, :, :, channel] = arr_2d_normalized
1148
-
1149
- return normalized_stack.astype(save_dtype)
1150
-
1151
1182
  def _get_lists_for_normalization(settings):
1152
1183
  """
1153
1184
  Get lists for normalization based on the provided settings.
@@ -1166,22 +1197,25 @@ def _get_lists_for_normalization(settings):
1166
1197
  remove_background = []
1167
1198
 
1168
1199
  # Iterate through the channels and append the corresponding values if the channel is not None
1169
- for ch in settings['channels']:
1170
- if ch == settings['nucleus_channel']:
1171
- backgrounds.append(settings['nucleus_background'])
1172
- signal_to_noise.append(settings['nucleus_Signal_to_noise'])
1173
- signal_thresholds.append(settings['nucleus_Signal_to_noise']*settings['nucleus_background'])
1174
- remove_background.append(settings['remove_background_nucleus'])
1175
- elif ch == settings['cell_channel']:
1176
- backgrounds.append(settings['cell_background'])
1177
- signal_to_noise.append(settings['cell_Signal_to_noise'])
1178
- signal_thresholds.append(settings['cell_Signal_to_noise']*settings['cell_background'])
1179
- remove_background.append(settings['remove_background_cell'])
1180
- elif ch == settings['pathogen_channel']:
1181
- backgrounds.append(settings['pathogen_background'])
1182
- signal_to_noise.append(settings['pathogen_Signal_to_noise'])
1183
- signal_thresholds.append(settings['pathogen_Signal_to_noise']*settings['pathogen_background'])
1184
- remove_background.append(settings['remove_background_pathogen'])
1200
+ # for ch in settings['channels']:
1201
+ for ch in [settings['nucleus_channel'], settings['cell_channel'], settings['pathogen_channel']]:
1202
+ if not ch is None:
1203
+ if ch == settings['nucleus_channel']:
1204
+ backgrounds.append(settings['nucleus_background'])
1205
+ signal_to_noise.append(settings['nucleus_Signal_to_noise'])
1206
+ signal_thresholds.append(settings['nucleus_Signal_to_noise']*settings['nucleus_background'])
1207
+ remove_background.append(settings['remove_background_nucleus'])
1208
+ elif ch == settings['cell_channel']:
1209
+ backgrounds.append(settings['cell_background'])
1210
+ signal_to_noise.append(settings['cell_Signal_to_noise'])
1211
+ signal_thresholds.append(settings['cell_Signal_to_noise']*settings['cell_background'])
1212
+ remove_background.append(settings['remove_background_cell'])
1213
+ elif ch == settings['pathogen_channel']:
1214
+ backgrounds.append(settings['pathogen_background'])
1215
+ signal_to_noise.append(settings['pathogen_Signal_to_noise'])
1216
+ signal_thresholds.append(settings['pathogen_Signal_to_noise']*settings['pathogen_background'])
1217
+ remove_background.append(settings['remove_background_pathogen'])
1218
+
1185
1219
  return backgrounds, signal_to_noise, signal_thresholds, remove_background
1186
1220
 
1187
1221
  def _normalize_stack(src, backgrounds=[100, 100, 100], remove_backgrounds=[False, False, False], lower_percentile=2, save_dtype=np.float32, signal_to_noise=[5, 5, 5], signal_thresholds=[1000, 1000, 1000]):
@@ -1390,7 +1424,8 @@ def delete_empty_subdirectories(folder_path):
1390
1424
  def preprocess_img_data(settings):
1391
1425
 
1392
1426
  from .plot import plot_arrays, _plot_4D_arrays
1393
- from .utils import _run_test_mode, _get_regex, set_default_settings_preprocess_img_data
1427
+ from .utils import _run_test_mode, _get_regex
1428
+ from .settings import set_default_settings_preprocess_img_data
1394
1429
 
1395
1430
  """
1396
1431
  Preprocesses image data by converting z-stack images to maximum intensity projection (MIP) images.
@@ -1507,19 +1542,10 @@ def preprocess_img_data(settings):
1507
1542
  except Exception as e:
1508
1543
  print(f"Error: {e}")
1509
1544
 
1510
- backgrounds, signal_to_noise, signal_thresholds, remove_backgrounds = _get_lists_for_normalization(settings=settings)
1511
-
1512
- concatenate_and_normalize(src+'/stack',
1513
- mask_channels,
1514
- randomize,
1515
- timelapse,
1516
- batch_size,
1517
- backgrounds,
1518
- remove_backgrounds,
1519
- lower_percentile,
1520
- np.float32,
1521
- signal_to_noise,
1522
- signal_thresholds)
1545
+ concatenate_and_normalize(src=src+'/stack',
1546
+ channels=mask_channels,
1547
+ save_dtype=np.float32,
1548
+ settings=settings)
1523
1549
 
1524
1550
  if plot:
1525
1551
  _plot_4D_arrays(src+'/norm_channel_stack', nr_npz=1, nr=nr)
@@ -1601,13 +1627,13 @@ def _save_figure(fig, src, text, dpi=300, i=1, all_folders=1):
1601
1627
  del fig
1602
1628
  gc.collect()
1603
1629
 
1604
- def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus', 'pathogen', 'parasite', 'png_list']):
1630
+ def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus', 'pathogen', 'png_list']):
1605
1631
  """
1606
1632
  Reads and joins tables from a SQLite database.
1607
1633
 
1608
1634
  Args:
1609
1635
  db_path (str): The path to the SQLite database file.
1610
- table_names (list, optional): The names of the tables to read and join. Defaults to ['cell', 'cytoplasm', 'nucleus', 'pathogen', 'parasite', 'png_list'].
1636
+ table_names (list, optional): The names of the tables to read and join. Defaults to ['cell', 'cytoplasm', 'nucleus', 'pathogen', 'png_list'].
1611
1637
 
1612
1638
  Returns:
1613
1639
  pandas.DataFrame: The joined DataFrame containing the data from the specified tables, or None if an error occurs.
@@ -1629,9 +1655,9 @@ def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus',
1629
1655
  join_cols = ['object_label', 'plate', 'row', 'col']
1630
1656
  dataframes['cell'] = pd.merge(dataframes['cell'], png_list_df, on=join_cols, how='left')
1631
1657
  else:
1632
- print("Cell table not found. Cannot join with png_list.")
1633
- return None
1634
- for entity in ['nucleus', 'pathogen', 'parasite']:
1658
+ print("Cell table not found in database tables.")
1659
+ return png_list_df
1660
+ for entity in ['nucleus', 'pathogen']:
1635
1661
  if entity in dataframes:
1636
1662
  numeric_cols = dataframes[entity].select_dtypes(include=[np.number]).columns.tolist()
1637
1663
  non_numeric_cols = dataframes[entity].select_dtypes(exclude=[np.number]).columns.tolist()
@@ -1644,14 +1670,11 @@ def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus',
1644
1670
  joined_df = None
1645
1671
  if 'cell' in dataframes:
1646
1672
  joined_df = dataframes['cell']
1647
- if 'cytoplasm' in dataframes:
1648
- joined_df = pd.merge(joined_df, dataframes['cytoplasm'], on=['object_label', 'prcf'], how='left', suffixes=('', '_cytoplasm'))
1649
- for entity in ['nucleus', 'pathogen']:
1650
- if entity in dataframes:
1651
- joined_df = pd.merge(joined_df, dataframes[entity], left_on=['object_label', 'prcf'], right_index=True, how='left', suffixes=('', f'_{entity}'))
1652
- else:
1653
- print("Cell table not found. Cannot proceed with joining.")
1654
- return None
1673
+ if 'cytoplasm' in dataframes:
1674
+ joined_df = pd.merge(joined_df, dataframes['cytoplasm'], on=['object_label', 'prcf'], how='left', suffixes=('', '_cytoplasm'))
1675
+ for entity in ['nucleus', 'pathogen']:
1676
+ if entity in dataframes:
1677
+ joined_df = pd.merge(joined_df, dataframes[entity], left_on=['object_label', 'prcf'], right_index=True, how='left', suffixes=('', f'_{entity}'))
1655
1678
  return joined_df
1656
1679
 
1657
1680
  def _save_settings_to_db(settings):
@@ -2100,8 +2123,75 @@ def _results_to_csv(src, df, df_well):
2100
2123
  ###################################################
2101
2124
  # Classify
2102
2125
  ###################################################
2126
+
2127
+ def read_plot_model_stats(file_path ,save=False):
2128
+
2129
+ def _plot_and_save(train_df, val_df, column='accuracy', save=False, path=None, dpi=600):
2130
+
2131
+ pdf_path = os.path.join(path, f'{column}.pdf')
2132
+
2133
+ # Create subplots
2134
+ fig, axes = plt.subplots(1, 2, figsize=(20, 10), sharey=True)
2135
+
2136
+ # Plotting
2137
+ sns.lineplot(ax=axes[0], x='epoch', y=column, data=train_df, marker='o', color='red')
2138
+ sns.lineplot(ax=axes[1], x='epoch', y=column, data=val_df, marker='o', color='blue')
2139
+
2140
+ # Set titles and labels
2141
+ axes[0].set_title(f'Train {column} vs. Epoch', fontsize=20)
2142
+ axes[0].set_xlabel('Epoch', fontsize=16)
2143
+ axes[0].set_ylabel(column, fontsize=16)
2144
+ axes[0].tick_params(axis='both', which='major', labelsize=12)
2145
+
2146
+ axes[1].set_title(f'Validation {column} vs. Epoch', fontsize=20)
2147
+ axes[1].set_xlabel('Epoch', fontsize=16)
2148
+ axes[1].tick_params(axis='both', which='major', labelsize=12)
2149
+
2150
+ plt.tight_layout()
2151
+
2152
+ if save:
2153
+ plt.savefig(pdf_path, format='pdf', dpi=dpi)
2154
+ else:
2155
+ plt.show()
2156
+ # Read the CSV into a dataframe
2157
+ df = pd.read_csv(file_path, index_col=0)
2158
+
2159
+ # Split the dataframe into train and validation based on the index
2160
+ train_df = df.filter(like='_train', axis=0).copy()
2161
+ val_df = df.filter(like='_val', axis=0).copy()
2162
+
2163
+ fldr_1 = os.path.dirname(file_path)
2103
2164
 
2104
- def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94]):
2165
+ train_csv_path = os.path.join(fldr_1, 'train.csv')
2166
+ val_csv_path = os.path.join(fldr_1, 'validation.csv')
2167
+
2168
+ fldr_2 = os.path.dirname(fldr_1)
2169
+ fldr_3 = os.path.dirname(fldr_2)
2170
+ bn_1 = os.path.basename(fldr_1)
2171
+ bn_2 = os.path.basename(fldr_2)
2172
+ bn_3 = os.path.basename(fldr_3)
2173
+ model_name = str(f'{bn_1}_{bn_2}_{bn_3}')
2174
+
2175
+ # Extract epochs from index
2176
+ train_df['epoch'] = [int(idx.split('_')[0]) for idx in train_df.index]
2177
+ val_df['epoch'] = [int(idx.split('_')[0]) for idx in val_df.index]
2178
+
2179
+ # Save dataframes to a CSV file
2180
+ train_df.to_csv(train_csv_path)
2181
+ val_df.to_csv(val_csv_path)
2182
+
2183
+ if save:
2184
+ # Setting the style
2185
+ sns.set(style="whitegrid")
2186
+
2187
+ _plot_and_save(train_df, val_df, column='accuracy', save=save, path=fldr_1)
2188
+ _plot_and_save(train_df, val_df, column='neg_accuracy', save=save, path=fldr_1)
2189
+ _plot_and_save(train_df, val_df, column='pos_accuracy', save=save, path=fldr_1)
2190
+ _plot_and_save(train_df, val_df, column='loss', save=save, path=fldr_1)
2191
+ _plot_and_save(train_df, val_df, column='prauc', save=save, path=fldr_1)
2192
+ _plot_and_save(train_df, val_df, column='optimal_threshold', save=save, path=fldr_1)
2193
+
2194
+ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94], channels=['r','g','b']):
2105
2195
  """
2106
2196
  Save the model based on certain conditions during training.
2107
2197
 
@@ -2114,35 +2204,25 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
2114
2204
  epochs (int): The total number of epochs.
2115
2205
  intermedeate_save (list, optional): List of accuracy thresholds to trigger intermediate model saves.
2116
2206
  Defaults to [0.99, 0.98, 0.95, 0.94].
2207
+ channels (list, optional): List of channels used. Defaults to ['r', 'g', 'b'].
2117
2208
  """
2118
-
2119
- if epoch % 100 == 0:
2120
- torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}.pth')
2121
-
2122
- if epoch == epochs:
2123
- torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}.pth')
2124
-
2125
- if results_df['neg_accuracy'].dropna().mean() >= intermedeate_save[0] and results_df['pos_accuracy'].dropna().mean() >= intermedeate_save[0]:
2126
- percentile = str(intermedeate_save[0]*100)
2127
- print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
2128
- torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{str(percentile)}.pth')
2129
2209
 
2130
- elif results_df['neg_accuracy'].dropna().mean() >= intermedeate_save[1] and results_df['pos_accuracy'].dropna().mean() >= intermedeate_save[1]:
2131
- percentile = str(intermedeate_save[1]*100)
2132
- print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
2133
- torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{str(percentile)}.pth')
2134
-
2135
- elif results_df['neg_accuracy'].dropna().mean() >= intermedeate_save[2] and results_df['pos_accuracy'].dropna().mean() >= intermedeate_save[2]:
2136
- percentile = str(intermedeate_save[2]*100)
2137
- print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
2138
- torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{str(percentile)}.pth')
2210
+ channels_str = ''.join(channels)
2139
2211
 
2140
- elif results_df['neg_accuracy'].dropna().mean() >= intermedeate_save[3] and results_df['pos_accuracy'].dropna().mean() >= intermedeate_save[3]:
2141
- percentile = str(intermedeate_save[3]*100)
2212
+ def save_model_at_threshold(threshold, epoch, suffix=""):
2213
+ percentile = str(threshold * 100)
2142
2214
  print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
2143
- torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{str(percentile)}.pth')
2215
+ torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}{suffix}_acc_{percentile}_channels_{channels_str}.pth')
2144
2216
 
2145
- def _save_progress(dst, results_df, train_metrics_df):
2217
+ if epoch % 100 == 0 or epoch == epochs:
2218
+ torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_channels_{channels_str}.pth')
2219
+
2220
+ for threshold in intermedeate_save:
2221
+ if results_df['neg_accuracy'].dropna().mean() >= threshold and results_df['pos_accuracy'].dropna().mean() >= threshold:
2222
+ save_model_at_threshold(threshold, epoch)
2223
+ break # Ensure we only save for the highest matching threshold
2224
+
2225
+ def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
2146
2226
  """
2147
2227
  Save the progress of the classification model.
2148
2228
 
@@ -2161,11 +2241,14 @@ def _save_progress(dst, results_df, train_metrics_df):
2161
2241
  results_df.to_csv(results_path, index=True, header=True, mode='w')
2162
2242
  else:
2163
2243
  results_df.to_csv(results_path, index=True, header=False, mode='a')
2244
+
2164
2245
  training_metrics_path = os.path.join(dst, 'training_metrics.csv')
2165
2246
  if not os.path.exists(training_metrics_path):
2166
2247
  train_metrics_df.to_csv(training_metrics_path, index=True, header=True, mode='w')
2167
2248
  else:
2168
2249
  train_metrics_df.to_csv(training_metrics_path, index=True, header=False, mode='a')
2250
+ if epoch == epochs:
2251
+ read_plot_model_stats(results_path, save=True)
2169
2252
  return
2170
2253
 
2171
2254
  def _save_settings(settings, src):
spacr/measure.py CHANGED
@@ -920,7 +920,8 @@ def measure_crop(settings):
920
920
  from .io import _save_settings_to_db
921
921
  from .timelapse import _timelapse_masks_to_gif, _scmovie
922
922
  from .plot import _save_scimg_plot
923
- from .utils import _list_endpoint_subdirectories, _generate_representative_images, get_measure_crop_settings, measure_test_mode
923
+ from .utils import _list_endpoint_subdirectories, _generate_representative_images, measure_test_mode
924
+ from .settings import get_measure_crop_settings
924
925
 
925
926
  settings = get_measure_crop_settings(settings)
926
927
  settings = measure_test_mode(settings)
@@ -0,0 +1,23 @@
1
+ Key,Value
2
+ img_src,/nas_mnt/carruthers/patrick/Plaque_assay_training/train
3
+ model_name,toxo_plaque
4
+ model_type,cyto
5
+ Signal_to_noise,10
6
+ background,200
7
+ remove_background,False
8
+ learning_rate,0.2
9
+ weight_decay,1e-05
10
+ batch_size,8
11
+ n_epochs,25000
12
+ from_scratch,False
13
+ diameter,30
14
+ resize,True
15
+ width_height,"[1120, 1120]"
16
+ verbose,True
17
+ channels,"[0, 0]"
18
+ normalize,True
19
+ percentiles,
20
+ circular,False
21
+ invert,False
22
+ grayscale,True
23
+ test,False