spacr 0.0.1__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +6 -2
- spacr/__main__.py +0 -2
- spacr/alpha.py +807 -0
- spacr/annotate_app.py +118 -120
- spacr/chris.py +50 -0
- spacr/cli.py +25 -187
- spacr/core.py +1611 -389
- spacr/deep_spacr.py +696 -0
- spacr/foldseek.py +779 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/graph_learning.py +320 -0
- spacr/graph_learning_lap.py +84 -0
- spacr/gui.py +145 -0
- spacr/gui_2.py +90 -0
- spacr/gui_classify_app.py +187 -0
- spacr/gui_mask_app.py +149 -174
- spacr/gui_measure_app.py +116 -109
- spacr/gui_sim_app.py +0 -0
- spacr/gui_utils.py +679 -139
- spacr/io.py +620 -469
- spacr/mask_app.py +116 -9
- spacr/measure.py +178 -84
- spacr/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/old_code.py +255 -1
- spacr/plot.py +263 -100
- spacr/sequencing.py +1130 -0
- spacr/sim.py +634 -122
- spacr/timelapse.py +343 -53
- spacr/train.py +195 -22
- spacr/umap.py +0 -689
- spacr/utils.py +1530 -188
- spacr-0.0.6.dist-info/METADATA +118 -0
- spacr-0.0.6.dist-info/RECORD +39 -0
- {spacr-0.0.1.dist-info → spacr-0.0.6.dist-info}/WHEEL +1 -1
- spacr-0.0.6.dist-info/entry_points.txt +9 -0
- spacr-0.0.1.dist-info/METADATA +0 -64
- spacr-0.0.1.dist-info/RECORD +0 -26
- spacr-0.0.1.dist-info/entry_points.txt +0 -5
- {spacr-0.0.1.dist-info → spacr-0.0.6.dist-info}/LICENSE +0 -0
- {spacr-0.0.1.dist-info → spacr-0.0.6.dist-info}/top_level.txt +0 -0
spacr/utils.py
CHANGED
@@ -1,10 +1,18 @@
-import os, re, sqlite3,
+import sys, os, re, sqlite3, torch, torchvision, random, string, shutil, cv2, tarfile, glob
 
 import numpy as np
+from cellpose import models as cp_models
+from cellpose import denoise
+
 from skimage import morphology
 from skimage.measure import label, regionprops_table, regionprops
 import skimage.measure as measure
-from
+from skimage.transform import resize as resizescikit
+from skimage.morphology import dilation, square
+from skimage.measure import find_contours
+from skimage.segmentation import clear_border
+
+from collections import defaultdict, OrderedDict
 from PIL import Image
 import pandas as pd
 from statsmodels.stats.outliers_influence import variance_inflation_factor
@@ -13,37 +21,257 @@ import statsmodels.formula.api as smf
 import statsmodels.api as sm
 from statsmodels.stats.multitest import multipletests
 from itertools import combinations
-from collections import OrderedDict
 from functools import reduce
-from IPython.display import display
+from IPython.display import display
+
 from multiprocessing import Pool, cpu_count
-from
+from concurrent.futures import ThreadPoolExecutor
+
 import torch.nn as nn
 import torch.nn.functional as F
-#from torchsummary import summary
 from torch.utils.checkpoint import checkpoint
 from torch.utils.data import Subset
 from torch.autograd import grad
-
-from skimage.segmentation import clear_border
+
 import seaborn as sns
 import matplotlib.pyplot as plt
+from matplotlib.offsetbox import OffsetImage, AnnotationBbox
+
 import scipy.ndimage as ndi
+from scipy.spatial import distance
 from scipy.stats import fisher_exact
-from scipy.ndimage import
+from scipy.ndimage.filters import gaussian_filter
+from scipy.spatial import ConvexHull
+from scipy.interpolate import splprep, splev
+
+from sklearn.preprocessing import StandardScaler
 from skimage.exposure import rescale_intensity
 from sklearn.metrics import auc, precision_recall_curve
 from sklearn.model_selection import train_test_split
 from sklearn.linear_model import Lasso, Ridge
 from sklearn.preprocessing import OneHotEncoder
+from sklearn.cluster import KMeans
+from sklearn.preprocessing import StandardScaler
+from sklearn.cluster import DBSCAN
+from sklearn.cluster import KMeans
+from sklearn.manifold import TSNE
+
+import umap.umap_ as umap
+
+from torchvision import models
 from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
+import torchvision.transforms as transforms
 
 from .logger import log_function_call
 
-
-
-
-
+def check_mask_folder(src,mask_fldr):
+
+    mask_folder = os.path.join(src,'norm_channel_stack',mask_fldr)
+    stack_folder = os.path.join(src,'stack')
+
+    if not os.path.exists(mask_folder):
+        return True
+
+    mask_count = sum(1 for file in os.listdir(mask_folder) if file.endswith('.npy'))
+    stack_count = sum(1 for file in os.listdir(stack_folder) if file.endswith('.npy'))
+
+    if mask_count == stack_count:
+        print(f'All masks have been generated for {mask_fldr}')
+        return False
+    else:
+        return True
+
+def set_default_plot_merge_settings():
+    settings = {}
+    settings.setdefault('include_noninfected', True)
+    settings.setdefault('include_multiinfected', True)
+    settings.setdefault('include_multinucleated', True)
+    settings.setdefault('remove_background', False)
+    settings.setdefault('filter_min_max', None)
+    settings.setdefault('channel_dims', [0,1,2,3])
+    settings.setdefault('backgrounds', [100,100,100,100])
+    settings.setdefault('cell_mask_dim', 4)
+    settings.setdefault('nucleus_mask_dim', 5)
+    settings.setdefault('pathogen_mask_dim', 6)
+    settings.setdefault('outline_thickness', 3)
+    settings.setdefault('outline_color', 'gbr')
+    settings.setdefault('overlay_chans', [1,2,3])
+    settings.setdefault('overlay', True)
+    settings.setdefault('normalization_percentiles', [2,98])
+    settings.setdefault('normalize', True)
+    settings.setdefault('print_object_number', True)
+    settings.setdefault('nr', 1)
+    settings.setdefault('figuresize', 50)
+    settings.setdefault('cmap', 'inferno')
+    settings.setdefault('verbose', True)
+
+    return settings
+
+def set_default_settings_preprocess_generate_masks(src, settings={}):
+    # Main settings
+    settings['src'] = src
+    settings.setdefault('preprocess', True)
+    settings.setdefault('masks', True)
+    settings.setdefault('save', True)
+    settings.setdefault('batch_size', 50)
+    settings.setdefault('test_mode', False)
+    settings.setdefault('test_images', 10)
+    settings.setdefault('magnification', 20)
+    settings.setdefault('custom_regex', None)
+    settings.setdefault('metadata_type', 'cellvoyager')
+    settings.setdefault('workers', os.cpu_count()-4)
+    settings.setdefault('randomize', True)
+    settings.setdefault('verbose', True)
+
+    settings.setdefault('remove_background_cell', False)
+    settings.setdefault('remove_background_nucleus', False)
+    settings.setdefault('remove_background_pathogen', False)
+
+    # Channel settings
+    settings.setdefault('cell_channel', None)
+    settings.setdefault('nucleus_channel', None)
+    settings.setdefault('pathogen_channel', None)
+    settings.setdefault('channels', [0,1,2,3])
+    settings.setdefault('pathogen_background', 100)
+    settings.setdefault('pathogen_Signal_to_noise', 10)
+    settings.setdefault('pathogen_CP_prob', 0)
+    settings.setdefault('cell_background', 100)
+    settings.setdefault('cell_Signal_to_noise', 10)
+    settings.setdefault('cell_CP_prob', 0)
+    settings.setdefault('nucleus_background', 100)
+    settings.setdefault('nucleus_Signal_to_noise', 10)
+    settings.setdefault('nucleus_CP_prob', 0)
+
+    settings.setdefault('nucleus_FT', 100)
+    settings.setdefault('cell_FT', 100)
+    settings.setdefault('pathogen_FT', 100)
+
+    # Plot settings
+    settings.setdefault('plot', False)
+    settings.setdefault('figuresize', 50)
+    settings.setdefault('cmap', 'inferno')
+    settings.setdefault('normalize', True)
+    settings.setdefault('normalize_plots', True)
+    settings.setdefault('examples_to_plot', 1)
+
+    # Analasys settings
+    settings.setdefault('pathogen_model', None)
+    settings.setdefault('merge_pathogens', False)
+    settings.setdefault('filter', False)
+    settings.setdefault('lower_percentile', 2)
+
+    # Timelapse settings
+    settings.setdefault('timelapse', False)
+    settings.setdefault('fps', 2)
+    settings.setdefault('timelapse_displacement', None)
+    settings.setdefault('timelapse_memory', 3)
+    settings.setdefault('timelapse_frame_limits', None)
+    settings.setdefault('timelapse_remove_transient', False)
+    settings.setdefault('timelapse_mode', 'trackpy')
+    settings.setdefault('timelapse_objects', 'cells')
+
+    # Misc settings
+    settings.setdefault('all_to_mip', False)
+    settings.setdefault('pick_slice', False)
+    settings.setdefault('skip_mode', '01')
+    settings.setdefault('upscale', False)
+    settings.setdefault('upscale_factor', 2.0)
+    settings.setdefault('adjust_cells', False)
+
+    return settings
+
+def set_default_settings_preprocess_img_data(settings):
+
+    metadata_type = settings.setdefault('metadata_type', 'cellvoyager')
+    custom_regex = settings.setdefault('custom_regex', None)
+    nr = settings.setdefault('nr', 1)
+    plot = settings.setdefault('plot', True)
+    batch_size = settings.setdefault('batch_size', 50)
+    timelapse = settings.setdefault('timelapse', False)
+    lower_percentile = settings.setdefault('lower_percentile', 2)
+    randomize = settings.setdefault('randomize', True)
+    all_to_mip = settings.setdefault('all_to_mip', False)
+    pick_slice = settings.setdefault('pick_slice', False)
+    skip_mode = settings.setdefault('skip_mode', False)
+
+    cmap = settings.setdefault('cmap', 'inferno')
+    figuresize = settings.setdefault('figuresize', 50)
+    normalize = settings.setdefault('normalize', True)
+    save_dtype = settings.setdefault('save_dtype', 'uint16')
+
+    test_mode = settings.setdefault('test_mode', False)
+    test_images = settings.setdefault('test_images', 10)
+    random_test = settings.setdefault('random_test', True)
+
+    return settings, metadata_type, custom_regex, nr, plot, batch_size, timelapse, lower_percentile, randomize, all_to_mip, pick_slice, skip_mode, cmap, figuresize, normalize, save_dtype, test_mode, test_images, random_test
+
+def smooth_hull_lines(cluster_data):
+    hull = ConvexHull(cluster_data)
+
+    # Extract vertices of the hull
+    vertices = hull.points[hull.vertices]
+
+    # Close the loop
+    vertices = np.vstack([vertices, vertices[0, :]])
+
+    # Parameterize the vertices
+    tck, u = splprep(vertices.T, u=None, s=0.0)
+
+    # Evaluate spline at new parameter values
+    new_points = splev(np.linspace(0, 1, 100), tck)
+
+    return new_points[0], new_points[1]
+
+def _gen_rgb_image(image, channels):
+    """
+    Generate an RGB image from the specified channels of the input image.
+
+    Args:
+        image (ndarray): The input image.
+        channels (list): List of channel indices to use for RGB.
+
+    Returns:
+        rgb_image (ndarray): The generated RGB image.
+    """
+    rgb_image = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.float32)
+    for i, chan in enumerate(channels):
+        if chan < image.shape[2]:
+            rgb_image[:, :, i] = image[:, :, chan]
+    return rgb_image
+
+def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_thickness):
+    outlines = []
+    overlayed_image = rgb_image.copy()
+
+    def process_dim(mask_dim):
+        mask = np.take(image, mask_dim, axis=-1)
+        outline = np.zeros_like(mask, dtype=np.uint8) # Use uint8 for contour detection efficiency
+
+        # Find and draw contours
+        for j in np.unique(mask):
+            if j == 0:
+                continue # Skip background
+            contours = find_contours(mask == j, 0.5)
+            # Convert contours for OpenCV format and draw directly to optimize
+            cv_contours = [np.flip(contour.astype(int), axis=1) for contour in contours]
+            cv2.drawContours(outline, cv_contours, -1, color=255, thickness=outline_thickness)
+
+        return dilation(outline, square(outline_thickness))
+
+    # Parallel processing
+    with ThreadPoolExecutor() as executor:
+        outlines = list(executor.map(process_dim, mask_dims))
+
+    # Overlay outlines onto the RGB image
+    for i, outline in enumerate(outlines):
+        color = np.array(outline_colors[i % len(outline_colors)])
+        for j in np.unique(outline):
+            if j == 0:
+                continue # Skip background
+            mask = outline == j
+            overlayed_image[mask] = color # Direct assignment with broadcasting
+
+    return overlayed_image, outlines, image
 
 def _convert_cq1_well_id(well_id):
     """
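All of the new settings helpers share the same `dict.setdefault` pattern: values supplied by the caller win, and every missing key is filled in, so downstream code can rely on a fully populated settings dict. A minimal sketch of how `set_default_settings_preprocess_generate_masks` would be called (the source path is hypothetical):

    from spacr.utils import set_default_settings_preprocess_generate_masks

    settings = set_default_settings_preprocess_generate_masks('/data/plate1', settings={'batch_size': 100})
    assert settings['batch_size'] == 100    # caller-supplied value is kept
    assert settings['magnification'] == 20  # unset keys fall back to the defaults above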
@@ -114,8 +342,8 @@ def _extract_filename_metadata(filenames, src, images_by_key, regular_expression
         if metadata_type =='cq1':
             orig_wellID = wellID
             wellID = _convert_cq1_well_id(wellID)
-            clear_output(wait=True)
-            print(f'
+            #clear_output(wait=True)
+            print(f'Converted Well ID: {orig_wellID} to {wellID}', end='\r', flush=True)
 
         if pick_slice:
             try:
@@ -302,43 +530,82 @@ def _annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['rh'], pa
     df['condition'] = df.apply(lambda row: '_'.join(filter(None, [row.get('pathogen'), row.get('treatment')])), axis=1)
     df['condition'] = df['condition'].apply(lambda x: x if x else 'none')
     return df
-
-def normalize_to_dtype(array,
+
+def normalize_to_dtype(array, p1=2, p2=98):
     """
-    Normalize the
+    Normalize each image in the stack to its own percentiles.
 
     Parameters:
     - array: numpy array
-        The input
-    -
+        The input stack to be normalized.
+    - p1: int, optional
         The lower percentile value for normalization. Default is 2.
-    -
+    - p2: int, optional
         The upper percentile value for normalization. Default is 98.
-    - percentiles: list of tuples, optional
-        A list of tuples containing the percentile values for each image in the array.
-        If provided, the percentiles for each image will be used instead of q1 and q2.
 
     Returns:
     - new_stack: numpy array
-        The normalized
+        The normalized stack with the same shape as the input stack.
     """
     nimg = array.shape[2]
    new_stack = np.empty_like(array)
-
-
+
+    for i in range(nimg):
+        img = array[:, :, i]
         non_zero_img = img[img > 0]
-
-
-
-
-        else:
-
-
-
-
-
-
-
+
+        if non_zero_img.size > 0:
+            img_min = np.percentile(non_zero_img, p1)
+            img_max = np.percentile(non_zero_img, p2)
+        else:
+            img_min = img.min()
+            img_max = img.max()
+
+        # Determine output range based on dtype
+        if np.issubdtype(array.dtype, np.integer):
+            out_range = (0, np.iinfo(array.dtype).max)
+        else:
+            out_range = (0.0, 1.0)
+
+        img = rescale_intensity(img, in_range=(img_min, img_max), out_range=out_range).astype(array.dtype)
+        new_stack[:, :, i] = img
+
+    return new_stack
+
+def normalize_to_dtype(array, p1=2, p2=98):
+    """
+    Normalize each image in the stack to its own percentiles.
+
+    Parameters:
+    - array: numpy array
+        The input stack to be normalized.
+    - p1: int, optional
+        The lower percentile value for normalization. Default is 2.
+    - p2: int, optional
+        The upper percentile value for normalization. Default is 98.
+
+    Returns:
+    - new_stack: numpy array
+        The normalized stack with the same shape as the input stack.
+    """
+    nimg = array.shape[2]
+    new_stack = np.empty_like(array, dtype=np.float32)
+
+    for i in range(nimg):
+        img = array[:, :, i]
+        non_zero_img = img[img > 0]
+
+        if non_zero_img.size > 0:
+            img_min = np.percentile(non_zero_img, p1)
+            img_max = np.percentile(non_zero_img, p2)
+        else:
+            img_min = img.min()
+            img_max = img.max()
+
+        # Normalize to the range (0, 1) for visualization
+        img = rescale_intensity(img, in_range=(img_min, img_max), out_range=(0.0, 1.0))
+        new_stack[:, :, i] = img
+
     return new_stack
 
 def _list_endpoint_subdirectories(base_dir):
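Note that this hunk defines `normalize_to_dtype` twice with the same signature; at import time the second definition, which always rescales to float in [0, 1], silently replaces the first, dtype-preserving one. The shared percentile logic can be exercised on its own; a minimal sketch with synthetic data:

    import numpy as np
    from skimage.exposure import rescale_intensity

    rng = np.random.default_rng(0)
    stack = rng.integers(0, 65535, size=(64, 64, 3), dtype=np.uint16)

    img = stack[:, :, 0].astype(np.float32)
    non_zero = img[img > 0]
    p_low, p_high = np.percentile(non_zero, (2, 98))
    normalized = rescale_intensity(img, in_range=(p_low, p_high), out_range=(0.0, 1.0))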
@@ -673,9 +940,6 @@ def _crop_center(img, cell_mask, new_width, new_height, normalize=(2,98)):
     img = img[start_y:end_y, start_x:end_x, :]
     return img
 
-
-
-
 def _masks_to_masks_stack(masks):
     """
     Convert a list of masks into a stack of masks.
@@ -692,53 +956,50 @@ def _masks_to_masks_stack(masks):
     return mask_stack
 
 def _get_diam(mag, obj):
-
-
-
-
-
-
-
-
-
-
-
-
-    if
-
-
-
-
-
-
-
-    elif
-    if
-
-    if
-
-    if
-
+
+    if mag == 20:
+        if obj == 'cell':
+            diamiter = 120
+        elif obj == 'nucleus':
+            diamiter = 60
+        elif obj == 'pathogen':
+            diamiter = 20
+        else:
+            raise ValueError("Invalid magnification: Use 20, 40 or 60")
+
+    elif mag == 40:
+        if obj == 'cell':
+            diamiter = 160
+        elif obj == 'nucleus':
+            diamiter = 80
+        elif obj == 'pathogen':
+            diamiter = 40
+        else:
+            raise ValueError("Invalid magnification: Use 20, 40 or 60")
+
+    elif mag == 60:
+        if obj == 'cell':
+            diamiter = 200
+        if obj == 'nucleus':
+            diamiter = 90
+        if obj == 'pathogen':
+            diamiter = 60
+        else:
+            raise ValueError("Invalid magnification: Use 20, 40 or 60")
     else:
-        raise ValueError("Invalid
-
+        raise ValueError("Invalid magnification: Use 20, 40 or 60")
+
     return diamiter
 
 def _get_object_settings(object_type, settings):
-
     object_settings = {}
-
-    object_settings['filter_size'] = False
-    object_settings['filter_dimm'] = False
-    print(object_type)
+
     object_settings['diameter'] = _get_diam(settings['magnification'], obj=object_type)
-    object_settings['
-    object_settings['
-    object_settings['maximum_size'] = object_settings['minimum_size']*50
+    object_settings['minimum_size'] = (object_settings['diameter']**2)/4
+    object_settings['maximum_size'] = (object_settings['diameter']**2)*10
     object_settings['merge'] = False
-    object_settings['net_avg'] = True
     object_settings['resample'] = True
+    object_settings['remove_border_objects'] = False
     object_settings['model_name'] = 'cyto'
 
     if object_type == 'cell':
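With the rewritten `_get_diam`, the size bounds in `_get_object_settings` become a pure function of magnification and object type. Worked through for a 'cell' at 20x: the diameter is 120, so

    diameter = 120                    # _get_diam(20, 'cell')
    minimum_size = (diameter**2) / 4  # 3600.0 pixels
    maximum_size = (diameter**2) * 10 # 144000 pixels

replacing the old `maximum_size = minimum_size*50` coupling.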
@@ -746,20 +1007,29 @@ def _get_object_settings(object_type, settings):
             object_settings['model_name'] = 'cyto'
         else:
             object_settings['model_name'] = 'cyto2'
-
+        object_settings['filter_size'] = False
+        object_settings['filter_intensity'] = False
+        object_settings['restore_type'] = settings.get('cell_restore_type', None)
+
     elif object_type == 'nucleus':
         object_settings['model_name'] = 'nuclei'
+        object_settings['filter_size'] = False
+        object_settings['filter_intensity'] = False
+        object_settings['restore_type'] = settings.get('nucleus_restore_type', None)
 
     elif object_type == 'pathogen':
-        object_settings['model_name'] = 'cyto3'
-
-    elif object_type == 'pathogen_nucleus':
-        object_settings['filter_size'] = True
         object_settings['model_name'] = 'cyto'
+        object_settings['filter_size'] = False
+        object_settings['filter_intensity'] = False
+        object_settings['resample'] = False
+        object_settings['restore_type'] = settings.get('pathogen_restore_type', None)
+        object_settings['merge'] = settings['merge_pathogens']
 
     else:
         print(f'Object type: {object_type} not supported. Supported object types are : cell, nucleus and pathogen')
-
+
+    if settings['verbose']:
+        print(object_settings)
 
     return object_settings
 
@@ -786,6 +1056,7 @@ def _pivot_counts_table(db_path):
         return df
 
     def _pivot_dataframe(df):
+
         """
         Pivot the DataFrame.
 
@@ -812,61 +1083,32 @@ def _pivot_counts_table(db_path):
     pivoted_df.to_sql('pivoted_counts', conn, if_exists='replace', index=False)
     conn.close()
 
-def
-    cellpose_channels = {}
-    if nucleus_chann_dim in mask_channels:
-        cellpose_channels['nucleus'] = [0, mask_channels.index(nucleus_chann_dim)]
-    if pathogen_chann_dim in mask_channels:
-        cellpose_channels['pathogen'] = [0, mask_channels.index(pathogen_chann_dim)]
-    if cell_chann_dim in mask_channels:
-        cellpose_channels['cell'] = [0, mask_channels.index(cell_chann_dim)]
-    return cellpose_channels
+def _get_cellpose_channels(src, nucleus_channel, pathogen_channel, cell_channel):
 
-
-
-
+    cell_mask_path = os.path.join(src, 'norm_channel_stack', 'cell_mask_stack')
+    nucleus_mask_path = os.path.join(src, 'norm_channel_stack', 'nucleus_mask_stack')
+    pathogen_mask_path = os.path.join(src, 'norm_channel_stack', 'pathogen_mask_stack')
 
-    # Initialize a list to keep track of the channels in their new order
-    new_channel_order = []
-
-    # Add each channel to the new order list if it is not None
-    if cell_channel is not None:
-        new_channel_order.append(('cell', cell_channel))
-    if nucleus_channel is not None:
-        new_channel_order.append(('nucleus', nucleus_channel))
-    if pathogen_channel is not None:
-        new_channel_order.append(('pathogen', pathogen_channel))
-
-    # Sort the list based on the original channel indices to maintain the original order
-    new_channel_order.sort(key=lambda x: x[1])
-    print(new_channel_order)
-    # Assign new indices based on the sorted order
-    for new_index, (channel_name, _) in enumerate(new_channel_order):
-        cellpose_channels[channel_name] = [new_index, 0]
-
-    if cell_channel is not None and nucleus_channel is not None:
-        cellpose_channels['cell'][1] = cellpose_channels['nucleus'][0]
-
-    return cellpose_channels
 
-
+    if os.path.exists(cell_mask_path) or os.path.exists(nucleus_mask_path) or os.path.exists(pathogen_mask_path):
+        if nucleus_channel is None or nucleus_channel is None or nucleus_channel is None:
+            print('Warning: Cellpose masks already exist. Unexpected behaviour when setting any object dimention to None when the object masks have been created.')
+
     cellpose_channels = {}
     if not nucleus_channel is None:
         cellpose_channels['nucleus'] = [0,0]
 
     if not pathogen_channel is None:
         if not nucleus_channel is None:
-
+            if not pathogen_channel is None:
+                cellpose_channels['pathogen'] = [0,2]
+            else:
+                cellpose_channels['pathogen'] = [0,1]
         else:
             cellpose_channels['pathogen'] = [0,0]
 
     if not cell_channel is None:
         if not nucleus_channel is None:
-            if not pathogen_channel is None:
-                cellpose_channels['cell'] = [0,2]
-            else:
-                cellpose_channels['cell'] = [0,1]
-        elif not pathogen_channel is None:
             cellpose_channels['cell'] = [0,1]
         else:
             cellpose_channels['cell'] = [0,0]
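The rewritten `_get_cellpose_channels` no longer derives positions from the raw channel indices; it assigns them from the fixed order nucleus → pathogen → cell. A worked example, read off the branches above, for a stack where all three channels are set:

    channels = _get_cellpose_channels(src, nucleus_channel=0, pathogen_channel=2, cell_channel=1)
    # {'nucleus': [0, 0], 'pathogen': [0, 2], 'cell': [0, 1]}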
@@ -1027,9 +1269,6 @@ def _group_by_well(df):
     # Apply mean function to numeric columns and first to non-numeric
     df_grouped = df.groupby(['plate', 'row', 'col']).agg({**{col: np.mean for col in numeric_cols}, **{col: 'first' for col in non_numeric_cols}})
     return df_grouped
-
-
-
 
 ###################################################
 # Classify
@@ -1044,7 +1283,7 @@ class Cache:
         cache (OrderedDict): The cache data structure.
     """
 
-    def
+    def __init__(self, max_size):
        self.cache = OrderedDict()
        self.max_size = max_size
 
@@ -1075,7 +1314,7 @@ class ScaledDotProductAttention(nn.Module):
 
     """
 
-    def
+    def __init__(self, d_k):
         super(ScaledDotProductAttention, self).__init__()
         self.d_k = d_k
 
@@ -1106,7 +1345,7 @@ class SelfAttention(nn.Module):
         d_k (int): Dimensionality of the key and query vectors.
     """
 
-    def
+    def __init__(self, in_channels, d_k):
         super(SelfAttention, self).__init__()
         self.W_q = nn.Linear(in_channels, d_k)
         self.W_k = nn.Linear(in_channels, d_k)
@@ -1130,7 +1369,7 @@ class SelfAttention(nn.Module):
         return output
 
 class ScaledDotProductAttention(nn.Module):
-    def
+    def __init__(self, d_k):
         """
         Initializes the ScaledDotProductAttention module.
 
@@ -1167,7 +1406,7 @@ class SelfAttention(nn.Module):
             in_channels (int): Number of input channels.
             d_k (int): Dimensionality of the key and query vectors.
         """
-    def
+    def __init__(self, in_channels, d_k):
         super(SelfAttention, self).__init__()
         self.W_q = nn.Linear(in_channels, d_k)
         self.W_k = nn.Linear(in_channels, d_k)
@@ -1198,7 +1437,7 @@ class EarlyFusion(nn.Module):
     Args:
         in_channels (int): Number of input channels.
     """
-    def
+    def __init__(self, in_channels):
         super(EarlyFusion, self).__init__()
         self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=1, stride=1)
 
@@ -1217,7 +1456,7 @@ class EarlyFusion(nn.Module):
 
 # Spatial Attention Mechanism
 class SpatialAttention(nn.Module):
-    def
+    def __init__(self, kernel_size=7):
         """
         Initializes the SpatialAttention module.
 
@@ -1262,7 +1501,7 @@ class MultiScaleBlockWithAttention(nn.Module):
         forward: Forward method for the module.
     """
 
-    def
+    def __init__(self, in_channels, out_channels):
         super(MultiScaleBlockWithAttention, self).__init__()
         self.dilated_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, dilation=1, padding=1)
         self.spatial_attention = nn.Conv2d(out_channels, out_channels, kernel_size=1)
@@ -1295,7 +1534,7 @@ class MultiScaleBlockWithAttention(nn.Module):
 
 # Final Classifier
 class CustomCellClassifier(nn.Module):
-    def
+    def __init__(self, num_classes, pathogen_channel, use_attention, use_checkpoint, dropout_rate):
         super(CustomCellClassifier, self).__init__()
         self.early_fusion = EarlyFusion(in_channels=3)
 
@@ -1324,7 +1563,7 @@ class CustomCellClassifier(nn.Module):
 
 #CNN and Transformer class, pick any Torch model.
 class TorchModel(nn.Module):
-    def
+    def __init__(self, model_name='resnet50', pretrained=True, dropout_rate=None, use_checkpoint=False):
         super(TorchModel, self).__init__()
         self.model_name = model_name
         self.use_checkpoint = use_checkpoint
@@ -1398,7 +1637,7 @@ class TorchModel(nn.Module):
         return logits
 
 class FocalLossWithLogits(nn.Module):
-    def
+    def __init__(self, alpha=1, gamma=2):
         super(FocalLossWithLogits, self).__init__()
         self.alpha = alpha
         self.gamma = gamma
@@ -1410,7 +1649,7 @@ class FocalLossWithLogits(nn.Module):
         return focal_loss.mean()
 
 class ResNet(nn.Module):
-    def
+    def __init__(self, resnet_type='resnet50', dropout_rate=None, use_checkpoint=False, init_weights='imagenet'):
         super(ResNet, self).__init__()
 
         resnet_map = {
@@ -1763,25 +2002,24 @@ def annotate_predictions(csv_loc):
     df['cond'] = df.apply(assign_condition, axis=1)
     return df
 
-def
+def initiate_counter(counter_, lock_):
     global counter, lock
     counter = counter_
     lock = lock_
 
-def add_images_to_tar(
-    global counter, lock, total_images
-    paths_chunk, tar_path = args
+def add_images_to_tar(paths_chunk, tar_path, total_images):
     with tarfile.open(tar_path, 'w') as tar:
-        for img_path in paths_chunk:
+        for i, img_path in enumerate(paths_chunk):
             arcname = os.path.basename(img_path)
             try:
                 tar.add(img_path, arcname=arcname)
                 with lock:
                     counter.value += 1
-
+                    if counter.value % 100 == 0: # Print every 100 updates
+                        progress = (counter.value / total_images) * 100
+                        print(f"Progress: {counter.value}/{total_images} ({progress:.2f}%)", end='\r', file=sys.stdout, flush=True)
             except FileNotFoundError:
                 print(f"File not found: {img_path}")
-    return tar_path
 
 def generate_fraction_map(df, gene_column, min_frequency=0.0):
     df['fraction'] = df['count']/df['well_read_sum']
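`initiate_counter` and the reworked `add_images_to_tar` follow the standard multiprocessing idiom: a shared `Value` and `Lock` are installed as module globals in each worker through the pool initializer. A minimal sketch of the wiring (the paths here are hypothetical):

    from multiprocessing import Pool, Value, Lock

    paths = ['/data/imgs/a.png', '/data/imgs/b.png']  # hypothetical
    counter, lock = Value('i', 0), Lock()
    with Pool(2, initializer=initiate_counter, initargs=(counter, lock)) as pool:
        pool.starmap(add_images_to_tar, [(paths[:1], '/tmp/a.tar', len(paths)),
                                         (paths[1:], '/tmp/b.tar', len(paths))])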
@@ -2230,8 +2468,8 @@ def dice_coefficient(mask1, mask2):
 def extract_boundaries(mask, dilation_radius=1):
     binary_mask = (mask > 0).astype(np.uint8)
     struct_elem = np.ones((dilation_radius*2+1, dilation_radius*2+1))
-    dilated = binary_dilation(binary_mask, footprint=struct_elem)
-    eroded = binary_erosion(binary_mask, footprint=struct_elem)
+    dilated = morphology.binary_dilation(binary_mask, footprint=struct_elem)
+    eroded = morphology.binary_erosion(binary_mask, footprint=struct_elem)
     boundary = dilated ^ eroded
     return boundary
 
@@ -2612,24 +2850,21 @@ def _filter_object(mask, min_value):
     mask[np.isin(mask, to_remove)] = 0
     return mask
 
-def _filter_cp_masks(masks, flows, filter_size, minimum_size, maximum_size, remove_border_objects, merge,
+def _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size, maximum_size, remove_border_objects, merge, batch, plot, figuresize):
+
     """
     Filter the masks based on various criteria such as size, border objects, merging, and intensity.
 
     Args:
         masks (list): List of masks.
         flows (list): List of flows.
-        refine_masks (bool): Flag indicating whether to refine masks.
         filter_size (bool): Flag indicating whether to filter based on size.
+        filter_intensity (bool): Flag indicating whether to filter based on intensity.
         minimum_size (int): Minimum size of objects to keep.
         maximum_size (int): Maximum size of objects to keep.
         remove_border_objects (bool): Flag indicating whether to remove border objects.
         merge (bool): Flag indicating whether to merge adjacent objects.
-        filter_dimm (bool): Flag indicating whether to filter based on intensity.
         batch (ndarray): Batch of images.
-        moving_avg_q1 (float): Moving average of the first quartile of object intensities.
-        moving_avg_q3 (float): Moving average of the third quartile of object intensities.
-        moving_count (int): Count of moving averages.
         plot (bool): Flag indicating whether to plot the masks.
         figuresize (tuple): Size of the figure.
 
@@ -2641,51 +2876,66 @@ def _filter_cp_masks(masks, flows, filter_size, minimum_size, maximum_size, remo
 
     mask_stack = []
     for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
+
         if plot and idx == 0:
             num_objects = mask_object_count(mask)
             print(f'Number of objects before filtration: {num_objects}')
             plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
 
-        if
-
-            valid_labels = props['label'][np.logical_and(props['area'] > minimum_size, props['area'] < maximum_size)] # Select labels of valid size.
-            masks[idx] = np.isin(mask, valid_labels) * mask # Keep only valid objects.
+        if merge:
+            mask = merge_touching_objects(mask, threshold=0.66)
             if plot and idx == 0:
                 num_objects = mask_object_count(mask)
-                print(f'Number of objects after
+                print(f'Number of objects after merging adjacent objects, : {num_objects}')
                 plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
-
-
+
+        if filter_size:
+            props = measure.regionprops_table(mask, properties=['label', 'area'])
+            valid_labels = props['label'][np.logical_and(props['area'] > minimum_size, props['area'] < maximum_size)]
+            mask = np.isin(mask, valid_labels) * mask
             if plot and idx == 0:
                 num_objects = mask_object_count(mask)
-                print(f'Number of objects after
+                print(f'Number of objects after size filtration >{minimum_size} and <{maximum_size} : {num_objects}')
                 plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
-
-
+
+        if filter_intensity:
+            intensity_image = image[:, :, 1]
+            props = measure.regionprops_table(mask, intensity_image=intensity_image, properties=['label', 'mean_intensity'])
+            mean_intensities = np.array(props['mean_intensity']).reshape(-1, 1)
+
+            if mean_intensities.shape[0] >= 2:
+                kmeans = KMeans(n_clusters=2, random_state=0).fit(mean_intensities)
+                centroids = kmeans.cluster_centers_
+
+                # Calculate the Euclidean distance between the two centroids
+                dist_between_centroids = distance.euclidean(centroids[0], centroids[1])
+
+                # Set a threshold for the minimum distance to consider clusters distinct
+                distance_threshold = 0.25
+
+                if dist_between_centroids > distance_threshold:
+                    high_intensity_cluster = np.argmax(centroids)
+                    valid_labels = np.array(props['label'])[kmeans.labels_ == high_intensity_cluster]
+                    mask = np.isin(mask, valid_labels) * mask
+
            if plot and idx == 0:
                num_objects = mask_object_count(mask)
-
+                props_after = measure.regionprops_table(mask, intensity_image=intensity_image, properties=['label', 'mean_intensity'])
+                mean_intensities_after = np.mean(np.array(props_after['mean_intensity']))
+                average_intensity_before = np.mean(mean_intensities)
+                print(f'Number of objects after potential intensity clustering: {num_objects}. Mean intensity before:{average_intensity_before:.4f}. After:{mean_intensities_after:.4f}.')
                plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
-
-
-
-
-        object_intensities = [np.mean(batch[idx, :, :, 1][mask == label]) for label in unique_labels if label != 0]
-        object_q1s = [np.percentile(intensities, 25) for intensities in object_intensities if intensities.size > 0]
-        object_q3s = [np.percentile(intensities, 75) for intensities in object_intensities if intensities.size > 0]
-        if object_q1s:
-            object_q1_mean = np.mean(object_q1s)
-            object_q3_mean = np.mean(object_q3s)
-            moving_avg_q1 = (moving_avg_q1 * moving_count + object_q1_mean) / (moving_count + 1)
-            moving_avg_q3 = (moving_avg_q3 * moving_count + object_q3_mean) / (moving_count + 1)
-            moving_count += 1
-        mask = remove_intensity_objects(batch[idx, :, :, 1], mask, intensity_threshold=moving_avg_q1, mode='low')
-        mask = remove_intensity_objects(batch[idx, :, :, 1], mask, intensity_threshold=moving_avg_q3, mode='high')
+
+
+        if remove_border_objects:
+            mask = clear_border(mask)
            if plot and idx == 0:
                num_objects = mask_object_count(mask)
-                print(f'
+                print(f'Number of objects after removing border objects, : {num_objects}')
                plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
+
        mask_stack.append(mask)
+
    return mask_stack
 
 def _object_filter(df, object_type, size_range, intensity_range, mask_chans, mask_chan):
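The new `filter_intensity` branch above is a two-cluster KMeans gate: per-object mean intensities are clustered into two groups, and only when the centroids are clearly separated is the dimmer cluster dropped. The same gate in isolation, on synthetic intensities (a sketch, not a spacr call):

    import numpy as np
    from sklearn.cluster import KMeans
    from scipy.spatial import distance

    mean_intensities = np.array([0.10, 0.12, 0.90, 0.95]).reshape(-1, 1)
    kmeans = KMeans(n_clusters=2, random_state=0).fit(mean_intensities)
    centroids = kmeans.cluster_centers_
    if distance.euclidean(centroids[0], centroids[1]) > 0.25:
        keep = kmeans.labels_ == np.argmax(centroids)  # keep the brighter cluster only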
@@ -2721,6 +2971,1098 @@ def _object_filter(df, object_type, size_range, intensity_range, mask_chans, mas
     print(f'After {object_type} maximum mean intensity filter: {len(df)}')
     return df
 
-
-
-
+def _get_regex(metadata_type, img_format, custom_regex=None):
+
+    if img_format == None:
+        img_format == '.tif'
+    if metadata_type == 'cellvoyager':
+        regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+    elif metadata_type == 'cq1':
+        regex = f'W(?P<wellID>.*)F(?P<fieldID>.*)T(?P<timeID>.*)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+    elif metadata_type == 'nikon':
+        regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+    elif metadata_type == 'zeis':
+        regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+    elif metadata_type == 'leica':
+        regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+    elif metadata_type == 'custom':
+        regex = f'({custom_regex}){img_format}'
+
+    print(f'regex mode:{metadata_type} regex:{regex}')
+    return regex
+
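Every pattern in `_get_regex` uses named capture groups, so downstream code can pull metadata out by name. A quick check of the cellvoyager pattern against a made-up filename:

    import re

    regex = r'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*)\.tif'
    m = re.match(regex, 'plate1_B02_T0001F001L01A01Z01C1.tif')
    print(m.group('wellID'), m.group('fieldID'), m.group('chanID'))  # B02 001 1

Note also that `img_format == '.tif'` in the `None` branch is a comparison, not an assignment, so a `None` format is never actually defaulted.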
+def _run_test_mode(src, regex, timelapse=False, test_images=10, random_test=True):
+
+    if timelapse:
+        test_images = 1 # Use only 1 set for timelapse to ensure full sequence inclusion
+
+    test_folder_path = os.path.join(src, 'test')
+    os.makedirs(test_folder_path, exist_ok=True)
+    regular_expression = re.compile(regex)
+
+    if os.path.exists(os.path.join(src, 'orig')):
+        src = os.path.join(src, 'orig')
+
+    all_filenames = [filename for filename in os.listdir(src) if regular_expression.match(filename)]
+    print(f'Found {len(all_filenames)} files')
+    images_by_set = defaultdict(list)
+
+    for filename in all_filenames:
+        match = regular_expression.match(filename)
+        if match:
+            plate = match.group('plateID') if 'plateID' in match.groupdict() else os.path.basename(src)
+            well = match.group('wellID')
+            field = match.group('fieldID')
+            set_identifier = (plate, well, field)
+            images_by_set[set_identifier].append(filename)
+
+    # Prepare for random selection
+    set_identifiers = list(images_by_set.keys())
+    if random_test:
+        random.seed(42)
+        random.shuffle(set_identifiers) # Randomize the order
+
+    # Select a subset based on the test_images count
+    selected_sets = set_identifiers[:test_images]
+
+    # Print information about the number of sets used
+    print(f'Using {len(selected_sets)} random image set(s) for test model')
+
+    # Copy files for selected sets to the test folder
+    for set_identifier in selected_sets:
+        for filename in images_by_set[set_identifier]:
+            shutil.copy(os.path.join(src, filename), test_folder_path)
+
+    return test_folder_path
+
+def _choose_model(model_name, device, object_type='cell', restore_type=None, object_settings={}):
+
+    if object_type == 'pathogen':
+        if model_name == 'toxo_pv_lumen':
+            diameter = object_settings['diameter']
+            current_dir = os.path.dirname(__file__)
+            model_path = os.path.join(current_dir, 'models', 'cp', 'toxo_pv_lumen.CP_model')
+            print(model_path)
+            model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=model_path, diam_mean=diameter, device=device)
+            #model = cp_models.Cellpose(gpu=torch.cuda.is_available(), model_type='cyto', device=device)
+            print(f'Using Toxoplasma PV lumen model to generate pathogen masks')
+            return model
+
+    restore_list = ['denoise', 'deblur', 'upsample', None]
+    if restore_type not in restore_list:
+        print(f"Invalid restore type. Choose from {restore_list} defaulting to None")
+        restore_type = None
+
+    if restore_type == None:
+        if model_name in ['cyto', 'cyto2', 'cyto3', 'nuclei']:
+            model = cp_models.Cellpose(gpu=torch.cuda.is_available(), model_type=model_name, device=device)
+
+    else:
+        if object_type == 'nucleus':
+            restore = f'{type}_nuclei'
+            model = denoise.CellposeDenoiseModel(gpu=torch.cuda.is_available(), model_type="nuclei",restore_type=restore, chan2_restore=False, device=device)
+        else:
+            restore = f'{type}_cyto3'
+            if model_name =='cyto2':
+                chan2_restore = True
+            if model_name =='cyto':
+                chan2_restore = False
+            model = denoise.CellposeDenoiseModel(gpu=torch.cuda.is_available(), model_type="cyto3",restore_type=restore, chan2_restore=chan2_restore, device=device)
+
+    return model
+
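`_choose_model` reduces to three cases: the bundled `toxo_pv_lumen` weights, a plain Cellpose model, or a `CellposeDenoiseModel` with a restore step. The restore branch appears to interpolate the Python built-in `type` rather than the `restore_type` argument (f'{type}_cyto3' renders as "<class 'type'>_cyto3"), which looks like a bug carried into this release. A hedged sketch of the plain-model path only:

    import torch
    from cellpose import models as cp_models

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = cp_models.Cellpose(gpu=torch.cuda.is_available(), model_type='cyto', device=device)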
+class SelectChannels:
+    def __init__(self, channels):
+        self.channels = channels
+
+    def __call__(self, img):
+        img = img.clone()
+        if 1 not in self.channels:
+            img[0, :, :] = 0 # Zero out the red channel
+        if 2 not in self.channels:
+            img[1, :, :] = 0 # Zero out the green channel
+        if 3 not in self.channels:
+            img[2, :, :] = 0 # Zero out the blue channel
+        return img
+
+def preprocess_image(image_path, image_size=224, channels=[1,2,3], normalize=True):
+
+    if normalize:
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.CenterCrop(size=(image_size, image_size)),
+            SelectChannels(channels),
+            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+    else:
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.CenterCrop(size=(image_size, image_size)),
+            SelectChannels(channels)])
+
+    image = Image.open(image_path).convert('RGB')
+    input_tensor = transform(image).unsqueeze(0)
+    return image, input_tensor
+
+
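`SelectChannels` zeroes the RGB planes that were not requested, so a classifier trained on, say, green-only crops always sees the same input layout. A small usage sketch (the crop path is hypothetical):

    image, input_tensor = preprocess_image('crops/cell_0001.png', image_size=224, channels=[2], normalize=True)
    print(input_tensor.shape)  # torch.Size([1, 3, 224, 224]); red and blue planes are zeroed

As with `normalize_to_dtype`, note that a second `preprocess_image` further down (resize plus ImageNet normalization, no channel masking) re-binds the same name, so this version is shadowed at import time.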
+class SaliencyMapGenerator:
+    def __init__(self, model):
+        self.model = model
+
+    def compute_saliency_maps(self, X, y):
+        self.model.eval()
+        X.requires_grad_()
+
+        # Forward pass
+        scores = self.model(X).squeeze()
+
+        # For binary classification, target scores can be the single output
+        target_scores = scores * (2 * y - 1)
+
+        self.model.zero_grad()
+        target_scores.backward(torch.ones_like(target_scores))
+
+        saliency = X.grad.abs()
+        return saliency
+
+    def plot_saliency_maps(self, X, y, saliency, class_names):
+        N = X.shape[0]
+        for i in range(N):
+            plt.subplot(2, N, i + 1)
+            plt.imshow(X[i].permute(1, 2, 0).cpu().numpy())
+            plt.axis('off')
+            plt.title(class_names[y[i]])
+            plt.subplot(2, N, N + i + 1)
+            plt.imshow(saliency[i].cpu().numpy(), cmap=plt.cm.hot)
+            plt.axis('off')
+        plt.gcf().set_size_inches(12, 5)
+        plt.show()
+
+def preprocess_image(image_path, normalize=True, image_size=224, channels=[1,2,3]):
+    preprocess = transforms.Compose([
+        transforms.Resize((image_size, image_size)),
+        transforms.ToTensor(),
+    ])
+
+    image = Image.open(image_path).convert('RGB')
+    input_tensor = preprocess(image)
+    if normalize:
+        input_tensor = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(input_tensor)
+    input_tensor = input_tensor.unsqueeze(0)
+
+    return image, input_tensor
+
+def class_visualization(target_y, model_path, dtype, img_size=224, channels=[0,1,2], l2_reg=1e-3, learning_rate=25, num_iterations=100, blur_every=10, max_jitter=16, show_every=25, class_names = ['nc', 'pc']):
+
+    def jitter(img, ox, oy):
+        # Randomly jitter the image
+        return torch.roll(torch.roll(img, ox, dims=2), oy, dims=3)
+
+    def blur_image(img, sigma=1):
+        # Apply Gaussian blur to the image
+        img_np = img.cpu().numpy()
+        for i in range(img_np.shape[1]):
+            img_np[:, i] = gaussian_filter(img_np[:, i], sigma=sigma)
+        img.copy_(torch.tensor(img_np).to(img.device))
+
+    def deprocess(img_tensor):
+        # Convert the tensor image to a numpy array for visualization
+        img_tensor = img_tensor.clone()
+        for c in range(3):
+            img_tensor[:, c] = img_tensor[:, c] * SQUEEZENET_STD[c] + SQUEEZENET_MEAN[c]
+        img_tensor = img_tensor.clamp(0, 1)
+        return img_tensor.squeeze().permute(1, 2, 0).cpu().numpy()
+
+    # Assuming these are defined somewhere in your codebase
+    SQUEEZENET_MEAN = [0.485, 0.456, 0.406]
+    SQUEEZENET_STD = [0.229, 0.224, 0.225]
+
+    model = torch.load(model_path)
+
+    dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
+    len_chans = len(channels)
+    model.type(dtype)
+
+    # Randomly initialize the image as a PyTorch Tensor, and make it requires gradient.
+    img = torch.randn(1, len_chans, img_size, img_size).mul_(1.0).type(dtype).requires_grad_()
+
+    for t in range(num_iterations):
+        # Randomly jitter the image a bit; this gives slightly nicer results
+        ox, oy = random.randint(0, max_jitter), random.randint(0, max_jitter)
+        img.data.copy_(jitter(img.data, ox, oy))
+
+        # Forward pass
+        score = model(img)
+
+        if target_y == 0:
+            target_score = -score
+        else:
+            target_score = score
+
+        # Add regularization
+        target_score = target_score - l2_reg * torch.norm(img)
+
+        # Backward pass
+        target_score.backward()
+
+        # Gradient ascent step
+        with torch.no_grad():
+            img += learning_rate * img.grad / torch.norm(img.grad)
+            img.grad.zero_()
+
+        # Undo the random jitter
+        img.data.copy_(jitter(img.data, -ox, -oy))
+
+        # As regularizer, clamp and periodically blur the image
+        for c in range(3):
+            lo = float(-SQUEEZENET_MEAN[c] / SQUEEZENET_STD[c])
+            hi = float((1.0 - SQUEEZENET_MEAN[c]) / SQUEEZENET_STD[c])
+            img.data[:, c].clamp_(min=lo, max=hi)
+        if t % blur_every == 0:
+            blur_image(img.data, sigma=0.5)
+
+        # Periodically show the image
+        if t == 0 or (t + 1) % show_every == 0 or t == num_iterations - 1:
+            plt.imshow(deprocess(img.data.clone().cpu()))
+            class_name = class_names[target_y]
+            plt.title('%s\nIteration %d / %d' % (class_name, t + 1, num_iterations))
+            plt.gcf().set_size_inches(4, 4)
+            plt.axis('off')
+            plt.show()
+
+    return deprocess(img.data.cpu())
+
+def get_submodules(model, prefix=''):
+    submodules = []
+    for name, module in model.named_children():
+        full_name = prefix + ('.' if prefix else '') + name
+        submodules.append(full_name)
+        submodules.extend(get_submodules(module, full_name))
+    return submodules
+
+class GradCAM:
+    def __init__(self, model, target_layers=None, use_cuda=True):
+        self.model = model
+        self.model.eval()
+        self.target_layers = target_layers
+        self.cuda = use_cuda
+        if self.cuda:
+            self.model = model.cuda()
+
+    def forward(self, input):
+        return self.model(input)
+
+    def __call__(self, x, index=None):
+        if self.cuda:
+            x = x.cuda()
+
+        features = []
+        def hook(module, input, output):
+            features.append(output)
+
+        handles = []
+        for name, module in self.model.named_modules():
+            if name in self.target_layers:
+                handles.append(module.register_forward_hook(hook))
+
+        output = self.forward(x)
+        if index is None:
+            index = np.argmax(output.data.cpu().numpy())
+
+        one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
+        one_hot[0][index] = 1
+        one_hot = torch.from_numpy(one_hot).requires_grad_(True)
+        if self.cuda:
+            one_hot = one_hot.cuda()
+
+        one_hot = torch.sum(one_hot * output)
+        self.model.zero_grad()
+        one_hot.backward(retain_graph=True)
+
+        grads_val = features[0].grad.cpu().data.numpy()
+        target = features[0].cpu().data.numpy()[0, :]
+
+        weights = np.mean(grads_val, axis=(2, 3))[0, :]
+        cam = np.zeros(target.shape[1:], dtype=np.float32)
+
+        for i, w in enumerate(weights):
+            cam += w * target[i, :, :]
+
+        cam = np.maximum(cam, 0)
+        cam = cv2.resize(cam, (x.size(2), x.size(3)))
+        cam = cam - np.min(cam)
+        cam = cam / np.max(cam)
+
+        for handle in handles:
+            handle.remove()
+
+        return cam
+
+def show_cam_on_image(img, mask):
+    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
+    heatmap = np.float32(heatmap) / 255
+    cam = heatmap + np.float32(img)
+    cam = cam / np.max(cam)
+    return np.uint8(255 * cam)
+
+def recommend_target_layers(model):
+    target_layers = []
+    for name, module in model.named_modules():
+        if isinstance(module, torch.nn.Conv2d):
+            target_layers.append(name)
+    # Choose the last conv layer as the recommended target layer
+    if target_layers:
+        return [target_layers[-1]], target_layers
+    else:
+        raise ValueError("No convolutional layers found in the model.")
+
+class IntegratedGradients:
+    def __init__(self, model):
+        self.model = model
+        self.model.eval()
+
+    def generate_integrated_gradients(self, input_tensor, target_label_idx, baseline=None, num_steps=50):
+        if baseline is None:
+            baseline = torch.zeros_like(input_tensor)
+
+        assert baseline.shape == input_tensor.shape
+
+        # Scale input and compute gradients
+        scaled_inputs = [(baseline + (float(i) / num_steps) * (input_tensor - baseline)).requires_grad_(True) for i in range(0, num_steps + 1)]
+        grads = []
+        for scaled_input in scaled_inputs:
+            out = self.model(scaled_input)
+            self.model.zero_grad()
+            out[0, target_label_idx].backward(retain_graph=True)
+            grads.append(scaled_input.grad.data.cpu().numpy())
+
+        avg_grads = np.mean(grads[:-1], axis=0)
+        integrated_grads = (input_tensor.cpu().data.numpy() - baseline.cpu().data.numpy()) * avg_grads
+        return integrated_grads
+
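The three attribution helpers (`SaliencyMapGenerator`, `GradCAM`, `IntegratedGradients`) share one pattern: wrap an eval-mode model, run a forward pass, and backpropagate a class score to the input or to a hooked layer. A minimal end-to-end sketch with `IntegratedGradients` (the model here is a hypothetical stand-in):

    import torch
    import torchvision

    model = torchvision.models.resnet18(weights=None)
    ig = IntegratedGradients(model)
    x = torch.randn(1, 3, 224, 224)
    attributions = ig.generate_integrated_gradients(x, target_label_idx=0, num_steps=25)
    print(attributions.shape)  # (1, 3, 224, 224)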
|
+def get_db_paths(src):
+    if isinstance(src, str):
+        src = [src]
+    db_paths = [os.path.join(source, 'measurements/measurements.db') for source in src]
+    return db_paths
+
+def get_sequencing_paths(src):
+    if isinstance(src, str):
+        src = [src]
+    seq_paths = [os.path.join(source, 'sequencing/sequencing_data.csv') for source in src]
+    return seq_paths
+
+def load_image_paths(c, visualize):
+    c.execute('SELECT * FROM png_list')
+    data = c.fetchall()
+    columns_info = c.execute('PRAGMA table_info(png_list)').fetchall()
+    column_names = [col_info[1] for col_info in columns_info]
+    image_paths_df = pd.DataFrame(data, columns=column_names)
+    if visualize:
+        object_visualize = visualize + '_png'
+        image_paths_df = image_paths_df[image_paths_df['png_path'].str.contains(object_visualize)]
+    image_paths_df = image_paths_df.set_index('prcfo')
+    return image_paths_df
+
+def merge_dataframes(df, image_paths_df, verbose):
+    df.set_index('prcfo', inplace=True)
+    df = image_paths_df.merge(df, left_index=True, right_index=True)
+    if verbose:
+        display(df)
+    return df
+
+def remove_highly_correlated_columns(df, threshold):
+    corr_matrix = df.corr().abs()
+    upper_tri = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
+    to_drop = [column for column in upper_tri.columns if any(upper_tri[column] > threshold)]
+    return df.drop(to_drop, axis=1)
+
+def filter_columns(df, filter_by):
+    if filter_by != 'morphology':
+        cols_to_include = [col for col in df.columns if filter_by in str(col)]
+    else:
+        cols_to_include = [col for col in df.columns if 'channel' not in str(col)]
+    df = df[cols_to_include]
+    return df
+
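A quick illustration of remove_highly_correlated_columns on a toy frame:

    import pandas as pd

    toy = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [2, 4, 6, 8], 'c': [4, 1, 3, 2]})
    # 'b' is perfectly correlated with 'a', so it is dropped at threshold 0.95
    print(remove_highly_correlated_columns(toy, 0.95).columns.tolist())  # ['a', 'c']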
+def reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric, eps, min_samples, clustering, reduction_method='umap', verbose=False, embedding=None, n_jobs=-1, mode='fit', model=False):
+    """
+    Perform dimensionality reduction and clustering on the given data.
+
+    Parameters:
+    numeric_data (np.ndarray): Numeric data for embedding and clustering.
+    n_neighbors (int or float): Number of neighbors for UMAP or perplexity for t-SNE.
+    min_dist (float): Minimum distance for UMAP.
+    metric (str): Metric for UMAP and DBSCAN.
+    eps (float): Epsilon for DBSCAN.
+    min_samples (int): Minimum samples for DBSCAN or number of clusters for KMeans.
+    clustering (str): Clustering method ('dbscan' or 'kmeans').
+    reduction_method (str): Dimensionality reduction method ('umap' or 'tsne').
+    verbose (bool): Whether to print verbose output.
+    embedding (np.ndarray, optional): Precomputed embedding. Default is None.
+    mode (str): 'fit' trains a new reducer; any other value transforms with `model`.
+    model: A previously fitted reducer, required when mode != 'fit'. Default is False.
+
+    Returns:
+    tuple: embedding, labels, reducer
+    """
+
+    if verbose:
+        v = 1
+    else:
+        v = 0
+
+    if isinstance(n_neighbors, float):
+        n_neighbors = int(n_neighbors * len(numeric_data))
+
+    if n_neighbors <= 2:
+        n_neighbors = 2
+
+    if mode == 'fit':
+        if reduction_method == 'umap':
+            reducer = umap.UMAP(n_neighbors=n_neighbors,
+                                n_components=2,
+                                metric=metric,
+                                n_epochs=None,
+                                learning_rate=1.0,
+                                init='spectral',
+                                min_dist=min_dist,
+                                spread=1.0,
+                                set_op_mix_ratio=1.0,
+                                local_connectivity=1,
+                                repulsion_strength=1.0,
+                                negative_sample_rate=5,
+                                transform_queue_size=4.0,
+                                a=None,
+                                b=None,
+                                random_state=42,
+                                metric_kwds=None,
+                                angular_rp_forest=False,
+                                target_n_neighbors=-1,
+                                target_metric='categorical',
+                                target_metric_kwds=None,
+                                target_weight=0.5,
+                                transform_seed=42,
+                                n_jobs=n_jobs,
+                                verbose=verbose)
+
+        elif reduction_method == 'tsne':
+            reducer = TSNE(n_components=2,
+                           perplexity=n_neighbors,
+                           early_exaggeration=12.0,
+                           learning_rate=200.0,
+                           n_iter=1000,
+                           n_iter_without_progress=300,
+                           min_grad_norm=1e-7,
+                           metric=metric,
+                           init='random',
+                           verbose=v,
+                           random_state=42,
+                           method='barnes_hut',
+                           angle=0.5,
+                           n_jobs=n_jobs)
+
+        else:
+            raise ValueError(f"Unsupported reduction method: {reduction_method}. Supported methods are 'umap' and 'tsne'")
+
+        embedding = reducer.fit_transform(numeric_data)
+        if verbose:
+            print('Trained and fit reducer')
+
+    else:
+        if model is not None:
+            embedding = model.transform(numeric_data)
+            reducer = model
+            if verbose:
+                print('Fit data to reducer')
+        else:
+            raise ValueError("Model is None. Please provide a model for transform.")
+
+    if clustering == 'dbscan':
+        clustering_model = DBSCAN(eps=eps, min_samples=min_samples, metric=metric, n_jobs=n_jobs)
+    elif clustering == 'kmeans':
+        clustering_model = KMeans(n_clusters=min_samples, random_state=42)
+    else:
+        raise ValueError(f"Unsupported clustering method: {clustering}. Supported methods are 'dbscan' and 'kmeans'")
+
+    clustering_model.fit(embedding)
+    labels = clustering_model.labels_ if clustering == 'dbscan' else clustering_model.predict(embedding)
+
+    if verbose:
+        print(f'Embedding shape: {embedding.shape}')
+
+    return embedding, labels, reducer
+
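A minimal fit-then-transform sketch for reduction_and_clustering, with random features standing in for real measurements:

    import numpy as np

    data = np.random.rand(200, 30)
    emb, labels, reducer = reduction_and_clustering(
        data, n_neighbors=15, min_dist=0.1, metric='euclidean',
        eps=0.5, min_samples=10, clustering='dbscan')
    # Re-use the fitted reducer on held-out rows without re-fitting
    emb2, labels2, _ = reduction_and_clustering(
        np.random.rand(50, 30), 15, 0.1, 'euclidean', 0.5, 10, 'dbscan',
        mode='transform', model=reducer)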
+def reduction_and_clustering_v1(numeric_data, n_neighbors, min_dist, metric, eps, min_samples, clustering, reduction_method='umap', verbose=False, embedding=None, n_jobs=-1):
+    """
+    Perform dimensionality reduction and clustering on the given data.
+
+    Parameters:
+    numeric_data (np.ndarray): Numeric data for embedding and clustering.
+    n_neighbors (int or float): Number of neighbors for UMAP or perplexity for t-SNE.
+    min_dist (float): Minimum distance for UMAP.
+    metric (str): Metric for UMAP and DBSCAN.
+    eps (float): Epsilon for DBSCAN.
+    min_samples (int): Minimum samples for DBSCAN or number of clusters for KMeans.
+    clustering (str): Clustering method ('dbscan' or 'kmeans').
+    reduction_method (str): Dimensionality reduction method ('umap' or 'tsne').
+    verbose (bool): Whether to print verbose output.
+    embedding (np.ndarray, optional): Precomputed embedding. Default is None.
+
+    Returns:
+    tuple: embedding, labels
+    """
+
+    if verbose:
+        v = 1
+    else:
+        v = 0
+
+    if isinstance(n_neighbors, float):
+        n_neighbors = int(n_neighbors * len(numeric_data))
+
+    if n_neighbors <= 2:
+        n_neighbors = 2
+
+    if reduction_method == 'umap':
+        reducer = umap.UMAP(n_neighbors=n_neighbors,
+                            n_components=2,
+                            metric=metric,
+                            n_epochs=None,
+                            learning_rate=1.0,
+                            init='spectral',
+                            min_dist=min_dist,
+                            spread=1.0,
+                            set_op_mix_ratio=1.0,
+                            local_connectivity=1,
+                            repulsion_strength=1.0,
+                            negative_sample_rate=5,
+                            transform_queue_size=4.0,
+                            a=None,
+                            b=None,
+                            random_state=42,
+                            metric_kwds=None,
+                            angular_rp_forest=False,
+                            target_n_neighbors=-1,
+                            target_metric='categorical',
+                            target_metric_kwds=None,
+                            target_weight=0.5,
+                            transform_seed=42,
+                            n_jobs=n_jobs,
+                            verbose=verbose)
+
+    elif reduction_method == 'tsne':
+        reducer = TSNE(n_components=2,
+                       perplexity=n_neighbors,
+                       early_exaggeration=12.0,
+                       learning_rate=200.0,
+                       n_iter=1000,
+                       n_iter_without_progress=300,
+                       min_grad_norm=1e-7,
+                       metric=metric,
+                       init='random',
+                       verbose=v,
+                       random_state=42,
+                       method='barnes_hut',
+                       angle=0.5,
+                       n_jobs=n_jobs)
+
+    else:
+        raise ValueError(f"Unsupported reduction method: {reduction_method}. Supported methods are 'umap' and 'tsne'")
+
+    if embedding is None:
+        embedding = reducer.fit_transform(numeric_data)
+
+    if clustering == 'dbscan':
+        clustering_model = DBSCAN(eps=eps, min_samples=min_samples, metric=metric, n_jobs=n_jobs)
+    elif clustering == 'kmeans':
+        clustering_model = KMeans(n_clusters=min_samples, random_state=42)
+    else:
+        raise ValueError(f"Unsupported clustering method: {clustering}. Supported methods are 'dbscan' and 'kmeans'")
+
+    clustering_model.fit(embedding)
+    labels = clustering_model.labels_ if clustering == 'dbscan' else clustering_model.predict(embedding)
+
+    if verbose:
+        print(f'Embedding shape: {embedding.shape}')
+
+    return embedding, labels
+
+def remove_noise(embedding, labels):
+    non_noise_indices = labels != -1
+    embedding = embedding[non_noise_indices]
+    labels = labels[non_noise_indices]
+    return embedding, labels
+
+def plot_embedding(embedding, image_paths, labels, image_nr, img_zoom, colors, plot_by_cluster, plot_outlines, plot_points, plot_images, smooth_lines, black_background, figuresize, dot_size, remove_image_canvas, verbose):
+    unique_labels = np.unique(labels)
+    colors, label_to_color_index = assign_colors(unique_labels, colors)
+    cluster_centers = [np.mean(embedding[labels == cluster_label], axis=0) for cluster_label in unique_labels]
+    fig, ax = setup_plot(figuresize, black_background)
+    plot_clusters(ax, embedding, labels, colors, cluster_centers, plot_outlines, plot_points, smooth_lines, figuresize, dot_size, verbose)
+    if image_paths is not None and plot_images:
+        plot_umap_images(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, plot_by_cluster, remove_image_canvas, verbose)
+    plt.show()
+    return fig
+
+def generate_colors(num_clusters, black_background):
+    random_colors = np.random.rand(num_clusters + 1, 4)
+    random_colors[:, 3] = 1
+    specific_colors = [
+        [155 / 255, 55 / 255, 155 / 255, 1],
+        [55 / 255, 155 / 255, 155 / 255, 1],
+        [55 / 255, 155 / 255, 255 / 255, 1],
+        [255 / 255, 55 / 255, 155 / 255, 1]
+    ]
+    random_colors = np.vstack((specific_colors, random_colors[len(specific_colors):]))
+    if not black_background:
+        random_colors = np.vstack(([0, 0, 0, 1], random_colors))
+    return random_colors
+
+def assign_colors(unique_labels, random_colors):
+    normalized_colors = random_colors / 255
+    colors_img = [tuple(color) for color in normalized_colors]
+    colors = [tuple(color) for color in random_colors]
+    label_to_color_index = {label: index for index, label in enumerate(unique_labels)}
+    return colors, label_to_color_index
+
+def setup_plot(figuresize, black_background):
+    if black_background:
+        plt.rcParams.update({'figure.facecolor': 'black', 'axes.facecolor': 'black', 'text.color': 'white', 'xtick.color': 'white', 'ytick.color': 'white', 'axes.labelcolor': 'white'})
+    else:
+        plt.rcParams.update({'figure.facecolor': 'white', 'axes.facecolor': 'white', 'text.color': 'black', 'xtick.color': 'black', 'ytick.color': 'black', 'axes.labelcolor': 'black'})
+    fig, ax = plt.subplots(1, 1, figsize=(figuresize, figuresize))
+    return fig, ax
+
+def plot_clusters(ax, embedding, labels, colors, cluster_centers, plot_outlines, plot_points, smooth_lines, figuresize=50, dot_size=50, verbose=False):
+    unique_labels = np.unique(labels)
+    for cluster_label, color, center in zip(unique_labels, colors, cluster_centers):
+        cluster_data = embedding[labels == cluster_label]
+        if smooth_lines:
+            if cluster_data.shape[0] > 2:
+                x_smooth, y_smooth = smooth_hull_lines(cluster_data)
+                if plot_outlines:
+                    plt.plot(x_smooth, y_smooth, color=color, linewidth=2)
+        else:
+            if cluster_data.shape[0] > 2:
+                hull = ConvexHull(cluster_data)
+                for simplex in hull.simplices:
+                    if plot_outlines:
+                        plt.plot(hull.points[simplex, 0], hull.points[simplex, 1], color=color, linewidth=4)
+        if plot_points:
+            scatter = ax.scatter(cluster_data[:, 0], cluster_data[:, 1], s=dot_size, c=[color], alpha=0.5, label=f'Cluster {cluster_label if cluster_label != -1 else "Noise"}')
+        else:
+            scatter = ax.scatter(cluster_data[:, 0], cluster_data[:, 1], s=dot_size, c=[color], alpha=0, label=f'Cluster {cluster_label if cluster_label != -1 else "Noise"}')
+        ax.text(center[0], center[1], str(cluster_label), fontsize=12, ha='center', va='center')
+    plt.legend(loc='best', fontsize=int(figuresize * 0.75))
+    plt.xlabel('UMAP Dimension 1', fontsize=int(figuresize * 0.75))
+    plt.ylabel('UMAP Dimension 2', fontsize=int(figuresize * 0.75))
+    plt.tick_params(axis='both', which='major', labelsize=int(figuresize * 0.75))
+
+def plot_umap_images(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, plot_by_cluster, remove_image_canvas, verbose):
+    if plot_by_cluster:
+        cluster_indices = {label: np.where(labels == label)[0] for label in np.unique(labels) if label != -1}
+        plot_images_by_cluster(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, cluster_indices, remove_image_canvas, verbose)
+    else:
+        indices = random.sample(range(len(embedding)), image_nr)
+        for i, index in enumerate(indices):
+            x, y = embedding[index]
+            img = Image.open(image_paths[index])
+            plot_image(ax, x, y, img, img_zoom, remove_image_canvas)
+
+def plot_images_by_cluster(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, cluster_indices, remove_image_canvas, verbose):
+    for cluster_label, color in zip(np.unique(labels), colors):
+        if cluster_label == -1:
+            continue
+        indices = cluster_indices.get(cluster_label, [])
+        if len(indices) > image_nr:
+            indices = random.sample(list(indices), image_nr)
+        for index in indices:
+            x, y = embedding[index]
+            img = Image.open(image_paths[index])
+            plot_image(ax, x, y, img, img_zoom, remove_image_canvas)
+
+def plot_image(ax, x, y, img, img_zoom, remove_image_canvas=True):
+    if remove_image_canvas:
+        img = remove_canvas(img)  # remove_canvas expects a PIL image (it reads img.mode)
+    else:
+        img = np.array(img)
+    imagebox = OffsetImage(img, zoom=img_zoom)
+    ab = AnnotationBbox(imagebox, (x, y), frameon=False)
+    ax.add_artist(ab)
+
+def remove_canvas(img):
+    if img.mode in ['L', 'I']:
+        img_data = np.array(img)
+        img_data = img_data / np.max(img_data)
+        alpha_channel = (img_data > 0).astype(float)
+        img_data_rgb = np.stack([img_data] * 3, axis=-1)
+        img_data_with_alpha = np.dstack([img_data_rgb, alpha_channel])
+    elif img.mode == 'RGB':
+        img_data = np.array(img)
+        img_data = img_data / 255.0
+        alpha_channel = (np.sum(img_data, axis=-1) > 0).astype(float)
+        img_data_with_alpha = np.dstack([img_data, alpha_channel])
+    else:
+        raise ValueError(f"Unsupported image mode: {img.mode}")
+    return img_data_with_alpha
+
+def plot_clusters_grid(embedding, labels, image_nr, image_paths, colors, figuresize, black_background, verbose):
+    unique_labels = np.unique(labels)
+    num_clusters = len(unique_labels[unique_labels != -1])
+    if num_clusters == 0:
+        print("No clusters found.")
+        return
+    cluster_images = {label: [] for label in unique_labels if label != -1}
+    cluster_indices = {label: np.where(labels == label)[0] for label in unique_labels if label != -1}
+    for cluster_label, indices in cluster_indices.items():
+        if cluster_label == -1:
+            continue
+        if len(indices) > image_nr:
+            indices = random.sample(list(indices), image_nr)
+        for index in indices:
+            img_path = image_paths[index]
+            img_array = Image.open(img_path)
+            img = np.array(img_array)
+            cluster_images[cluster_label].append(img)
+    fig = plot_grid(cluster_images, colors, figuresize, black_background, verbose)
+    return fig
+
+def plot_grid(cluster_images, colors, figuresize, black_background, verbose):
+    num_clusters = len(cluster_images)
+    max_figsize = 200  # Set a maximum figure size
+    if figuresize * num_clusters > max_figsize:
+        figuresize = max_figsize / num_clusters
+
+    grid_fig, grid_axes = plt.subplots(1, num_clusters, figsize=(figuresize * num_clusters, figuresize), gridspec_kw={'wspace': 0.2, 'hspace': 0})
+    if num_clusters == 1:
+        grid_axes = [grid_axes]  # Ensure grid_axes is always iterable
+    for cluster_label, axes in zip(cluster_images.keys(), grid_axes):
+        images = cluster_images[cluster_label]
+        num_images = len(images)
+        grid_size = int(np.ceil(np.sqrt(num_images)))
+        image_size = 0.9 / grid_size
+        whitespace = (1 - grid_size * image_size) / (grid_size + 1)
+
+        if isinstance(cluster_label, str):
+            idx = list(cluster_images.keys()).index(cluster_label)
+            color = colors[idx]
+            if verbose:
+                print(f'Label: {cluster_label} index: {idx}')
+        else:
+            color = colors[cluster_label]
+
+        axes.add_patch(plt.Rectangle((0, 0), 1, 1, transform=axes.transAxes, color=color[:3]))
+        axes.axis('off')
+        for i, img in enumerate(images):
+            row = i // grid_size
+            col = i % grid_size
+            x_pos = (col + 1) * whitespace + col * image_size
+            y_pos = 1 - ((row + 1) * whitespace + (row + 1) * image_size)
+            ax_img = axes.inset_axes([x_pos, y_pos, image_size, image_size], transform=axes.transAxes)
+            ax_img.imshow(img, cmap='gray', aspect='auto')
+            ax_img.axis('off')
+            ax_img.set_aspect('equal')
+            ax_img.set_facecolor(color[:3])
+
+    # Add cluster labels beside the grid
+    spacing_factor = 0.5  # Adjust this value to control the spacing between labels
+    for i, (cluster_label, color) in enumerate(zip(cluster_images.keys(), colors)):
+        label_y = 1 - (i + 1) * (spacing_factor / num_clusters)  # Adjust y position for each label
+        grid_fig.text(1.05, label_y, f'Cluster {cluster_label}', verticalalignment='center', fontsize=figuresize, color='black' if not black_background else 'white')
+        grid_fig.patches.append(plt.Rectangle((1, label_y - 0.02), 0.03, 0.03, transform=grid_fig.transFigure, color=color[:3], clip_on=False))
+
+    plt.show()
+    return grid_fig
+
+def correct_paths(df, base_path):
+
+    if 'png_path' not in df.columns:
+        print("No 'png_path' column found in the dataframe.")
+        return df, None
+
+    image_paths = df['png_path'].to_list()
+
+    adjusted_image_paths = []
+    for path in image_paths:
+        if base_path not in path:
+            parts = path.split('/data/')
+            if len(parts) > 1:
+                new_path = os.path.join(base_path, 'data', parts[1])
+                adjusted_image_paths.append(new_path)
+            else:
+                adjusted_image_paths.append(path)
+        else:
+            adjusted_image_paths.append(path)
+
+    df['png_path'] = adjusted_image_paths
+    image_paths = df['png_path'].to_list()
+    return df, image_paths
+
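A quick illustration of the path rebasing, using toy paths:

    import pandas as pd

    toy = pd.DataFrame({'png_path': ['/old/root/data/plate1/cell_1.png']})
    toy, paths = correct_paths(toy, '/new/root')
    # paths[0] == '/new/root/data/plate1/cell_1.png'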
+def correct_paths_v1(df, base_path):
+    if 'png_path' not in df.columns:
+        print("No 'png_path' column found in the dataframe.")
+        return df, None
+
+    image_paths = df['png_path'].to_list()
+
+    adjusted_image_paths = []
+    for path in image_paths:
+        if base_path not in path:
+            print(f"Adjusting path: {path}")
+            parts = path.split('data/')
+            if len(parts) > 1:
+                new_path = os.path.join(base_path, 'data', parts[1])
+                adjusted_image_paths.append(new_path)
+            else:
+                adjusted_image_paths.append(path)
+        else:
+            adjusted_image_paths.append(path)
+
+    df['png_path'] = adjusted_image_paths
+    image_paths = df['png_path'].to_list()
+    return df, image_paths
+
+def get_umap_image_settings(settings=None):
+    if settings is None:
+        settings = {}  # avoid a shared mutable default argument
+    settings.setdefault('src', 'path')
+    settings.setdefault('row_limit', 1000)
+    settings.setdefault('tables', ['cell', 'cytoplasm', 'nucleus', 'pathogen'])
+    settings.setdefault('visualize', 'cell')
+    settings.setdefault('image_nr', 16)
+    settings.setdefault('dot_size', 50)
+    settings.setdefault('n_neighbors', 1000)
+    settings.setdefault('min_dist', 0.1)
+    settings.setdefault('metric', 'euclidean')
+    settings.setdefault('eps', 0.5)
+    settings.setdefault('min_samples', 1000)
+    settings.setdefault('filter_by', 'channel_0')
+    settings.setdefault('img_zoom', 0.5)
+    settings.setdefault('plot_by_cluster', True)
+    settings.setdefault('plot_cluster_grids', True)
+    settings.setdefault('remove_cluster_noise', True)
+    settings.setdefault('remove_highly_correlated', True)
+    settings.setdefault('log_data', False)
+    settings.setdefault('figuresize', 60)
+    settings.setdefault('black_background', True)
+    settings.setdefault('remove_image_canvas', False)
+    settings.setdefault('plot_outlines', True)
+    settings.setdefault('plot_points', True)
+    settings.setdefault('smooth_lines', True)
+    settings.setdefault('clustering', 'dbscan')
+    settings.setdefault('exclude', None)
+    settings.setdefault('col_to_compare', 'col')
+    settings.setdefault('pos', 'c1')
+    settings.setdefault('neg', 'c2')
+    settings.setdefault('mix', 'c3')
+    settings.setdefault('embedding_by_controls', False)
+    settings.setdefault('plot_images', True)
+    settings.setdefault('reduction_method', 'umap')
+    settings.setdefault('save_figure', False)
+    settings.setdefault('n_jobs', -1)
+    settings.setdefault('color_by', None)
+    settings.setdefault('exclude_conditions', None)
+    settings.setdefault('analyze_clusters', False)
+    settings.setdefault('resnet_features', False)
+    settings.setdefault('verbose', True)
+    return settings
+
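Callers typically start from these defaults and override only a few keys; '/path/to/experiment' below is a placeholder:

    settings = get_umap_image_settings({'src': '/path/to/experiment',
                                        'clustering': 'kmeans',
                                        'min_samples': 8})
    # every key not passed in falls back to the defaults above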
+def preprocess_data(df, filter_by, remove_highly_correlated, log_data, exclude):
+    """
+    Preprocesses the given dataframe by applying filtering, removing highly correlated columns,
+    applying log transformation, filling NaN values, and scaling the numeric data.
+
+    Args:
+        df (pandas.DataFrame): The input dataframe.
+        filter_by (str or None): The channel of interest to filter the dataframe by.
+        remove_highly_correlated (bool or float): Whether to remove highly correlated columns.
+            If a float is provided, it represents the correlation threshold.
+        log_data (bool): Whether to apply log transformation to the numeric data.
+        exclude (list or None): List of features to exclude from the filtering process.
+
+    Returns:
+        numpy.ndarray: The preprocessed numeric data.
+
+    Raises:
+        ValueError: If no numeric columns are available after filtering.
+    """
+    # Apply filtering based on the `filter_by` parameter
+    if filter_by is not None:
+        df, _ = filter_dataframe_features(df, channel_of_interest=filter_by, exclude=exclude)
+
+    # Select numerical features
+    numeric_data = df.select_dtypes(include=['number'])
+
+    # Check if numeric_data is empty
+    if numeric_data.empty:
+        raise ValueError("No numeric columns available after filtering. Please check the filter_by and exclude parameters.")
+
+    # Remove highly correlated columns
+    if remove_highly_correlated is not False:
+        if isinstance(remove_highly_correlated, float):
+            numeric_data = remove_highly_correlated_columns(numeric_data, remove_highly_correlated)
+        else:
+            numeric_data = remove_highly_correlated_columns(numeric_data, 0.95)
+
+    # Apply log transformation
+    if log_data:
+        numeric_data = np.log(numeric_data + 1e-6)
+
+    # Fill NaN values with the column mean
+    numeric_data = numeric_data.fillna(numeric_data.mean())
+
+    # Scale the numeric data
+    scaler = StandardScaler(copy=True, with_mean=True, with_std=True)
+    numeric_data = scaler.fit_transform(numeric_data)
+
+    return numeric_data
+
+def filter_dataframe_features(df, channel_of_interest, exclude=None):
+    """
+    Filter the dataframe `df` based on the specified `channel_of_interest` and `exclude` parameters.
+
+    Parameters:
+    - df (pandas.DataFrame): The input dataframe to be filtered.
+    - channel_of_interest (str, int, list, None): The channel(s) of interest to filter the dataframe.
+      If None, no filtering is applied. If 'morphology', only morphology features are included.
+      If an integer, only the specified channel is included. If a list, only the specified channels are included.
+      If a string, only the specified channel is included.
+    - exclude (str, list, None): The feature(s) to exclude from the filtered dataframe.
+      If None, no features are excluded. If a string, the specified feature is excluded.
+      If a list, the specified features are excluded.
+
+    Returns:
+    - filtered_df (pandas.DataFrame): The filtered dataframe based on the specified parameters.
+    - features (list): The list of selected features after filtering.
+    """
+    if channel_of_interest is None:
+        feature_string = None
+    elif channel_of_interest == 'morphology':
+        feature_string = 'morphology'
+    elif isinstance(channel_of_interest, list):
+        feature_string = []
+        for i in channel_of_interest:
+            feature_string_tmp = f'channel_{i}'
+            feature_string.append(feature_string_tmp)
+    elif isinstance(channel_of_interest, int):
+        feature_string = f'channel_{channel_of_interest}'
+    elif isinstance(channel_of_interest, str):
+        feature_string = channel_of_interest
+
+    # Remove columns with a single value
+    df = df.loc[:, df.nunique() > 1]
+
+    # Select numerical features
+    features = df.select_dtypes(include=[np.number]).columns.tolist()
+
+    if feature_string is not None:
+        feature_list = ['channel_0', 'channel_1', 'channel_2', 'channel_3']
+
+        # Remove feature_string from the list if it exists
+        if isinstance(feature_string, str):
+            if feature_string in feature_list:
+                feature_list.remove(feature_string)
+        elif isinstance(feature_string, list):
+            feature_list = [feature for feature in feature_list if feature not in feature_string]
+
+        if feature_string != 'morphology':
+            # a list of channel strings needs any()-matching; a bare `in` on a list would raise TypeError
+            if isinstance(feature_string, list):
+                features = [feature for feature in features if any(fs in feature for fs in feature_string)]
+            else:
+                features = [feature for feature in features if feature_string in feature]
+
+        # Iterate through the list and remove columns from df
+        for feature_ in feature_list:
+            features = [feature for feature in features if feature_ not in feature]
+            print(f'After removing {feature_} features: {len(features)}')
+
+    if isinstance(exclude, list):
+        features = [feature for feature in features if feature not in exclude]
+    elif isinstance(exclude, str):
+        features.remove(exclude)
+
+    filtered_df = df[features]
+
+    return filtered_df, features
+
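A small illustration of the channel filter with toy column names:

    import pandas as pd

    toy = pd.DataFrame({'cell_channel_0_mean': [1.0, 2.0],
                        'cell_channel_1_mean': [3.0, 4.0],
                        'cell_area': [10.0, 20.0]})
    filtered, feats = filter_dataframe_features(toy, channel_of_interest=0)
    # feats keeps only the channel_0 measurements: ['cell_channel_0_mean']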
+# Check whether a candidate image position overlaps any already-placed image
+def check_overlap(current_position, other_positions, threshold):
+    for other_position in other_positions:
+        distance = np.linalg.norm(np.array(current_position) - np.array(other_position))
+        if distance < threshold:
+            return True
+    return False
+
+# Try random offsets around a given point until a non-overlapping position is found
+def find_non_overlapping_position(x, y, image_positions, threshold, max_attempts=100):
+    offset_range = 10  # Adjust the range for random offsets
+    attempts = 0
+    while attempts < max_attempts:
+        random_offset_x = random.uniform(-offset_range, offset_range)
+        random_offset_y = random.uniform(-offset_range, offset_range)
+        new_x = x + random_offset_x
+        new_y = y + random_offset_y
+        if not check_overlap((new_x, new_y), image_positions, threshold):
+            return new_x, new_y
+        attempts += 1
+    return x, y  # Fall back to the original position if no suitable position is found
+
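A short sketch of the overlap helpers used when scattering image thumbnails:

    placed = [(0.0, 0.0), (5.0, 5.0)]
    x, y = find_non_overlapping_position(0.5, 0.5, placed, threshold=3.0)
    # (x, y) is nudged by random offsets until it is at least `threshold` away
    # from every position in `placed` (or returned unchanged after max_attempts)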
+def search_reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric, eps, min_samples, clustering, reduction_method, verbose, reduction_param=None, embedding=None, n_jobs=-1):
+    """
+    Perform dimensionality reduction and clustering on the given data.
+
+    Parameters:
+    numeric_data (np.array): Numeric data to process.
+    n_neighbors (int): Number of neighbors for UMAP or perplexity for tSNE.
+    min_dist (float): Minimum distance for UMAP.
+    metric (str): Metric for UMAP, tSNE, and DBSCAN.
+    eps (float): Epsilon for DBSCAN clustering.
+    min_samples (int): Minimum samples for DBSCAN or number of clusters for KMeans.
+    clustering (str): Clustering method ('dbscan' or 'kmeans').
+    reduction_method (str): Dimensionality reduction method ('umap' or 'tsne').
+    verbose (bool): Whether to print verbose output.
+    reduction_param (dict): Additional parameters for the reduction method.
+    embedding (np.array): Precomputed embedding (optional).
+    n_jobs (int): Number of parallel jobs to run.
+
+    Returns:
+    embedding (np.array): Embedding of the data.
+    labels (np.array): Cluster labels.
+    """
+
+    if isinstance(n_neighbors, float):
+        n_neighbors = int(n_neighbors * len(numeric_data))
+    if n_neighbors <= 1:
+        n_neighbors = 2
+        print(f'n_neighbors cannot be less than 2. Setting n_neighbors to {n_neighbors}')
+
+    reduction_param = reduction_param or {}
+    reduction_param = {k: v for k, v in reduction_param.items() if k not in ['perplexity', 'n_neighbors', 'min_dist', 'metric', 'method']}
+
+    if reduction_method == 'umap':
+        reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric, n_jobs=n_jobs, **reduction_param)
+    elif reduction_method == 'tsne':
+        reducer = TSNE(n_components=2, perplexity=n_neighbors, metric=metric, n_jobs=n_jobs, **reduction_param)
+    else:
+        raise ValueError(f"Unsupported reduction method: {reduction_method}. Supported methods are 'umap' and 'tsne'")
+
+    if embedding is None:
+        embedding = reducer.fit_transform(numeric_data)
+
+    if clustering == 'dbscan':
+        clustering_model = DBSCAN(eps=eps, min_samples=min_samples, metric=metric)
+    elif clustering == 'kmeans':
+        from sklearn.cluster import KMeans
+        clustering_model = KMeans(n_clusters=min_samples, random_state=42)
+    else:
+        raise ValueError(f"Unsupported clustering method: {clustering}. Supported methods are 'dbscan' and 'kmeans'")
+    clustering_model.fit(embedding)
+    labels = clustering_model.labels_ if clustering == 'dbscan' else clustering_model.predict(embedding)
+    if verbose:
+        print(f'Embedding shape: {embedding.shape}')
+    return embedding, labels
+
+
+