spacr 0.0.2__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +2 -2
- spacr/__main__.py +0 -2
- spacr/alpha.py +803 -14
- spacr/annotate_app.py +118 -120
- spacr/chris.py +50 -0
- spacr/core.py +1544 -533
- spacr/deep_spacr.py +696 -0
- spacr/foldseek.py +779 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/graph_learning.py +297 -253
- spacr/gui.py +145 -0
- spacr/gui_2.py +90 -0
- spacr/gui_classify_app.py +70 -80
- spacr/gui_mask_app.py +114 -91
- spacr/gui_measure_app.py +109 -88
- spacr/gui_utils.py +376 -32
- spacr/io.py +441 -438
- spacr/mask_app.py +116 -9
- spacr/measure.py +169 -69
- spacr/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/old_code.py +70 -2
- spacr/plot.py +173 -17
- spacr/sequencing.py +1130 -0
- spacr/sim.py +630 -125
- spacr/timelapse.py +139 -10
- spacr/train.py +188 -21
- spacr/umap.py +0 -689
- spacr/utils.py +1360 -119
- {spacr-0.0.2.dist-info → spacr-0.0.6.dist-info}/METADATA +17 -29
- spacr-0.0.6.dist-info/RECORD +39 -0
- {spacr-0.0.2.dist-info → spacr-0.0.6.dist-info}/WHEEL +1 -1
- spacr-0.0.6.dist-info/entry_points.txt +9 -0
- spacr-0.0.2.dist-info/RECORD +0 -31
- spacr-0.0.2.dist-info/entry_points.txt +0 -7
- {spacr-0.0.2.dist-info → spacr-0.0.6.dist-info}/LICENSE +0 -0
- {spacr-0.0.2.dist-info → spacr-0.0.6.dist-info}/top_level.txt +0 -0
spacr/utils.py
CHANGED
@@ -1,12 +1,18 @@
-import os, re, sqlite3,
+import sys, os, re, sqlite3, torch, torchvision, random, string, shutil, cv2, tarfile, glob
 
 import numpy as np
 from cellpose import models as cp_models
 from cellpose import denoise
+
 from skimage import morphology
 from skimage.measure import label, regionprops_table, regionprops
 import skimage.measure as measure
-from
+from skimage.transform import resize as resizescikit
+from skimage.morphology import dilation, square
+from skimage.measure import find_contours
+from skimage.segmentation import clear_border
+
+from collections import defaultdict, OrderedDict
 from PIL import Image
 import pandas as pd
 from statsmodels.stats.outliers_influence import variance_inflation_factor
@@ -15,54 +21,225 @@ import statsmodels.formula.api as smf
 import statsmodels.api as sm
 from statsmodels.stats.multitest import multipletests
 from itertools import combinations
-from collections import OrderedDict
 from functools import reduce
-from IPython.display import display
+from IPython.display import display
+
 from multiprocessing import Pool, cpu_count
-from
-
-from skimage.measure import find_contours
+from concurrent.futures import ThreadPoolExecutor
+
 import torch.nn as nn
 import torch.nn.functional as F
-#from torchsummary import summary
 from torch.utils.checkpoint import checkpoint
 from torch.utils.data import Subset
 from torch.autograd import grad
-
-from skimage.segmentation import clear_border
+
 import seaborn as sns
 import matplotlib.pyplot as plt
+from matplotlib.offsetbox import OffsetImage, AnnotationBbox
+
 import scipy.ndimage as ndi
 from scipy.spatial import distance
 from scipy.stats import fisher_exact
-from scipy.ndimage import
+from scipy.ndimage.filters import gaussian_filter
+from scipy.spatial import ConvexHull
+from scipy.interpolate import splprep, splev
+
+from sklearn.preprocessing import StandardScaler
 from skimage.exposure import rescale_intensity
 from sklearn.metrics import auc, precision_recall_curve
 from sklearn.model_selection import train_test_split
 from sklearn.linear_model import Lasso, Ridge
 from sklearn.preprocessing import OneHotEncoder
 from sklearn.cluster import KMeans
+from sklearn.preprocessing import StandardScaler
+from sklearn.cluster import DBSCAN
+from sklearn.cluster import KMeans
+from sklearn.manifold import TSNE
+
+import umap.umap_ as umap
+
+from torchvision import models
 from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
+import torchvision.transforms as transforms
 
 from .logger import log_function_call
 
-
-
-
-
+def check_mask_folder(src,mask_fldr):
+
+    mask_folder = os.path.join(src,'norm_channel_stack',mask_fldr)
+    stack_folder = os.path.join(src,'stack')
+
+    if not os.path.exists(mask_folder):
+        return True
+
+    mask_count = sum(1 for file in os.listdir(mask_folder) if file.endswith('.npy'))
+    stack_count = sum(1 for file in os.listdir(stack_folder) if file.endswith('.npy'))
+
+    if mask_count == stack_count:
+        print(f'All masks have been generated for {mask_fldr}')
+        return False
+    else:
+        return True
+
+def set_default_plot_merge_settings():
+    settings = {}
+    settings.setdefault('include_noninfected', True)
+    settings.setdefault('include_multiinfected', True)
+    settings.setdefault('include_multinucleated', True)
+    settings.setdefault('remove_background', False)
+    settings.setdefault('filter_min_max', None)
+    settings.setdefault('channel_dims', [0,1,2,3])
+    settings.setdefault('backgrounds', [100,100,100,100])
+    settings.setdefault('cell_mask_dim', 4)
+    settings.setdefault('nucleus_mask_dim', 5)
+    settings.setdefault('pathogen_mask_dim', 6)
+    settings.setdefault('outline_thickness', 3)
+    settings.setdefault('outline_color', 'gbr')
+    settings.setdefault('overlay_chans', [1,2,3])
+    settings.setdefault('overlay', True)
+    settings.setdefault('normalization_percentiles', [2,98])
+    settings.setdefault('normalize', True)
+    settings.setdefault('print_object_number', True)
+    settings.setdefault('nr', 1)
+    settings.setdefault('figuresize', 50)
+    settings.setdefault('cmap', 'inferno')
+    settings.setdefault('verbose', True)
+
+    return settings
+
+def set_default_settings_preprocess_generate_masks(src, settings={}):
+    # Main settings
+    settings['src'] = src
+    settings.setdefault('preprocess', True)
+    settings.setdefault('masks', True)
+    settings.setdefault('save', True)
+    settings.setdefault('batch_size', 50)
+    settings.setdefault('test_mode', False)
+    settings.setdefault('test_images', 10)
+    settings.setdefault('magnification', 20)
+    settings.setdefault('custom_regex', None)
+    settings.setdefault('metadata_type', 'cellvoyager')
+    settings.setdefault('workers', os.cpu_count()-4)
+    settings.setdefault('randomize', True)
+    settings.setdefault('verbose', True)
+
+    settings.setdefault('remove_background_cell', False)
+    settings.setdefault('remove_background_nucleus', False)
+    settings.setdefault('remove_background_pathogen', False)
+
+    # Channel settings
+    settings.setdefault('cell_channel', None)
+    settings.setdefault('nucleus_channel', None)
+    settings.setdefault('pathogen_channel', None)
+    settings.setdefault('channels', [0,1,2,3])
+    settings.setdefault('pathogen_background', 100)
+    settings.setdefault('pathogen_Signal_to_noise', 10)
+    settings.setdefault('pathogen_CP_prob', 0)
+    settings.setdefault('cell_background', 100)
+    settings.setdefault('cell_Signal_to_noise', 10)
+    settings.setdefault('cell_CP_prob', 0)
+    settings.setdefault('nucleus_background', 100)
+    settings.setdefault('nucleus_Signal_to_noise', 10)
+    settings.setdefault('nucleus_CP_prob', 0)
+
+    settings.setdefault('nucleus_FT', 100)
+    settings.setdefault('cell_FT', 100)
+    settings.setdefault('pathogen_FT', 100)
+
+    # Plot settings
+    settings.setdefault('plot', False)
+    settings.setdefault('figuresize', 50)
+    settings.setdefault('cmap', 'inferno')
+    settings.setdefault('normalize', True)
+    settings.setdefault('normalize_plots', True)
+    settings.setdefault('examples_to_plot', 1)
+
+    # Analasys settings
+    settings.setdefault('pathogen_model', None)
+    settings.setdefault('merge_pathogens', False)
+    settings.setdefault('filter', False)
+    settings.setdefault('lower_percentile', 2)
+
+    # Timelapse settings
+    settings.setdefault('timelapse', False)
+    settings.setdefault('fps', 2)
+    settings.setdefault('timelapse_displacement', None)
+    settings.setdefault('timelapse_memory', 3)
+    settings.setdefault('timelapse_frame_limits', None)
+    settings.setdefault('timelapse_remove_transient', False)
+    settings.setdefault('timelapse_mode', 'trackpy')
+    settings.setdefault('timelapse_objects', 'cells')
+
+    # Misc settings
+    settings.setdefault('all_to_mip', False)
+    settings.setdefault('pick_slice', False)
+    settings.setdefault('skip_mode', '01')
+    settings.setdefault('upscale', False)
+    settings.setdefault('upscale_factor', 2.0)
+    settings.setdefault('adjust_cells', False)
+
+    return settings
+
+def set_default_settings_preprocess_img_data(settings):
+
+    metadata_type = settings.setdefault('metadata_type', 'cellvoyager')
+    custom_regex = settings.setdefault('custom_regex', None)
+    nr = settings.setdefault('nr', 1)
+    plot = settings.setdefault('plot', True)
+    batch_size = settings.setdefault('batch_size', 50)
+    timelapse = settings.setdefault('timelapse', False)
+    lower_percentile = settings.setdefault('lower_percentile', 2)
+    randomize = settings.setdefault('randomize', True)
+    all_to_mip = settings.setdefault('all_to_mip', False)
+    pick_slice = settings.setdefault('pick_slice', False)
+    skip_mode = settings.setdefault('skip_mode', False)
+
+    cmap = settings.setdefault('cmap', 'inferno')
+    figuresize = settings.setdefault('figuresize', 50)
+    normalize = settings.setdefault('normalize', True)
+    save_dtype = settings.setdefault('save_dtype', 'uint16')
+
+    test_mode = settings.setdefault('test_mode', False)
+    test_images = settings.setdefault('test_images', 10)
+    random_test = settings.setdefault('random_test', True)
+
+    return settings, metadata_type, custom_regex, nr, plot, batch_size, timelapse, lower_percentile, randomize, all_to_mip, pick_slice, skip_mode, cmap, figuresize, normalize, save_dtype, test_mode, test_images, random_test
+
+def smooth_hull_lines(cluster_data):
+    hull = ConvexHull(cluster_data)
+
+    # Extract vertices of the hull
+    vertices = hull.points[hull.vertices]
+
+    # Close the loop
+    vertices = np.vstack([vertices, vertices[0, :]])
+
+    # Parameterize the vertices
+    tck, u = splprep(vertices.T, u=None, s=0.0)
+
+    # Evaluate spline at new parameter values
+    new_points = splev(np.linspace(0, 1, 100), tck)
+
+    return new_points[0], new_points[1]
+
+def _gen_rgb_image(image, channels):
+    """
+    Generate an RGB image from the specified channels of the input image.
 
+    Args:
+        image (ndarray): The input image.
+        channels (list): List of channel indices to use for RGB.
 
-
-
-
-    rgb_image
-
+    Returns:
+        rgb_image (ndarray): The generated RGB image.
+    """
+    rgb_image = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.float32)
+    for i, chan in enumerate(channels):
+        if chan < image.shape[2]:
+            rgb_image[:, :, i] = image[:, :, chan]
     return rgb_image
 
 def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_thickness):
-    from concurrent.futures import ThreadPoolExecutor
-    import cv2
-
     outlines = []
     overlayed_image = rgb_image.copy()
 
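The hunk above adds several default-settings helpers. A minimal sketch (not part of the package diff) of how they can be driven; the source path and mask folder name are hypothetical, and any key passed in explicitly survives the setdefault calls:

import spacr.utils as su

settings = su.set_default_settings_preprocess_generate_masks('/data/plate1', settings={'magnification': 40})
print(settings['magnification'])   # 40, the explicit value wins over the default of 20
print(settings['batch_size'])      # 50, filled in by setdefault

# check_mask_folder returns True when masks still need to be generated for a channel stack
if su.check_mask_folder('/data/plate1', 'cell_mask_stack'):   # 'cell_mask_stack' is a hypothetical folder name
    print('cell masks still need to be generated')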
@@ -71,11 +248,13 @@ def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_th
         outline = np.zeros_like(mask, dtype=np.uint8)  # Use uint8 for contour detection efficiency
 
         # Find and draw contours
-        for j in np.unique(mask)
+        for j in np.unique(mask):
+            if j == 0:
+                continue  # Skip background
             contours = find_contours(mask == j, 0.5)
             # Convert contours for OpenCV format and draw directly to optimize
             cv_contours = [np.flip(contour.astype(int), axis=1) for contour in contours]
-            cv2.drawContours(outline, cv_contours, -1, color=
+            cv2.drawContours(outline, cv_contours, -1, color=255, thickness=outline_thickness)
 
         return dilation(outline, square(outline_thickness))
 
@@ -83,19 +262,15 @@ def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_th
     with ThreadPoolExecutor() as executor:
         outlines = list(executor.map(process_dim, mask_dims))
 
-    # Overlay outlines onto the RGB image
+    # Overlay outlines onto the RGB image
     for i, outline in enumerate(outlines):
-
-
-
-
+        color = np.array(outline_colors[i % len(outline_colors)])
+        for j in np.unique(outline):
+            if j == 0:
+                continue  # Skip background
             mask = outline == j
             overlayed_image[mask] = color  # Direct assignment with broadcasting
 
-    # Remove mask_dims from image
-    channels_to_keep = [i for i in range(image.shape[-1]) if i not in mask_dims]
-    image = np.take(image, channels_to_keep, axis=-1)
-
     return overlayed_image, outlines, image
 
 def _convert_cq1_well_id(well_id):
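A self-contained sketch of the outline-drawing pattern used in the hunks above, run on a synthetic labelled mask; it assumes numpy, OpenCV and scikit-image are installed, and uses int32 point arrays as OpenCV expects:

import numpy as np
import cv2
from skimage.measure import find_contours
from skimage.morphology import dilation, square

mask = np.zeros((64, 64), dtype=np.uint16)
mask[10:30, 10:30] = 1                                  # one labelled object
outline = np.zeros_like(mask, dtype=np.uint8)
for j in np.unique(mask):
    if j == 0:
        continue                                        # skip background
    contours = find_contours(mask == j, 0.5)
    # flip (row, col) -> (x, y) and cast to int32 for OpenCV
    cv_contours = [np.flip(c, axis=1).astype(np.int32).reshape(-1, 1, 2) for c in contours]
    cv2.drawContours(outline, cv_contours, -1, color=255, thickness=2)
outline = dilation(outline, square(2))                  # thicken the outline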
@@ -355,43 +530,82 @@ def _annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['rh'], pa
     df['condition'] = df.apply(lambda row: '_'.join(filter(None, [row.get('pathogen'), row.get('treatment')])), axis=1)
     df['condition'] = df['condition'].apply(lambda x: x if x else 'none')
     return df
-
-def normalize_to_dtype(array,
+
+def normalize_to_dtype(array, p1=2, p2=98):
     """
-    Normalize the
+    Normalize each image in the stack to its own percentiles.
 
     Parameters:
     - array: numpy array
-        The input
-    -
+        The input stack to be normalized.
+    - p1: int, optional
         The lower percentile value for normalization. Default is 2.
-    -
+    - p2: int, optional
         The upper percentile value for normalization. Default is 98.
-    - percentiles: list of tuples, optional
-        A list of tuples containing the percentile values for each image in the array.
-        If provided, the percentiles for each image will be used instead of q1 and q2.
 
     Returns:
     - new_stack: numpy array
-        The normalized
+        The normalized stack with the same shape as the input stack.
     """
     nimg = array.shape[2]
     new_stack = np.empty_like(array)
-
-
+
+    for i in range(nimg):
+        img = array[:, :, i]
         non_zero_img = img[img > 0]
-
-
-
-
-        else:
-
-
-
-
-
-
-
+
+        if non_zero_img.size > 0:
+            img_min = np.percentile(non_zero_img, p1)
+            img_max = np.percentile(non_zero_img, p2)
+        else:
+            img_min = img.min()
+            img_max = img.max()
+
+        # Determine output range based on dtype
+        if np.issubdtype(array.dtype, np.integer):
+            out_range = (0, np.iinfo(array.dtype).max)
+        else:
+            out_range = (0.0, 1.0)
+
+        img = rescale_intensity(img, in_range=(img_min, img_max), out_range=out_range).astype(array.dtype)
+        new_stack[:, :, i] = img
+
+    return new_stack
+
+def normalize_to_dtype(array, p1=2, p2=98):
+    """
+    Normalize each image in the stack to its own percentiles.
+
+    Parameters:
+    - array: numpy array
+        The input stack to be normalized.
+    - p1: int, optional
+        The lower percentile value for normalization. Default is 2.
+    - p2: int, optional
+        The upper percentile value for normalization. Default is 98.
+
+    Returns:
+    - new_stack: numpy array
+        The normalized stack with the same shape as the input stack.
+    """
+    nimg = array.shape[2]
+    new_stack = np.empty_like(array, dtype=np.float32)
+
+    for i in range(nimg):
+        img = array[:, :, i]
+        non_zero_img = img[img > 0]
+
+        if non_zero_img.size > 0:
+            img_min = np.percentile(non_zero_img, p1)
+            img_max = np.percentile(non_zero_img, p2)
+        else:
+            img_min = img.min()
+            img_max = img.max()
+
+        # Normalize to the range (0, 1) for visualization
+        img = rescale_intensity(img, in_range=(img_min, img_max), out_range=(0.0, 1.0))
+        new_stack[:, :, i] = img
+
     return new_stack
 
 def _list_endpoint_subdirectories(base_dir):
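A short usage sketch for the rewritten percentile normalization on a synthetic (H, W, N) stack; note the module defines normalize_to_dtype twice in this hunk, so the second definition (float output scaled to 0-1) is the one that remains bound at import time:

import numpy as np
from spacr.utils import normalize_to_dtype

stack = np.random.randint(0, 65535, size=(128, 128, 3), dtype=np.uint16)   # three synthetic channels
normalized = normalize_to_dtype(stack, p1=2, p2=98)
print(normalized.shape, normalized.dtype, float(normalized.max()) <= 1.0)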
@@ -749,7 +963,7 @@ def _get_diam(mag, obj):
         elif obj == 'nucleus':
             diamiter = 60
         elif obj == 'pathogen':
-            diamiter =
+            diamiter = 20
         else:
             raise ValueError("Invalid magnification: Use 20, 40 or 60")
 
@@ -769,7 +983,7 @@ def _get_diam(mag, obj):
         if obj == 'nucleus':
             diamiter = 90
         if obj == 'pathogen':
-            diamiter =
+            diamiter = 60
         else:
             raise ValueError("Invalid magnification: Use 20, 40 or 60")
     else:
@@ -781,8 +995,8 @@ def _get_object_settings(object_type, settings):
     object_settings = {}
 
     object_settings['diameter'] = _get_diam(settings['magnification'], obj=object_type)
-    object_settings['minimum_size'] = (object_settings['diameter']**2)/
-    object_settings['maximum_size'] = (object_settings['diameter']**2)*
+    object_settings['minimum_size'] = (object_settings['diameter']**2)/4
+    object_settings['maximum_size'] = (object_settings['diameter']**2)*10
     object_settings['merge'] = False
     object_settings['resample'] = True
     object_settings['remove_border_objects'] = False
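Worked example of the size bounds these defaults produce: with the pathogen diameter of 20 px at 20x magnification (from the _get_diam hunk above), the area filter spans 100 to 4000 px.

diameter = 20
minimum_size = (diameter**2) / 4    # 100.0 px
maximum_size = (diameter**2) * 10   # 4000 px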
@@ -793,21 +1007,23 @@ def _get_object_settings(object_type, settings):
             object_settings['model_name'] = 'cyto'
         else:
             object_settings['model_name'] = 'cyto2'
-        object_settings['filter_size'] =
-        object_settings['filter_intensity'] =
+        object_settings['filter_size'] = False
+        object_settings['filter_intensity'] = False
         object_settings['restore_type'] = settings.get('cell_restore_type', None)
 
     elif object_type == 'nucleus':
         object_settings['model_name'] = 'nuclei'
-        object_settings['filter_size'] =
-        object_settings['filter_intensity'] =
+        object_settings['filter_size'] = False
+        object_settings['filter_intensity'] = False
         object_settings['restore_type'] = settings.get('nucleus_restore_type', None)
 
     elif object_type == 'pathogen':
         object_settings['model_name'] = 'cyto'
-        object_settings['filter_size'] =
-        object_settings['filter_intensity'] =
+        object_settings['filter_size'] = False
+        object_settings['filter_intensity'] = False
+        object_settings['resample'] = False
         object_settings['restore_type'] = settings.get('pathogen_restore_type', None)
+        object_settings['merge'] = settings['merge_pathogens']
 
     else:
         print(f'Object type: {object_type} not supported. Supported object types are : cell, nucleus and pathogen')
@@ -884,17 +1100,15 @@ def _get_cellpose_channels(src, nucleus_channel, pathogen_channel, cell_channel)
 
     if not pathogen_channel is None:
         if not nucleus_channel is None:
-
+            if not pathogen_channel is None:
+                cellpose_channels['pathogen'] = [0,2]
+            else:
+                cellpose_channels['pathogen'] = [0,1]
         else:
             cellpose_channels['pathogen'] = [0,0]
 
     if not cell_channel is None:
         if not nucleus_channel is None:
-            if not pathogen_channel is None:
-                cellpose_channels['cell'] = [0,2]
-            else:
-                cellpose_channels['cell'] = [0,1]
-        elif not pathogen_channel is None:
             cellpose_channels['cell'] = [0,1]
         else:
             cellpose_channels['cell'] = [0,0]
@@ -1069,7 +1283,7 @@ class Cache:
        cache (OrderedDict): The cache data structure.
    """
 
-    def
+    def __init__(self, max_size):
        self.cache = OrderedDict()
        self.max_size = max_size
 
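The restored constructor pairs an OrderedDict with a max_size. The hunk does not show the rest of the class, so the following is only an assumption about how such a cache is commonly completed (LRU-style get/put), not the package's actual methods:

from collections import OrderedDict

class LRUCacheSketch:
    """Illustrative only; the real spacr Cache methods are outside this hunk."""
    def __init__(self, max_size):
        self.cache = OrderedDict()
        self.max_size = max_size

    def get(self, key):
        if key in self.cache:
            self.cache.move_to_end(key)      # mark as most recently used
            return self.cache[key]
        return None

    def put(self, key, value):
        self.cache[key] = value
        self.cache.move_to_end(key)
        if len(self.cache) > self.max_size:
            self.cache.popitem(last=False)   # evict the least recently used entry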
@@ -1100,7 +1314,7 @@ class ScaledDotProductAttention(nn.Module):
 
    """
 
-    def
+    def __init__(self, d_k):
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = d_k
 
@@ -1131,7 +1345,7 @@ class SelfAttention(nn.Module):
        d_k (int): Dimensionality of the key and query vectors.
    """
 
-    def
+    def __init__(self, in_channels, d_k):
        super(SelfAttention, self).__init__()
        self.W_q = nn.Linear(in_channels, d_k)
        self.W_k = nn.Linear(in_channels, d_k)
@@ -1155,7 +1369,7 @@ class SelfAttention(nn.Module):
        return output
 
 class ScaledDotProductAttention(nn.Module):
-    def
+    def __init__(self, d_k):
        """
        Initializes the ScaledDotProductAttention module.
 
@@ -1192,7 +1406,7 @@ class SelfAttention(nn.Module):
        in_channels (int): Number of input channels.
        d_k (int): Dimensionality of the key and query vectors.
    """
-    def
+    def __init__(self, in_channels, d_k):
        super(SelfAttention, self).__init__()
        self.W_q = nn.Linear(in_channels, d_k)
        self.W_k = nn.Linear(in_channels, d_k)
@@ -1223,7 +1437,7 @@ class EarlyFusion(nn.Module):
    Args:
        in_channels (int): Number of input channels.
    """
-    def
+    def __init__(self, in_channels):
        super(EarlyFusion, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=1, stride=1)
 
@@ -1242,7 +1456,7 @@ class EarlyFusion(nn.Module):
 
 # Spatial Attention Mechanism
 class SpatialAttention(nn.Module):
-    def
+    def __init__(self, kernel_size=7):
        """
        Initializes the SpatialAttention module.
 
@@ -1287,7 +1501,7 @@ class MultiScaleBlockWithAttention(nn.Module):
        forward: Forward method for the module.
    """
 
-    def
+    def __init__(self, in_channels, out_channels):
        super(MultiScaleBlockWithAttention, self).__init__()
        self.dilated_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, dilation=1, padding=1)
        self.spatial_attention = nn.Conv2d(out_channels, out_channels, kernel_size=1)
@@ -1320,7 +1534,7 @@ class MultiScaleBlockWithAttention(nn.Module):
 
 # Final Classifier
 class CustomCellClassifier(nn.Module):
-    def
+    def __init__(self, num_classes, pathogen_channel, use_attention, use_checkpoint, dropout_rate):
        super(CustomCellClassifier, self).__init__()
        self.early_fusion = EarlyFusion(in_channels=3)
 
@@ -1349,7 +1563,7 @@ class CustomCellClassifier(nn.Module):
 
 #CNN and Transformer class, pick any Torch model.
 class TorchModel(nn.Module):
-    def
+    def __init__(self, model_name='resnet50', pretrained=True, dropout_rate=None, use_checkpoint=False):
        super(TorchModel, self).__init__()
        self.model_name = model_name
        self.use_checkpoint = use_checkpoint
@@ -1423,7 +1637,7 @@ class TorchModel(nn.Module):
        return logits
 
 class FocalLossWithLogits(nn.Module):
-    def
+    def __init__(self, alpha=1, gamma=2):
        super(FocalLossWithLogits, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
@@ -1435,7 +1649,7 @@ class FocalLossWithLogits(nn.Module):
        return focal_loss.mean()
 
 class ResNet(nn.Module):
-    def
+    def __init__(self, resnet_type='resnet50', dropout_rate=None, use_checkpoint=False, init_weights='imagenet'):
        super(ResNet, self).__init__()
 
        resnet_map = {
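A hedged usage sketch for the focal-loss criterion whose constructor is restored above. The forward body is outside these hunks, so the (logits, targets) call signature below is an assumption based on the returned focal_loss.mean() shown in the diff:

import torch
from spacr.utils import FocalLossWithLogits

criterion = FocalLossWithLogits(alpha=1, gamma=2)
logits = torch.randn(8, 1, requires_grad=True)           # raw, unnormalized model outputs
targets = torch.randint(0, 2, (8, 1)).float()            # binary labels
loss = criterion(logits, targets)                        # assumes the usual (logits, targets) signature
loss.backward()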
@@ -1788,25 +2002,24 @@ def annotate_predictions(csv_loc):
    df['cond'] = df.apply(assign_condition, axis=1)
    return df
 
-def
+def initiate_counter(counter_, lock_):
    global counter, lock
    counter = counter_
    lock = lock_
 
-def add_images_to_tar(
-    global counter, lock, total_images
-    paths_chunk, tar_path = args
+def add_images_to_tar(paths_chunk, tar_path, total_images):
    with tarfile.open(tar_path, 'w') as tar:
-        for img_path in paths_chunk:
+        for i, img_path in enumerate(paths_chunk):
            arcname = os.path.basename(img_path)
            try:
                tar.add(img_path, arcname=arcname)
                with lock:
                    counter.value += 1
-
+                    if counter.value % 100 == 0:  # Print every 100 updates
+                        progress = (counter.value / total_images) * 100
+                        print(f"Progress: {counter.value}/{total_images} ({progress:.2f}%)", end='\r', file=sys.stdout, flush=True)
            except FileNotFoundError:
                print(f"File not found: {img_path}")
-    return tar_path
 
 def generate_fraction_map(df, gene_column, min_frequency=0.0):
    df['fraction'] = df['count']/df['well_read_sum']
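A sketch of how the reworked initiate_counter/add_images_to_tar pair can be driven from a multiprocessing pool, using the standard initializer pattern so the shared counter and lock become module globals in each worker; the image paths and tar name are hypothetical:

import multiprocessing as mp
from spacr.utils import initiate_counter, add_images_to_tar

paths = ['/data/pngs/a.png', '/data/pngs/b.png']         # hypothetical image paths
counter = mp.Value('i', 0)
lock = mp.Lock()
with mp.Pool(processes=2, initializer=initiate_counter, initargs=(counter, lock)) as pool:
    pool.starmap(add_images_to_tar, [(paths, '/data/images.tar', len(paths))])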
@@ -2255,8 +2468,8 @@ def dice_coefficient(mask1, mask2):
 def extract_boundaries(mask, dilation_radius=1):
    binary_mask = (mask > 0).astype(np.uint8)
    struct_elem = np.ones((dilation_radius*2+1, dilation_radius*2+1))
-    dilated = binary_dilation(binary_mask, footprint=struct_elem)
-    eroded = binary_erosion(binary_mask, footprint=struct_elem)
+    dilated = morphology.binary_dilation(binary_mask, footprint=struct_elem)
+    eroded = morphology.binary_erosion(binary_mask, footprint=struct_elem)
    boundary = dilated ^ eroded
    return boundary
 
@@ -2669,6 +2882,13 @@ def _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size,
                print(f'Number of objects before filtration: {num_objects}')
                plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
 
+        if merge:
+            mask = merge_touching_objects(mask, threshold=0.66)
+            if plot and idx == 0:
+                num_objects = mask_object_count(mask)
+                print(f'Number of objects after merging adjacent objects, : {num_objects}')
+                plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
+
        if filter_size:
            props = measure.regionprops_table(mask, properties=['label', 'area'])
            valid_labels = props['label'][np.logical_and(props['area'] > minimum_size, props['area'] < maximum_size)]
@@ -2714,13 +2934,6 @@ def _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size,
                print(f'Number of objects after removing border objects, : {num_objects}')
                plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
 
-        if merge:
-            mask = merge_touching_objects(mask, threshold=0.25)
-            if plot and idx == 0:
-                num_objects = mask_object_count(mask)
-                print(f'Number of objects after merging adjacent objects, : {num_objects}')
-                plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
-
        mask_stack.append(mask)
 
    return mask_stack
@@ -2758,15 +2971,37 @@ def _object_filter(df, object_type, size_range, intensity_range, mask_chans, mas
        print(f'After {object_type} maximum mean intensity filter: {len(df)}')
    return df
 
-def
+def _get_regex(metadata_type, img_format, custom_regex=None):
+
+    if img_format == None:
+        img_format == '.tif'
+    if metadata_type == 'cellvoyager':
+        regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+    elif metadata_type == 'cq1':
+        regex = f'W(?P<wellID>.*)F(?P<fieldID>.*)T(?P<timeID>.*)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+    elif metadata_type == 'nikon':
+        regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+    elif metadata_type == 'zeis':
+        regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+    elif metadata_type == 'leica':
+        regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+    elif metadata_type == 'custom':
+        regex = f'({custom_regex}){img_format}'
+
+    print(f'regex mode:{metadata_type} regex:{regex}')
+    return regex
+
+def _run_test_mode(src, regex, timelapse=False, test_images=10, random_test=True):
+
    if timelapse:
        test_images = 1  # Use only 1 set for timelapse to ensure full sequence inclusion
-
-        test_images = 10  # Use 10 sets for non-timelapse scenarios
-
+
    test_folder_path = os.path.join(src, 'test')
    os.makedirs(test_folder_path, exist_ok=True)
    regular_expression = re.compile(regex)
+
+    if os.path.exists(os.path.join(src, 'orig')):
+        src = os.path.join(src, 'orig')
 
    all_filenames = [filename for filename in os.listdir(src) if regular_expression.match(filename)]
    print(f'Found {len(all_filenames)} files')
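A quick check (not from the package) of the cellvoyager pattern returned by _get_regex, matched against a made-up filename; the filename is purely illustrative and '.tif' stands in for the img_format argument:

import re

regex = r'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*).tif'
m = re.match(regex, 'plate1_B02_T0001F001L01A01Z01C01.tif')   # hypothetical CellVoyager-style name
if m:
    print(m.group('wellID'), m.group('fieldID'), m.group('chanID'))   # B02 001 01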
@@ -2778,24 +3013,20 @@ def _run_test_mode(src, regex, timelapse=False):
        plate = match.group('plateID') if 'plateID' in match.groupdict() else os.path.basename(src)
        well = match.group('wellID')
        field = match.group('fieldID')
-
-        if timelapse:
-            set_identifier = (plate, well, field)
-        else:
-            # For non-timelapse, you might want to distinguish sets more granularly
-            # Here, assuming you're grouping by plate, well, and field for simplicity
-            set_identifier = (plate, well, field)
+        set_identifier = (plate, well, field)
        images_by_set[set_identifier].append(filename)
 
    # Prepare for random selection
    set_identifiers = list(images_by_set.keys())
+    if random_test:
+        random.seed(42)
        random.shuffle(set_identifiers)  # Randomize the order
 
    # Select a subset based on the test_images count
    selected_sets = set_identifiers[:test_images]
 
    # Print information about the number of sets used
-    print(f'Using {
+    print(f'Using {len(selected_sets)} random image set(s) for test model')
 
    # Copy files for selected sets to the test folder
    for set_identifier in selected_sets:
@@ -2804,24 +3035,1034 @@ def _run_test_mode(src, regex, timelapse=False):
|
|
2804
3035
|
|
2805
3036
|
return test_folder_path
|
2806
3037
|
|
2807
|
-
def _choose_model(model_name, device, object_type='cell', restore_type=None):
|
3038
|
+
def _choose_model(model_name, device, object_type='cell', restore_type=None, object_settings={}):
|
3039
|
+
|
3040
|
+
if object_type == 'pathogen':
|
3041
|
+
if model_name == 'toxo_pv_lumen':
|
3042
|
+
diameter = object_settings['diameter']
|
3043
|
+
current_dir = os.path.dirname(__file__)
|
3044
|
+
model_path = os.path.join(current_dir, 'models', 'cp', 'toxo_pv_lumen.CP_model')
|
3045
|
+
print(model_path)
|
3046
|
+
model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=model_path, diam_mean=diameter, device=device)
|
3047
|
+
#model = cp_models.Cellpose(gpu=torch.cuda.is_available(), model_type='cyto', device=device)
|
3048
|
+
print(f'Using Toxoplasma PV lumen model to generate pathogen masks')
|
3049
|
+
return model
|
3050
|
+
|
2808
3051
|
restore_list = ['denoise', 'deblur', 'upsample', None]
|
2809
3052
|
if restore_type not in restore_list:
|
2810
3053
|
print(f"Invalid restore type. Choose from {restore_list} defaulting to None")
|
2811
3054
|
restore_type = None
|
2812
3055
|
|
2813
3056
|
if restore_type == None:
|
2814
|
-
|
3057
|
+
if model_name in ['cyto', 'cyto2', 'cyto3', 'nuclei']:
|
3058
|
+
model = cp_models.Cellpose(gpu=torch.cuda.is_available(), model_type=model_name, device=device)
|
3059
|
+
|
2815
3060
|
else:
|
2816
3061
|
if object_type == 'nucleus':
|
2817
3062
|
restore = f'{type}_nuclei'
|
2818
|
-
model = denoise.CellposeDenoiseModel(gpu=
|
3063
|
+
model = denoise.CellposeDenoiseModel(gpu=torch.cuda.is_available(), model_type="nuclei",restore_type=restore, chan2_restore=False, device=device)
|
2819
3064
|
else:
|
2820
3065
|
restore = f'{type}_cyto3'
|
2821
3066
|
if model_name =='cyto2':
|
2822
3067
|
chan2_restore = True
|
2823
3068
|
if model_name =='cyto':
|
2824
3069
|
chan2_restore = False
|
2825
|
-
model = denoise.CellposeDenoiseModel(gpu=
|
3070
|
+
model = denoise.CellposeDenoiseModel(gpu=torch.cuda.is_available(), model_type="cyto3",restore_type=restore, chan2_restore=chan2_restore, device=device)
|
3071
|
+
|
3072
|
+
return model
|
3073
|
+
|
3074
|
+
class SelectChannels:
|
3075
|
+
def __init__(self, channels):
|
3076
|
+
self.channels = channels
|
3077
|
+
|
3078
|
+
def __call__(self, img):
|
3079
|
+
img = img.clone()
|
3080
|
+
if 1 not in self.channels:
|
3081
|
+
img[0, :, :] = 0 # Zero out the red channel
|
3082
|
+
if 2 not in self.channels:
|
3083
|
+
img[1, :, :] = 0 # Zero out the green channel
|
3084
|
+
if 3 not in self.channels:
|
3085
|
+
img[2, :, :] = 0 # Zero out the blue channel
|
3086
|
+
return img
|
3087
|
+
|
3088
|
+
def preprocess_image(image_path, image_size=224, channels=[1,2,3], normalize=True):
|
3089
|
+
|
3090
|
+
if normalize:
|
3091
|
+
transform = transforms.Compose([
|
3092
|
+
transforms.ToTensor(),
|
3093
|
+
transforms.CenterCrop(size=(image_size, image_size)),
|
3094
|
+
SelectChannels(channels),
|
3095
|
+
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
3096
|
+
else:
|
3097
|
+
transform = transforms.Compose([
|
3098
|
+
transforms.ToTensor(),
|
3099
|
+
transforms.CenterCrop(size=(image_size, image_size)),
|
3100
|
+
SelectChannels(channels)])
|
3101
|
+
|
3102
|
+
image = Image.open(image_path).convert('RGB')
|
3103
|
+
input_tensor = transform(image).unsqueeze(0)
|
3104
|
+
return image, input_tensor
|
3105
|
+
|
3106
|
+
|
3107
|
+
class SaliencyMapGenerator:
|
3108
|
+
def __init__(self, model):
|
3109
|
+
self.model = model
|
3110
|
+
|
3111
|
+
def compute_saliency_maps(self, X, y):
|
3112
|
+
self.model.eval()
|
3113
|
+
X.requires_grad_()
|
3114
|
+
|
3115
|
+
# Forward pass
|
3116
|
+
scores = self.model(X).squeeze()
|
3117
|
+
|
3118
|
+
# For binary classification, target scores can be the single output
|
3119
|
+
target_scores = scores * (2 * y - 1)
|
3120
|
+
|
3121
|
+
self.model.zero_grad()
|
3122
|
+
target_scores.backward(torch.ones_like(target_scores))
|
3123
|
+
|
3124
|
+
saliency = X.grad.abs()
|
3125
|
+
return saliency
|
3126
|
+
|
3127
|
+
def plot_saliency_maps(self, X, y, saliency, class_names):
|
3128
|
+
N = X.shape[0]
|
3129
|
+
for i in range(N):
|
3130
|
+
plt.subplot(2, N, i + 1)
|
3131
|
+
plt.imshow(X[i].permute(1, 2, 0).cpu().numpy())
|
3132
|
+
plt.axis('off')
|
3133
|
+
plt.title(class_names[y[i]])
|
3134
|
+
plt.subplot(2, N, N + i + 1)
|
3135
|
+
plt.imshow(saliency[i].cpu().numpy(), cmap=plt.cm.hot)
|
3136
|
+
plt.axis('off')
|
3137
|
+
plt.gcf().set_size_inches(12, 5)
|
3138
|
+
plt.show()
|
3139
|
+
|
3140
|
+
def preprocess_image(image_path, normalize=True, image_size=224, channels=[1,2,3]):
|
3141
|
+
preprocess = transforms.Compose([
|
3142
|
+
transforms.Resize((image_size, image_size)),
|
3143
|
+
transforms.ToTensor(),
|
3144
|
+
])
|
3145
|
+
|
3146
|
+
image = Image.open(image_path).convert('RGB')
|
3147
|
+
input_tensor = preprocess(image)
|
3148
|
+
if normalize:
|
3149
|
+
input_tensor = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(input_tensor)
|
3150
|
+
input_tensor = input_tensor.unsqueeze(0)
|
3151
|
+
|
3152
|
+
return image, input_tensor
|
3153
|
+
|
3154
|
+
def class_visualization(target_y, model_path, dtype, img_size=224, channels=[0,1,2], l2_reg=1e-3, learning_rate=25, num_iterations=100, blur_every=10, max_jitter=16, show_every=25, class_names = ['nc', 'pc']):
|
3155
|
+
|
3156
|
+
def jitter(img, ox, oy):
|
3157
|
+
# Randomly jitter the image
|
3158
|
+
return torch.roll(torch.roll(img, ox, dims=2), oy, dims=3)
|
3159
|
+
|
3160
|
+
def blur_image(img, sigma=1):
|
3161
|
+
# Apply Gaussian blur to the image
|
3162
|
+
img_np = img.cpu().numpy()
|
3163
|
+
for i in range(img_np.shape[1]):
|
3164
|
+
img_np[:, i] = gaussian_filter(img_np[:, i], sigma=sigma)
|
3165
|
+
img.copy_(torch.tensor(img_np).to(img.device))
|
3166
|
+
|
3167
|
+
def deprocess(img_tensor):
|
3168
|
+
# Convert the tensor image to a numpy array for visualization
|
3169
|
+
img_tensor = img_tensor.clone()
|
3170
|
+
for c in range(3):
|
3171
|
+
img_tensor[:, c] = img_tensor[:, c] * SQUEEZENET_STD[c] + SQUEEZENET_MEAN[c]
|
3172
|
+
img_tensor = img_tensor.clamp(0, 1)
|
3173
|
+
return img_tensor.squeeze().permute(1, 2, 0).cpu().numpy()
|
3174
|
+
|
3175
|
+
# Assuming these are defined somewhere in your codebase
|
3176
|
+
SQUEEZENET_MEAN = [0.485, 0.456, 0.406]
|
3177
|
+
SQUEEZENET_STD = [0.229, 0.224, 0.225]
|
3178
|
+
|
3179
|
+
model = torch.load(model_path)
|
3180
|
+
|
3181
|
+
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
|
3182
|
+
len_chans = len(channels)
|
3183
|
+
model.type(dtype)
|
3184
|
+
|
3185
|
+
# Randomly initialize the image as a PyTorch Tensor, and make it requires gradient.
|
3186
|
+
img = torch.randn(1, len_chans, img_size, img_size).mul_(1.0).type(dtype).requires_grad_()
|
3187
|
+
|
3188
|
+
for t in range(num_iterations):
|
3189
|
+
# Randomly jitter the image a bit; this gives slightly nicer results
|
3190
|
+
ox, oy = random.randint(0, max_jitter), random.randint(0, max_jitter)
|
3191
|
+
img.data.copy_(jitter(img.data, ox, oy))
|
3192
|
+
|
3193
|
+
# Forward pass
|
3194
|
+
score = model(img)
|
3195
|
+
|
3196
|
+
if target_y == 0:
|
3197
|
+
target_score = -score
|
3198
|
+
else:
|
3199
|
+
target_score = score
|
3200
|
+
|
3201
|
+
# Add regularization
|
3202
|
+
target_score = target_score - l2_reg * torch.norm(img)
|
3203
|
+
|
3204
|
+
# Backward pass
|
3205
|
+
target_score.backward()
|
3206
|
+
|
3207
|
+
# Gradient ascent step
|
3208
|
+
with torch.no_grad():
|
3209
|
+
img += learning_rate * img.grad / torch.norm(img.grad)
|
3210
|
+
img.grad.zero_()
|
3211
|
+
|
3212
|
+
# Undo the random jitter
|
3213
|
+
img.data.copy_(jitter(img.data, -ox, -oy))
|
3214
|
+
|
3215
|
+
# As regularizer, clamp and periodically blur the image
|
3216
|
+
for c in range(3):
|
3217
|
+
lo = float(-SQUEEZENET_MEAN[c] / SQUEEZENET_STD[c])
|
3218
|
+
hi = float((1.0 - SQUEEZENET_MEAN[c]) / SQUEEZENET_STD[c])
|
3219
|
+
img.data[:, c].clamp_(min=lo, max=hi)
|
3220
|
+
if t % blur_every == 0:
|
3221
|
+
blur_image(img.data, sigma=0.5)
|
3222
|
+
|
3223
|
+
# Periodically show the image
|
3224
|
+
if t == 0 or (t + 1) % show_every == 0 or t == num_iterations - 1:
|
3225
|
+
plt.imshow(deprocess(img.data.clone().cpu()))
|
3226
|
+
class_name = class_names[target_y]
|
3227
|
+
plt.title('%s\nIteration %d / %d' % (class_name, t + 1, num_iterations))
|
3228
|
+
plt.gcf().set_size_inches(4, 4)
|
3229
|
+
plt.axis('off')
|
3230
|
+
plt.show()
|
3231
|
+
|
3232
|
+
return deprocess(img.data.cpu())
|
3233
|
+
|
3234
|
+
def get_submodules(model, prefix=''):
|
3235
|
+
submodules = []
|
3236
|
+
for name, module in model.named_children():
|
3237
|
+
full_name = prefix + ('.' if prefix else '') + name
|
3238
|
+
submodules.append(full_name)
|
3239
|
+
submodules.extend(get_submodules(module, full_name))
|
3240
|
+
return submodules
|
3241
|
+
|
3242
|
+
class GradCAM:
|
3243
|
+
def __init__(self, model, target_layers=None, use_cuda=True):
|
3244
|
+
self.model = model
|
3245
|
+
self.model.eval()
|
3246
|
+
self.target_layers = target_layers
|
3247
|
+
self.cuda = use_cuda
|
3248
|
+
if self.cuda:
|
3249
|
+
self.model = model.cuda()
|
3250
|
+
|
3251
|
+
def forward(self, input):
|
3252
|
+
return self.model(input)
|
3253
|
+
|
3254
|
+
def __call__(self, x, index=None):
|
3255
|
+
if self.cuda:
|
3256
|
+
x = x.cuda()
|
3257
|
+
|
3258
|
+
features = []
|
3259
|
+
def hook(module, input, output):
|
3260
|
+
features.append(output)
|
3261
|
+
|
3262
|
+
handles = []
|
3263
|
+
for name, module in self.model.named_modules():
|
3264
|
+
if name in self.target_layers:
|
3265
|
+
handles.append(module.register_forward_hook(hook))
|
3266
|
+
|
3267
|
+
output = self.forward(x)
|
3268
|
+
if index is None:
|
3269
|
+
index = np.argmax(output.data.cpu().numpy())
|
3270
|
+
|
3271
|
+
one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
|
3272
|
+
one_hot[0][index] = 1
|
3273
|
+
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
3274
|
+
if self.cuda:
|
3275
|
+
one_hot = one_hot.cuda()
|
3276
|
+
|
3277
|
+
one_hot = torch.sum(one_hot * output)
|
3278
|
+
self.model.zero_grad()
|
3279
|
+
one_hot.backward(retain_graph=True)
|
3280
|
+
|
3281
|
+
grads_val = features[0].grad.cpu().data.numpy()
|
3282
|
+
target = features[0].cpu().data.numpy()[0, :]
|
3283
|
+
|
3284
|
+
weights = np.mean(grads_val, axis=(2, 3))[0, :]
|
3285
|
+
cam = np.zeros(target.shape[1:], dtype=np.float32)
|
3286
|
+
|
3287
|
+
for i, w in enumerate(weights):
|
3288
|
+
cam += w * target[i, :, :]
|
3289
|
+
|
3290
|
+
cam = np.maximum(cam, 0)
|
3291
|
+
cam = cv2.resize(cam, (x.size(2), x.size(3)))
|
3292
|
+
cam = cam - np.min(cam)
|
3293
|
+
cam = cam / np.max(cam)
|
3294
|
+
|
3295
|
+
for handle in handles:
|
3296
|
+
handle.remove()
|
3297
|
+
|
3298
|
+
return cam
|
3299
|
+
|
3300
|
+
def show_cam_on_image(img, mask):
|
3301
|
+
heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
|
3302
|
+
heatmap = np.float32(heatmap) / 255
|
3303
|
+
cam = heatmap + np.float32(img)
|
3304
|
+
cam = cam / np.max(cam)
|
3305
|
+
return np.uint8(255 * cam)
|
3306
|
+
|
3307
|
+
def recommend_target_layers(model):
|
3308
|
+
target_layers = []
|
3309
|
+
for name, module in model.named_modules():
|
3310
|
+
if isinstance(module, torch.nn.Conv2d):
|
3311
|
+
target_layers.append(name)
|
3312
|
+
# Choose the last conv layer as the recommended target layer
|
3313
|
+
if target_layers:
|
3314
|
+
return [target_layers[-1]], target_layers
|
3315
|
+
else:
|
3316
|
+
raise ValueError("No convolutional layers found in the model.")
|
3317
|
+
|
3318
|
+
class IntegratedGradients:
|
3319
|
+
def __init__(self, model):
|
3320
|
+
self.model = model
|
3321
|
+
self.model.eval()
|
3322
|
+
|
3323
|
+
def generate_integrated_gradients(self, input_tensor, target_label_idx, baseline=None, num_steps=50):
|
3324
|
+
if baseline is None:
|
3325
|
+
baseline = torch.zeros_like(input_tensor)
|
3326
|
+
|
3327
|
+
assert baseline.shape == input_tensor.shape
|
3328
|
+
|
3329
|
+
# Scale input and compute gradients
|
3330
|
+
scaled_inputs = [(baseline + (float(i) / num_steps) * (input_tensor - baseline)).requires_grad_(True) for i in range(0, num_steps + 1)]
|
3331
|
+
grads = []
|
3332
|
+
for scaled_input in scaled_inputs:
|
3333
|
+
out = self.model(scaled_input)
|
3334
|
+
self.model.zero_grad()
|
3335
|
+
out[0, target_label_idx].backward(retain_graph=True)
|
3336
|
+
grads.append(scaled_input.grad.data.cpu().numpy())
|
3337
|
+
|
3338
|
+
avg_grads = np.mean(grads[:-1], axis=0)
|
3339
|
+
integrated_grads = (input_tensor.cpu().data.numpy() - baseline.cpu().data.numpy()) * avg_grads
|
3340
|
+
return integrated_grads
|
3341
|
+
|
3342
|
+
def get_db_paths(src):
|
3343
|
+
if isinstance(src, str):
|
3344
|
+
src = [src]
|
3345
|
+
db_paths = [os.path.join(source, 'measurements/measurements.db') for source in src]
|
3346
|
+
return db_paths
|
3347
|
+
|
3348
|
+
def get_sequencing_paths(src):
|
3349
|
+
if isinstance(src, str):
|
3350
|
+
src = [src]
|
3351
|
+
seq_paths = [os.path.join(source, 'sequencing/sequencing_data.csv') for source in src]
|
3352
|
+
return seq_paths
|
3353
|
+
|
3354
|
+
def load_image_paths(c, visualize):
|
3355
|
+
c.execute(f'SELECT * FROM png_list')
|
3356
|
+
data = c.fetchall()
|
3357
|
+
columns_info = c.execute(f'PRAGMA table_info(png_list)').fetchall()
|
3358
|
+
column_names = [col_info[1] for col_info in columns_info]
|
3359
|
+
image_paths_df = pd.DataFrame(data, columns=column_names)
|
3360
|
+
if visualize:
|
3361
|
+
object_visualize = visualize + '_png'
|
3362
|
+
image_paths_df = image_paths_df[image_paths_df['png_path'].str.contains(object_visualize)]
|
3363
|
+
image_paths_df = image_paths_df.set_index('prcfo')
|
3364
|
+
return image_paths_df
|
3365
|
+
|
3366
|
+
def merge_dataframes(df, image_paths_df, verbose):
|
3367
|
+
df.set_index('prcfo', inplace=True)
|
3368
|
+
df = image_paths_df.merge(df, left_index=True, right_index=True)
|
3369
|
+
if verbose:
|
3370
|
+
display(df)
|
3371
|
+
return df
|
3372
|
+
|
3373
|
+
def remove_highly_correlated_columns(df, threshold):
|
3374
|
+
corr_matrix = df.corr().abs()
|
3375
|
+
upper_tri = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
|
3376
|
+
to_drop = [column for column in upper_tri.columns if any(upper_tri[column] > threshold)]
|
3377
|
+
return df.drop(to_drop, axis=1)
|
3378
|
+
|
3379
|
+
def filter_columns(df, filter_by):
|
3380
|
+
if filter_by != 'morphology':
|
3381
|
+
cols_to_include = [col for col in df.columns if filter_by in str(col)]
|
3382
|
+
else:
|
3383
|
+
cols_to_include = [col for col in df.columns if 'channel' not in str(col)]
|
3384
|
+
df = df[cols_to_include]
|
3385
|
+
return df
|
3386
|
+
|
3387
|
+
def reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric, eps, min_samples, clustering, reduction_method='umap', verbose=False, embedding=None, n_jobs=-1, mode='fit', model=False):
|
3388
|
+
"""
|
3389
|
+
Perform dimensionality reduction and clustering on the given data.
|
2826
3390
|
|
2827
|
-
|
3391
|
+
Parameters:
|
3392
|
+
numeric_data (np.ndarray): Numeric data for embedding and clustering.
|
3393
|
+
n_neighbors (int or float): Number of neighbors for UMAP or perplexity for t-SNE.
|
3394
|
+
min_dist (float): Minimum distance for UMAP.
|
3395
|
+
metric (str): Metric for UMAP and DBSCAN.
|
3396
|
+
eps (float): Epsilon for DBSCAN.
|
3397
|
+
min_samples (int): Minimum samples for DBSCAN or number of clusters for KMeans.
|
3398
|
+
clustering (str): Clustering method ('DBSCAN' or 'KMeans').
|
3399
|
+
reduction_method (str): Dimensionality reduction method ('UMAP' or 'tSNE').
|
3400
|
+
verbose (bool): Whether to print verbose output.
|
3401
|
+
embedding (np.ndarray, optional): Precomputed embedding. Default is None.
|
3402
|
+
return_model (bool): Whether to return the reducer model. Default is False.
|
3403
|
+
|
3404
|
+
Returns:
|
3405
|
+
tuple: embedding, labels (and optionally the reducer model)
|
3406
|
+
"""
|
3407
|
+
|
3408
|
+
if verbose:
|
3409
|
+
v = 1
|
3410
|
+
else:
|
3411
|
+
v = 0
|
3412
|
+
|
3413
|
+
if isinstance(n_neighbors, float):
|
3414
|
+
n_neighbors = int(n_neighbors * len(numeric_data))
|
3415
|
+
|
3416
|
+
if n_neighbors <= 2:
|
3417
|
+
n_neighbors = 2
|
3418
|
+
|
3419
|
+
if mode == 'fit':
|
3420
|
+
if reduction_method == 'umap':
|
3421
|
+
reducer = umap.UMAP(n_neighbors=n_neighbors,
|
3422
|
+
n_components=2,
|
3423
|
+
metric=metric,
|
3424
|
+
n_epochs=None,
|
3425
|
+
learning_rate=1.0,
|
3426
|
+
init='spectral',
|
3427
|
+
min_dist=min_dist,
|
3428
|
+
spread=1.0,
|
3429
|
+
set_op_mix_ratio=1.0,
|
3430
|
+
local_connectivity=1,
|
3431
|
+
repulsion_strength=1.0,
|
3432
|
+
negative_sample_rate=5,
|
3433
|
+
transform_queue_size=4.0,
|
3434
|
+
a=None,
|
3435
|
+
b=None,
|
3436
|
+
random_state=42,
|
3437
|
+
metric_kwds=None,
|
3438
|
+
angular_rp_forest=False,
|
3439
|
+
target_n_neighbors=-1,
|
3440
|
+
target_metric='categorical',
|
3441
|
+
target_metric_kwds=None,
|
3442
|
+
target_weight=0.5,
|
3443
|
+
transform_seed=42,
|
3444
|
+
n_jobs=n_jobs,
|
3445
|
+
verbose=verbose)
|
3446
|
+
|
3447
|
+
elif reduction_method == 'tsne':
|
3448
|
+
reducer = TSNE(n_components=2,
|
3449
|
+
perplexity=n_neighbors,
|
3450
|
+
early_exaggeration=12.0,
|
3451
|
+
learning_rate=200.0,
|
3452
|
+
n_iter=1000,
|
3453
|
+
n_iter_without_progress=300,
|
3454
|
+
min_grad_norm=1e-7,
|
3455
|
+
metric=metric,
|
3456
|
+
init='random',
|
3457
|
+
verbose=v,
|
3458
|
+
random_state=42,
|
3459
|
+
method='barnes_hut',
|
3460
|
+
angle=0.5,
|
3461
|
+
n_jobs=n_jobs)
|
3462
|
+
|
3463
|
+
else:
|
3464
|
+
raise ValueError(f"Unsupported reduction method: {reduction_method}. Supported methods are 'umap' and 'tsne'")
|
3465
|
+
|
3466
|
+
embedding = reducer.fit_transform(numeric_data)
|
3467
|
+
if verbose:
|
3468
|
+
print(f'Trained and fit reducer')
|
3469
|
+
|
3470
|
+
else:
|
3471
|
+
if not model is None:
|
3472
|
+
embedding = model.transform(numeric_data)
|
3473
|
+
reducer = model
|
3474
|
+
if verbose:
|
3475
|
+
print(f'Fit data to reducer')
|
3476
|
+
else:
|
3477
|
+
raise ValueError(f"Model is None. Please provide a model for transform.")
|
3478
|
+
|
3479
|
+
if clustering == 'dbscan':
|
3480
|
+
clustering_model = DBSCAN(eps=eps, min_samples=min_samples, metric=metric, n_jobs=n_jobs)
|
3481
|
+
elif clustering == 'kmeans':
|
3482
|
+
clustering_model = KMeans(n_clusters=min_samples, random_state=42)
|
3483
|
+
|
3484
|
+
clustering_model.fit(embedding)
|
3485
|
+
labels = clustering_model.labels_ if clustering == 'dbscan' else clustering_model.predict(embedding)
|
3486
|
+
|
3487
|
+
if verbose:
|
3488
|
+
print(f'Embedding shape: {embedding.shape}')
|
3489
|
+
|
3490
|
+
return embedding, labels, reducer
|
3491
|
+
|
3492
|
+
def reduction_and_clustering_v1(numeric_data, n_neighbors, min_dist, metric, eps, min_samples, clustering, reduction_method='umap', verbose=False, embedding=None, n_jobs=-1):
    """
    Perform dimensionality reduction and clustering on the given data.

    Parameters:
    numeric_data (np.ndarray): Numeric data for embedding and clustering.
    n_neighbors (int or float): Number of neighbors for UMAP or perplexity for t-SNE.
    min_dist (float): Minimum distance for UMAP.
    metric (str): Metric for UMAP and DBSCAN.
    eps (float): Epsilon for DBSCAN.
    min_samples (int): Minimum samples for DBSCAN or number of clusters for KMeans.
    clustering (str): Clustering method ('DBSCAN' or 'KMeans').
    reduction_method (str): Dimensionality reduction method ('UMAP' or 'tSNE').
    verbose (bool): Whether to print verbose output.
    embedding (np.ndarray, optional): Precomputed embedding. Default is None.

    Returns:
    tuple: embedding, labels
    """

    if verbose:
        v = 1
    else:
        v = 0

    if isinstance(n_neighbors, float):
        n_neighbors = int(n_neighbors * len(numeric_data))

    if n_neighbors <= 2:
        n_neighbors = 2

    if reduction_method == 'umap':
        reducer = umap.UMAP(n_neighbors=n_neighbors,
                            n_components=2,
                            metric=metric,
                            n_epochs=None,
                            learning_rate=1.0,
                            init='spectral',
                            min_dist=min_dist,
                            spread=1.0,
                            set_op_mix_ratio=1.0,
                            local_connectivity=1,
                            repulsion_strength=1.0,
                            negative_sample_rate=5,
                            transform_queue_size=4.0,
                            a=None,
                            b=None,
                            random_state=42,
                            metric_kwds=None,
                            angular_rp_forest=False,
                            target_n_neighbors=-1,
                            target_metric='categorical',
                            target_metric_kwds=None,
                            target_weight=0.5,
                            transform_seed=42,
                            n_jobs=n_jobs,
                            verbose=verbose)

    elif reduction_method == 'tsne':

        #tsne_params.setdefault('n_components', 2)
        #reducer = TSNE(**tsne_params)

        reducer = TSNE(n_components=2,
                       perplexity=n_neighbors,
                       early_exaggeration=12.0,
                       learning_rate=200.0,
                       n_iter=1000,
                       n_iter_without_progress=300,
                       min_grad_norm=1e-7,
                       metric=metric,
                       init='random',
                       verbose=v,
                       random_state=42,
                       method='barnes_hut',
                       angle=0.5,
                       n_jobs=n_jobs)

    else:
        raise ValueError(f"Unsupported reduction method: {reduction_method}. Supported methods are 'umap' and 'tsne'")

    if embedding is None:
        embedding = reducer.fit_transform(numeric_data)

    if clustering == 'dbscan':
        clustering_model = DBSCAN(eps=eps, min_samples=min_samples, metric=metric, n_jobs=n_jobs)
    elif clustering == 'kmeans':
        clustering_model = KMeans(n_clusters=min_samples, random_state=42)
    else:
        raise ValueError(f"Unsupported clustering method: {clustering}. Supported methods are 'dbscan' and 'kmeans'")

    clustering_model.fit(embedding)
    labels = clustering_model.labels_ if clustering == 'dbscan' else clustering_model.predict(embedding)

    if verbose:
        print(f'Embedding shape: {embedding.shape}')

    return embedding, labels

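A minimal usage sketch (illustrative only, not part of the packaged code), assuming the module-level umap-learn and scikit-learn imports in this file are available:

import numpy as np
from spacr.utils import reduction_and_clustering_v1

features = np.random.rand(500, 20)  # 500 objects x 20 measurements (toy data)
embedding, labels = reduction_and_clustering_v1(
    features, n_neighbors=15, min_dist=0.1, metric='euclidean',
    eps=0.5, min_samples=10, clustering='dbscan', reduction_method='umap')
print(embedding.shape, np.unique(labels))  # (500, 2) plus the DBSCAN cluster ids
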
def remove_noise(embedding, labels):
    non_noise_indices = labels != -1
    embedding = embedding[non_noise_indices]
    labels = labels[non_noise_indices]
    return embedding, labels

def plot_embedding(embedding, image_paths, labels, image_nr, img_zoom, colors, plot_by_cluster, plot_outlines, plot_points, plot_images, smooth_lines, black_background, figuresize, dot_size, remove_image_canvas, verbose):
    unique_labels = np.unique(labels)
    #num_clusters = len(unique_labels[unique_labels != 0])
    colors, label_to_color_index = assign_colors(unique_labels, colors)
    cluster_centers = [np.mean(embedding[labels == cluster_label], axis=0) for cluster_label in unique_labels]
    fig, ax = setup_plot(figuresize, black_background)
    plot_clusters(ax, embedding, labels, colors, cluster_centers, plot_outlines, plot_points, smooth_lines, figuresize, dot_size, verbose)
    if image_paths is not None and plot_images:
        plot_umap_images(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, plot_by_cluster, remove_image_canvas, verbose)
    plt.show()
    return fig

def generate_colors(num_clusters, black_background):
    random_colors = np.random.rand(num_clusters + 1, 4)
    random_colors[:, 3] = 1
    specific_colors = [
        [155 / 255, 55 / 255, 155 / 255, 1],
        [55 / 255, 155 / 255, 155 / 255, 1],
        [55 / 255, 155 / 255, 255 / 255, 1],
        [255 / 255, 55 / 255, 155 / 255, 1]
    ]
    random_colors = np.vstack((specific_colors, random_colors[len(specific_colors):]))
    if not black_background:
        random_colors = np.vstack(([0, 0, 0, 1], random_colors))
    return random_colors

def assign_colors(unique_labels, random_colors):
    normalized_colors = random_colors / 255
    colors_img = [tuple(color) for color in normalized_colors]
    colors = [tuple(color) for color in random_colors]
    label_to_color_index = {label: index for index, label in enumerate(unique_labels)}
    return colors, label_to_color_index

def setup_plot(figuresize, black_background):
    if black_background:
        plt.rcParams.update({'figure.facecolor': 'black', 'axes.facecolor': 'black', 'text.color': 'white', 'xtick.color': 'white', 'ytick.color': 'white', 'axes.labelcolor': 'white'})
    else:
        plt.rcParams.update({'figure.facecolor': 'white', 'axes.facecolor': 'white', 'text.color': 'black', 'xtick.color': 'black', 'ytick.color': 'black', 'axes.labelcolor': 'black'})
    fig, ax = plt.subplots(1, 1, figsize=(figuresize, figuresize))
    return fig, ax

def plot_clusters(ax, embedding, labels, colors, cluster_centers, plot_outlines, plot_points, smooth_lines, figuresize=50, dot_size=50, verbose=False):
    unique_labels = np.unique(labels)
    for cluster_label, color, center in zip(unique_labels, colors, cluster_centers):
        cluster_data = embedding[labels == cluster_label]
        if smooth_lines:
            if cluster_data.shape[0] > 2:
                x_smooth, y_smooth = smooth_hull_lines(cluster_data)
                if plot_outlines:
                    plt.plot(x_smooth, y_smooth, color=color, linewidth=2)
        else:
            if cluster_data.shape[0] > 2:
                hull = ConvexHull(cluster_data)
                for simplex in hull.simplices:
                    if plot_outlines:
                        plt.plot(hull.points[simplex, 0], hull.points[simplex, 1], color=color, linewidth=4)
        if plot_points:
            scatter = ax.scatter(cluster_data[:, 0], cluster_data[:, 1], s=dot_size, c=[color], alpha=0.5, label=f'Cluster {cluster_label if cluster_label != -1 else "Noise"}')
        else:
            scatter = ax.scatter(cluster_data[:, 0], cluster_data[:, 1], s=dot_size, c=[color], alpha=0, label=f'Cluster {cluster_label if cluster_label != -1 else "Noise"}')
        ax.text(center[0], center[1], str(cluster_label), fontsize=12, ha='center', va='center')
    plt.legend(loc='best', fontsize=int(figuresize * 0.75))
    plt.xlabel('UMAP Dimension 1', fontsize=int(figuresize * 0.75))
    plt.ylabel('UMAP Dimension 2', fontsize=int(figuresize * 0.75))
    plt.tick_params(axis='both', which='major', labelsize=int(figuresize * 0.75))

def plot_umap_images(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, plot_by_cluster, remove_image_canvas, verbose):
    if plot_by_cluster:
        cluster_indices = {label: np.where(labels == label)[0] for label in np.unique(labels) if label != -1}
        plot_images_by_cluster(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, cluster_indices, remove_image_canvas, verbose)
    else:
        indices = random.sample(range(len(embedding)), image_nr)
        for i, index in enumerate(indices):
            x, y = embedding[index]
            img = Image.open(image_paths[index])
            plot_image(ax, x, y, img, img_zoom, remove_image_canvas)

def plot_images_by_cluster(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, cluster_indices, remove_image_canvas, verbose):
    for cluster_label, color in zip(np.unique(labels), colors):
        if cluster_label == -1:
            continue
        indices = cluster_indices.get(cluster_label, [])
        if len(indices) > image_nr:
            indices = random.sample(list(indices), image_nr)
        for index in indices:
            x, y = embedding[index]
            img = Image.open(image_paths[index])
            plot_image(ax, x, y, img, img_zoom, remove_image_canvas)

def plot_image(ax, x, y, img, img_zoom, remove_image_canvas=True):
    # remove_canvas() inspects img.mode, so it must run on the PIL image
    # before the conversion to a numpy array.
    if remove_image_canvas:
        img = remove_canvas(img)
    else:
        img = np.array(img)
    imagebox = OffsetImage(img, zoom=img_zoom)
    ab = AnnotationBbox(imagebox, (x, y), frameon=False)
    ax.add_artist(ab)

def remove_canvas(img):
    if img.mode in ['L', 'I']:
        img_data = np.array(img)
        img_data = img_data / np.max(img_data)
        alpha_channel = (img_data > 0).astype(float)
        img_data_rgb = np.stack([img_data] * 3, axis=-1)
        img_data_with_alpha = np.dstack([img_data_rgb, alpha_channel])
    elif img.mode == 'RGB':
        img_data = np.array(img)
        img_data = img_data / 255.0
        alpha_channel = (np.sum(img_data, axis=-1) > 0).astype(float)
        img_data_with_alpha = np.dstack([img_data, alpha_channel])
    else:
        raise ValueError(f"Unsupported image mode: {img.mode}")
    return img_data_with_alpha

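To illustrate the alpha-masking step above (a sketch, not part of the package): zero-valued background pixels get alpha 0 while any non-zero pixel stays opaque, so pasted crops do not blot out neighbouring points.

import numpy as np
from PIL import Image
from spacr.utils import remove_canvas

gray = Image.fromarray(np.array([[0, 128], [255, 0]], dtype=np.uint8), mode='L')
rgba = remove_canvas(gray)   # (2, 2, 4) array; alpha channel is 0 where the pixel is 0
print(rgba[..., 3])          # [[0. 1.] [1. 0.]]
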
def plot_clusters_grid(embedding, labels, image_nr, image_paths, colors, figuresize, black_background, verbose):
    unique_labels = np.unique(labels)
    num_clusters = len(unique_labels[unique_labels != -1])
    if num_clusters == 0:
        print("No clusters found.")
        return
    cluster_images = {label: [] for label in unique_labels if label != -1}
    cluster_indices = {label: np.where(labels == label)[0] for label in unique_labels if label != -1}
    for cluster_label, indices in cluster_indices.items():
        if cluster_label == -1:
            continue
        if len(indices) > image_nr:
            indices = random.sample(list(indices), image_nr)
        for index in indices:
            img_path = image_paths[index]
            img_array = Image.open(img_path)
            img = np.array(img_array)
            cluster_images[cluster_label].append(img)
    fig = plot_grid(cluster_images, colors, figuresize, black_background, verbose)
    return fig

def plot_grid(cluster_images, colors, figuresize, black_background, verbose):
    num_clusters = len(cluster_images)
    max_figsize = 200  # Set a maximum figure size
    if figuresize * num_clusters > max_figsize:
        figuresize = max_figsize / num_clusters

    grid_fig, grid_axes = plt.subplots(1, num_clusters, figsize=(figuresize * num_clusters, figuresize), gridspec_kw={'wspace': 0.2, 'hspace': 0})
    if num_clusters == 1:
        grid_axes = [grid_axes]  # Ensure grid_axes is always iterable
    for cluster_label, axes in zip(cluster_images.keys(), grid_axes):
        images = cluster_images[cluster_label]
        num_images = len(images)
        grid_size = int(np.ceil(np.sqrt(num_images)))
        image_size = 0.9 / grid_size
        whitespace = (1 - grid_size * image_size) / (grid_size + 1)

        if isinstance(cluster_label, str):
            idx = list(cluster_images.keys()).index(cluster_label)
            color = colors[idx]
            if verbose:
                print(f'Label: {cluster_label} index: {idx}')
        else:
            color = colors[cluster_label]

        axes.add_patch(plt.Rectangle((0, 0), 1, 1, transform=axes.transAxes, color=color[:3]))
        axes.axis('off')
        for i, img in enumerate(images):
            row = i // grid_size
            col = i % grid_size
            x_pos = (col + 1) * whitespace + col * image_size
            y_pos = 1 - ((row + 1) * whitespace + (row + 1) * image_size)
            ax_img = axes.inset_axes([x_pos, y_pos, image_size, image_size], transform=axes.transAxes)
            ax_img.imshow(img, cmap='gray', aspect='auto')
            ax_img.axis('off')
            ax_img.set_aspect('equal')
            ax_img.set_facecolor(color[:3])

    # Add cluster labels beside the UMAP plot
    spacing_factor = 0.5  # Adjust this value to control the spacing between labels
    for i, (cluster_label, color) in enumerate(zip(cluster_images.keys(), colors)):
        label_y = 1 - (i + 1) * (spacing_factor / num_clusters)  # Adjust y position for each label
        grid_fig.text(1.05, label_y, f'Cluster {cluster_label}', verticalalignment='center', fontsize=figuresize, color='black' if not black_background else 'white')
        grid_fig.patches.append(plt.Rectangle((1, label_y - 0.02), 0.03, 0.03, transform=grid_fig.transFigure, color=color[:3], clip_on=False))

    plt.show()
    return grid_fig

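The inset-axes geometry is easiest to see with a worked example (illustrative only): with five images the panel gets a 3x3 grid, each tile takes 0.3 of the panel, and the leftover 0.1 is split into four gutters.

import numpy as np

num_images = 5
grid_size = int(np.ceil(np.sqrt(num_images)))                # 3
image_size = 0.9 / grid_size                                 # 0.3 of the panel per tile
whitespace = (1 - grid_size * image_size) / (grid_size + 1)  # 0.025 between tiles
print(grid_size, image_size, whitespace)                     # 3, 0.3, 0.025 (up to float rounding)
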
def correct_paths(df, base_path):

    if 'png_path' not in df.columns:
        print("No 'png_path' column found in the dataframe.")
        return df, None

    image_paths = df['png_path'].to_list()

    adjusted_image_paths = []
    for path in image_paths:
        if base_path not in path:
            parts = path.split('/data/')
            if len(parts) > 1:
                new_path = os.path.join(base_path, 'data', parts[1])
                adjusted_image_paths.append(new_path)
            else:
                adjusted_image_paths.append(path)
        else:
            adjusted_image_paths.append(path)

    df['png_path'] = adjusted_image_paths
    image_paths = df['png_path'].to_list()
    return df, image_paths

def correct_paths_v1(df, base_path):
    if 'png_path' not in df.columns:
        print("No 'png_path' column found in the dataframe.")
        return df, None

    image_paths = df['png_path'].to_list()

    adjusted_image_paths = []
    for path in image_paths:
        if base_path not in path:
            print(f"Adjusting path: {path}")
            parts = path.split('data/')
            if len(parts) > 1:
                new_path = os.path.join(base_path, 'data', parts[1])
                adjusted_image_paths.append(new_path)
            else:
                adjusted_image_paths.append(path)
        else:
            adjusted_image_paths.append(path)

    df['png_path'] = adjusted_image_paths
    image_paths = df['png_path'].to_list()
    return df, image_paths

def get_umap_image_settings(settings={}):
    settings.setdefault('src', 'path')
    settings.setdefault('row_limit', 1000)
    settings.setdefault('tables', ['cell', 'cytoplasm', 'nucleus', 'pathogen'])
    settings.setdefault('visualize', 'cell')
    settings.setdefault('image_nr', 16)
    settings.setdefault('dot_size', 50)
    settings.setdefault('n_neighbors', 1000)
    settings.setdefault('min_dist', 0.1)
    settings.setdefault('metric', 'euclidean')
    settings.setdefault('eps', 0.5)
    settings.setdefault('min_samples', 1000)
    settings.setdefault('filter_by', 'channel_0')
    settings.setdefault('img_zoom', 0.5)
    settings.setdefault('plot_by_cluster', True)
    settings.setdefault('plot_cluster_grids', True)
    settings.setdefault('remove_cluster_noise', True)
    settings.setdefault('remove_highly_correlated', True)
    settings.setdefault('log_data', False)
    settings.setdefault('figuresize', 60)
    settings.setdefault('black_background', True)
    settings.setdefault('remove_image_canvas', False)
    settings.setdefault('plot_outlines', True)
    settings.setdefault('plot_points', True)
    settings.setdefault('smooth_lines', True)
    settings.setdefault('clustering', 'dbscan')
    settings.setdefault('exclude', None)
    settings.setdefault('col_to_compare', 'col')
    settings.setdefault('pos', 'c1')
    settings.setdefault('neg', 'c2')
    settings.setdefault('mix', 'c3')
    settings.setdefault('embedding_by_controls', False)
    settings.setdefault('plot_images', True)
    settings.setdefault('reduction_method', 'umap')
    settings.setdefault('save_figure', False)
    settings.setdefault('n_jobs', -1)
    settings.setdefault('color_by', None)
    settings.setdefault('exclude_conditions', None)
    settings.setdefault('analyze_clusters', False)
    settings.setdefault('resnet_features', False)
    settings.setdefault('verbose', True)
    return settings

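A minimal usage sketch (not part of the package; the source path is a placeholder): caller-supplied values win, everything else falls back to the defaults above.

from spacr.utils import get_umap_image_settings

settings = get_umap_image_settings({'src': '/path/to/experiment', 'image_nr': 25})
print(settings['image_nr'], settings['clustering'])  # 25 dbscan
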
def preprocess_data(df, filter_by, remove_highly_correlated, log_data, exclude):
    """
    Preprocesses the given dataframe by applying filtering, removing highly correlated columns,
    applying log transformation, filling NaN values, and scaling the numeric data.

    Args:
    df (pandas.DataFrame): The input dataframe.
    filter_by (str or None): The channel of interest to filter the dataframe by.
    remove_highly_correlated (bool or float): Whether to remove highly correlated columns.
        If a float is provided, it represents the correlation threshold.
    log_data (bool): Whether to apply log transformation to the numeric data.
    exclude (list or None): List of features to exclude from the filtering process.

    Returns:
    numpy.ndarray: The preprocessed numeric data.

    Raises:
    ValueError: If no numeric columns are available after filtering.
    """
    # Apply filtering based on the `filter_by` parameter
    if filter_by is not None:
        df, _ = filter_dataframe_features(df, channel_of_interest=filter_by, exclude=exclude)

    # Select numerical features
    numeric_data = df.select_dtypes(include=['number'])

    # Check if numeric_data is empty
    if numeric_data.empty:
        raise ValueError("No numeric columns available after filtering. Please check the filter_by and exclude parameters.")

    # Remove highly correlated columns
    if remove_highly_correlated is not False:
        if isinstance(remove_highly_correlated, float):
            numeric_data = remove_highly_correlated_columns(numeric_data, remove_highly_correlated)
        else:
            numeric_data = remove_highly_correlated_columns(numeric_data, 0.95)

    # Apply log transformation
    if log_data:
        numeric_data = np.log(numeric_data + 1e-6)

    # Fill NaN values with the column mean
    numeric_data = numeric_data.fillna(numeric_data.mean())

    # Scale the numeric data
    scaler = StandardScaler(copy=True, with_mean=True, with_std=True)
    numeric_data = scaler.fit_transform(numeric_data)

    return numeric_data

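A toy example (illustrative only; the column names are made up) showing the shape of the returned matrix when filtering and correlation pruning are disabled:

import pandas as pd
from spacr.utils import preprocess_data

df = pd.DataFrame({'cell_channel_0_mean_intensity': [1.0, 2.0, 3.0],
                   'cell_channel_1_mean_intensity': [4.0, 5.0, 6.0],
                   'cell_area': [10.0, 20.0, 30.0]})
X = preprocess_data(df, filter_by=None, remove_highly_correlated=False,
                    log_data=True, exclude=None)
print(X.shape)  # (3, 3): log-transformed, mean-imputed, z-scored features
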
def filter_dataframe_features(df, channel_of_interest, exclude=None):
    """
    Filter the dataframe `df` based on the specified `channel_of_interest` and `exclude` parameters.

    Parameters:
    - df (pandas.DataFrame): The input dataframe to be filtered.
    - channel_of_interest (str, int, list, None): The channel(s) of interest to filter the dataframe.
      If None, no filtering is applied. If 'morphology', only morphology features are included.
      If an integer, only the specified channel is included. If a list, only the specified channels are included.
      If a string, only the specified channel is included.
    - exclude (str, list, None): The feature(s) to exclude from the filtered dataframe.
      If None, no features are excluded. If a string, the specified feature is excluded.
      If a list, the specified features are excluded.

    Returns:
    - filtered_df (pandas.DataFrame): The filtered dataframe based on the specified parameters.
    - features (list): The list of selected features after filtering.
    """
    if channel_of_interest is None:
        feature_string = None
    elif channel_of_interest == 'morphology':
        feature_string = 'morphology'
    elif isinstance(channel_of_interest, list):
        feature_string = []
        for i in channel_of_interest:
            feature_string_tmp = f'channel_{i}'
            feature_string.append(feature_string_tmp)
    elif isinstance(channel_of_interest, int):
        feature_string = f'channel_{channel_of_interest}'
    elif isinstance(channel_of_interest, str):
        feature_string = channel_of_interest

    # Remove columns with a single value
    df = df.loc[:, df.nunique() > 1]

    # Select numerical features
    features = df.select_dtypes(include=[np.number]).columns.tolist()

    if feature_string is not None:
        feature_list = ['channel_0', 'channel_1', 'channel_2', 'channel_3']

        # Remove feature_string from the list if it exists
        if isinstance(feature_string, str):
            if feature_string in feature_list:
                feature_list.remove(feature_string)
        elif isinstance(feature_string, list):
            feature_list = [feature for feature in feature_list if feature not in feature_string]

        # Keep only features that match the requested channel(s)
        if feature_string != 'morphology':
            if isinstance(feature_string, list):
                features = [feature for feature in features if any(fs in feature for fs in feature_string)]
            else:
                features = [feature for feature in features if feature_string in feature]

        # Iterate through the list and remove columns from df
        for feature_ in feature_list:
            features = [feature for feature in features if feature_ not in feature]
            print(f'After removing {feature_} features: {len(features)}')

    if isinstance(exclude, list):
        features = [feature for feature in features if feature not in exclude]
    elif isinstance(exclude, str):
        features.remove(exclude)

    filtered_df = df[features]

    return filtered_df, features

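A small illustration (not from the package; column names are hypothetical) of selecting a single channel's features:

import pandas as pd
from spacr.utils import filter_dataframe_features

df = pd.DataFrame({'cell_channel_0_mean_intensity': [1, 2, 3],
                   'cell_channel_1_mean_intensity': [4, 5, 6],
                   'cell_area': [7, 8, 9]})
filtered_df, features = filter_dataframe_features(df, channel_of_interest=0)
print(features)  # ['cell_channel_0_mean_intensity']
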
# Create a function to check if images overlap
def check_overlap(current_position, other_positions, threshold):
    for other_position in other_positions:
        distance = np.linalg.norm(np.array(current_position) - np.array(other_position))
        if distance < threshold:
            return True
    return False

# Define a function to try random positions around a given point
def find_non_overlapping_position(x, y, image_positions, threshold, max_attempts=100):
    offset_range = 10  # Adjust the range for random offsets
    attempts = 0
    while attempts < max_attempts:
        random_offset_x = random.uniform(-offset_range, offset_range)
        random_offset_y = random.uniform(-offset_range, offset_range)
        new_x = x + random_offset_x
        new_y = y + random_offset_y
        if not check_overlap((new_x, new_y), image_positions, threshold):
            return new_x, new_y
        attempts += 1
    return x, y  # Return the original position if no suitable position found

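A quick sketch of how these helpers cooperate when placing thumbnails (illustrative only):

from spacr.utils import check_overlap, find_non_overlapping_position

placed = [(10.0, 10.0), (12.0, 11.0)]  # thumbnails already on the canvas
x, y = find_non_overlapping_position(10.5, 10.5, placed, threshold=5)
print(check_overlap((x, y), placed, threshold=5))  # typically False once a free spot is found
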
def search_reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric, eps, min_samples, clustering, reduction_method, verbose, reduction_param=None, embedding=None, n_jobs=-1):
    """
    Perform dimensionality reduction and clustering on the given data.

    Parameters:
    numeric_data (np.array): Numeric data to process.
    n_neighbors (int): Number of neighbors for UMAP or perplexity for tSNE.
    min_dist (float): Minimum distance for UMAP.
    metric (str): Metric for UMAP, tSNE, and DBSCAN.
    eps (float): Epsilon for DBSCAN clustering.
    min_samples (int): Minimum samples for DBSCAN or number of clusters for KMeans.
    clustering (str): Clustering method ('DBSCAN' or 'KMeans').
    reduction_method (str): Dimensionality reduction method ('UMAP' or 'tSNE').
    verbose (bool): Whether to print verbose output.
    reduction_param (dict): Additional parameters for the reduction method.
    embedding (np.array): Precomputed embedding (optional).
    n_jobs (int): Number of parallel jobs to run.

    Returns:
    embedding (np.array): Embedding of the data.
    labels (np.array): Cluster labels.
    """

    if isinstance(n_neighbors, float):
        n_neighbors = int(n_neighbors * len(numeric_data))
    if n_neighbors <= 1:
        n_neighbors = 2
        print(f'n_neighbors cannot be less than 2. Setting n_neighbors to {n_neighbors}')

    reduction_param = reduction_param or {}
    reduction_param = {k: v for k, v in reduction_param.items() if k not in ['perplexity', 'n_neighbors', 'min_dist', 'metric', 'method']}

    if reduction_method == 'umap':
        reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric, n_jobs=n_jobs, **reduction_param)
    elif reduction_method == 'tsne':
        reducer = TSNE(n_components=2, perplexity=n_neighbors, metric=metric, n_jobs=n_jobs, **reduction_param)
    else:
        raise ValueError(f"Unsupported reduction method: {reduction_method}. Supported methods are 'umap' and 'tsne'")

    if embedding is None:
        embedding = reducer.fit_transform(numeric_data)

    if clustering == 'dbscan':
        clustering_model = DBSCAN(eps=eps, min_samples=min_samples, metric=metric)
    elif clustering == 'kmeans':
        from sklearn.cluster import KMeans
        clustering_model = KMeans(n_clusters=min_samples, random_state=42)
    else:
        raise ValueError(f"Unsupported clustering method: {clustering}. Supported methods are 'dbscan' and 'kmeans'")
    clustering_model.fit(embedding)
    labels = clustering_model.labels_ if clustering == 'dbscan' else clustering_model.predict(embedding)
    if verbose:
        print(f'Embedding shape: {embedding.shape}')
    return embedding, labels

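A usage sketch for parameter searches (illustrative only, assuming umap-learn and scikit-learn are installed): re-running the clustering over a few DBSCAN eps values on toy data.

import numpy as np
from spacr.utils import search_reduction_and_clustering

features = np.random.rand(300, 15)
for eps in (0.3, 0.5, 0.7):
    embedding, labels = search_reduction_and_clustering(
        features, n_neighbors=15, min_dist=0.1, metric='euclidean',
        eps=eps, min_samples=10, clustering='dbscan',
        reduction_method='umap', verbose=False)
    n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
    print(f'eps={eps}: {n_clusters} clusters')
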