spacr 0.0.82__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +4 -0
- spacr/annotate_app.py +4 -0
- spacr/annotate_app_v2.py +511 -0
- spacr/core.py +254 -172
- spacr/deep_spacr.py +137 -50
- spacr/graph_learning.py +28 -8
- spacr/gui.py +5 -5
- spacr/gui_2.py +106 -36
- spacr/gui_classify_app.py +3 -3
- spacr/gui_mask_app.py +34 -11
- spacr/gui_measure_app.py +32 -17
- spacr/gui_utils.py +96 -29
- spacr/io.py +227 -144
- spacr/measure.py +2 -1
- spacr/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model +0 -0
- spacr/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +23 -0
- spacr/plot.py +102 -6
- spacr/sequencing.py +140 -91
- spacr/settings.py +477 -0
- spacr/timelapse.py +0 -3
- spacr/utils.py +312 -275
- {spacr-0.0.82.dist-info → spacr-0.1.1.dist-info}/METADATA +1 -1
- spacr-0.1.1.dist-info/RECORD +40 -0
- spacr-0.0.82.dist-info/RECORD +0 -36
- {spacr-0.0.82.dist-info → spacr-0.1.1.dist-info}/LICENSE +0 -0
- {spacr-0.0.82.dist-info → spacr-0.1.1.dist-info}/WHEEL +0 -0
- {spacr-0.0.82.dist-info → spacr-0.1.1.dist-info}/entry_points.txt +0 -0
- {spacr-0.0.82.dist-info → spacr-0.1.1.dist-info}/top_level.txt +0 -0
spacr/io.py
CHANGED
@@ -21,6 +21,7 @@ from multiprocessing import Pool, cpu_count
|
|
21
21
|
from torch.utils.data import Dataset
|
22
22
|
import matplotlib.pyplot as plt
|
23
23
|
from torchvision.transforms import ToTensor
|
24
|
+
import seaborn as sns
|
24
25
|
|
25
26
|
|
26
27
|
from .logger import log_function_call
|
@@ -193,7 +194,8 @@ def _load_normalized_images_and_labels(image_files, label_files, channels=None,
|
|
193
194
|
|
194
195
|
images = []
|
195
196
|
labels = []
|
196
|
-
|
197
|
+
orig_dims = []
|
198
|
+
|
197
199
|
num_channels = 4
|
198
200
|
percentiles_1 = [[] for _ in range(num_channels)]
|
199
201
|
percentiles_99 = [[] for _ in range(num_channels)]
|
@@ -204,10 +206,11 @@ def _load_normalized_images_and_labels(image_files, label_files, channels=None,
|
|
204
206
|
if label_files is not None:
|
205
207
|
label_names = [os.path.basename(f) for f in label_files]
|
206
208
|
label_dir = os.path.dirname(label_files[0])
|
207
|
-
|
209
|
+
|
208
210
|
# Load, normalize, and resize images
|
209
211
|
for i, img_file in enumerate(image_files):
|
210
212
|
image = cellpose.io.imread(img_file)
|
213
|
+
orig_dims.append((image.shape[0], image.shape[1]))
|
211
214
|
if invert:
|
212
215
|
image = invert_image(image)
|
213
216
|
if circular:
|
@@ -287,7 +290,7 @@ def _load_normalized_images_and_labels(image_files, label_files, channels=None,
|
|
287
290
|
if visualize and images and labels:
|
288
291
|
plot_resize(images, normalized_images, labels, labels)
|
289
292
|
|
290
|
-
return normalized_images, labels, image_names, label_names
|
293
|
+
return normalized_images, labels, image_names, label_names, orig_dims
|
291
294
|
|
292
295
|
class CombineLoaders:
|
293
296
|
|
@@ -310,14 +313,14 @@ class CombineLoaders:
|
|
310
313
|
|
311
314
|
"""
|
312
315
|
|
313
|
-
def
|
316
|
+
def __init__(self, train_loaders):
|
314
317
|
self.train_loaders = train_loaders
|
315
318
|
self.loader_iters = [iter(loader) for loader in train_loaders]
|
316
319
|
|
317
|
-
def
|
320
|
+
def __iter__(self):
|
318
321
|
return self
|
319
322
|
|
320
|
-
def
|
323
|
+
def __next__(self):
|
321
324
|
while self.loader_iters:
|
322
325
|
random.shuffle(self.loader_iters) # Shuffle the loader_iters list
|
323
326
|
for i, loader_iter in enumerate(self.loader_iters):
|
@@ -340,7 +343,7 @@ class CombinedDataset(Dataset):
|
|
340
343
|
shuffle (bool, optional): Whether to shuffle the combined dataset. Defaults to True.
|
341
344
|
"""
|
342
345
|
|
343
|
-
def
|
346
|
+
def __init__(self, datasets, shuffle=True):
|
344
347
|
self.datasets = datasets
|
345
348
|
self.lengths = [len(dataset) for dataset in datasets]
|
346
349
|
self.total_length = sum(self.lengths)
|
@@ -350,14 +353,14 @@ class CombinedDataset(Dataset):
|
|
350
353
|
random.shuffle(self.indices)
|
351
354
|
else:
|
352
355
|
self.indices = None
|
353
|
-
def
|
356
|
+
def __getitem__(self, index):
|
354
357
|
if self.shuffle:
|
355
358
|
index = self.indices[index]
|
356
359
|
for dataset, length in zip(self.datasets, self.lengths):
|
357
360
|
if index < length:
|
358
361
|
return dataset[index]
|
359
362
|
index -= length
|
360
|
-
def
|
363
|
+
def __len__(self):
|
361
364
|
return self.total_length
|
362
365
|
|
363
366
|
class NoClassDataset(Dataset):
|
@@ -541,7 +544,7 @@ class NoClassDataset(Dataset):
|
|
541
544
|
|
542
545
|
|
543
546
|
class TarImageDataset(Dataset):
|
544
|
-
def
|
547
|
+
def __init__(self, tar_path, transform=None):
|
545
548
|
self.tar_path = tar_path
|
546
549
|
self.transform = transform
|
547
550
|
|
@@ -549,10 +552,10 @@ class TarImageDataset(Dataset):
|
|
549
552
|
with tarfile.open(self.tar_path, 'r') as f:
|
550
553
|
self.members = [m for m in f.getmembers() if m.isfile()]
|
551
554
|
|
552
|
-
def
|
555
|
+
def __len__(self):
|
553
556
|
return len(self.members)
|
554
557
|
|
555
|
-
def
|
558
|
+
def __getitem__(self, idx):
|
556
559
|
with tarfile.open(self.tar_path, 'r') as f:
|
557
560
|
m = self.members[idx]
|
558
561
|
img_file = f.extractfile(m)
|
@@ -997,7 +1000,75 @@ def _concatenate_channel(src, channels, randomize=True, timelapse=False, batch_s
|
|
997
1000
|
print(f'All files concatenated and saved to:{channel_stack_loc}')
|
998
1001
|
return channel_stack_loc
|
999
1002
|
|
1000
|
-
def
|
1003
|
+
def _normalize_img_batch(stack, channels, save_dtype, settings):
|
1004
|
+
|
1005
|
+
"""
|
1006
|
+
Normalize the stack of images.
|
1007
|
+
|
1008
|
+
Args:
|
1009
|
+
stack (numpy.ndarray): The stack of images to normalize.
|
1010
|
+
lower_percentile (int): Lower percentile value for normalization.
|
1011
|
+
save_dtype (numpy.dtype): Data type for saving the normalized stack.
|
1012
|
+
settings (dict): keword arguments
|
1013
|
+
|
1014
|
+
Returns:
|
1015
|
+
numpy.ndarray: The normalized stack.
|
1016
|
+
"""
|
1017
|
+
|
1018
|
+
normalized_stack = np.zeros_like(stack, dtype=np.float32)
|
1019
|
+
|
1020
|
+
#for channel in range(stack.shape[-1]):
|
1021
|
+
for channel in channels:
|
1022
|
+
if channel == settings['nucleus_channel']:
|
1023
|
+
background = settings['nucleus_background']
|
1024
|
+
signal_threshold = settings['nucleus_Signal_to_noise']*settings['nucleus_background']
|
1025
|
+
remove_background = settings['remove_background_nucleus']
|
1026
|
+
|
1027
|
+
if channel == settings['cell_channel']:
|
1028
|
+
background = settings['cell_background']
|
1029
|
+
signal_threshold = settings['cell_Signal_to_noise']*settings['cell_background']
|
1030
|
+
remove_background = settings['remove_background_cell']
|
1031
|
+
|
1032
|
+
if channel == settings['pathogen_channel']:
|
1033
|
+
background = settings['pathogen_background']
|
1034
|
+
signal_threshold = settings['pathogen_Signal_to_noise']*settings['pathogen_background']
|
1035
|
+
remove_background = settings['remove_background_pathogen']
|
1036
|
+
|
1037
|
+
single_channel = stack[:, :, :, channel]
|
1038
|
+
|
1039
|
+
print(f'Processing channel {channel}: background={background}, signal_threshold={signal_threshold}, remove_background={remove_background}')
|
1040
|
+
|
1041
|
+
# Step 3: Remove background if required
|
1042
|
+
if remove_background:
|
1043
|
+
single_channel[single_channel < background] = 0
|
1044
|
+
|
1045
|
+
# Step 4: Calculate global lower percentile for the channel
|
1046
|
+
non_zero_single_channel = single_channel[single_channel != 0]
|
1047
|
+
global_lower = np.percentile(non_zero_single_channel, settings['lower_percentile'])
|
1048
|
+
|
1049
|
+
# Step 5: Calculate global upper percentile for the channel
|
1050
|
+
global_upper = None
|
1051
|
+
for upper_p in np.linspace(98, 99.5, num=16):
|
1052
|
+
upper_value = np.percentile(non_zero_single_channel, upper_p)
|
1053
|
+
if upper_value >= signal_threshold:
|
1054
|
+
global_upper = upper_value
|
1055
|
+
break
|
1056
|
+
|
1057
|
+
if global_upper is None:
|
1058
|
+
global_upper = np.percentile(non_zero_single_channel, 99.5) # Fallback in case no upper percentile met the threshold
|
1059
|
+
|
1060
|
+
print(f'Channel {channel}: global_lower={global_lower}, global_upper={global_upper}, Signal-to-noise={global_upper / global_lower}')
|
1061
|
+
|
1062
|
+
# Step 6: Normalize each array from global_lower to global_upper between 0 and 1
|
1063
|
+
for array_index in range(single_channel.shape[0]):
|
1064
|
+
arr_2d = single_channel[array_index, :, :]
|
1065
|
+
arr_2d_normalized = exposure.rescale_intensity(arr_2d, in_range=(global_lower, global_upper), out_range=(0, 1))
|
1066
|
+
normalized_stack[array_index, :, :, channel] = arr_2d_normalized
|
1067
|
+
|
1068
|
+
return normalized_stack.astype(save_dtype)
|
1069
|
+
|
1070
|
+
def concatenate_and_normalize(src, channels, save_dtype=np.float32, settings={}):
|
1071
|
+
|
1001
1072
|
"""
|
1002
1073
|
Concatenates and normalizes channel data from multiple files and saves the normalized data.
|
1003
1074
|
|
@@ -1017,12 +1088,14 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
|
|
1017
1088
|
Returns:
|
1018
1089
|
str: The directory path where the concatenated and normalized channel data is saved.
|
1019
1090
|
"""
|
1091
|
+
# n c p
|
1020
1092
|
channels = [item for item in channels if item is not None]
|
1093
|
+
|
1021
1094
|
paths = []
|
1022
1095
|
output_fldr = os.path.join(os.path.dirname(src), 'norm_channel_stack')
|
1023
1096
|
os.makedirs(output_fldr, exist_ok=True)
|
1024
1097
|
|
1025
|
-
if timelapse:
|
1098
|
+
if settings['timelapse']:
|
1026
1099
|
try:
|
1027
1100
|
time_stack_path_lists = _generate_time_lists(os.listdir(src))
|
1028
1101
|
for i, time_stack_list in enumerate(time_stack_path_lists):
|
@@ -1034,12 +1107,19 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
|
|
1034
1107
|
parts = file.split('_')
|
1035
1108
|
name = parts[0] + '_' + parts[1] + '_' + parts[2]
|
1036
1109
|
array = np.load(path)
|
1037
|
-
array = np.take(array, channels, axis=2)
|
1110
|
+
#array = np.take(array, channels, axis=2)
|
1038
1111
|
stack_region.append(array)
|
1039
1112
|
filenames_region.append(os.path.basename(path))
|
1040
1113
|
print(f'Region {i + 1}/ {len(time_stack_path_lists)}', end='\r', flush=True)
|
1041
1114
|
stack = np.stack(stack_region)
|
1042
|
-
|
1115
|
+
|
1116
|
+
normalized_stack = _normalize_img_batch(stack=stack,
|
1117
|
+
channels=channels,
|
1118
|
+
save_dtype=save_dtype,
|
1119
|
+
settings=settings)
|
1120
|
+
|
1121
|
+
normalized_stack = normalized_stack[..., channels]
|
1122
|
+
|
1043
1123
|
save_loc = os.path.join(output_fldr, f'{name}_norm_timelapse.npz')
|
1044
1124
|
np.savez(save_loc, data=normalized_stack, filenames=filenames_region)
|
1045
1125
|
print(save_loc)
|
@@ -1052,7 +1132,7 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
|
|
1052
1132
|
if file.endswith('.npy'):
|
1053
1133
|
path = os.path.join(src, file)
|
1054
1134
|
paths.append(path)
|
1055
|
-
if randomize:
|
1135
|
+
if settings['randomize']:
|
1056
1136
|
random.shuffle(paths)
|
1057
1137
|
nr_files = len(paths)
|
1058
1138
|
batch_index = 0
|
@@ -1061,12 +1141,12 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
|
|
1061
1141
|
|
1062
1142
|
for i, path in enumerate(paths):
|
1063
1143
|
array = np.load(path)
|
1064
|
-
array = np.take(array, channels, axis=2)
|
1144
|
+
#array = np.take(array, channels, axis=2)
|
1065
1145
|
stack_ls.append(array)
|
1066
1146
|
filenames_batch.append(os.path.basename(path))
|
1067
1147
|
print(f'Concatenated: {i + 1}/{nr_files} files')
|
1068
1148
|
|
1069
|
-
if (i + 1) % batch_size == 0 or i + 1 == nr_files:
|
1149
|
+
if (i + 1) % settings['batch_size'] == 0 or i + 1 == nr_files:
|
1070
1150
|
unique_shapes = {arr.shape[:-1] for arr in stack_ls}
|
1071
1151
|
if len(unique_shapes) > 1:
|
1072
1152
|
max_dims = np.max(np.array(list(unique_shapes)), axis=0)
|
@@ -1080,8 +1160,13 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
|
|
1080
1160
|
stack = np.stack(padded_stack_ls)
|
1081
1161
|
else:
|
1082
1162
|
stack = np.stack(stack_ls)
|
1083
|
-
|
1084
|
-
normalized_stack = _normalize_img_batch(stack,
|
1163
|
+
|
1164
|
+
normalized_stack = _normalize_img_batch(stack=stack,
|
1165
|
+
channels=channels,
|
1166
|
+
save_dtype=save_dtype,
|
1167
|
+
settings=settings)
|
1168
|
+
|
1169
|
+
normalized_stack = normalized_stack[..., channels]
|
1085
1170
|
|
1086
1171
|
save_loc = os.path.join(output_fldr, f'stack_{batch_index}_norm.npz')
|
1087
1172
|
np.savez(save_loc, data=normalized_stack, filenames=filenames_batch)
|
@@ -1090,64 +1175,10 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
|
|
1090
1175
|
stack_ls = []
|
1091
1176
|
filenames_batch = []
|
1092
1177
|
padded_stack_ls = []
|
1178
|
+
|
1093
1179
|
print(f'All files concatenated and normalized. Saved to: {output_fldr}')
|
1094
1180
|
return output_fldr
|
1095
1181
|
|
1096
|
-
def _normalize_img_batch(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds):
|
1097
|
-
"""
|
1098
|
-
Normalize the stack of images.
|
1099
|
-
|
1100
|
-
Args:
|
1101
|
-
stack (numpy.ndarray): The stack of images to normalize.
|
1102
|
-
backgrounds (list): Background values for each channel.
|
1103
|
-
remove_backgrounds (list): Whether to remove background values for each channel.
|
1104
|
-
lower_percentile (int): Lower percentile value for normalization.
|
1105
|
-
save_dtype (numpy.dtype): Data type for saving the normalized stack.
|
1106
|
-
signal_to_noise (list): Signal-to-noise ratio thresholds for each channel.
|
1107
|
-
signal_thresholds (list): Signal thresholds for each channel.
|
1108
|
-
|
1109
|
-
Returns:
|
1110
|
-
numpy.ndarray: The normalized stack.
|
1111
|
-
"""
|
1112
|
-
normalized_stack = np.zeros_like(stack, dtype=np.float32)
|
1113
|
-
|
1114
|
-
for chan_index, channel in enumerate(range(stack.shape[-1])):
|
1115
|
-
single_channel = stack[:, :, :, channel]
|
1116
|
-
background = backgrounds[chan_index]
|
1117
|
-
signal_threshold = signal_thresholds[chan_index]
|
1118
|
-
remove_background = remove_backgrounds[chan_index]
|
1119
|
-
|
1120
|
-
print(f'Processing channel {chan_index}: background={background}, signal_threshold={signal_threshold}, remove_background={remove_background}')
|
1121
|
-
|
1122
|
-
# Step 3: Remove background if required
|
1123
|
-
if remove_background:
|
1124
|
-
single_channel[single_channel < background] = 0
|
1125
|
-
|
1126
|
-
# Step 4: Calculate global lower percentile for the channel
|
1127
|
-
non_zero_single_channel = single_channel[single_channel != 0]
|
1128
|
-
global_lower = np.percentile(non_zero_single_channel, lower_percentile)
|
1129
|
-
|
1130
|
-
# Step 5: Calculate global upper percentile for the channel
|
1131
|
-
global_upper = None
|
1132
|
-
for upper_p in np.linspace(98, 99.5, num=16):
|
1133
|
-
upper_value = np.percentile(non_zero_single_channel, upper_p)
|
1134
|
-
if upper_value >= signal_threshold:
|
1135
|
-
global_upper = upper_value
|
1136
|
-
break
|
1137
|
-
|
1138
|
-
if global_upper is None:
|
1139
|
-
global_upper = np.percentile(non_zero_single_channel, 99.5) # Fallback in case no upper percentile met the threshold
|
1140
|
-
|
1141
|
-
print(f'Channel {chan_index}: global_lower={global_lower}, global_upper={global_upper}, Signal-to-noise={global_upper / global_lower}')
|
1142
|
-
|
1143
|
-
# Step 6: Normalize each array from global_lower to global_upper between 0 and 1
|
1144
|
-
for array_index in range(single_channel.shape[0]):
|
1145
|
-
arr_2d = single_channel[array_index, :, :]
|
1146
|
-
arr_2d_normalized = exposure.rescale_intensity(arr_2d, in_range=(global_lower, global_upper), out_range=(0, 1))
|
1147
|
-
normalized_stack[array_index, :, :, channel] = arr_2d_normalized
|
1148
|
-
|
1149
|
-
return normalized_stack.astype(save_dtype)
|
1150
|
-
|
1151
1182
|
def _get_lists_for_normalization(settings):
|
1152
1183
|
"""
|
1153
1184
|
Get lists for normalization based on the provided settings.
|
@@ -1166,22 +1197,25 @@ def _get_lists_for_normalization(settings):
|
|
1166
1197
|
remove_background = []
|
1167
1198
|
|
1168
1199
|
# Iterate through the channels and append the corresponding values if the channel is not None
|
1169
|
-
for ch in settings['channels']:
|
1170
|
-
|
1171
|
-
|
1172
|
-
|
1173
|
-
|
1174
|
-
|
1175
|
-
|
1176
|
-
|
1177
|
-
|
1178
|
-
|
1179
|
-
|
1180
|
-
|
1181
|
-
|
1182
|
-
|
1183
|
-
|
1184
|
-
|
1200
|
+
# for ch in settings['channels']:
|
1201
|
+
for ch in [settings['nucleus_channel'], settings['cell_channel'], settings['pathogen_channel']]:
|
1202
|
+
if not ch is None:
|
1203
|
+
if ch == settings['nucleus_channel']:
|
1204
|
+
backgrounds.append(settings['nucleus_background'])
|
1205
|
+
signal_to_noise.append(settings['nucleus_Signal_to_noise'])
|
1206
|
+
signal_thresholds.append(settings['nucleus_Signal_to_noise']*settings['nucleus_background'])
|
1207
|
+
remove_background.append(settings['remove_background_nucleus'])
|
1208
|
+
elif ch == settings['cell_channel']:
|
1209
|
+
backgrounds.append(settings['cell_background'])
|
1210
|
+
signal_to_noise.append(settings['cell_Signal_to_noise'])
|
1211
|
+
signal_thresholds.append(settings['cell_Signal_to_noise']*settings['cell_background'])
|
1212
|
+
remove_background.append(settings['remove_background_cell'])
|
1213
|
+
elif ch == settings['pathogen_channel']:
|
1214
|
+
backgrounds.append(settings['pathogen_background'])
|
1215
|
+
signal_to_noise.append(settings['pathogen_Signal_to_noise'])
|
1216
|
+
signal_thresholds.append(settings['pathogen_Signal_to_noise']*settings['pathogen_background'])
|
1217
|
+
remove_background.append(settings['remove_background_pathogen'])
|
1218
|
+
|
1185
1219
|
return backgrounds, signal_to_noise, signal_thresholds, remove_background
|
1186
1220
|
|
1187
1221
|
def _normalize_stack(src, backgrounds=[100, 100, 100], remove_backgrounds=[False, False, False], lower_percentile=2, save_dtype=np.float32, signal_to_noise=[5, 5, 5], signal_thresholds=[1000, 1000, 1000]):
|
@@ -1390,7 +1424,8 @@ def delete_empty_subdirectories(folder_path):
|
|
1390
1424
|
def preprocess_img_data(settings):
|
1391
1425
|
|
1392
1426
|
from .plot import plot_arrays, _plot_4D_arrays
|
1393
|
-
from .utils import _run_test_mode, _get_regex
|
1427
|
+
from .utils import _run_test_mode, _get_regex
|
1428
|
+
from .settings import set_default_settings_preprocess_img_data
|
1394
1429
|
|
1395
1430
|
"""
|
1396
1431
|
Preprocesses image data by converting z-stack images to maximum intensity projection (MIP) images.
|
@@ -1507,19 +1542,10 @@ def preprocess_img_data(settings):
|
|
1507
1542
|
except Exception as e:
|
1508
1543
|
print(f"Error: {e}")
|
1509
1544
|
|
1510
|
-
|
1511
|
-
|
1512
|
-
|
1513
|
-
|
1514
|
-
randomize,
|
1515
|
-
timelapse,
|
1516
|
-
batch_size,
|
1517
|
-
backgrounds,
|
1518
|
-
remove_backgrounds,
|
1519
|
-
lower_percentile,
|
1520
|
-
np.float32,
|
1521
|
-
signal_to_noise,
|
1522
|
-
signal_thresholds)
|
1545
|
+
concatenate_and_normalize(src=src+'/stack',
|
1546
|
+
channels=mask_channels,
|
1547
|
+
save_dtype=np.float32,
|
1548
|
+
settings=settings)
|
1523
1549
|
|
1524
1550
|
if plot:
|
1525
1551
|
_plot_4D_arrays(src+'/norm_channel_stack', nr_npz=1, nr=nr)
|
@@ -1601,13 +1627,13 @@ def _save_figure(fig, src, text, dpi=300, i=1, all_folders=1):
|
|
1601
1627
|
del fig
|
1602
1628
|
gc.collect()
|
1603
1629
|
|
1604
|
-
def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus', 'pathogen', '
|
1630
|
+
def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus', 'pathogen', 'png_list']):
|
1605
1631
|
"""
|
1606
1632
|
Reads and joins tables from a SQLite database.
|
1607
1633
|
|
1608
1634
|
Args:
|
1609
1635
|
db_path (str): The path to the SQLite database file.
|
1610
|
-
table_names (list, optional): The names of the tables to read and join. Defaults to ['cell', 'cytoplasm', 'nucleus', 'pathogen', '
|
1636
|
+
table_names (list, optional): The names of the tables to read and join. Defaults to ['cell', 'cytoplasm', 'nucleus', 'pathogen', 'png_list'].
|
1611
1637
|
|
1612
1638
|
Returns:
|
1613
1639
|
pandas.DataFrame: The joined DataFrame containing the data from the specified tables, or None if an error occurs.
|
@@ -1629,9 +1655,9 @@ def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus',
|
|
1629
1655
|
join_cols = ['object_label', 'plate', 'row', 'col']
|
1630
1656
|
dataframes['cell'] = pd.merge(dataframes['cell'], png_list_df, on=join_cols, how='left')
|
1631
1657
|
else:
|
1632
|
-
print("Cell table not found
|
1633
|
-
return
|
1634
|
-
for entity in ['nucleus', 'pathogen'
|
1658
|
+
print("Cell table not found in database tables.")
|
1659
|
+
return png_list_df
|
1660
|
+
for entity in ['nucleus', 'pathogen']:
|
1635
1661
|
if entity in dataframes:
|
1636
1662
|
numeric_cols = dataframes[entity].select_dtypes(include=[np.number]).columns.tolist()
|
1637
1663
|
non_numeric_cols = dataframes[entity].select_dtypes(exclude=[np.number]).columns.tolist()
|
@@ -1644,14 +1670,11 @@ def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus',
|
|
1644
1670
|
joined_df = None
|
1645
1671
|
if 'cell' in dataframes:
|
1646
1672
|
joined_df = dataframes['cell']
|
1647
|
-
|
1648
|
-
|
1649
|
-
|
1650
|
-
|
1651
|
-
|
1652
|
-
else:
|
1653
|
-
print("Cell table not found. Cannot proceed with joining.")
|
1654
|
-
return None
|
1673
|
+
if 'cytoplasm' in dataframes:
|
1674
|
+
joined_df = pd.merge(joined_df, dataframes['cytoplasm'], on=['object_label', 'prcf'], how='left', suffixes=('', '_cytoplasm'))
|
1675
|
+
for entity in ['nucleus', 'pathogen']:
|
1676
|
+
if entity in dataframes:
|
1677
|
+
joined_df = pd.merge(joined_df, dataframes[entity], left_on=['object_label', 'prcf'], right_index=True, how='left', suffixes=('', f'_{entity}'))
|
1655
1678
|
return joined_df
|
1656
1679
|
|
1657
1680
|
def _save_settings_to_db(settings):
|
@@ -2100,8 +2123,75 @@ def _results_to_csv(src, df, df_well):
|
|
2100
2123
|
###################################################
|
2101
2124
|
# Classify
|
2102
2125
|
###################################################
|
2126
|
+
|
2127
|
+
def read_plot_model_stats(file_path ,save=False):
|
2128
|
+
|
2129
|
+
def _plot_and_save(train_df, val_df, column='accuracy', save=False, path=None, dpi=600):
|
2130
|
+
|
2131
|
+
pdf_path = os.path.join(path, f'{column}.pdf')
|
2132
|
+
|
2133
|
+
# Create subplots
|
2134
|
+
fig, axes = plt.subplots(1, 2, figsize=(20, 10), sharey=True)
|
2135
|
+
|
2136
|
+
# Plotting
|
2137
|
+
sns.lineplot(ax=axes[0], x='epoch', y=column, data=train_df, marker='o', color='red')
|
2138
|
+
sns.lineplot(ax=axes[1], x='epoch', y=column, data=val_df, marker='o', color='blue')
|
2139
|
+
|
2140
|
+
# Set titles and labels
|
2141
|
+
axes[0].set_title(f'Train {column} vs. Epoch', fontsize=20)
|
2142
|
+
axes[0].set_xlabel('Epoch', fontsize=16)
|
2143
|
+
axes[0].set_ylabel(column, fontsize=16)
|
2144
|
+
axes[0].tick_params(axis='both', which='major', labelsize=12)
|
2145
|
+
|
2146
|
+
axes[1].set_title(f'Validation {column} vs. Epoch', fontsize=20)
|
2147
|
+
axes[1].set_xlabel('Epoch', fontsize=16)
|
2148
|
+
axes[1].tick_params(axis='both', which='major', labelsize=12)
|
2149
|
+
|
2150
|
+
plt.tight_layout()
|
2151
|
+
|
2152
|
+
if save:
|
2153
|
+
plt.savefig(pdf_path, format='pdf', dpi=dpi)
|
2154
|
+
else:
|
2155
|
+
plt.show()
|
2156
|
+
# Read the CSV into a dataframe
|
2157
|
+
df = pd.read_csv(file_path, index_col=0)
|
2158
|
+
|
2159
|
+
# Split the dataframe into train and validation based on the index
|
2160
|
+
train_df = df.filter(like='_train', axis=0).copy()
|
2161
|
+
val_df = df.filter(like='_val', axis=0).copy()
|
2162
|
+
|
2163
|
+
fldr_1 = os.path.dirname(file_path)
|
2103
2164
|
|
2104
|
-
|
2165
|
+
train_csv_path = os.path.join(fldr_1, 'train.csv')
|
2166
|
+
val_csv_path = os.path.join(fldr_1, 'validation.csv')
|
2167
|
+
|
2168
|
+
fldr_2 = os.path.dirname(fldr_1)
|
2169
|
+
fldr_3 = os.path.dirname(fldr_2)
|
2170
|
+
bn_1 = os.path.basename(fldr_1)
|
2171
|
+
bn_2 = os.path.basename(fldr_2)
|
2172
|
+
bn_3 = os.path.basename(fldr_3)
|
2173
|
+
model_name = str(f'{bn_1}_{bn_2}_{bn_3}')
|
2174
|
+
|
2175
|
+
# Extract epochs from index
|
2176
|
+
train_df['epoch'] = [int(idx.split('_')[0]) for idx in train_df.index]
|
2177
|
+
val_df['epoch'] = [int(idx.split('_')[0]) for idx in val_df.index]
|
2178
|
+
|
2179
|
+
# Save dataframes to a CSV file
|
2180
|
+
train_df.to_csv(train_csv_path)
|
2181
|
+
val_df.to_csv(val_csv_path)
|
2182
|
+
|
2183
|
+
if save:
|
2184
|
+
# Setting the style
|
2185
|
+
sns.set(style="whitegrid")
|
2186
|
+
|
2187
|
+
_plot_and_save(train_df, val_df, column='accuracy', save=save, path=fldr_1)
|
2188
|
+
_plot_and_save(train_df, val_df, column='neg_accuracy', save=save, path=fldr_1)
|
2189
|
+
_plot_and_save(train_df, val_df, column='pos_accuracy', save=save, path=fldr_1)
|
2190
|
+
_plot_and_save(train_df, val_df, column='loss', save=save, path=fldr_1)
|
2191
|
+
_plot_and_save(train_df, val_df, column='prauc', save=save, path=fldr_1)
|
2192
|
+
_plot_and_save(train_df, val_df, column='optimal_threshold', save=save, path=fldr_1)
|
2193
|
+
|
2194
|
+
def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94], channels=['r','g','b']):
|
2105
2195
|
"""
|
2106
2196
|
Save the model based on certain conditions during training.
|
2107
2197
|
|
@@ -2114,35 +2204,25 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
|
|
2114
2204
|
epochs (int): The total number of epochs.
|
2115
2205
|
intermedeate_save (list, optional): List of accuracy thresholds to trigger intermediate model saves.
|
2116
2206
|
Defaults to [0.99, 0.98, 0.95, 0.94].
|
2207
|
+
channels (list, optional): List of channels used. Defaults to ['r', 'g', 'b'].
|
2117
2208
|
"""
|
2118
|
-
|
2119
|
-
if epoch % 100 == 0:
|
2120
|
-
torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}.pth')
|
2121
|
-
|
2122
|
-
if epoch == epochs:
|
2123
|
-
torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}.pth')
|
2124
|
-
|
2125
|
-
if results_df['neg_accuracy'].dropna().mean() >= intermedeate_save[0] and results_df['pos_accuracy'].dropna().mean() >= intermedeate_save[0]:
|
2126
|
-
percentile = str(intermedeate_save[0]*100)
|
2127
|
-
print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
|
2128
|
-
torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{str(percentile)}.pth')
|
2129
2209
|
|
2130
|
-
|
2131
|
-
percentile = str(intermedeate_save[1]*100)
|
2132
|
-
print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
|
2133
|
-
torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{str(percentile)}.pth')
|
2134
|
-
|
2135
|
-
elif results_df['neg_accuracy'].dropna().mean() >= intermedeate_save[2] and results_df['pos_accuracy'].dropna().mean() >= intermedeate_save[2]:
|
2136
|
-
percentile = str(intermedeate_save[2]*100)
|
2137
|
-
print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
|
2138
|
-
torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{str(percentile)}.pth')
|
2210
|
+
channels_str = ''.join(channels)
|
2139
2211
|
|
2140
|
-
|
2141
|
-
percentile = str(
|
2212
|
+
def save_model_at_threshold(threshold, epoch, suffix=""):
|
2213
|
+
percentile = str(threshold * 100)
|
2142
2214
|
print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
|
2143
|
-
torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{
|
2215
|
+
torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}{suffix}_acc_{percentile}_channels_{channels_str}.pth')
|
2144
2216
|
|
2145
|
-
|
2217
|
+
if epoch % 100 == 0 or epoch == epochs:
|
2218
|
+
torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_channels_{channels_str}.pth')
|
2219
|
+
|
2220
|
+
for threshold in intermedeate_save:
|
2221
|
+
if results_df['neg_accuracy'].dropna().mean() >= threshold and results_df['pos_accuracy'].dropna().mean() >= threshold:
|
2222
|
+
save_model_at_threshold(threshold, epoch)
|
2223
|
+
break # Ensure we only save for the highest matching threshold
|
2224
|
+
|
2225
|
+
def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
|
2146
2226
|
"""
|
2147
2227
|
Save the progress of the classification model.
|
2148
2228
|
|
@@ -2161,11 +2241,14 @@ def _save_progress(dst, results_df, train_metrics_df):
|
|
2161
2241
|
results_df.to_csv(results_path, index=True, header=True, mode='w')
|
2162
2242
|
else:
|
2163
2243
|
results_df.to_csv(results_path, index=True, header=False, mode='a')
|
2244
|
+
|
2164
2245
|
training_metrics_path = os.path.join(dst, 'training_metrics.csv')
|
2165
2246
|
if not os.path.exists(training_metrics_path):
|
2166
2247
|
train_metrics_df.to_csv(training_metrics_path, index=True, header=True, mode='w')
|
2167
2248
|
else:
|
2168
2249
|
train_metrics_df.to_csv(training_metrics_path, index=True, header=False, mode='a')
|
2250
|
+
if epoch == epochs:
|
2251
|
+
read_plot_model_stats(results_path, save=True)
|
2169
2252
|
return
|
2170
2253
|
|
2171
2254
|
def _save_settings(settings, src):
|
spacr/measure.py
CHANGED
@@ -920,7 +920,8 @@ def measure_crop(settings):
|
|
920
920
|
from .io import _save_settings_to_db
|
921
921
|
from .timelapse import _timelapse_masks_to_gif, _scmovie
|
922
922
|
from .plot import _save_scimg_plot
|
923
|
-
from .utils import _list_endpoint_subdirectories, _generate_representative_images,
|
923
|
+
from .utils import _list_endpoint_subdirectories, _generate_representative_images, measure_test_mode
|
924
|
+
from .settings import get_measure_crop_settings
|
924
925
|
|
925
926
|
settings = get_measure_crop_settings(settings)
|
926
927
|
settings = measure_test_mode(settings)
|
Binary file
|
@@ -0,0 +1,23 @@
|
|
1
|
+
Key,Value
|
2
|
+
img_src,/nas_mnt/carruthers/patrick/Plaque_assay_training/train
|
3
|
+
model_name,toxo_plaque
|
4
|
+
model_type,cyto
|
5
|
+
Signal_to_noise,10
|
6
|
+
background,200
|
7
|
+
remove_background,False
|
8
|
+
learning_rate,0.2
|
9
|
+
weight_decay,1e-05
|
10
|
+
batch_size,8
|
11
|
+
n_epochs,25000
|
12
|
+
from_scratch,False
|
13
|
+
diameter,30
|
14
|
+
resize,True
|
15
|
+
width_height,"[1120, 1120]"
|
16
|
+
verbose,True
|
17
|
+
channels,"[0, 0]"
|
18
|
+
normalize,True
|
19
|
+
percentiles,
|
20
|
+
circular,False
|
21
|
+
invert,False
|
22
|
+
grayscale,True
|
23
|
+
test,False
|