spacr 0.0.81__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +4 -0
- spacr/annotate_app.py +4 -0
- spacr/annotate_app_v2.py +511 -0
- spacr/core.py +258 -177
- spacr/deep_spacr.py +137 -50
- spacr/graph_learning.py +28 -8
- spacr/io.py +332 -142
- spacr/measure.py +2 -1
- spacr/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model +0 -0
- spacr/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +23 -0
- spacr/plot.py +102 -6
- spacr/sequencing.py +849 -129
- spacr/settings.py +477 -0
- spacr/timelapse.py +0 -3
- spacr/utils.py +312 -275
- {spacr-0.0.81.dist-info → spacr-0.1.0.dist-info}/METADATA +1 -1
- spacr-0.1.0.dist-info/RECORD +40 -0
- spacr-0.0.81.dist-info/RECORD +0 -36
- {spacr-0.0.81.dist-info → spacr-0.1.0.dist-info}/LICENSE +0 -0
- {spacr-0.0.81.dist-info → spacr-0.1.0.dist-info}/WHEEL +0 -0
- {spacr-0.0.81.dist-info → spacr-0.1.0.dist-info}/entry_points.txt +0 -0
- {spacr-0.0.81.dist-info → spacr-0.1.0.dist-info}/top_level.txt +0 -0
spacr/io.py
CHANGED
@@ -21,6 +21,7 @@ from multiprocessing import Pool, cpu_count
|
|
21
21
|
from torch.utils.data import Dataset
|
22
22
|
import matplotlib.pyplot as plt
|
23
23
|
from torchvision.transforms import ToTensor
|
24
|
+
import seaborn as sns
|
24
25
|
|
25
26
|
|
26
27
|
from .logger import log_function_call
|
@@ -87,7 +88,7 @@ def _load_images_and_labels(image_files, label_files, circular=False, invert=Fal
|
|
87
88
|
print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {labels[0].shape}, image type: labels[0].shape')
|
88
89
|
return images, labels, image_names, label_names
|
89
90
|
|
90
|
-
def
|
91
|
+
def _load_normalized_images_and_labels_v1(image_files, label_files, channels=None, percentiles=None, circular=False, invert=False, visualize=False, remove_background=False, background=0, Signal_to_noise=10):
|
91
92
|
|
92
93
|
from .plot import normalize_and_visualize
|
93
94
|
from .utils import invert_image, apply_mask
|
@@ -182,6 +183,115 @@ def _load_normalized_images_and_labels(image_files, label_files, channels=None,
|
|
182
183
|
|
183
184
|
return normalized_images, labels, image_names, label_names
|
184
185
|
|
186
|
+
def _load_normalized_images_and_labels(image_files, label_files, channels=None, percentiles=None, circular=False, invert=False, visualize=False, remove_background=False, background=0, Signal_to_noise=10, target_height=None, target_width=None):
    """
    Load images (and optional label masks), normalize every channel to [0, 1] and optionally resize.

    Each image is read with ``cellpose.io.imread``. Its first two dimensions are recorded in
    ``orig_dims`` before any processing. The image is then, in order:
      * optionally inverted,
      * optionally circularly masked,
      * reduced to ``channels`` (only when the raw image is 3-D),
      * background-clipped (when ``remove_background`` is set),
      * given a trailing channel axis if it is 2-D.

    How the normalization range is chosen per channel:
      * ``percentiles is None``: for each image (before resizing), the 2nd percentile and the
        first of the 98/99/99.9/99.99/99.999th percentiles that exceeds
        ``background * Signal_to_noise`` are collected. These are averaged over all images, giving
        one shared range per channel.
      * otherwise: ``percentiles[0]`` and ``percentiles[1]`` are computed per image and per
        channel (after resizing).

    Args:
        image_files (list[str]): Paths of the images to load.
        label_files (list[str] | None): Paths of the label masks, or None to load no labels.
        channels (int | list[int], optional): Channel selection applied to 3-D images.
        percentiles (tuple | None, optional): (low, high) percentiles for per-image normalization.
        circular (bool, optional): Apply ``apply_mask(image, output_value=0)`` to each image.
        invert (bool, optional): Invert each image with ``invert_image``.
        visualize (bool, optional): Show raw vs. normalized images and a resize summary plot.
        remove_background (bool, optional): Set pixels below ``background`` to 0.
        background (float, optional): Background intensity; also scales the signal threshold.
        Signal_to_noise (float, optional): Multiplier on ``background`` that gives the minimum
            acceptable upper normalization bound.
        target_height (int | None, optional): Resize height. Resizing happens only when both
            target sizes are given.
        target_width (int | None, optional): Resize width. Resizing happens only when both
            target sizes are given.

    Returns:
        tuple: ``(normalized_images, labels, image_names, label_names, orig_dims)``, where
        ``normalized_images`` are float32 arrays with a trailing channel axis and ``orig_dims``
        holds the (dim0, dim1) size of each raw image.
    """
    from .plot import normalize_and_visualize, plot_resize
    from .utils import invert_image, apply_mask
    from skimage.transform import resize as resizescikit

    signal_thresholds = background * Signal_to_noise
    lower_percentile = 2
    upper_percentile_candidates = [98, 99, 99.9, 99.99, 99.999]
    resize_images = target_height is not None and target_width is not None

    images = []
    labels = []
    orig_dims = []

    # Per-channel percentile samples. They are grown on demand so that images with any
    # number of channels work (a fixed size of 4 broke on >4 channels and produced
    # NaN means for unused channels).
    percentiles_1 = []
    percentiles_99 = []

    image_names = [os.path.basename(f) for f in image_files]
    image_dir = os.path.dirname(image_files[0]) if image_files else None

    if label_files is not None:
        label_names = [os.path.basename(f) for f in label_files]
        label_dir = os.path.dirname(label_files[0]) if label_files else None
    else:
        label_names = []
        label_dir = None

    # Load, preprocess, collect percentile statistics and resize every image
    for img_file in image_files:
        image = cellpose.io.imread(img_file)
        orig_dims.append((image.shape[0], image.shape[1]))
        if invert:
            image = invert_image(image)
        if circular:
            image = apply_mask(image, output_value=0)

        # If specific channels are specified, select them
        if channels is not None and image.ndim == 3:
            image = image[..., channels]

        if remove_background:
            image[image < background] = 0

        if image.ndim < 3:
            image = np.expand_dims(image, axis=-1)

        if percentiles is None:
            while len(percentiles_1) < image.shape[-1]:
                percentiles_1.append([])
                percentiles_99.append([])
            for c in range(image.shape[-1]):
                channel_data = image[..., c]
                percentiles_1[c].append(np.percentile(channel_data, lower_percentile))
                # Keep only the first candidate upper percentile that clears the signal threshold
                for percentile in upper_percentile_candidates:
                    p = np.percentile(channel_data, percentile)
                    if p > signal_thresholds:
                        percentiles_99[c].append(p)
                        break

        if resize_images:
            # Resize the spatial axes only; any trailing axes (channels) keep their size
            image_shape = (target_height, target_width) + image.shape[2:]
            image = resizescikit(image, image_shape, preserve_range=True, anti_aliasing=True).astype(image.dtype)

        images.append(image)

    if percentiles is None:
        # Calculate average percentiles for normalization.
        # NOTE(review): if no candidate percentile cleared the threshold for a channel in any
        # image, the upper bound falls back to the averaged lower bound (zero-width range).
        # Confirm this is intended.
        avg_p1 = [np.mean(p) for p in percentiles_1]
        avg_p99 = [np.mean(p) if len(p) > 0 else np.mean(percentiles_1[i]) for i, p in enumerate(percentiles_99)]
        print(f'Average 1st percentiles: {avg_p1}, Average 99th percentiles: {avg_p99}')

    normalized_images = []
    for image in images:
        normalized_image = np.zeros_like(image, dtype=np.float32)
        for c in range(image.shape[-1]):
            if percentiles is None:
                in_range = (avg_p1[c], avg_p99[c])
            else:
                in_range = (np.percentile(image[..., c], percentiles[0]), np.percentile(image[..., c], percentiles[1]))
            normalized_image[..., c] = rescale_intensity(image[..., c], in_range=in_range, out_range=(0, 1))
        normalized_images.append(normalized_image)
        if visualize:
            normalize_and_visualize(image, normalized_image, title=f"Channel {c+1} Normalized")

    if label_files is not None:
        for lbl_file in label_files:
            label = cellpose.io.imread(lbl_file)
            if resize_images:
                # Nearest-neighbour resize (order=0, no anti-aliasing) keeps label ids intact
                label = resizescikit(label, (target_height, target_width), order=0, preserve_range=True, anti_aliasing=False).astype(label.dtype)
            labels.append(label)

    print(f'Loaded and normalized {len(normalized_images)} images and {len(labels)} labels from {image_dir} and {label_dir}')

    if visualize and images and labels:
        plot_resize(images, normalized_images, labels, labels)

    return normalized_images, labels, image_names, label_names, orig_dims
|
294
|
+
|
185
295
|
class CombineLoaders:
|
186
296
|
|
187
297
|
"""
|
@@ -203,14 +313,14 @@ class CombineLoaders:
|
|
203
313
|
|
204
314
|
"""
|
205
315
|
|
206
|
-
def
|
316
|
+
def __init__(self, train_loaders):
|
207
317
|
self.train_loaders = train_loaders
|
208
318
|
self.loader_iters = [iter(loader) for loader in train_loaders]
|
209
319
|
|
210
|
-
def
|
320
|
+
def __iter__(self):
|
211
321
|
return self
|
212
322
|
|
213
|
-
def
|
323
|
+
def __next__(self):
|
214
324
|
while self.loader_iters:
|
215
325
|
random.shuffle(self.loader_iters) # Shuffle the loader_iters list
|
216
326
|
for i, loader_iter in enumerate(self.loader_iters):
|
@@ -233,7 +343,7 @@ class CombinedDataset(Dataset):
|
|
233
343
|
shuffle (bool, optional): Whether to shuffle the combined dataset. Defaults to True.
|
234
344
|
"""
|
235
345
|
|
236
|
-
def
|
346
|
+
def __init__(self, datasets, shuffle=True):
|
237
347
|
self.datasets = datasets
|
238
348
|
self.lengths = [len(dataset) for dataset in datasets]
|
239
349
|
self.total_length = sum(self.lengths)
|
@@ -243,14 +353,14 @@ class CombinedDataset(Dataset):
|
|
243
353
|
random.shuffle(self.indices)
|
244
354
|
else:
|
245
355
|
self.indices = None
|
246
|
-
def
|
356
|
+
def __getitem__(self, index):
|
247
357
|
if self.shuffle:
|
248
358
|
index = self.indices[index]
|
249
359
|
for dataset, length in zip(self.datasets, self.lengths):
|
250
360
|
if index < length:
|
251
361
|
return dataset[index]
|
252
362
|
index -= length
|
253
|
-
def
|
363
|
+
def __len__(self):
|
254
364
|
return self.total_length
|
255
365
|
|
256
366
|
class NoClassDataset(Dataset):
|
@@ -434,7 +544,7 @@ class NoClassDataset(Dataset):
|
|
434
544
|
|
435
545
|
|
436
546
|
class TarImageDataset(Dataset):
|
437
|
-
def
|
547
|
+
def __init__(self, tar_path, transform=None):
|
438
548
|
self.tar_path = tar_path
|
439
549
|
self.transform = transform
|
440
550
|
|
@@ -442,10 +552,10 @@ class TarImageDataset(Dataset):
|
|
442
552
|
with tarfile.open(self.tar_path, 'r') as f:
|
443
553
|
self.members = [m for m in f.getmembers() if m.isfile()]
|
444
554
|
|
445
|
-
def
|
555
|
+
def __len__(self):
|
446
556
|
return len(self.members)
|
447
557
|
|
448
|
-
def
|
558
|
+
def __getitem__(self, idx):
|
449
559
|
with tarfile.open(self.tar_path, 'r') as f:
|
450
560
|
m = self.members[idx]
|
451
561
|
img_file = f.extractfile(m)
|
@@ -890,7 +1000,75 @@ def _concatenate_channel(src, channels, randomize=True, timelapse=False, batch_s
|
|
890
1000
|
print(f'All files concatenated and saved to:{channel_stack_loc}')
|
891
1001
|
return channel_stack_loc
|
892
1002
|
|
893
|
-
def
|
1003
|
+
def _normalize_img_batch(stack, channels, save_dtype, settings):
    """
    Normalize the selected channels of an image stack to the range [0, 1].

    The per-channel parameters come from ``settings``, depending on which of
    ``nucleus_channel`` / ``cell_channel`` / ``pathogen_channel`` the channel equals. If
    several roles share a channel index, the later role wins (pathogen over cell over
    nucleus). The bounds are computed over all non-zero pixels of the channel across the
    whole batch:
      * lower bound: the ``settings['lower_percentile']`` percentile;
      * upper bound: the first percentile in 98, 98.1, ..., 99.5 whose value is at least
        ``<role>_Signal_to_noise * <role>_background``, or the 99.5th percentile if none is.
    Each 2-D slice of the channel is then rescaled with those bounds.

    Args:
        stack (numpy.ndarray): Image stack, indexed as ``stack[:, :, :, channel]``
            (4-D with channels on the last axis).
        channels (list[int]): Channel indices (last axis) to normalize. Channels not listed
            stay 0 in the output.
        save_dtype (numpy.dtype): Data type of the returned stack.
        settings (dict): Keyword settings. Keys read: ``nucleus_channel``, ``cell_channel``,
            ``pathogen_channel``, ``lower_percentile``, and for the matching role
            ``<role>_background``, ``<role>_Signal_to_noise``, ``remove_background_<role>``.

    Returns:
        numpy.ndarray: Normalized stack with the same shape as ``stack``, cast to ``save_dtype``.

    Raises:
        ValueError: If a channel in ``channels`` is not assigned to any role in ``settings``.

    Note:
        When background removal is enabled it is applied in place, so ``stack`` itself is
        modified.
    """
    normalized_stack = np.zeros_like(stack, dtype=np.float32)

    for channel in channels:
        # Resolve this channel's parameters. A later matching role overrides an earlier one.
        params = None
        for role in ('nucleus', 'cell', 'pathogen'):
            if channel == settings[f'{role}_channel']:
                params = (settings[f'{role}_background'],
                          settings[f'{role}_Signal_to_noise'] * settings[f'{role}_background'],
                          settings[f'remove_background_{role}'])
        if params is None:
            # Previously this raised UnboundLocalError, or silently reused the previous channel's values
            raise ValueError(f'Channel {channel} is not configured as nucleus_channel, cell_channel or pathogen_channel in settings')
        background, signal_threshold, remove_background = params

        # Basic slicing returns a view, so background removal below writes into `stack`
        single_channel = stack[:, :, :, channel]

        print(f'Processing channel {channel}: background={background}, signal_threshold={signal_threshold}, remove_background={remove_background}')

        # Step 3: Remove background if required
        if remove_background:
            single_channel[single_channel < background] = 0

        # Step 4: Calculate global lower percentile for the channel
        non_zero_single_channel = single_channel[single_channel != 0]
        if non_zero_single_channel.size == 0:
            # np.percentile cannot handle an empty array; an all-zero channel normalizes to zeros
            print(f'Channel {channel}: no non-zero pixels, leaving normalized channel at 0')
            continue
        global_lower = np.percentile(non_zero_single_channel, settings['lower_percentile'])

        # Step 5: Calculate global upper percentile for the channel
        global_upper = None
        for upper_p in np.linspace(98, 99.5, num=16):
            upper_value = np.percentile(non_zero_single_channel, upper_p)
            if upper_value >= signal_threshold:
                global_upper = upper_value
                break

        if global_upper is None:
            global_upper = np.percentile(non_zero_single_channel, 99.5)  # Fallback in case no upper percentile met the threshold

        print(f'Channel {channel}: global_lower={global_lower}, global_upper={global_upper}, Signal-to-noise={global_upper / global_lower}')

        # Step 6: Normalize each array from global_lower to global_upper between 0 and 1.
        # This is done slice by slice to bound the size of temporary float arrays.
        for array_index in range(single_channel.shape[0]):
            arr_2d = single_channel[array_index, :, :]
            arr_2d_normalized = exposure.rescale_intensity(arr_2d, in_range=(global_lower, global_upper), out_range=(0, 1))
            normalized_stack[array_index, :, :, channel] = arr_2d_normalized

    return normalized_stack.astype(save_dtype)
|
1069
|
+
|
1070
|
+
def concatenate_and_normalize(src, channels, save_dtype=np.float32, settings={}):
|
1071
|
+
|
894
1072
|
"""
|
895
1073
|
Concatenates and normalizes channel data from multiple files and saves the normalized data.
|
896
1074
|
|
@@ -910,12 +1088,14 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
|
|
910
1088
|
Returns:
|
911
1089
|
str: The directory path where the concatenated and normalized channel data is saved.
|
912
1090
|
"""
|
1091
|
+
# n c p
|
913
1092
|
channels = [item for item in channels if item is not None]
|
1093
|
+
|
914
1094
|
paths = []
|
915
1095
|
output_fldr = os.path.join(os.path.dirname(src), 'norm_channel_stack')
|
916
1096
|
os.makedirs(output_fldr, exist_ok=True)
|
917
1097
|
|
918
|
-
if timelapse:
|
1098
|
+
if settings['timelapse']:
|
919
1099
|
try:
|
920
1100
|
time_stack_path_lists = _generate_time_lists(os.listdir(src))
|
921
1101
|
for i, time_stack_list in enumerate(time_stack_path_lists):
|
@@ -927,12 +1107,19 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
|
|
927
1107
|
parts = file.split('_')
|
928
1108
|
name = parts[0] + '_' + parts[1] + '_' + parts[2]
|
929
1109
|
array = np.load(path)
|
930
|
-
array = np.take(array, channels, axis=2)
|
1110
|
+
#array = np.take(array, channels, axis=2)
|
931
1111
|
stack_region.append(array)
|
932
1112
|
filenames_region.append(os.path.basename(path))
|
933
1113
|
print(f'Region {i + 1}/ {len(time_stack_path_lists)}', end='\r', flush=True)
|
934
1114
|
stack = np.stack(stack_region)
|
935
|
-
|
1115
|
+
|
1116
|
+
normalized_stack = _normalize_img_batch(stack=stack,
|
1117
|
+
channels=channels,
|
1118
|
+
save_dtype=save_dtype,
|
1119
|
+
settings=settings)
|
1120
|
+
|
1121
|
+
normalized_stack = normalized_stack[..., channels]
|
1122
|
+
|
936
1123
|
save_loc = os.path.join(output_fldr, f'{name}_norm_timelapse.npz')
|
937
1124
|
np.savez(save_loc, data=normalized_stack, filenames=filenames_region)
|
938
1125
|
print(save_loc)
|
@@ -945,7 +1132,7 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
|
|
945
1132
|
if file.endswith('.npy'):
|
946
1133
|
path = os.path.join(src, file)
|
947
1134
|
paths.append(path)
|
948
|
-
if randomize:
|
1135
|
+
if settings['randomize']:
|
949
1136
|
random.shuffle(paths)
|
950
1137
|
nr_files = len(paths)
|
951
1138
|
batch_index = 0
|
@@ -954,12 +1141,12 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
|
|
954
1141
|
|
955
1142
|
for i, path in enumerate(paths):
|
956
1143
|
array = np.load(path)
|
957
|
-
array = np.take(array, channels, axis=2)
|
1144
|
+
#array = np.take(array, channels, axis=2)
|
958
1145
|
stack_ls.append(array)
|
959
1146
|
filenames_batch.append(os.path.basename(path))
|
960
1147
|
print(f'Concatenated: {i + 1}/{nr_files} files')
|
961
1148
|
|
962
|
-
if (i + 1) % batch_size == 0 or i + 1 == nr_files:
|
1149
|
+
if (i + 1) % settings['batch_size'] == 0 or i + 1 == nr_files:
|
963
1150
|
unique_shapes = {arr.shape[:-1] for arr in stack_ls}
|
964
1151
|
if len(unique_shapes) > 1:
|
965
1152
|
max_dims = np.max(np.array(list(unique_shapes)), axis=0)
|
@@ -973,8 +1160,13 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
|
|
973
1160
|
stack = np.stack(padded_stack_ls)
|
974
1161
|
else:
|
975
1162
|
stack = np.stack(stack_ls)
|
976
|
-
|
977
|
-
normalized_stack = _normalize_img_batch(stack,
|
1163
|
+
|
1164
|
+
normalized_stack = _normalize_img_batch(stack=stack,
|
1165
|
+
channels=channels,
|
1166
|
+
save_dtype=save_dtype,
|
1167
|
+
settings=settings)
|
1168
|
+
|
1169
|
+
normalized_stack = normalized_stack[..., channels]
|
978
1170
|
|
979
1171
|
save_loc = os.path.join(output_fldr, f'stack_{batch_index}_norm.npz')
|
980
1172
|
np.savez(save_loc, data=normalized_stack, filenames=filenames_batch)
|
@@ -983,64 +1175,10 @@ def concatenate_and_normalize(src, channels, randomize=True, timelapse=False, ba
|
|
983
1175
|
stack_ls = []
|
984
1176
|
filenames_batch = []
|
985
1177
|
padded_stack_ls = []
|
1178
|
+
|
986
1179
|
print(f'All files concatenated and normalized. Saved to: {output_fldr}')
|
987
1180
|
return output_fldr
|
988
1181
|
|
989
|
-
def _normalize_img_batch(stack, backgrounds, remove_backgrounds, lower_percentile, save_dtype, signal_to_noise, signal_thresholds):
|
990
|
-
"""
|
991
|
-
Normalize the stack of images.
|
992
|
-
|
993
|
-
Args:
|
994
|
-
stack (numpy.ndarray): The stack of images to normalize.
|
995
|
-
backgrounds (list): Background values for each channel.
|
996
|
-
remove_backgrounds (list): Whether to remove background values for each channel.
|
997
|
-
lower_percentile (int): Lower percentile value for normalization.
|
998
|
-
save_dtype (numpy.dtype): Data type for saving the normalized stack.
|
999
|
-
signal_to_noise (list): Signal-to-noise ratio thresholds for each channel.
|
1000
|
-
signal_thresholds (list): Signal thresholds for each channel.
|
1001
|
-
|
1002
|
-
Returns:
|
1003
|
-
numpy.ndarray: The normalized stack.
|
1004
|
-
"""
|
1005
|
-
normalized_stack = np.zeros_like(stack, dtype=np.float32)
|
1006
|
-
|
1007
|
-
for chan_index, channel in enumerate(range(stack.shape[-1])):
|
1008
|
-
single_channel = stack[:, :, :, channel]
|
1009
|
-
background = backgrounds[chan_index]
|
1010
|
-
signal_threshold = signal_thresholds[chan_index]
|
1011
|
-
remove_background = remove_backgrounds[chan_index]
|
1012
|
-
|
1013
|
-
print(f'Processing channel {chan_index}: background={background}, signal_threshold={signal_threshold}, remove_background={remove_background}')
|
1014
|
-
|
1015
|
-
# Step 3: Remove background if required
|
1016
|
-
if remove_background:
|
1017
|
-
single_channel[single_channel < background] = 0
|
1018
|
-
|
1019
|
-
# Step 4: Calculate global lower percentile for the channel
|
1020
|
-
non_zero_single_channel = single_channel[single_channel != 0]
|
1021
|
-
global_lower = np.percentile(non_zero_single_channel, lower_percentile)
|
1022
|
-
|
1023
|
-
# Step 5: Calculate global upper percentile for the channel
|
1024
|
-
global_upper = None
|
1025
|
-
for upper_p in np.linspace(98, 99.5, num=16):
|
1026
|
-
upper_value = np.percentile(non_zero_single_channel, upper_p)
|
1027
|
-
if upper_value >= signal_threshold:
|
1028
|
-
global_upper = upper_value
|
1029
|
-
break
|
1030
|
-
|
1031
|
-
if global_upper is None:
|
1032
|
-
global_upper = np.percentile(non_zero_single_channel, 99.5) # Fallback in case no upper percentile met the threshold
|
1033
|
-
|
1034
|
-
print(f'Channel {chan_index}: global_lower={global_lower}, global_upper={global_upper}, Signal-to-noise={global_upper / global_lower}')
|
1035
|
-
|
1036
|
-
# Step 6: Normalize each array from global_lower to global_upper between 0 and 1
|
1037
|
-
for array_index in range(single_channel.shape[0]):
|
1038
|
-
arr_2d = single_channel[array_index, :, :]
|
1039
|
-
arr_2d_normalized = exposure.rescale_intensity(arr_2d, in_range=(global_lower, global_upper), out_range=(0, 1))
|
1040
|
-
normalized_stack[array_index, :, :, channel] = arr_2d_normalized
|
1041
|
-
|
1042
|
-
return normalized_stack.astype(save_dtype)
|
1043
|
-
|
1044
1182
|
def _get_lists_for_normalization(settings):
|
1045
1183
|
"""
|
1046
1184
|
Get lists for normalization based on the provided settings.
|
@@ -1059,22 +1197,25 @@ def _get_lists_for_normalization(settings):
|
|
1059
1197
|
remove_background = []
|
1060
1198
|
|
1061
1199
|
# Iterate through the channels and append the corresponding values if the channel is not None
|
1062
|
-
for ch in settings['channels']:
|
1063
|
-
|
1064
|
-
|
1065
|
-
|
1066
|
-
|
1067
|
-
|
1068
|
-
|
1069
|
-
|
1070
|
-
|
1071
|
-
|
1072
|
-
|
1073
|
-
|
1074
|
-
|
1075
|
-
|
1076
|
-
|
1077
|
-
|
1200
|
+
# for ch in settings['channels']:
|
1201
|
+
for ch in [settings['nucleus_channel'], settings['cell_channel'], settings['pathogen_channel']]:
|
1202
|
+
if not ch is None:
|
1203
|
+
if ch == settings['nucleus_channel']:
|
1204
|
+
backgrounds.append(settings['nucleus_background'])
|
1205
|
+
signal_to_noise.append(settings['nucleus_Signal_to_noise'])
|
1206
|
+
signal_thresholds.append(settings['nucleus_Signal_to_noise']*settings['nucleus_background'])
|
1207
|
+
remove_background.append(settings['remove_background_nucleus'])
|
1208
|
+
elif ch == settings['cell_channel']:
|
1209
|
+
backgrounds.append(settings['cell_background'])
|
1210
|
+
signal_to_noise.append(settings['cell_Signal_to_noise'])
|
1211
|
+
signal_thresholds.append(settings['cell_Signal_to_noise']*settings['cell_background'])
|
1212
|
+
remove_background.append(settings['remove_background_cell'])
|
1213
|
+
elif ch == settings['pathogen_channel']:
|
1214
|
+
backgrounds.append(settings['pathogen_background'])
|
1215
|
+
signal_to_noise.append(settings['pathogen_Signal_to_noise'])
|
1216
|
+
signal_thresholds.append(settings['pathogen_Signal_to_noise']*settings['pathogen_background'])
|
1217
|
+
remove_background.append(settings['remove_background_pathogen'])
|
1218
|
+
|
1078
1219
|
return backgrounds, signal_to_noise, signal_thresholds, remove_background
|
1079
1220
|
|
1080
1221
|
def _normalize_stack(src, backgrounds=[100, 100, 100], remove_backgrounds=[False, False, False], lower_percentile=2, save_dtype=np.float32, signal_to_noise=[5, 5, 5], signal_thresholds=[1000, 1000, 1000]):
|
@@ -1283,7 +1424,8 @@ def delete_empty_subdirectories(folder_path):
|
|
1283
1424
|
def preprocess_img_data(settings):
|
1284
1425
|
|
1285
1426
|
from .plot import plot_arrays, _plot_4D_arrays
|
1286
|
-
from .utils import _run_test_mode, _get_regex
|
1427
|
+
from .utils import _run_test_mode, _get_regex
|
1428
|
+
from .settings import set_default_settings_preprocess_img_data
|
1287
1429
|
|
1288
1430
|
"""
|
1289
1431
|
Preprocesses image data by converting z-stack images to maximum intensity projection (MIP) images.
|
@@ -1400,19 +1542,10 @@ def preprocess_img_data(settings):
|
|
1400
1542
|
except Exception as e:
|
1401
1543
|
print(f"Error: {e}")
|
1402
1544
|
|
1403
|
-
|
1404
|
-
|
1405
|
-
|
1406
|
-
|
1407
|
-
randomize,
|
1408
|
-
timelapse,
|
1409
|
-
batch_size,
|
1410
|
-
backgrounds,
|
1411
|
-
remove_backgrounds,
|
1412
|
-
lower_percentile,
|
1413
|
-
np.float32,
|
1414
|
-
signal_to_noise,
|
1415
|
-
signal_thresholds)
|
1545
|
+
concatenate_and_normalize(src=src+'/stack',
|
1546
|
+
channels=mask_channels,
|
1547
|
+
save_dtype=np.float32,
|
1548
|
+
settings=settings)
|
1416
1549
|
|
1417
1550
|
if plot:
|
1418
1551
|
_plot_4D_arrays(src+'/norm_channel_stack', nr_npz=1, nr=nr)
|
@@ -1494,13 +1627,13 @@ def _save_figure(fig, src, text, dpi=300, i=1, all_folders=1):
|
|
1494
1627
|
del fig
|
1495
1628
|
gc.collect()
|
1496
1629
|
|
1497
|
-
def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus', 'pathogen', '
|
1630
|
+
def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus', 'pathogen', 'png_list']):
|
1498
1631
|
"""
|
1499
1632
|
Reads and joins tables from a SQLite database.
|
1500
1633
|
|
1501
1634
|
Args:
|
1502
1635
|
db_path (str): The path to the SQLite database file.
|
1503
|
-
table_names (list, optional): The names of the tables to read and join. Defaults to ['cell', 'cytoplasm', 'nucleus', 'pathogen', '
|
1636
|
+
table_names (list, optional): The names of the tables to read and join. Defaults to ['cell', 'cytoplasm', 'nucleus', 'pathogen', 'png_list'].
|
1504
1637
|
|
1505
1638
|
Returns:
|
1506
1639
|
pandas.DataFrame: The joined DataFrame containing the data from the specified tables, or None if an error occurs.
|
@@ -1522,9 +1655,9 @@ def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus',
|
|
1522
1655
|
join_cols = ['object_label', 'plate', 'row', 'col']
|
1523
1656
|
dataframes['cell'] = pd.merge(dataframes['cell'], png_list_df, on=join_cols, how='left')
|
1524
1657
|
else:
|
1525
|
-
print("Cell table not found
|
1526
|
-
return
|
1527
|
-
for entity in ['nucleus', 'pathogen'
|
1658
|
+
print("Cell table not found in database tables.")
|
1659
|
+
return png_list_df
|
1660
|
+
for entity in ['nucleus', 'pathogen']:
|
1528
1661
|
if entity in dataframes:
|
1529
1662
|
numeric_cols = dataframes[entity].select_dtypes(include=[np.number]).columns.tolist()
|
1530
1663
|
non_numeric_cols = dataframes[entity].select_dtypes(exclude=[np.number]).columns.tolist()
|
@@ -1537,14 +1670,11 @@ def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus',
|
|
1537
1670
|
joined_df = None
|
1538
1671
|
if 'cell' in dataframes:
|
1539
1672
|
joined_df = dataframes['cell']
|
1540
|
-
|
1541
|
-
|
1542
|
-
|
1543
|
-
|
1544
|
-
|
1545
|
-
else:
|
1546
|
-
print("Cell table not found. Cannot proceed with joining.")
|
1547
|
-
return None
|
1673
|
+
if 'cytoplasm' in dataframes:
|
1674
|
+
joined_df = pd.merge(joined_df, dataframes['cytoplasm'], on=['object_label', 'prcf'], how='left', suffixes=('', '_cytoplasm'))
|
1675
|
+
for entity in ['nucleus', 'pathogen']:
|
1676
|
+
if entity in dataframes:
|
1677
|
+
joined_df = pd.merge(joined_df, dataframes[entity], left_on=['object_label', 'prcf'], right_index=True, how='left', suffixes=('', f'_{entity}'))
|
1548
1678
|
return joined_df
|
1549
1679
|
|
1550
1680
|
def _save_settings_to_db(settings):
|
@@ -1993,8 +2123,75 @@ def _results_to_csv(src, df, df_well):
|
|
1993
2123
|
###################################################
|
1994
2124
|
# Classify
|
1995
2125
|
###################################################
|
2126
|
+
|
2127
|
+
def read_plot_model_stats(file_path, save=False):
    """
    Split a training-results CSV into train/validation tables and plot each metric per epoch.

    The CSV is read with its first column as the index. Rows are assigned as follows:
      * rows whose index label contains ``'_train'`` are training rows;
      * rows whose index label contains ``'_val'`` are validation rows.
    The epoch number is parsed from the text before the first underscore of the label. Both
    subsets, including the added ``epoch`` column, are written to ``train.csv`` and
    ``validation.csv`` in the CSV's directory.

    One side-by-side (train | validation) line plot is drawn for each of these columns:
    accuracy, neg_accuracy, pos_accuracy, loss, prauc and optimal_threshold.

    Args:
        file_path (str): Path to the results CSV.
        save (bool, optional): If True, apply seaborn's "whitegrid" style and save each plot
            as ``<column>.pdf`` next to the CSV. Otherwise display the plots. Defaults to False.
    """

    def _plot_and_save(train_df, val_df, column='accuracy', save=False, path=None, dpi=600):
        """Plot ``column`` against epoch for the train (left) and validation (right) rows."""
        # Create subplots
        fig, axes = plt.subplots(1, 2, figsize=(20, 10), sharey=True)

        # Plotting
        sns.lineplot(ax=axes[0], x='epoch', y=column, data=train_df, marker='o', color='red')
        sns.lineplot(ax=axes[1], x='epoch', y=column, data=val_df, marker='o', color='blue')

        # Set titles and labels
        axes[0].set_title(f'Train {column} vs. Epoch', fontsize=20)
        axes[0].set_xlabel('Epoch', fontsize=16)
        axes[0].set_ylabel(column, fontsize=16)
        axes[0].tick_params(axis='both', which='major', labelsize=12)

        axes[1].set_title(f'Validation {column} vs. Epoch', fontsize=20)
        axes[1].set_xlabel('Epoch', fontsize=16)
        axes[1].tick_params(axis='both', which='major', labelsize=12)

        fig.tight_layout()

        if save:
            pdf_path = os.path.join(path, f'{column}.pdf')
            fig.savefig(pdf_path, format='pdf', dpi=dpi)
            # Release the figure. Without this, every call leaks six large figures.
            plt.close(fig)
        else:
            plt.show()

    # Read the CSV into a dataframe
    df = pd.read_csv(file_path, index_col=0)

    # Split the dataframe into train and validation based on the index
    train_df = df.filter(like='_train', axis=0).copy()
    val_df = df.filter(like='_val', axis=0).copy()

    fldr_1 = os.path.dirname(file_path)
    train_csv_path = os.path.join(fldr_1, 'train.csv')
    val_csv_path = os.path.join(fldr_1, 'validation.csv')

    # Extract epochs from index
    train_df['epoch'] = [int(idx.split('_')[0]) for idx in train_df.index]
    val_df['epoch'] = [int(idx.split('_')[0]) for idx in val_df.index]

    # Save dataframes to a CSV file
    train_df.to_csv(train_csv_path)
    val_df.to_csv(val_csv_path)

    if save:
        # Setting the style
        sns.set(style="whitegrid")

    for column in ('accuracy', 'neg_accuracy', 'pos_accuracy', 'loss', 'prauc', 'optimal_threshold'):
        _plot_and_save(train_df, val_df, column=column, save=save, path=fldr_1)
|
2193
|
+
|
2194
|
+
def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94], channels=['r','g','b']):
|
1998
2195
|
"""
|
1999
2196
|
Save the model based on certain conditions during training.
|
2000
2197
|
|
@@ -2007,35 +2204,25 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
|
|
2007
2204
|
epochs (int): The total number of epochs.
|
2008
2205
|
intermedeate_save (list, optional): List of accuracy thresholds to trigger intermediate model saves.
|
2009
2206
|
Defaults to [0.99, 0.98, 0.95, 0.94].
|
2207
|
+
channels (list, optional): List of channels used. Defaults to ['r', 'g', 'b'].
|
2010
2208
|
"""
|
2011
|
-
|
2012
|
-
if epoch % 100 == 0:
|
2013
|
-
torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}.pth')
|
2014
|
-
|
2015
|
-
if epoch == epochs:
|
2016
|
-
torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}.pth')
|
2017
|
-
|
2018
|
-
if results_df['neg_accuracy'].dropna().mean() >= intermedeate_save[0] and results_df['pos_accuracy'].dropna().mean() >= intermedeate_save[0]:
|
2019
|
-
percentile = str(intermedeate_save[0]*100)
|
2020
|
-
print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
|
2021
|
-
torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{str(percentile)}.pth')
|
2022
2209
|
|
2023
|
-
|
2024
|
-
percentile = str(intermedeate_save[1]*100)
|
2025
|
-
print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
|
2026
|
-
torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{str(percentile)}.pth')
|
2210
|
+
channels_str = ''.join(channels)
|
2027
2211
|
|
2028
|
-
|
2029
|
-
percentile = str(
|
2212
|
+
def save_model_at_threshold(threshold, epoch, suffix=""):
|
2213
|
+
percentile = str(threshold * 100)
|
2030
2214
|
print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
|
2031
|
-
torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{
|
2032
|
-
|
2033
|
-
|
2034
|
-
|
2035
|
-
print(f'\rfound: {percentile}% accurate model', end='\r', flush=True)
|
2036
|
-
torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_acc_{str(percentile)}.pth')
|
2215
|
+
torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}{suffix}_acc_{percentile}_channels_{channels_str}.pth')
|
2216
|
+
|
2217
|
+
if epoch % 100 == 0 or epoch == epochs:
|
2218
|
+
torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_channels_{channels_str}.pth')
|
2037
2219
|
|
2038
|
-
|
2220
|
+
for threshold in intermedeate_save:
|
2221
|
+
if results_df['neg_accuracy'].dropna().mean() >= threshold and results_df['pos_accuracy'].dropna().mean() >= threshold:
|
2222
|
+
save_model_at_threshold(threshold, epoch)
|
2223
|
+
break # Ensure we only save for the highest matching threshold
|
2224
|
+
|
2225
|
+
def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
|
2039
2226
|
"""
|
2040
2227
|
Save the progress of the classification model.
|
2041
2228
|
|
@@ -2054,11 +2241,14 @@ def _save_progress(dst, results_df, train_metrics_df):
|
|
2054
2241
|
results_df.to_csv(results_path, index=True, header=True, mode='w')
|
2055
2242
|
else:
|
2056
2243
|
results_df.to_csv(results_path, index=True, header=False, mode='a')
|
2244
|
+
|
2057
2245
|
training_metrics_path = os.path.join(dst, 'training_metrics.csv')
|
2058
2246
|
if not os.path.exists(training_metrics_path):
|
2059
2247
|
train_metrics_df.to_csv(training_metrics_path, index=True, header=True, mode='w')
|
2060
2248
|
else:
|
2061
2249
|
train_metrics_df.to_csv(training_metrics_path, index=True, header=False, mode='a')
|
2250
|
+
if epoch == epochs:
|
2251
|
+
read_plot_model_stats(results_path, save=True)
|
2062
2252
|
return
|
2063
2253
|
|
2064
2254
|
def _save_settings(settings, src):
|