spacr 0.0.36__py3-none-any.whl → 0.0.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +11 -4
- spacr/__main__.py +0 -2
- spacr/alpha.py +514 -2
- spacr/annotate_app.py +112 -116
- spacr/core.py +864 -728
- spacr/deep_spacr.py +696 -0
- spacr/foldseek.py +2 -16
- spacr/graph_learning.py +297 -253
- spacr/gui.py +9 -8
- spacr/gui_2.py +90 -0
- spacr/gui_classify_app.py +3 -4
- spacr/gui_mask_app.py +9 -9
- spacr/gui_measure_app.py +3 -5
- spacr/gui_utils.py +132 -33
- spacr/io.py +308 -464
- spacr/mask_app.py +109 -5
- spacr/measure.py +15 -1
- spacr/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/old_code.py +69 -1
- spacr/plot.py +23 -6
- spacr/sequencing.py +1130 -0
- spacr/sim.py +0 -42
- spacr/timelapse.py +0 -1
- spacr/train.py +172 -13
- spacr/umap.py +0 -689
- spacr/utils.py +1322 -75
- {spacr-0.0.36.dist-info → spacr-0.0.62.dist-info}/METADATA +14 -29
- spacr-0.0.62.dist-info/RECORD +39 -0
- {spacr-0.0.36.dist-info → spacr-0.0.62.dist-info}/entry_points.txt +1 -0
- spacr-0.0.36.dist-info/RECORD +0 -35
- {spacr-0.0.36.dist-info → spacr-0.0.62.dist-info}/LICENSE +0 -0
- {spacr-0.0.36.dist-info → spacr-0.0.62.dist-info}/WHEEL +0 -0
- {spacr-0.0.36.dist-info → spacr-0.0.62.dist-info}/top_level.txt +0 -0
spacr/utils.py
CHANGED
@@ -1,12 +1,18 @@
-import sys, os, re, sqlite3,
+import sys, os, re, sqlite3, torch, torchvision, random, string, shutil, cv2, tarfile, glob

 import numpy as np
 from cellpose import models as cp_models
 from cellpose import denoise
+
 from skimage import morphology
 from skimage.measure import label, regionprops_table, regionprops
 import skimage.measure as measure
-from
+from skimage.transform import resize as resizescikit
+from skimage.morphology import dilation, square
+from skimage.measure import find_contours
+from skimage.segmentation import clear_border
+
+from collections import defaultdict, OrderedDict
 from PIL import Image
 import pandas as pd
 from statsmodels.stats.outliers_influence import variance_inflation_factor
@@ -15,48 +21,225 @@ import statsmodels.formula.api as smf
 import statsmodels.api as sm
 from statsmodels.stats.multitest import multipletests
 from itertools import combinations
-from collections import OrderedDict
 from functools import reduce
-from IPython.display import display
+from IPython.display import display
+
 from multiprocessing import Pool, cpu_count
-from
-
-from skimage.measure import find_contours
+from concurrent.futures import ThreadPoolExecutor
+
 import torch.nn as nn
 import torch.nn.functional as F
-#from torchsummary import summary
 from torch.utils.checkpoint import checkpoint
 from torch.utils.data import Subset
 from torch.autograd import grad
-
-from skimage.segmentation import clear_border
+
 import seaborn as sns
 import matplotlib.pyplot as plt
+from matplotlib.offsetbox import OffsetImage, AnnotationBbox
+
 import scipy.ndimage as ndi
 from scipy.spatial import distance
 from scipy.stats import fisher_exact
-from scipy.ndimage import
+from scipy.ndimage.filters import gaussian_filter
+from scipy.spatial import ConvexHull
+from scipy.interpolate import splprep, splev
+
+from sklearn.preprocessing import StandardScaler
 from skimage.exposure import rescale_intensity
 from sklearn.metrics import auc, precision_recall_curve
 from sklearn.model_selection import train_test_split
 from sklearn.linear_model import Lasso, Ridge
 from sklearn.preprocessing import OneHotEncoder
 from sklearn.cluster import KMeans
+from sklearn.preprocessing import StandardScaler
+from sklearn.cluster import DBSCAN
+from sklearn.cluster import KMeans
+from sklearn.manifold import TSNE
+
+import umap.umap_ as umap
+
+from torchvision import models
 from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
+import torchvision.transforms as transforms

 from .logger import log_function_call

-def
-
-
-
-
+def check_mask_folder(src,mask_fldr):
+
+    mask_folder = os.path.join(src,'norm_channel_stack',mask_fldr)
+    stack_folder = os.path.join(src,'stack')
+
+    if not os.path.exists(mask_folder):
+        return True
+
+    mask_count = sum(1 for file in os.listdir(mask_folder) if file.endswith('.npy'))
+    stack_count = sum(1 for file in os.listdir(stack_folder) if file.endswith('.npy'))
+
+    if mask_count == stack_count:
+        print(f'All masks have been generated for {mask_fldr}')
+        return False
+    else:
+        return True
+
+def set_default_plot_merge_settings():
+    settings = {}
+    settings.setdefault('include_noninfected', True)
+    settings.setdefault('include_multiinfected', True)
+    settings.setdefault('include_multinucleated', True)
+    settings.setdefault('remove_background', False)
+    settings.setdefault('filter_min_max', None)
+    settings.setdefault('channel_dims', [0,1,2,3])
+    settings.setdefault('backgrounds', [100,100,100,100])
+    settings.setdefault('cell_mask_dim', 4)
+    settings.setdefault('nucleus_mask_dim', 5)
+    settings.setdefault('pathogen_mask_dim', 6)
+    settings.setdefault('outline_thickness', 3)
+    settings.setdefault('outline_color', 'gbr')
+    settings.setdefault('overlay_chans', [1,2,3])
+    settings.setdefault('overlay', True)
+    settings.setdefault('normalization_percentiles', [2,98])
+    settings.setdefault('normalize', True)
+    settings.setdefault('print_object_number', True)
+    settings.setdefault('nr', 1)
+    settings.setdefault('figuresize', 50)
+    settings.setdefault('cmap', 'inferno')
+    settings.setdefault('verbose', True)
+
+    return settings
+
+def set_default_settings_preprocess_generate_masks(src, settings={}):
+    # Main settings
+    settings['src'] = src
+    settings.setdefault('preprocess', True)
+    settings.setdefault('masks', True)
+    settings.setdefault('save', True)
+    settings.setdefault('batch_size', 50)
+    settings.setdefault('test_mode', False)
+    settings.setdefault('test_images', 10)
+    settings.setdefault('magnification', 20)
+    settings.setdefault('custom_regex', None)
+    settings.setdefault('metadata_type', 'cellvoyager')
+    settings.setdefault('workers', os.cpu_count()-4)
+    settings.setdefault('randomize', True)
+    settings.setdefault('verbose', True)
+
+    settings.setdefault('remove_background_cell', False)
+    settings.setdefault('remove_background_nucleus', False)
+    settings.setdefault('remove_background_pathogen', False)
+
+    # Channel settings
+    settings.setdefault('cell_channel', None)
+    settings.setdefault('nucleus_channel', None)
+    settings.setdefault('pathogen_channel', None)
+    settings.setdefault('channels', [0,1,2,3])
+    settings.setdefault('pathogen_background', 100)
+    settings.setdefault('pathogen_Signal_to_noise', 10)
+    settings.setdefault('pathogen_CP_prob', 0)
+    settings.setdefault('cell_background', 100)
+    settings.setdefault('cell_Signal_to_noise', 10)
+    settings.setdefault('cell_CP_prob', 0)
+    settings.setdefault('nucleus_background', 100)
+    settings.setdefault('nucleus_Signal_to_noise', 10)
+    settings.setdefault('nucleus_CP_prob', 0)
+
+    settings.setdefault('nucleus_FT', 100)
+    settings.setdefault('cell_FT', 100)
+    settings.setdefault('pathogen_FT', 100)
+
+    # Plot settings
+    settings.setdefault('plot', False)
+    settings.setdefault('figuresize', 50)
+    settings.setdefault('cmap', 'inferno')
+    settings.setdefault('normalize', True)
+    settings.setdefault('normalize_plots', True)
+    settings.setdefault('examples_to_plot', 1)
+
+    # Analasys settings
+    settings.setdefault('pathogen_model', None)
+    settings.setdefault('merge_pathogens', False)
+    settings.setdefault('filter', False)
+    settings.setdefault('lower_percentile', 2)
+
+    # Timelapse settings
+    settings.setdefault('timelapse', False)
+    settings.setdefault('fps', 2)
+    settings.setdefault('timelapse_displacement', None)
+    settings.setdefault('timelapse_memory', 3)
+    settings.setdefault('timelapse_frame_limits', None)
+    settings.setdefault('timelapse_remove_transient', False)
+    settings.setdefault('timelapse_mode', 'trackpy')
+    settings.setdefault('timelapse_objects', 'cells')
+
+    # Misc settings
+    settings.setdefault('all_to_mip', False)
+    settings.setdefault('pick_slice', False)
+    settings.setdefault('skip_mode', '01')
+    settings.setdefault('upscale', False)
+    settings.setdefault('upscale_factor', 2.0)
+    settings.setdefault('adjust_cells', False)
+
+    return settings
+
+def set_default_settings_preprocess_img_data(settings):
+
+    metadata_type = settings.setdefault('metadata_type', 'cellvoyager')
+    custom_regex = settings.setdefault('custom_regex', None)
+    nr = settings.setdefault('nr', 1)
+    plot = settings.setdefault('plot', True)
+    batch_size = settings.setdefault('batch_size', 50)
+    timelapse = settings.setdefault('timelapse', False)
+    lower_percentile = settings.setdefault('lower_percentile', 2)
+    randomize = settings.setdefault('randomize', True)
+    all_to_mip = settings.setdefault('all_to_mip', False)
+    pick_slice = settings.setdefault('pick_slice', False)
+    skip_mode = settings.setdefault('skip_mode', False)
+
+    cmap = settings.setdefault('cmap', 'inferno')
+    figuresize = settings.setdefault('figuresize', 50)
+    normalize = settings.setdefault('normalize', True)
+    save_dtype = settings.setdefault('save_dtype', 'uint16')
+
+    test_mode = settings.setdefault('test_mode', False)
+    test_images = settings.setdefault('test_images', 10)
+    random_test = settings.setdefault('random_test', True)
+
+    return settings, metadata_type, custom_regex, nr, plot, batch_size, timelapse, lower_percentile, randomize, all_to_mip, pick_slice, skip_mode, cmap, figuresize, normalize, save_dtype, test_mode, test_images, random_test
+
+def smooth_hull_lines(cluster_data):
+    hull = ConvexHull(cluster_data)
+
+    # Extract vertices of the hull
+    vertices = hull.points[hull.vertices]
+
+    # Close the loop
+    vertices = np.vstack([vertices, vertices[0, :]])
+
+    # Parameterize the vertices
+    tck, u = splprep(vertices.T, u=None, s=0.0)
+
+    # Evaluate spline at new parameter values
+    new_points = splev(np.linspace(0, 1, 100), tck)
+
+    return new_points[0], new_points[1]
+
+def _gen_rgb_image(image, channels):
+    """
+    Generate an RGB image from the specified channels of the input image.
+
+    Args:
+        image (ndarray): The input image.
+        channels (list): List of channel indices to use for RGB.
+
+    Returns:
+        rgb_image (ndarray): The generated RGB image.
+    """
+    rgb_image = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.float32)
+    for i, chan in enumerate(channels):
+        if chan < image.shape[2]:
+            rgb_image[:, :, i] = image[:, :, chan]
     return rgb_image

 def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_thickness):
-    from concurrent.futures import ThreadPoolExecutor
-    import cv2
-
     outlines = []
     overlayed_image = rgb_image.copy()

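Note on the settings helpers added above: they fill in defaults with dict.setdefault, so any key the caller supplies survives the merge. A minimal sketch of the intended call pattern, assuming spacr 0.0.62 is installed (the source path below is hypothetical):

    from spacr.utils import set_default_settings_preprocess_generate_masks

    settings = {'batch_size': 10}  # caller-supplied value
    settings = set_default_settings_preprocess_generate_masks('/data/plate1', settings)
    assert settings['batch_size'] == 10     # setdefault keeps the caller's value
    assert settings['magnification'] == 20  # missing keys fall back to the defaults above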
@@ -66,11 +249,12 @@ def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_th

         # Find and draw contours
         for j in np.unique(mask):
-
+            if j == 0:
+                continue  # Skip background
             contours = find_contours(mask == j, 0.5)
             # Convert contours for OpenCV format and draw directly to optimize
             cv_contours = [np.flip(contour.astype(int), axis=1) for contour in contours]
-            cv2.drawContours(outline, cv_contours, -1, color=
+            cv2.drawContours(outline, cv_contours, -1, color=255, thickness=outline_thickness)

         return dilation(outline, square(outline_thickness))

@@ -78,19 +262,15 @@ def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_th
     with ThreadPoolExecutor() as executor:
         outlines = list(executor.map(process_dim, mask_dims))

-    # Overlay outlines onto the RGB image
+    # Overlay outlines onto the RGB image
     for i, outline in enumerate(outlines):
-
-
-
-
+        color = np.array(outline_colors[i % len(outline_colors)])
+        for j in np.unique(outline):
+            if j == 0:
+                continue  # Skip background
             mask = outline == j
             overlayed_image[mask] = color  # Direct assignment with broadcasting

-    # Remove mask_dims from image
-    channels_to_keep = [i for i in range(image.shape[-1]) if i not in mask_dims]
-    image = np.take(image, channels_to_keep, axis=-1)
-
     return overlayed_image, outlines, image

 def _convert_cq1_well_id(well_id):
@@ -350,43 +530,82 @@ def _annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['rh'], pa
     df['condition'] = df.apply(lambda row: '_'.join(filter(None, [row.get('pathogen'), row.get('treatment')])), axis=1)
     df['condition'] = df['condition'].apply(lambda x: x if x else 'none')
     return df
-
-def normalize_to_dtype(array,
+
+def normalize_to_dtype(array, p1=2, p2=98):
     """
-    Normalize the
+    Normalize each image in the stack to its own percentiles.

     Parameters:
     - array: numpy array
-        The input
-    -
+        The input stack to be normalized.
+    - p1: int, optional
         The lower percentile value for normalization. Default is 2.
-    -
+    - p2: int, optional
         The upper percentile value for normalization. Default is 98.
-    - percentiles: list of tuples, optional
-        A list of tuples containing the percentile values for each image in the array.
-        If provided, the percentiles for each image will be used instead of q1 and q2.

     Returns:
     - new_stack: numpy array
-        The normalized
+        The normalized stack with the same shape as the input stack.
     """
     nimg = array.shape[2]
     new_stack = np.empty_like(array)
-
-
+
+    for i in range(nimg):
+        img = array[:, :, i]
         non_zero_img = img[img > 0]
-
-
-
-
-    else:
-
-
-
-
-
-
+
+        if non_zero_img.size > 0:
+            img_min = np.percentile(non_zero_img, p1)
+            img_max = np.percentile(non_zero_img, p2)
+        else:
+            img_min = img.min()
+            img_max = img.max()
+
+        # Determine output range based on dtype
+        if np.issubdtype(array.dtype, np.integer):
+            out_range = (0, np.iinfo(array.dtype).max)
+        else:
+            out_range = (0.0, 1.0)
+
+        img = rescale_intensity(img, in_range=(img_min, img_max), out_range=out_range).astype(array.dtype)
+        new_stack[:, :, i] = img
+
+    return new_stack
+
+def normalize_to_dtype(array, p1=2, p2=98):
+    """
+    Normalize each image in the stack to its own percentiles.
+
+    Parameters:
+    - array: numpy array
+        The input stack to be normalized.
+    - p1: int, optional
+        The lower percentile value for normalization. Default is 2.
+    - p2: int, optional
+        The upper percentile value for normalization. Default is 98.
+
+    Returns:
+    - new_stack: numpy array
+        The normalized stack with the same shape as the input stack.
+    """
+    nimg = array.shape[2]
+    new_stack = np.empty_like(array, dtype=np.float32)
+
+    for i in range(nimg):
+        img = array[:, :, i]
+        non_zero_img = img[img > 0]
+
+        if non_zero_img.size > 0:
+            img_min = np.percentile(non_zero_img, p1)
+            img_max = np.percentile(non_zero_img, p2)
+        else:
+            img_min = img.min()
+            img_max = img.max()
+
+        # Normalize to the range (0, 1) for visualization
+        img = rescale_intensity(img, in_range=(img_min, img_max), out_range=(0.0, 1.0))
+        new_stack[:, :, i] = img
+
     return new_stack

 def _list_endpoint_subdirectories(base_dir):
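Note that the hunk above adds two back-to-back definitions of normalize_to_dtype; Python binds the name to the last one, so the float32 variant that rescales each channel into (0.0, 1.0) is the definition in effect and the dtype-aware variant above it is shadowed. A minimal usage sketch under that assumption (the input stack is synthetic):

    import numpy as np
    from spacr.utils import normalize_to_dtype

    stack = np.random.randint(0, 65535, (256, 256, 4), dtype=np.uint16)  # H x W x channels
    norm = normalize_to_dtype(stack)  # per-channel 2nd-98th percentile rescaling of non-zero pixels
    print(norm.dtype, norm.min(), norm.max())  # float32, values clipped to [0.0, 1.0]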
@@ -744,7 +963,7 @@ def _get_diam(mag, obj):
         elif obj == 'nucleus':
             diamiter = 60
         elif obj == 'pathogen':
-            diamiter =
+            diamiter = 20
         else:
             raise ValueError("Invalid magnification: Use 20, 40 or 60")

@@ -764,7 +983,7 @@ def _get_diam(mag, obj):
         if obj == 'nucleus':
             diamiter = 90
         if obj == 'pathogen':
-            diamiter =
+            diamiter = 60
         else:
             raise ValueError("Invalid magnification: Use 20, 40 or 60")
     else:
@@ -800,8 +1019,9 @@ def _get_object_settings(object_type, settings):

     elif object_type == 'pathogen':
         object_settings['model_name'] = 'cyto'
-        object_settings['filter_size'] =
+        object_settings['filter_size'] = False
         object_settings['filter_intensity'] = False
+        object_settings['resample'] = False
         object_settings['restore_type'] = settings.get('pathogen_restore_type', None)
         object_settings['merge'] = settings['merge_pathogens']

@@ -2751,15 +2971,37 @@ def _object_filter(df, object_type, size_range, intensity_range, mask_chans, mas
         print(f'After {object_type} maximum mean intensity filter: {len(df)}')
     return df

-def
+def _get_regex(metadata_type, img_format, custom_regex=None):
+
+    if img_format == None:
+        img_format == '.tif'
+    if metadata_type == 'cellvoyager':
+        regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+    elif metadata_type == 'cq1':
+        regex = f'W(?P<wellID>.*)F(?P<fieldID>.*)T(?P<timeID>.*)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+    elif metadata_type == 'nikon':
+        regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+    elif metadata_type == 'zeis':
+        regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+    elif metadata_type == 'leica':
+        regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+    elif metadata_type == 'custom':
+        regex = f'({custom_regex}){img_format}'
+
+    print(f'regex mode:{metadata_type} regex:{regex}')
+    return regex
+
+def _run_test_mode(src, regex, timelapse=False, test_images=10, random_test=True):
+
     if timelapse:
         test_images = 1 # Use only 1 set for timelapse to ensure full sequence inclusion
-
-        test_images = 10 # Use 10 sets for non-timelapse scenarios
-
+
     test_folder_path = os.path.join(src, 'test')
     os.makedirs(test_folder_path, exist_ok=True)
     regular_expression = re.compile(regex)
+
+    if os.path.exists(os.path.join(src, 'orig')):
+        src = os.path.join(src, 'orig')

     all_filenames = [filename for filename in os.listdir(src) if regular_expression.match(filename)]
     print(f'Found {len(all_filenames)} files')
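Two details worth noting in _get_regex above: each pattern captures metadata through named groups (wellID, fieldID, chanID, ...), and img_format == '.tif' is a comparison rather than an assignment, so passing img_format=None leaves it unset when the f-strings are built. A minimal sketch of matching a CellVoyager-style name (the filename is hypothetical):

    import re
    from spacr.utils import _get_regex

    regex = _get_regex('cellvoyager', '.tif')
    m = re.match(regex, 'plate1_A01_T0001F001L01A01Z01C01.tif')
    print(m.group('wellID'), m.group('fieldID'), m.group('chanID'))  # A01 001 01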
@@ -2771,25 +3013,20 @@ def _run_test_mode(src, regex, timelapse=False):
             plate = match.group('plateID') if 'plateID' in match.groupdict() else os.path.basename(src)
             well = match.group('wellID')
             field = match.group('fieldID')
-
-            if timelapse:
-                set_identifier = (plate, well, field)
-            else:
-                # For non-timelapse, you might want to distinguish sets more granularly
-                # Here, assuming you're grouping by plate, well, and field for simplicity
-                set_identifier = (plate, well, field)
+            set_identifier = (plate, well, field)
             images_by_set[set_identifier].append(filename)

     # Prepare for random selection
     set_identifiers = list(images_by_set.keys())
-
+    if random_test:
+        random.seed(42)
         random.shuffle(set_identifiers) # Randomize the order

     # Select a subset based on the test_images count
     selected_sets = set_identifiers[:test_images]

     # Print information about the number of sets used
-    print(f'Using {
+    print(f'Using {len(selected_sets)} random image set(s) for test model')

     # Copy files for selected sets to the test folder
     for set_identifier in selected_sets:
@@ -2798,24 +3035,1034 @@ def _run_test_mode(src, regex, timelapse=False):
|
|
2798
3035
|
|
2799
3036
|
return test_folder_path
|
2800
3037
|
|
2801
|
-
def _choose_model(model_name, device, object_type='cell', restore_type=None):
|
3038
|
+
def _choose_model(model_name, device, object_type='cell', restore_type=None, object_settings={}):
|
3039
|
+
|
3040
|
+
if object_type == 'pathogen':
|
3041
|
+
if model_name == 'toxo_pv_lumen':
|
3042
|
+
diameter = object_settings['diameter']
|
3043
|
+
current_dir = os.path.dirname(__file__)
|
3044
|
+
model_path = os.path.join(current_dir, 'models', 'cp', 'toxo_pv_lumen.CP_model')
|
3045
|
+
print(model_path)
|
3046
|
+
model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=model_path, diam_mean=diameter, device=device)
|
3047
|
+
#model = cp_models.Cellpose(gpu=torch.cuda.is_available(), model_type='cyto', device=device)
|
3048
|
+
print(f'Using Toxoplasma PV lumen model to generate pathogen masks')
|
3049
|
+
return model
|
3050
|
+
|
2802
3051
|
restore_list = ['denoise', 'deblur', 'upsample', None]
|
2803
3052
|
if restore_type not in restore_list:
|
2804
3053
|
print(f"Invalid restore type. Choose from {restore_list} defaulting to None")
|
2805
3054
|
restore_type = None
|
2806
3055
|
|
2807
3056
|
if restore_type == None:
|
2808
|
-
|
3057
|
+
if model_name in ['cyto', 'cyto2', 'cyto3', 'nuclei']:
|
3058
|
+
model = cp_models.Cellpose(gpu=torch.cuda.is_available(), model_type=model_name, device=device)
|
3059
|
+
|
2809
3060
|
else:
|
2810
3061
|
if object_type == 'nucleus':
|
2811
3062
|
restore = f'{type}_nuclei'
|
2812
|
-
model = denoise.CellposeDenoiseModel(gpu=
|
3063
|
+
model = denoise.CellposeDenoiseModel(gpu=torch.cuda.is_available(), model_type="nuclei",restore_type=restore, chan2_restore=False, device=device)
|
2813
3064
|
else:
|
2814
3065
|
restore = f'{type}_cyto3'
|
2815
3066
|
if model_name =='cyto2':
|
2816
3067
|
chan2_restore = True
|
2817
3068
|
if model_name =='cyto':
|
2818
3069
|
chan2_restore = False
|
2819
|
-
model = denoise.CellposeDenoiseModel(gpu=
|
3070
|
+
model = denoise.CellposeDenoiseModel(gpu=torch.cuda.is_available(), model_type="cyto3",restore_type=restore, chan2_restore=chan2_restore, device=device)
|
3071
|
+
|
3072
|
+
return model
|
3073
|
+
|
3074
|
+
class SelectChannels:
|
3075
|
+
def __init__(self, channels):
|
3076
|
+
self.channels = channels
|
3077
|
+
|
3078
|
+
def __call__(self, img):
|
3079
|
+
img = img.clone()
|
3080
|
+
if 1 not in self.channels:
|
3081
|
+
img[0, :, :] = 0 # Zero out the red channel
|
3082
|
+
if 2 not in self.channels:
|
3083
|
+
img[1, :, :] = 0 # Zero out the green channel
|
3084
|
+
if 3 not in self.channels:
|
3085
|
+
img[2, :, :] = 0 # Zero out the blue channel
|
3086
|
+
return img
|
3087
|
+
|
3088
|
+
def preprocess_image(image_path, image_size=224, channels=[1,2,3], normalize=True):
|
3089
|
+
|
3090
|
+
if normalize:
|
3091
|
+
transform = transforms.Compose([
|
3092
|
+
transforms.ToTensor(),
|
3093
|
+
transforms.CenterCrop(size=(image_size, image_size)),
|
3094
|
+
SelectChannels(channels),
|
3095
|
+
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
3096
|
+
else:
|
3097
|
+
transform = transforms.Compose([
|
3098
|
+
transforms.ToTensor(),
|
3099
|
+
transforms.CenterCrop(size=(image_size, image_size)),
|
3100
|
+
SelectChannels(channels)])
|
3101
|
+
|
3102
|
+
image = Image.open(image_path).convert('RGB')
|
3103
|
+
input_tensor = transform(image).unsqueeze(0)
|
3104
|
+
return image, input_tensor
|
3105
|
+
|
3106
|
+
|
3107
|
+
class SaliencyMapGenerator:
|
3108
|
+
def __init__(self, model):
|
3109
|
+
self.model = model
|
3110
|
+
|
3111
|
+
def compute_saliency_maps(self, X, y):
|
3112
|
+
self.model.eval()
|
3113
|
+
X.requires_grad_()
|
3114
|
+
|
3115
|
+
# Forward pass
|
3116
|
+
scores = self.model(X).squeeze()
|
3117
|
+
|
3118
|
+
# For binary classification, target scores can be the single output
|
3119
|
+
target_scores = scores * (2 * y - 1)
|
3120
|
+
|
3121
|
+
self.model.zero_grad()
|
3122
|
+
target_scores.backward(torch.ones_like(target_scores))
|
3123
|
+
|
3124
|
+
saliency = X.grad.abs()
|
3125
|
+
return saliency
|
3126
|
+
|
3127
|
+
def plot_saliency_maps(self, X, y, saliency, class_names):
|
3128
|
+
N = X.shape[0]
|
3129
|
+
for i in range(N):
|
3130
|
+
plt.subplot(2, N, i + 1)
|
3131
|
+
plt.imshow(X[i].permute(1, 2, 0).cpu().numpy())
|
3132
|
+
plt.axis('off')
|
3133
|
+
plt.title(class_names[y[i]])
|
3134
|
+
plt.subplot(2, N, N + i + 1)
|
3135
|
+
plt.imshow(saliency[i].cpu().numpy(), cmap=plt.cm.hot)
|
3136
|
+
plt.axis('off')
|
3137
|
+
plt.gcf().set_size_inches(12, 5)
|
3138
|
+
plt.show()
|
3139
|
+
|
3140
|
+
def preprocess_image(image_path, normalize=True, image_size=224, channels=[1,2,3]):
|
3141
|
+
preprocess = transforms.Compose([
|
3142
|
+
transforms.Resize((image_size, image_size)),
|
3143
|
+
transforms.ToTensor(),
|
3144
|
+
])
|
3145
|
+
|
3146
|
+
image = Image.open(image_path).convert('RGB')
|
3147
|
+
input_tensor = preprocess(image)
|
3148
|
+
if normalize:
|
3149
|
+
input_tensor = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(input_tensor)
|
3150
|
+
input_tensor = input_tensor.unsqueeze(0)
|
3151
|
+
|
3152
|
+
return image, input_tensor
|
3153
|
+
|
3154
|
+
def class_visualization(target_y, model_path, dtype, img_size=224, channels=[0,1,2], l2_reg=1e-3, learning_rate=25, num_iterations=100, blur_every=10, max_jitter=16, show_every=25, class_names = ['nc', 'pc']):
|
3155
|
+
|
3156
|
+
def jitter(img, ox, oy):
|
3157
|
+
# Randomly jitter the image
|
3158
|
+
return torch.roll(torch.roll(img, ox, dims=2), oy, dims=3)
|
3159
|
+
|
3160
|
+
def blur_image(img, sigma=1):
|
3161
|
+
# Apply Gaussian blur to the image
|
3162
|
+
img_np = img.cpu().numpy()
|
3163
|
+
for i in range(img_np.shape[1]):
|
3164
|
+
img_np[:, i] = gaussian_filter(img_np[:, i], sigma=sigma)
|
3165
|
+
img.copy_(torch.tensor(img_np).to(img.device))
|
3166
|
+
|
3167
|
+
def deprocess(img_tensor):
|
3168
|
+
# Convert the tensor image to a numpy array for visualization
|
3169
|
+
img_tensor = img_tensor.clone()
|
3170
|
+
for c in range(3):
|
3171
|
+
img_tensor[:, c] = img_tensor[:, c] * SQUEEZENET_STD[c] + SQUEEZENET_MEAN[c]
|
3172
|
+
img_tensor = img_tensor.clamp(0, 1)
|
3173
|
+
return img_tensor.squeeze().permute(1, 2, 0).cpu().numpy()
|
3174
|
+
|
3175
|
+
# Assuming these are defined somewhere in your codebase
|
3176
|
+
SQUEEZENET_MEAN = [0.485, 0.456, 0.406]
|
3177
|
+
SQUEEZENET_STD = [0.229, 0.224, 0.225]
|
3178
|
+
|
3179
|
+
model = torch.load(model_path)
|
3180
|
+
|
3181
|
+
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
|
3182
|
+
len_chans = len(channels)
|
3183
|
+
model.type(dtype)
|
3184
|
+
|
3185
|
+
# Randomly initialize the image as a PyTorch Tensor, and make it requires gradient.
|
3186
|
+
img = torch.randn(1, len_chans, img_size, img_size).mul_(1.0).type(dtype).requires_grad_()
|
3187
|
+
|
3188
|
+
for t in range(num_iterations):
|
3189
|
+
# Randomly jitter the image a bit; this gives slightly nicer results
|
3190
|
+
ox, oy = random.randint(0, max_jitter), random.randint(0, max_jitter)
|
3191
|
+
img.data.copy_(jitter(img.data, ox, oy))
|
3192
|
+
|
3193
|
+
# Forward pass
|
3194
|
+
score = model(img)
|
3195
|
+
|
3196
|
+
if target_y == 0:
|
3197
|
+
target_score = -score
|
3198
|
+
else:
|
3199
|
+
target_score = score
|
3200
|
+
|
3201
|
+
# Add regularization
|
3202
|
+
target_score = target_score - l2_reg * torch.norm(img)
|
3203
|
+
|
3204
|
+
# Backward pass
|
3205
|
+
target_score.backward()
|
3206
|
+
|
3207
|
+
# Gradient ascent step
|
3208
|
+
with torch.no_grad():
|
3209
|
+
img += learning_rate * img.grad / torch.norm(img.grad)
|
3210
|
+
img.grad.zero_()
|
3211
|
+
|
3212
|
+
# Undo the random jitter
|
3213
|
+
img.data.copy_(jitter(img.data, -ox, -oy))
|
3214
|
+
|
3215
|
+
# As regularizer, clamp and periodically blur the image
|
3216
|
+
for c in range(3):
|
3217
|
+
lo = float(-SQUEEZENET_MEAN[c] / SQUEEZENET_STD[c])
|
3218
|
+
hi = float((1.0 - SQUEEZENET_MEAN[c]) / SQUEEZENET_STD[c])
|
3219
|
+
img.data[:, c].clamp_(min=lo, max=hi)
|
3220
|
+
if t % blur_every == 0:
|
3221
|
+
blur_image(img.data, sigma=0.5)
|
3222
|
+
|
3223
|
+
# Periodically show the image
|
3224
|
+
if t == 0 or (t + 1) % show_every == 0 or t == num_iterations - 1:
|
3225
|
+
plt.imshow(deprocess(img.data.clone().cpu()))
|
3226
|
+
class_name = class_names[target_y]
|
3227
|
+
plt.title('%s\nIteration %d / %d' % (class_name, t + 1, num_iterations))
|
3228
|
+
plt.gcf().set_size_inches(4, 4)
|
3229
|
+
plt.axis('off')
|
3230
|
+
plt.show()
|
3231
|
+
|
3232
|
+
return deprocess(img.data.cpu())
|
3233
|
+
|
3234
|
+
def get_submodules(model, prefix=''):
|
3235
|
+
submodules = []
|
3236
|
+
for name, module in model.named_children():
|
3237
|
+
full_name = prefix + ('.' if prefix else '') + name
|
3238
|
+
submodules.append(full_name)
|
3239
|
+
submodules.extend(get_submodules(module, full_name))
|
3240
|
+
return submodules
|
3241
|
+
|
3242
|
+
class GradCAM:
|
3243
|
+
def __init__(self, model, target_layers=None, use_cuda=True):
|
3244
|
+
self.model = model
|
3245
|
+
self.model.eval()
|
3246
|
+
self.target_layers = target_layers
|
3247
|
+
self.cuda = use_cuda
|
3248
|
+
if self.cuda:
|
3249
|
+
self.model = model.cuda()
|
3250
|
+
|
3251
|
+
def forward(self, input):
|
3252
|
+
return self.model(input)
|
3253
|
+
|
3254
|
+
def __call__(self, x, index=None):
|
3255
|
+
if self.cuda:
|
3256
|
+
x = x.cuda()
|
3257
|
+
|
3258
|
+
features = []
|
3259
|
+
def hook(module, input, output):
|
3260
|
+
features.append(output)
|
3261
|
+
|
3262
|
+
handles = []
|
3263
|
+
for name, module in self.model.named_modules():
|
3264
|
+
if name in self.target_layers:
|
3265
|
+
handles.append(module.register_forward_hook(hook))
|
3266
|
+
|
3267
|
+
output = self.forward(x)
|
3268
|
+
if index is None:
|
3269
|
+
index = np.argmax(output.data.cpu().numpy())
|
3270
|
+
|
3271
|
+
one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
|
3272
|
+
one_hot[0][index] = 1
|
3273
|
+
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
3274
|
+
if self.cuda:
|
3275
|
+
one_hot = one_hot.cuda()
|
3276
|
+
|
3277
|
+
one_hot = torch.sum(one_hot * output)
|
3278
|
+
self.model.zero_grad()
|
3279
|
+
one_hot.backward(retain_graph=True)
|
3280
|
+
|
3281
|
+
grads_val = features[0].grad.cpu().data.numpy()
|
3282
|
+
target = features[0].cpu().data.numpy()[0, :]
|
3283
|
+
|
3284
|
+
weights = np.mean(grads_val, axis=(2, 3))[0, :]
|
3285
|
+
cam = np.zeros(target.shape[1:], dtype=np.float32)
|
3286
|
+
|
3287
|
+
for i, w in enumerate(weights):
|
3288
|
+
cam += w * target[i, :, :]
|
3289
|
+
|
3290
|
+
cam = np.maximum(cam, 0)
|
3291
|
+
cam = cv2.resize(cam, (x.size(2), x.size(3)))
|
3292
|
+
cam = cam - np.min(cam)
|
3293
|
+
cam = cam / np.max(cam)
|
3294
|
+
|
3295
|
+
for handle in handles:
|
3296
|
+
handle.remove()
|
3297
|
+
|
3298
|
+
return cam
|
3299
|
+
|
3300
|
+
def show_cam_on_image(img, mask):
|
3301
|
+
heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
|
3302
|
+
heatmap = np.float32(heatmap) / 255
|
3303
|
+
cam = heatmap + np.float32(img)
|
3304
|
+
cam = cam / np.max(cam)
|
3305
|
+
return np.uint8(255 * cam)
|
3306
|
+
|
3307
|
+
def recommend_target_layers(model):
|
3308
|
+
target_layers = []
|
3309
|
+
for name, module in model.named_modules():
|
3310
|
+
if isinstance(module, torch.nn.Conv2d):
|
3311
|
+
target_layers.append(name)
|
3312
|
+
# Choose the last conv layer as the recommended target layer
|
3313
|
+
if target_layers:
|
3314
|
+
return [target_layers[-1]], target_layers
|
3315
|
+
else:
|
3316
|
+
raise ValueError("No convolutional layers found in the model.")
|
2820
3317
|
|
2821
|
-
|
3318
|
+
class IntegratedGradients:
|
3319
|
+
def __init__(self, model):
|
3320
|
+
self.model = model
|
3321
|
+
self.model.eval()
|
3322
|
+
|
3323
|
+
def generate_integrated_gradients(self, input_tensor, target_label_idx, baseline=None, num_steps=50):
|
3324
|
+
if baseline is None:
|
3325
|
+
baseline = torch.zeros_like(input_tensor)
|
3326
|
+
|
3327
|
+
assert baseline.shape == input_tensor.shape
|
3328
|
+
|
3329
|
+
# Scale input and compute gradients
|
3330
|
+
scaled_inputs = [(baseline + (float(i) / num_steps) * (input_tensor - baseline)).requires_grad_(True) for i in range(0, num_steps + 1)]
|
3331
|
+
grads = []
|
3332
|
+
for scaled_input in scaled_inputs:
|
3333
|
+
out = self.model(scaled_input)
|
3334
|
+
self.model.zero_grad()
|
3335
|
+
out[0, target_label_idx].backward(retain_graph=True)
|
3336
|
+
grads.append(scaled_input.grad.data.cpu().numpy())
|
3337
|
+
|
3338
|
+
avg_grads = np.mean(grads[:-1], axis=0)
|
3339
|
+
integrated_grads = (input_tensor.cpu().data.numpy() - baseline.cpu().data.numpy()) * avg_grads
|
3340
|
+
return integrated_grads
|
3341
|
+
|
3342
|
+
def get_db_paths(src):
|
3343
|
+
if isinstance(src, str):
|
3344
|
+
src = [src]
|
3345
|
+
db_paths = [os.path.join(source, 'measurements/measurements.db') for source in src]
|
3346
|
+
return db_paths
|
3347
|
+
|
3348
|
+
def get_sequencing_paths(src):
|
3349
|
+
if isinstance(src, str):
|
3350
|
+
src = [src]
|
3351
|
+
seq_paths = [os.path.join(source, 'sequencing/sequencing_data.csv') for source in src]
|
3352
|
+
return seq_paths
|
3353
|
+
|
3354
|
+
def load_image_paths(c, visualize):
|
3355
|
+
c.execute(f'SELECT * FROM png_list')
|
3356
|
+
data = c.fetchall()
|
3357
|
+
columns_info = c.execute(f'PRAGMA table_info(png_list)').fetchall()
|
3358
|
+
column_names = [col_info[1] for col_info in columns_info]
|
3359
|
+
image_paths_df = pd.DataFrame(data, columns=column_names)
|
3360
|
+
if visualize:
|
3361
|
+
object_visualize = visualize + '_png'
|
3362
|
+
image_paths_df = image_paths_df[image_paths_df['png_path'].str.contains(object_visualize)]
|
3363
|
+
image_paths_df = image_paths_df.set_index('prcfo')
|
3364
|
+
return image_paths_df
|
3365
|
+
|
3366
|
+
def merge_dataframes(df, image_paths_df, verbose):
|
3367
|
+
df.set_index('prcfo', inplace=True)
|
3368
|
+
df = image_paths_df.merge(df, left_index=True, right_index=True)
|
3369
|
+
if verbose:
|
3370
|
+
display(df)
|
3371
|
+
return df
|
3372
|
+
|
3373
|
+
def remove_highly_correlated_columns(df, threshold):
|
3374
|
+
corr_matrix = df.corr().abs()
|
3375
|
+
upper_tri = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
|
3376
|
+
to_drop = [column for column in upper_tri.columns if any(upper_tri[column] > threshold)]
|
3377
|
+
return df.drop(to_drop, axis=1)
|
3378
|
+
|
3379
|
+
def filter_columns(df, filter_by):
|
3380
|
+
if filter_by != 'morphology':
|
3381
|
+
cols_to_include = [col for col in df.columns if filter_by in str(col)]
|
3382
|
+
else:
|
3383
|
+
cols_to_include = [col for col in df.columns if 'channel' not in str(col)]
|
3384
|
+
df = df[cols_to_include]
|
3385
|
+
return df
|
3386
|
+
|
3387
|
+
def reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric, eps, min_samples, clustering, reduction_method='umap', verbose=False, embedding=None, n_jobs=-1, mode='fit', model=False):
|
3388
|
+
"""
|
3389
|
+
Perform dimensionality reduction and clustering on the given data.
|
3390
|
+
|
3391
|
+
Parameters:
|
3392
|
+
numeric_data (np.ndarray): Numeric data for embedding and clustering.
|
3393
|
+
n_neighbors (int or float): Number of neighbors for UMAP or perplexity for t-SNE.
|
3394
|
+
min_dist (float): Minimum distance for UMAP.
|
3395
|
+
metric (str): Metric for UMAP and DBSCAN.
|
3396
|
+
eps (float): Epsilon for DBSCAN.
|
3397
|
+
min_samples (int): Minimum samples for DBSCAN or number of clusters for KMeans.
|
3398
|
+
clustering (str): Clustering method ('DBSCAN' or 'KMeans').
|
3399
|
+
reduction_method (str): Dimensionality reduction method ('UMAP' or 'tSNE').
|
3400
|
+
verbose (bool): Whether to print verbose output.
|
3401
|
+
embedding (np.ndarray, optional): Precomputed embedding. Default is None.
|
3402
|
+
return_model (bool): Whether to return the reducer model. Default is False.
|
3403
|
+
|
3404
|
+
Returns:
|
3405
|
+
tuple: embedding, labels (and optionally the reducer model)
|
3406
|
+
"""
|
3407
|
+
|
3408
|
+
if verbose:
|
3409
|
+
v = 1
|
3410
|
+
else:
|
3411
|
+
v = 0
|
3412
|
+
|
3413
|
+
if isinstance(n_neighbors, float):
|
3414
|
+
n_neighbors = int(n_neighbors * len(numeric_data))
|
3415
|
+
|
3416
|
+
if n_neighbors <= 2:
|
3417
|
+
n_neighbors = 2
|
3418
|
+
|
3419
|
+
if mode == 'fit':
|
3420
|
+
if reduction_method == 'umap':
|
3421
|
+
reducer = umap.UMAP(n_neighbors=n_neighbors,
|
3422
|
+
n_components=2,
|
3423
|
+
metric=metric,
|
3424
|
+
n_epochs=None,
|
3425
|
+
learning_rate=1.0,
|
3426
|
+
init='spectral',
|
3427
|
+
min_dist=min_dist,
|
3428
|
+
spread=1.0,
|
3429
|
+
set_op_mix_ratio=1.0,
|
3430
|
+
local_connectivity=1,
|
3431
|
+
repulsion_strength=1.0,
|
3432
|
+
negative_sample_rate=5,
|
3433
|
+
transform_queue_size=4.0,
|
3434
|
+
a=None,
|
3435
|
+
b=None,
|
3436
|
+
random_state=42,
|
3437
|
+
metric_kwds=None,
|
3438
|
+
angular_rp_forest=False,
|
3439
|
+
target_n_neighbors=-1,
|
3440
|
+
target_metric='categorical',
|
3441
|
+
target_metric_kwds=None,
|
3442
|
+
target_weight=0.5,
|
3443
|
+
transform_seed=42,
|
3444
|
+
n_jobs=n_jobs,
|
3445
|
+
verbose=verbose)
|
3446
|
+
|
3447
|
+
elif reduction_method == 'tsne':
|
3448
|
+
reducer = TSNE(n_components=2,
|
3449
|
+
perplexity=n_neighbors,
|
3450
|
+
early_exaggeration=12.0,
|
3451
|
+
learning_rate=200.0,
|
3452
|
+
n_iter=1000,
|
3453
|
+
n_iter_without_progress=300,
|
3454
|
+
min_grad_norm=1e-7,
|
3455
|
+
metric=metric,
|
3456
|
+
init='random',
|
3457
|
+
verbose=v,
|
3458
|
+
random_state=42,
|
3459
|
+
method='barnes_hut',
|
3460
|
+
angle=0.5,
|
3461
|
+
n_jobs=n_jobs)
|
3462
|
+
|
3463
|
+
else:
|
3464
|
+
raise ValueError(f"Unsupported reduction method: {reduction_method}. Supported methods are 'umap' and 'tsne'")
|
3465
|
+
|
3466
|
+
embedding = reducer.fit_transform(numeric_data)
|
3467
|
+
if verbose:
|
3468
|
+
print(f'Trained and fit reducer')
|
3469
|
+
|
3470
|
+
else:
|
3471
|
+
if not model is None:
|
3472
|
+
embedding = model.transform(numeric_data)
|
3473
|
+
reducer = model
|
3474
|
+
if verbose:
|
3475
|
+
print(f'Fit data to reducer')
|
3476
|
+
else:
|
3477
|
+
raise ValueError(f"Model is None. Please provide a model for transform.")
|
3478
|
+
|
3479
|
+
if clustering == 'dbscan':
|
3480
|
+
clustering_model = DBSCAN(eps=eps, min_samples=min_samples, metric=metric, n_jobs=n_jobs)
|
3481
|
+
elif clustering == 'kmeans':
|
3482
|
+
clustering_model = KMeans(n_clusters=min_samples, random_state=42)
|
3483
|
+
|
3484
|
+
clustering_model.fit(embedding)
|
3485
|
+
labels = clustering_model.labels_ if clustering == 'dbscan' else clustering_model.predict(embedding)
|
3486
|
+
|
3487
|
+
if verbose:
|
3488
|
+
print(f'Embedding shape: {embedding.shape}')
|
3489
|
+
|
3490
|
+
return embedding, labels, reducer
|
3491
|
+
|
3492
|
+
def reduction_and_clustering_v1(numeric_data, n_neighbors, min_dist, metric, eps, min_samples, clustering, reduction_method='umap', verbose=False, embedding=None, n_jobs=-1):
|
3493
|
+
"""
|
3494
|
+
Perform dimensionality reduction and clustering on the given data.
|
3495
|
+
|
3496
|
+
Parameters:
|
3497
|
+
numeric_data (np.ndarray): Numeric data for embedding and clustering.
|
3498
|
+
n_neighbors (int or float): Number of neighbors for UMAP or perplexity for t-SNE.
|
3499
|
+
min_dist (float): Minimum distance for UMAP.
|
3500
|
+
metric (str): Metric for UMAP and DBSCAN.
|
3501
|
+
eps (float): Epsilon for DBSCAN.
|
3502
|
+
min_samples (int): Minimum samples for DBSCAN or number of clusters for KMeans.
|
3503
|
+
clustering (str): Clustering method ('DBSCAN' or 'KMeans').
|
3504
|
+
reduction_method (str): Dimensionality reduction method ('UMAP' or 'tSNE').
|
3505
|
+
verbose (bool): Whether to print verbose output.
|
3506
|
+
embedding (np.ndarray, optional): Precomputed embedding. Default is None.
|
3507
|
+
|
3508
|
+
Returns:
|
3509
|
+
tuple: embedding, labels
|
3510
|
+
"""
|
3511
|
+
|
3512
|
+
if verbose:
|
3513
|
+
v=1
|
3514
|
+
else:
|
3515
|
+
v=0
|
3516
|
+
|
3517
|
+
if isinstance(n_neighbors, float):
|
3518
|
+
n_neighbors = int(n_neighbors * len(numeric_data))
|
3519
|
+
|
3520
|
+
if n_neighbors <= 2:
|
3521
|
+
n_neighbors = 2
|
3522
|
+
|
3523
|
+
if reduction_method == 'umap':
|
3524
|
+
reducer = umap.UMAP(n_neighbors=n_neighbors,
|
3525
|
+
n_components=2,
|
3526
|
+
metric=metric,
|
3527
|
+
n_epochs=None,
|
3528
|
+
learning_rate=1.0,
|
3529
|
+
init='spectral',
|
3530
|
+
min_dist=min_dist,
|
3531
|
+
spread=1.0,
|
3532
|
+
set_op_mix_ratio=1.0,
|
3533
|
+
local_connectivity=1,
|
3534
|
+
repulsion_strength=1.0,
|
3535
|
+
negative_sample_rate=5,
|
3536
|
+
transform_queue_size=4.0,
|
3537
|
+
a=None,
|
3538
|
+
b=None,
|
3539
|
+
random_state=42,
|
3540
|
+
metric_kwds=None,
|
3541
|
+
angular_rp_forest=False,
|
3542
|
+
target_n_neighbors=-1,
|
3543
|
+
target_metric='categorical',
|
3544
|
+
target_metric_kwds=None,
|
3545
|
+
target_weight=0.5,
|
3546
|
+
transform_seed=42,
|
3547
|
+
n_jobs=n_jobs,
|
3548
|
+
verbose=verbose)
|
3549
|
+
|
3550
|
+
elif reduction_method == 'tsne':
|
3551
|
+
|
3552
|
+
#tsne_params.setdefault('n_components', 2)
|
3553
|
+
#reducer = TSNE(**tsne_params)
|
3554
|
+
|
3555
|
+
reducer = TSNE(n_components=2,
|
3556
|
+
perplexity=n_neighbors,
|
3557
|
+
early_exaggeration=12.0,
|
3558
|
+
learning_rate=200.0,
|
3559
|
+
n_iter=1000,
|
3560
|
+
n_iter_without_progress=300,
|
3561
|
+
min_grad_norm=1e-7,
|
3562
|
+
metric=metric,
|
3563
|
+
init='random',
|
3564
|
+
verbose=v,
|
3565
|
+
random_state=42,
|
3566
|
+
method='barnes_hut',
|
3567
|
+
angle=0.5,
|
3568
|
+
n_jobs=n_jobs)
|
3569
|
+
|
3570
|
+
else:
|
3571
|
+
raise ValueError(f"Unsupported reduction method: {reduction_method}. Supported methods are 'umap' and 'tsne'")
|
3572
|
+
|
3573
|
+
if embedding is None:
|
3574
|
+
embedding = reducer.fit_transform(numeric_data)
|
3575
|
+
|
3576
|
+
if clustering == 'dbscan':
|
3577
|
+
clustering_model = DBSCAN(eps=eps, min_samples=min_samples, metric=metric, n_jobs=n_jobs)
|
3578
|
+
elif clustering == 'kmeans':
|
3579
|
+
clustering_model = KMeans(n_clusters=min_samples, random_state=42)
|
3580
|
+
else:
|
3581
|
+
raise ValueError(f"Unsupported clustering method: {clustering}. Supported methods are 'dbscan' and 'kmeans'")
|
3582
|
+
|
3583
|
+
clustering_model.fit(embedding)
|
3584
|
+
labels = clustering_model.labels_ if clustering == 'dbscan' else clustering_model.predict(embedding)
|
3585
|
+
|
3586
|
+
if verbose:
|
3587
|
+
print(f'Embedding shape: {embedding.shape}')
|
3588
|
+
|
3589
|
+
return embedding, labels
|
3590
|
+
|
3591
|
+
def remove_noise(embedding, labels):
|
3592
|
+
non_noise_indices = labels != -1
|
3593
|
+
embedding = embedding[non_noise_indices]
|
3594
|
+
labels = labels[non_noise_indices]
|
3595
|
+
return embedding, labels
|
3596
|
+
|
3597
|
+
def plot_embedding(embedding, image_paths, labels, image_nr, img_zoom, colors, plot_by_cluster, plot_outlines, plot_points, plot_images, smooth_lines, black_background, figuresize, dot_size, remove_image_canvas, verbose):
|
3598
|
+
unique_labels = np.unique(labels)
|
3599
|
+
#num_clusters = len(unique_labels[unique_labels != 0])
|
3600
|
+
colors, label_to_color_index = assign_colors(unique_labels, colors)
|
3601
|
+
cluster_centers = [np.mean(embedding[labels == cluster_label], axis=0) for cluster_label in unique_labels]
|
3602
|
+
fig, ax = setup_plot(figuresize, black_background)
|
3603
|
+
plot_clusters(ax, embedding, labels, colors, cluster_centers, plot_outlines, plot_points, smooth_lines, figuresize, dot_size, verbose)
|
3604
|
+
if not image_paths is None and plot_images:
|
3605
|
+
plot_umap_images(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, plot_by_cluster, remove_image_canvas, verbose)
|
3606
|
+
plt.show()
|
3607
|
+
return fig
|
3608
|
+
|
3609
|
+
def generate_colors(num_clusters, black_background):
|
3610
|
+
random_colors = np.random.rand(num_clusters + 1, 4)
|
3611
|
+
random_colors[:, 3] = 1
|
3612
|
+
specific_colors = [
|
3613
|
+
[155 / 255, 55 / 255, 155 / 255, 1],
|
3614
|
+
[55 / 255, 155 / 255, 155 / 255, 1],
|
3615
|
+
[55 / 255, 155 / 255, 255 / 255, 1],
|
3616
|
+
[255 / 255, 55 / 255, 155 / 255, 1]
|
3617
|
+
]
|
3618
|
+
random_colors = np.vstack((specific_colors, random_colors[len(specific_colors):]))
|
3619
|
+
if not black_background:
|
3620
|
+
random_colors = np.vstack(([0, 0, 0, 1], random_colors))
|
3621
|
+
return random_colors
|
3622
|
+
|
3623
|
+
def assign_colors(unique_labels, random_colors):
|
3624
|
+
normalized_colors = random_colors / 255
|
3625
|
+
colors_img = [tuple(color) for color in normalized_colors]
|
3626
|
+
colors = [tuple(color) for color in random_colors]
|
3627
|
+
label_to_color_index = {label: index for index, label in enumerate(unique_labels)}
|
3628
|
+
return colors, label_to_color_index
|
3629
|
+
|
3630
|
+
def setup_plot(figuresize, black_background):
|
3631
|
+
if black_background:
|
3632
|
+
plt.rcParams.update({'figure.facecolor': 'black', 'axes.facecolor': 'black', 'text.color': 'white', 'xtick.color': 'white', 'ytick.color': 'white', 'axes.labelcolor': 'white'})
|
3633
|
+
else:
|
3634
|
+
plt.rcParams.update({'figure.facecolor': 'white', 'axes.facecolor': 'white', 'text.color': 'black', 'xtick.color': 'black', 'ytick.color': 'black', 'axes.labelcolor': 'black'})
|
3635
|
+
fig, ax = plt.subplots(1, 1, figsize=(figuresize, figuresize))
|
3636
|
+
return fig, ax
|
3637
|
+
|
3638
|
+
def plot_clusters(ax, embedding, labels, colors, cluster_centers, plot_outlines, plot_points, smooth_lines, figuresize=50, dot_size=50, verbose=False):
|
3639
|
+
unique_labels = np.unique(labels)
|
3640
|
+
for cluster_label, color, center in zip(unique_labels, colors, cluster_centers):
|
3641
|
+
cluster_data = embedding[labels == cluster_label]
|
3642
|
+
if smooth_lines:
|
3643
|
+
if cluster_data.shape[0] > 2:
|
3644
|
+
x_smooth, y_smooth = smooth_hull_lines(cluster_data)
|
3645
|
+
if plot_outlines:
|
3646
|
+
plt.plot(x_smooth, y_smooth, color=color, linewidth=2)
|
3647
|
+
else:
|
3648
|
+
if cluster_data.shape[0] > 2:
|
3649
|
+
hull = ConvexHull(cluster_data)
|
3650
|
+
for simplex in hull.simplices:
|
3651
|
+
if plot_outlines:
|
3652
|
+
plt.plot(hull.points[simplex, 0], hull.points[simplex, 1], color=color, linewidth=4)
|
3653
|
+
if plot_points:
|
3654
|
+
scatter = ax.scatter(cluster_data[:, 0], cluster_data[:, 1], s=dot_size, c=[color], alpha=0.5, label=f'Cluster {cluster_label if cluster_label != -1 else "Noise"}')
|
3655
|
+
else:
|
3656
|
+
scatter = ax.scatter(cluster_data[:, 0], cluster_data[:, 1], s=dot_size, c=[color], alpha=0, label=f'Cluster {cluster_label if cluster_label != -1 else "Noise"}')
|
3657
|
+
ax.text(center[0], center[1], str(cluster_label), fontsize=12, ha='center', va='center')
|
3658
|
+
plt.legend(loc='best', fontsize=int(figuresize * 0.75))
|
3659
|
+
plt.xlabel('UMAP Dimension 1', fontsize=int(figuresize * 0.75))
|
3660
|
+
plt.ylabel('UMAP Dimension 2', fontsize=int(figuresize * 0.75))
|
3661
|
+
plt.tick_params(axis='both', which='major', labelsize=int(figuresize * 0.75))
|
3662
|
+
|
3663
|
+
def plot_umap_images(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, plot_by_cluster, remove_image_canvas, verbose):
|
3664
|
+
if plot_by_cluster:
|
3665
|
+
cluster_indices = {label: np.where(labels == label)[0] for label in np.unique(labels) if label != -1}
|
3666
|
+
plot_images_by_cluster(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, cluster_indices, remove_image_canvas, verbose)
|
3667
|
+
else:
|
3668
|
+
indices = random.sample(range(len(embedding)), image_nr)
|
3669
|
+
for i, index in enumerate(indices):
|
3670
|
+
x, y = embedding[index]
|
3671
|
+
img = Image.open(image_paths[index])
|
3672
|
+
plot_image(ax, x, y, img, img_zoom, remove_image_canvas)
|
3673
|
+
|
3674
|
+
def plot_images_by_cluster(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, cluster_indices, remove_image_canvas, verbose):
|
3675
|
+
for cluster_label, color in zip(np.unique(labels), colors):
|
3676
|
+
if cluster_label == -1:
|
3677
|
+
continue
|
3678
|
+
indices = cluster_indices.get(cluster_label, [])
|
3679
|
+
if len(indices) > image_nr:
|
3680
|
+
indices = random.sample(list(indices), image_nr)
|
3681
|
+
for index in indices:
|
3682
|
+
x, y = embedding[index]
|
3683
|
+
img = Image.open(image_paths[index])
|
3684
|
+
plot_image(ax, x, y, img, img_zoom, remove_image_canvas)
|
3685
|
+
|
3686
|
+
def plot_image(ax, x, y, img, img_zoom, remove_image_canvas=True):
|
3687
|
+
img = np.array(img)
|
3688
|
+
if remove_image_canvas:
|
3689
|
+
img = remove_canvas(img)
|
3690
|
+
imagebox = OffsetImage(img, zoom=img_zoom)
|
3691
|
+
ab = AnnotationBbox(imagebox, (x, y), frameon=False)
|
3692
|
+
ax.add_artist(ab)
|
3693
|
+
|
3694
|
+
def remove_canvas(img):
|
3695
|
+
if img.mode in ['L', 'I']:
|
3696
|
+
img_data = np.array(img)
|
3697
|
+
img_data = img_data / np.max(img_data)
|
3698
|
+
alpha_channel = (img_data > 0).astype(float)
|
3699
|
+
img_data_rgb = np.stack([img_data] * 3, axis=-1)
|
3700
|
+
img_data_with_alpha = np.dstack([img_data_rgb, alpha_channel])
|
3701
|
+
elif img.mode == 'RGB':
|
3702
|
+
img_data = np.array(img)
|
3703
|
+
img_data = img_data / 255.0
|
3704
|
+
alpha_channel = (np.sum(img_data, axis=-1) > 0).astype(float)
|
3705
|
+
img_data_with_alpha = np.dstack([img_data, alpha_channel])
|
3706
|
+
else:
|
3707
|
+
raise ValueError(f"Unsupported image mode: {img.mode}")
|
3708
|
+
return img_data_with_alpha
|
3709
|
+
|
3710
|
+
def plot_clusters_grid(embedding, labels, image_nr, image_paths, colors, figuresize, black_background, verbose):
|
3711
|
+
unique_labels = np.unique(labels)
|
3712
|
+
num_clusters = len(unique_labels[unique_labels != -1])
|
3713
|
+
if num_clusters == 0:
|
3714
|
+
print("No clusters found.")
|
3715
|
+
return
|
3716
|
+
cluster_images = {label: [] for label in unique_labels if label != -1}
|
3717
|
+
cluster_indices = {label: np.where(labels == label)[0] for label in unique_labels if label != -1}
|
3718
|
+
for cluster_label, indices in cluster_indices.items():
|
3719
|
+
if cluster_label == -1:
|
3720
|
+
continue
|
3721
|
+
if len(indices) > image_nr:
|
3722
|
+
indices = random.sample(list(indices), image_nr)
|
3723
|
+
for index in indices:
|
3724
|
+
img_path = image_paths[index]
|
3725
|
+
img_array = Image.open(img_path)
|
3726
|
+
img = np.array(img_array)
|
3727
|
+
cluster_images[cluster_label].append(img)
|
3728
|
+
fig = plot_grid(cluster_images, colors, figuresize, black_background, verbose)
|
3729
|
+
return fig
|
3730
|
+
|
3731
|
+
def plot_grid(cluster_images, colors, figuresize, black_background, verbose):
|
3732
|
+
num_clusters = len(cluster_images)
|
3733
|
+
max_figsize = 200 # Set a maximum figure size
|
3734
|
+
if figuresize * num_clusters > max_figsize:
|
3735
|
+
figuresize = max_figsize / num_clusters
|
3736
|
+
|
3737
|
+
grid_fig, grid_axes = plt.subplots(1, num_clusters, figsize=(figuresize * num_clusters, figuresize), gridspec_kw={'wspace': 0.2, 'hspace': 0})
|
3738
|
+
if num_clusters == 1:
|
3739
|
+
grid_axes = [grid_axes] # Ensure grid_axes is always iterable
|
3740
|
+
for cluster_label, axes in zip(cluster_images.keys(), grid_axes):
|
3741
|
+
images = cluster_images[cluster_label]
|
3742
|
+
num_images = len(images)
|
3743
|
+
grid_size = int(np.ceil(np.sqrt(num_images)))
|
3744
|
+
image_size = 0.9 / grid_size
|
3745
|
+
whitespace = (1 - grid_size * image_size) / (grid_size + 1)
|
3746
|
+
|
3747
|
+
if isinstance(cluster_label, str):
|
3748
|
+
idx = list(cluster_images.keys()).index(cluster_label)
|
3749
|
+
color = colors[idx]
|
3750
|
+
if verbose:
|
3751
|
+
print(f'Lable: {cluster_label} index: {idx}')
|
3752
|
+
else:
|
3753
|
+
color = colors[cluster_label]
|
3754
|
+
|
3755
|
+
axes.add_patch(plt.Rectangle((0, 0), 1, 1, transform=axes.transAxes, color=color[:3]))
|
3756
|
+
axes.axis('off')
|
3757
|
+
for i, img in enumerate(images):
|
3758
|
+
row = i // grid_size
|
3759
|
+
col = i % grid_size
|
3760
|
+
x_pos = (col + 1) * whitespace + col * image_size
|
3761
|
+
y_pos = 1 - ((row + 1) * whitespace + (row + 1) * image_size)
|
3762
|
+
ax_img = axes.inset_axes([x_pos, y_pos, image_size, image_size], transform=axes.transAxes)
|
3763
|
+
ax_img.imshow(img, cmap='gray', aspect='auto')
|
3764
|
+
ax_img.axis('off')
|
3765
|
+
ax_img.set_aspect('equal')
|
3766
|
+
ax_img.set_facecolor(color[:3])
|
3767
|
+
|
3768
|
+
# Add cluster labels beside the UMAP plot
|
3769
|
+
spacing_factor = 0.5 # Adjust this value to control the spacing between labels
|
3770
|
+
for i, (cluster_label, color) in enumerate(zip(cluster_images.keys(), colors)):
|
3771
|
+
label_y = 1 - (i + 1) * (spacing_factor / num_clusters) # Adjust y position for each label
|
3772
|
+
grid_fig.text(1.05, label_y, f'Cluster {cluster_label}', verticalalignment='center', fontsize=figuresize, color='black' if not black_background else 'white')
|
3773
|
+
grid_fig.patches.append(plt.Rectangle((1, label_y - 0.02), 0.03, 0.03, transform=grid_fig.transFigure, color=color[:3], clip_on=False))
|
3774
|
+
|
3775
|
+
plt.show()
|
3776
|
+
return grid_fig
|
3777
|
+
|
3778
|
+
+def correct_paths(df, base_path):
+
+    if 'png_path' not in df.columns:
+        print("No 'png_path' column found in the dataframe.")
+        return df, None
+
+    image_paths = df['png_path'].to_list()
+
+    adjusted_image_paths = []
+    for path in image_paths:
+        if base_path not in path:
+            parts = path.split('/data/')
+            if len(parts) > 1:
+                new_path = os.path.join(base_path, 'data', parts[1])
+                adjusted_image_paths.append(new_path)
+            else:
+                adjusted_image_paths.append(path)
+        else:
+            adjusted_image_paths.append(path)
+
+    df['png_path'] = adjusted_image_paths
+    image_paths = df['png_path'].to_list()
+    return df, image_paths
+
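A sketch of the re-rooting behaviour on a hypothetical layout: any png_path that does not already contain base_path is split on its first '/data/' component and re-rooted under <base_path>/data/.

import pandas as pd

df = pd.DataFrame({'png_path': ['/old/location/data/plate1/cell_001.png']})
df, image_paths = correct_paths(df, base_path='/new/location')
print(image_paths)  # ['/new/location/data/plate1/cell_001.png']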
+def correct_paths_v1(df, base_path):
+    if 'png_path' not in df.columns:
+        print("No 'png_path' column found in the dataframe.")
+        return df, None
+
+    image_paths = df['png_path'].to_list()
+
+    adjusted_image_paths = []
+    for path in image_paths:
+        if base_path not in path:
+            print(f"Adjusting path: {path}")
+            parts = path.split('data/')
+            if len(parts) > 1:
+                new_path = os.path.join(base_path, 'data', parts[1])
+                adjusted_image_paths.append(new_path)
+            else:
+                adjusted_image_paths.append(path)
+        else:
+            adjusted_image_paths.append(path)
+
+    df['png_path'] = adjusted_image_paths
+    image_paths = df['png_path'].to_list()
+    return df, image_paths
+
+def get_umap_image_settings(settings={}):
+    settings.setdefault('src', 'path')
+    settings.setdefault('row_limit', 1000)
+    settings.setdefault('tables', ['cell', 'cytoplasm', 'nucleus', 'pathogen'])
+    settings.setdefault('visualize', 'cell')
+    settings.setdefault('image_nr', 16)
+    settings.setdefault('dot_size', 50)
+    settings.setdefault('n_neighbors', 1000)
+    settings.setdefault('min_dist', 0.1)
+    settings.setdefault('metric', 'euclidean')
+    settings.setdefault('eps', 0.5)
+    settings.setdefault('min_samples', 1000)
+    settings.setdefault('filter_by', 'channel_0')
+    settings.setdefault('img_zoom', 0.5)
+    settings.setdefault('plot_by_cluster', True)
+    settings.setdefault('plot_cluster_grids', True)
+    settings.setdefault('remove_cluster_noise', True)
+    settings.setdefault('remove_highly_correlated', True)
+    settings.setdefault('log_data', False)
+    settings.setdefault('figuresize', 60)
+    settings.setdefault('black_background', True)
+    settings.setdefault('remove_image_canvas', False)
+    settings.setdefault('plot_outlines', True)
+    settings.setdefault('plot_points', True)
+    settings.setdefault('smooth_lines', True)
+    settings.setdefault('clustering', 'dbscan')
+    settings.setdefault('exclude', None)
+    settings.setdefault('col_to_compare', 'col')
+    settings.setdefault('pos', 'c1')
+    settings.setdefault('neg', 'c2')
+    settings.setdefault('mix', 'c3')
+    settings.setdefault('embedding_by_controls', False)
+    settings.setdefault('plot_images', True)
+    settings.setdefault('reduction_method', 'umap')
+    settings.setdefault('save_figure', False)
+    settings.setdefault('n_jobs', -1)
+    settings.setdefault('color_by', None)
+    settings.setdefault('exclude_conditions', None)
+    settings.setdefault('analyze_clusters', False)
+    settings.setdefault('resnet_features', False)
+    settings.setdefault('verbose', True)
+    return settings
+
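Because the helper uses setdefault throughout, it only fills in keys the caller has not set, so partial overrides are safe. A short sketch:

settings = get_umap_image_settings({'src': '/path/to/experiment', 'image_nr': 25, 'clustering': 'kmeans'})
print(settings['image_nr'], settings['eps'])  # 25 0.5  (override kept, default filled in)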
+def preprocess_data(df, filter_by, remove_highly_correlated, log_data, exclude):
+    """
+    Preprocess the given dataframe by applying filtering, removing highly correlated columns,
+    applying a log transformation, filling NaN values, and scaling the numeric data.
+
+    Args:
+        df (pandas.DataFrame): The input dataframe.
+        filter_by (str or None): The channel of interest to filter the dataframe by.
+        remove_highly_correlated (bool or float): Whether to remove highly correlated columns.
+            If a float is provided, it is used as the correlation threshold.
+        log_data (bool): Whether to apply a log transformation to the numeric data.
+        exclude (list or None): List of features to exclude from the filtering process.
+
+    Returns:
+        numpy.ndarray: The preprocessed numeric data.
+
+    Raises:
+        ValueError: If no numeric columns are available after filtering.
+    """
+    # Apply filtering based on the `filter_by` parameter
+    if filter_by is not None:
+        df, _ = filter_dataframe_features(df, channel_of_interest=filter_by, exclude=exclude)
+
+    # Select numerical features
+    numeric_data = df.select_dtypes(include=['number'])
+
+    # Check if numeric_data is empty
+    if numeric_data.empty:
+        raise ValueError("No numeric columns available after filtering. Please check the filter_by and exclude parameters.")
+
+    # Remove highly correlated columns
+    if remove_highly_correlated is not False:
+        if isinstance(remove_highly_correlated, float):
+            numeric_data = remove_highly_correlated_columns(numeric_data, remove_highly_correlated)
+        else:
+            numeric_data = remove_highly_correlated_columns(numeric_data, 0.95)
+
+    # Apply log transformation
+    if log_data:
+        numeric_data = np.log(numeric_data + 1e-6)
+
+    # Fill NaN values with the column mean
+    numeric_data = numeric_data.fillna(numeric_data.mean())
+
+    # Scale the numeric data
+    scaler = StandardScaler(copy=True, with_mean=True, with_std=True)
+    numeric_data = scaler.fit_transform(numeric_data)
+
+    return numeric_data
+
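A toy sketch of the preprocessing pipeline (hypothetical column names that follow the channel_N convention the filter expects; remove_highly_correlated_columns is defined elsewhere in this module):

import numpy as np
import pandas as pd

df = pd.DataFrame({
    'cell_channel_0_mean_intensity': np.random.rand(100),
    'cell_channel_1_mean_intensity': np.random.rand(100),
    'cell_area': np.random.rand(100),
})
X = preprocess_data(df, filter_by='channel_0', remove_highly_correlated=True, log_data=False, exclude=None)
print(X.shape)  # (100, n_remaining_features), scaled to zero mean and unit variance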
+def filter_dataframe_features(df, channel_of_interest, exclude=None):
+    """
+    Filter the dataframe `df` based on the specified `channel_of_interest` and `exclude` parameters.
+
+    Parameters:
+    - df (pandas.DataFrame): The input dataframe to be filtered.
+    - channel_of_interest (str, int, list, None): The channel(s) of interest to filter the dataframe.
+      If None, no filtering is applied. If 'morphology', only morphology features are included.
+      If an integer or string, only the specified channel is included. If a list, only the
+      specified channels are included.
+    - exclude (str, list, None): The feature(s) to exclude from the filtered dataframe.
+      If None, no features are excluded. If a string, the specified feature is excluded.
+      If a list, the specified features are excluded.
+
+    Returns:
+    - filtered_df (pandas.DataFrame): The filtered dataframe based on the specified parameters.
+    - features (list): The list of selected features after filtering.
+    """
+    if channel_of_interest is None:
+        feature_string = None
+    elif channel_of_interest == 'morphology':
+        feature_string = 'morphology'
+    elif isinstance(channel_of_interest, list):
+        feature_string = [f'channel_{i}' for i in channel_of_interest]
+    elif isinstance(channel_of_interest, int):
+        feature_string = f'channel_{channel_of_interest}'
+    elif isinstance(channel_of_interest, str):
+        feature_string = channel_of_interest
+
+    # Remove columns with a single value
+    df = df.loc[:, df.nunique() > 1]
+
+    # Select numerical features
+    features = df.select_dtypes(include=[np.number]).columns.tolist()
+
+    if feature_string is not None:
+        feature_list = ['channel_0', 'channel_1', 'channel_2', 'channel_3']
+
+        # Remove the channel(s) of interest from the removal list
+        if isinstance(feature_string, str):
+            if feature_string in feature_list:
+                feature_list.remove(feature_string)
+        elif isinstance(feature_string, list):
+            feature_list = [feature for feature in feature_list if feature not in feature_string]
+
+        # Keep only features that match the channel(s) of interest
+        # (handle the list case explicitly; a plain `feature_string in feature`
+        # would raise TypeError when feature_string is a list)
+        if feature_string != 'morphology':
+            if isinstance(feature_string, list):
+                features = [feature for feature in features if any(fs in feature for fs in feature_string)]
+            else:
+                features = [feature for feature in features if feature_string in feature]
+
+        # Drop features belonging to the remaining channels
+        for feature_ in feature_list:
+            features = [feature for feature in features if feature_ not in feature]
+            print(f'After removing {feature_} features: {len(features)}')
+
+    if isinstance(exclude, list):
+        features = [feature for feature in features if feature not in exclude]
+    elif isinstance(exclude, str):
+        features.remove(exclude)
+
+    filtered_df = df[features]
+
+    return filtered_df, features
+
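A sketch of the 'morphology' branch on hypothetical column names: intensity features named after channel_0 through channel_3 are stripped, leaving only morphology-style features.

import numpy as np
import pandas as pd

df = pd.DataFrame({
    'nucleus_channel_0_mean': np.random.rand(10),
    'nucleus_channel_1_mean': np.random.rand(10),
    'nucleus_area': np.random.rand(10),
})
filtered_df, features = filter_dataframe_features(df, channel_of_interest='morphology')
print(features)  # ['nucleus_area']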
+# Create a function to check if images overlap
+def check_overlap(current_position, other_positions, threshold):
+    for other_position in other_positions:
+        distance = np.linalg.norm(np.array(current_position) - np.array(other_position))
+        if distance < threshold:
+            return True
+    return False
+
+# Define a function to try random positions around a given point
+def find_non_overlapping_position(x, y, image_positions, threshold, max_attempts=100):
+    offset_range = 10  # Adjust the range for random offsets
+    attempts = 0
+    while attempts < max_attempts:
+        random_offset_x = random.uniform(-offset_range, offset_range)
+        random_offset_y = random.uniform(-offset_range, offset_range)
+        new_x = x + random_offset_x
+        new_y = y + random_offset_y
+        if not check_overlap((new_x, new_y), image_positions, threshold):
+            return new_x, new_y
+        attempts += 1
+    return x, y  # Return the original position if no suitable position found
+
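A sketch of how the two helpers might be combined when scattering thumbnails on an embedding; the threshold is in embedding coordinates and the values here are illustrative (results vary because the offsets are random):

placed = []
for x, y in [(0.0, 0.0), (0.5, 0.4), (8.0, 8.0)]:
    new_x, new_y = find_non_overlapping_position(x, y, placed, threshold=2.0)
    placed.append((new_x, new_y))
print(placed)  # the second point is jittered away from the first; the third usually stays put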
+def search_reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric, eps, min_samples, clustering, reduction_method, verbose, reduction_param=None, embedding=None, n_jobs=-1):
+    """
+    Perform dimensionality reduction and clustering on the given data.
+
+    Parameters:
+    numeric_data (np.array): Numeric data to process.
+    n_neighbors (int): Number of neighbors for UMAP or perplexity for tSNE.
+    min_dist (float): Minimum distance for UMAP.
+    metric (str): Metric for UMAP, tSNE, and DBSCAN.
+    eps (float): Epsilon for DBSCAN clustering.
+    min_samples (int): Minimum samples for DBSCAN or number of clusters for KMeans.
+    clustering (str): Clustering method ('dbscan' or 'kmeans').
+    reduction_method (str): Dimensionality reduction method ('umap' or 'tsne').
+    verbose (bool): Whether to print verbose output.
+    reduction_param (dict): Additional parameters for the reduction method.
+    embedding (np.array): Precomputed embedding (optional).
+    n_jobs (int): Number of parallel jobs to run.
+
+    Returns:
+    embedding (np.array): Embedding of the data.
+    labels (np.array): Cluster labels.
+    """
+    if isinstance(n_neighbors, float):
+        n_neighbors = int(n_neighbors * len(numeric_data))
+    if n_neighbors <= 1:
+        n_neighbors = 2
+        print(f'n_neighbors cannot be less than 2. Setting n_neighbors to {n_neighbors}')
+
+    reduction_param = reduction_param or {}
+    reduction_param = {k: v for k, v in reduction_param.items() if k not in ['perplexity', 'n_neighbors', 'min_dist', 'metric', 'method']}
+
+    if reduction_method == 'umap':
+        reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric, n_jobs=n_jobs, **reduction_param)
+    elif reduction_method == 'tsne':
+        reducer = TSNE(n_components=2, perplexity=n_neighbors, metric=metric, n_jobs=n_jobs, **reduction_param)
+    else:
+        raise ValueError(f"Unsupported reduction method: {reduction_method}. Supported methods are 'umap' and 'tsne'")
+
+    if embedding is None:
+        embedding = reducer.fit_transform(numeric_data)
+
+    if clustering == 'dbscan':
+        clustering_model = DBSCAN(eps=eps, min_samples=min_samples, metric=metric)
+    elif clustering == 'kmeans':
+        from sklearn.cluster import KMeans
+        clustering_model = KMeans(n_clusters=min_samples, random_state=42)
+    else:
+        raise ValueError(f"Unsupported clustering method: {clustering}. Supported methods are 'dbscan' and 'kmeans'")
+    clustering_model.fit(embedding)
+    labels = clustering_model.labels_ if clustering == 'dbscan' else clustering_model.predict(embedding)
+    if verbose:
+        print(f'Embedding shape: {embedding.shape}')
+    return embedding, labels
+
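Finally, a sketch of the combined reduction and clustering entry point on random data (assumes umap-learn and scikit-learn are available, which the function body relies on):

import numpy as np

X = np.random.rand(500, 20)
embedding, labels = search_reduction_and_clustering(
    X, n_neighbors=15, min_dist=0.1, metric='euclidean',
    eps=0.5, min_samples=10, clustering='dbscan',
    reduction_method='umap', verbose=True)
print(embedding.shape, np.unique(labels))  # (500, 2) plus the DBSCAN labels (-1 = noise)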