spacr 0.0.36__py3-none-any.whl → 0.0.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/utils.py CHANGED
@@ -1,12 +1,18 @@
1
- import sys, os, re, sqlite3, gc, torch, torchvision, time, random, string, shutil, cv2, tarfile, glob
1
+ import sys, os, re, sqlite3, torch, torchvision, random, string, shutil, cv2, tarfile, glob
2
2
 
3
3
  import numpy as np
4
4
  from cellpose import models as cp_models
5
5
  from cellpose import denoise
6
+
6
7
  from skimage import morphology
7
8
  from skimage.measure import label, regionprops_table, regionprops
8
9
  import skimage.measure as measure
9
- from collections import defaultdict
10
+ from skimage.transform import resize as resizescikit
11
+ from skimage.morphology import dilation, square
12
+ from skimage.measure import find_contours
13
+ from skimage.segmentation import clear_border
14
+
15
+ from collections import defaultdict, OrderedDict
10
16
  from PIL import Image
11
17
  import pandas as pd
12
18
  from statsmodels.stats.outliers_influence import variance_inflation_factor
@@ -15,48 +21,225 @@ import statsmodels.formula.api as smf
15
21
  import statsmodels.api as sm
16
22
  from statsmodels.stats.multitest import multipletests
17
23
  from itertools import combinations
18
- from collections import OrderedDict
19
24
  from functools import reduce
20
- from IPython.display import display, clear_output
25
+ from IPython.display import display
26
+
21
27
  from multiprocessing import Pool, cpu_count
22
- from skimage.transform import resize as resizescikit
23
- from skimage.morphology import dilation, square
24
- from skimage.measure import find_contours
28
+ from concurrent.futures import ThreadPoolExecutor
29
+
25
30
  import torch.nn as nn
26
31
  import torch.nn.functional as F
27
- #from torchsummary import summary
28
32
  from torch.utils.checkpoint import checkpoint
29
33
  from torch.utils.data import Subset
30
34
  from torch.autograd import grad
31
- from torchvision import models
32
- from skimage.segmentation import clear_border
35
+
33
36
  import seaborn as sns
34
37
  import matplotlib.pyplot as plt
38
+ from matplotlib.offsetbox import OffsetImage, AnnotationBbox
39
+
35
40
  import scipy.ndimage as ndi
36
41
  from scipy.spatial import distance
37
42
  from scipy.stats import fisher_exact
38
- from scipy.ndimage import binary_erosion, binary_dilation
43
+ from scipy.ndimage import gaussian_filter
44
+ from scipy.spatial import ConvexHull
45
+ from scipy.interpolate import splprep, splev
46
+
47
+ from sklearn.preprocessing import StandardScaler
39
48
  from skimage.exposure import rescale_intensity
40
49
  from sklearn.metrics import auc, precision_recall_curve
41
50
  from sklearn.model_selection import train_test_split
42
51
  from sklearn.linear_model import Lasso, Ridge
43
52
  from sklearn.preprocessing import OneHotEncoder
44
53
  from sklearn.cluster import KMeans
54
+ from sklearn.cluster import DBSCAN
55
+ from sklearn.manifold import TSNE
58
+
59
+ import umap.umap_ as umap
60
+
61
+ from torchvision import models
45
62
  from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
63
+ import torchvision.transforms as transforms
46
64
 
47
65
  from .logger import log_function_call
48
66
 
49
- def _gen_rgb_image(image, cahnnels):
50
- rgb_image = np.take(image, cahnnels, axis=-1)
51
- rgb_image = rgb_image.astype(float)
52
- rgb_image -= rgb_image.min()
53
- rgb_image /= rgb_image.max()
67
+ def check_mask_folder(src,mask_fldr):
68
+
69
+ mask_folder = os.path.join(src,'norm_channel_stack',mask_fldr)
70
+ stack_folder = os.path.join(src,'stack')
71
+
72
+ if not os.path.exists(mask_folder):
73
+ return True
74
+
75
+ mask_count = sum(1 for file in os.listdir(mask_folder) if file.endswith('.npy'))
76
+ stack_count = sum(1 for file in os.listdir(stack_folder) if file.endswith('.npy'))
77
+
78
+ if mask_count == stack_count:
79
+ print(f'All masks have been generated for {mask_fldr}')
80
+ return False
81
+ else:
82
+ return True
83
+
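A minimal usage sketch (the path and folder name are illustrative; `src` is assumed to contain `norm_channel_stack/<mask_fldr>` and `stack` subfolders of `.npy` files):

    if check_mask_folder('/data/experiment1', 'cell_mask_stack'):
        print('Masks missing or incomplete; regenerate them')  # e.g. rerun mask generation here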
84
+ def set_default_plot_merge_settings():
85
+ settings = {}
86
+ settings.setdefault('include_noninfected', True)
87
+ settings.setdefault('include_multiinfected', True)
88
+ settings.setdefault('include_multinucleated', True)
89
+ settings.setdefault('remove_background', False)
90
+ settings.setdefault('filter_min_max', None)
91
+ settings.setdefault('channel_dims', [0,1,2,3])
92
+ settings.setdefault('backgrounds', [100,100,100,100])
93
+ settings.setdefault('cell_mask_dim', 4)
94
+ settings.setdefault('nucleus_mask_dim', 5)
95
+ settings.setdefault('pathogen_mask_dim', 6)
96
+ settings.setdefault('outline_thickness', 3)
97
+ settings.setdefault('outline_color', 'gbr')
98
+ settings.setdefault('overlay_chans', [1,2,3])
99
+ settings.setdefault('overlay', True)
100
+ settings.setdefault('normalization_percentiles', [2,98])
101
+ settings.setdefault('normalize', True)
102
+ settings.setdefault('print_object_number', True)
103
+ settings.setdefault('nr', 1)
104
+ settings.setdefault('figuresize', 50)
105
+ settings.setdefault('cmap', 'inferno')
106
+ settings.setdefault('verbose', True)
107
+
108
+ return settings
109
+
110
+ def set_default_settings_preprocess_generate_masks(src, settings={}):
111
+ # Main settings
112
+ settings['src'] = src
113
+ settings.setdefault('preprocess', True)
114
+ settings.setdefault('masks', True)
115
+ settings.setdefault('save', True)
116
+ settings.setdefault('batch_size', 50)
117
+ settings.setdefault('test_mode', False)
118
+ settings.setdefault('test_images', 10)
119
+ settings.setdefault('magnification', 20)
120
+ settings.setdefault('custom_regex', None)
121
+ settings.setdefault('metadata_type', 'cellvoyager')
122
+ settings.setdefault('workers', max(1, os.cpu_count()-4))
123
+ settings.setdefault('randomize', True)
124
+ settings.setdefault('verbose', True)
125
+
126
+ settings.setdefault('remove_background_cell', False)
127
+ settings.setdefault('remove_background_nucleus', False)
128
+ settings.setdefault('remove_background_pathogen', False)
129
+
130
+ # Channel settings
131
+ settings.setdefault('cell_channel', None)
132
+ settings.setdefault('nucleus_channel', None)
133
+ settings.setdefault('pathogen_channel', None)
134
+ settings.setdefault('channels', [0,1,2,3])
135
+ settings.setdefault('pathogen_background', 100)
136
+ settings.setdefault('pathogen_Signal_to_noise', 10)
137
+ settings.setdefault('pathogen_CP_prob', 0)
138
+ settings.setdefault('cell_background', 100)
139
+ settings.setdefault('cell_Signal_to_noise', 10)
140
+ settings.setdefault('cell_CP_prob', 0)
141
+ settings.setdefault('nucleus_background', 100)
142
+ settings.setdefault('nucleus_Signal_to_noise', 10)
143
+ settings.setdefault('nucleus_CP_prob', 0)
144
+
145
+ settings.setdefault('nucleus_FT', 100)
146
+ settings.setdefault('cell_FT', 100)
147
+ settings.setdefault('pathogen_FT', 100)
148
+
149
+ # Plot settings
150
+ settings.setdefault('plot', False)
151
+ settings.setdefault('figuresize', 50)
152
+ settings.setdefault('cmap', 'inferno')
153
+ settings.setdefault('normalize', True)
154
+ settings.setdefault('normalize_plots', True)
155
+ settings.setdefault('examples_to_plot', 1)
156
+
157
+ # Analysis settings
158
+ settings.setdefault('pathogen_model', None)
159
+ settings.setdefault('merge_pathogens', False)
160
+ settings.setdefault('filter', False)
161
+ settings.setdefault('lower_percentile', 2)
162
+
163
+ # Timelapse settings
164
+ settings.setdefault('timelapse', False)
165
+ settings.setdefault('fps', 2)
166
+ settings.setdefault('timelapse_displacement', None)
167
+ settings.setdefault('timelapse_memory', 3)
168
+ settings.setdefault('timelapse_frame_limits', None)
169
+ settings.setdefault('timelapse_remove_transient', False)
170
+ settings.setdefault('timelapse_mode', 'trackpy')
171
+ settings.setdefault('timelapse_objects', 'cells')
172
+
173
+ # Misc settings
174
+ settings.setdefault('all_to_mip', False)
175
+ settings.setdefault('pick_slice', False)
176
+ settings.setdefault('skip_mode', '01')
177
+ settings.setdefault('upscale', False)
178
+ settings.setdefault('upscale_factor', 2.0)
179
+ settings.setdefault('adjust_cells', False)
180
+
181
+ return settings
182
+
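Since every option is applied with `setdefault`, callers override only the keys they care about; a sketch (the source path is illustrative):

    settings = set_default_settings_preprocess_generate_masks('/data/experiment1', settings={'magnification': 40, 'test_mode': True})
    print(settings['magnification'], settings['batch_size'])  # 40 50 (override kept, default filled in)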
183
+ def set_default_settings_preprocess_img_data(settings):
184
+
185
+ metadata_type = settings.setdefault('metadata_type', 'cellvoyager')
186
+ custom_regex = settings.setdefault('custom_regex', None)
187
+ nr = settings.setdefault('nr', 1)
188
+ plot = settings.setdefault('plot', True)
189
+ batch_size = settings.setdefault('batch_size', 50)
190
+ timelapse = settings.setdefault('timelapse', False)
191
+ lower_percentile = settings.setdefault('lower_percentile', 2)
192
+ randomize = settings.setdefault('randomize', True)
193
+ all_to_mip = settings.setdefault('all_to_mip', False)
194
+ pick_slice = settings.setdefault('pick_slice', False)
195
+ skip_mode = settings.setdefault('skip_mode', False)
196
+
197
+ cmap = settings.setdefault('cmap', 'inferno')
198
+ figuresize = settings.setdefault('figuresize', 50)
199
+ normalize = settings.setdefault('normalize', True)
200
+ save_dtype = settings.setdefault('save_dtype', 'uint16')
201
+
202
+ test_mode = settings.setdefault('test_mode', False)
203
+ test_images = settings.setdefault('test_images', 10)
204
+ random_test = settings.setdefault('random_test', True)
205
+
206
+ return settings, metadata_type, custom_regex, nr, plot, batch_size, timelapse, lower_percentile, randomize, all_to_mip, pick_slice, skip_mode, cmap, figuresize, normalize, save_dtype, test_mode, test_images, random_test
207
+
208
+ def smooth_hull_lines(cluster_data):
209
+ hull = ConvexHull(cluster_data)
210
+
211
+ # Extract vertices of the hull
212
+ vertices = hull.points[hull.vertices]
213
+
214
+ # Close the loop
215
+ vertices = np.vstack([vertices, vertices[0, :]])
216
+
217
+ # Parameterize the vertices
218
+ tck, u = splprep(vertices.T, u=None, s=0.0)
219
+
220
+ # Evaluate spline at new parameter values
221
+ new_points = splev(np.linspace(0, 1, 100), tck)
222
+
223
+ return new_points[0], new_points[1]
224
+
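For example, tracing a smoothed outline around a 2-D point cloud (a sketch; `splprep`'s default cubic spline needs at least four distinct hull vertices):

    points = np.random.rand(30, 2)          # e.g. 2-D embedding coordinates
    xs, ys = smooth_hull_lines(points)      # 100 points along the smoothed convex hull
    plt.plot(xs, ys); plt.scatter(points[:, 0], points[:, 1]); plt.show()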
225
+ def _gen_rgb_image(image, channels):
226
+ """
227
+ Generate an RGB image from the specified channels of the input image.
228
+
229
+ Args:
230
+ image (ndarray): The input image.
231
+ channels (list): List of channel indices to use for RGB.
232
+
233
+ Returns:
234
+ rgb_image (ndarray): The generated RGB image.
235
+ """
236
+ rgb_image = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.float32)
237
+ for i, chan in enumerate(channels):
238
+ if chan < image.shape[2]:
239
+ rgb_image[:, :, i] = image[:, :, chan]
54
240
  return rgb_image
55
241
 
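A sketch of the intended call, mapping three planes of a channel stack onto R, G and B (channel indices past the stack depth are simply left black):

    stack = np.random.rand(512, 512, 5)     # H x W x C stack
    rgb = _gen_rgb_image(stack, channels=[2, 0, 1])
    plt.imshow(rgb); plt.show()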
56
242
  def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_thickness):
57
- from concurrent.futures import ThreadPoolExecutor
58
- import cv2
59
-
60
243
  outlines = []
61
244
  overlayed_image = rgb_image.copy()
62
245
 
@@ -66,11 +249,12 @@ def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_th
66
249
 
67
250
  # Find and draw contours
68
251
  for j in np.unique(mask):
69
- #for j in np.unique(mask)[1:]:
252
+ if j == 0:
253
+ continue # Skip background
70
254
  contours = find_contours(mask == j, 0.5)
71
255
  # Convert contours for OpenCV format and draw directly to optimize
72
256
  cv_contours = [np.flip(contour.astype(int), axis=1) for contour in contours]
73
- cv2.drawContours(outline, cv_contours, -1, color=int(j), thickness=outline_thickness)
257
+ cv2.drawContours(outline, cv_contours, -1, color=255, thickness=outline_thickness)
74
258
 
75
259
  return dilation(outline, square(outline_thickness))
76
260
 
@@ -78,19 +262,15 @@ def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_th
78
262
  with ThreadPoolExecutor() as executor:
79
263
  outlines = list(executor.map(process_dim, mask_dims))
80
264
 
81
- # Overlay outlines onto the RGB image in a batch/vectorized manner if possible
265
+ # Overlay outlines onto the RGB image
82
266
  for i, outline in enumerate(outlines):
83
- # This part may need to be adapted to your specific use case and available functions
84
- # The goal is to overlay each outline with its respective color more efficiently
85
- color = outline_colors[i % len(outline_colors)]
86
- for j in np.unique(outline)[1:]:
267
+ color = np.array(outline_colors[i % len(outline_colors)])
268
+ for j in np.unique(outline):
269
+ if j == 0:
270
+ continue # Skip background
87
271
  mask = outline == j
88
272
  overlayed_image[mask] = color # Direct assignment with broadcasting
89
273
 
90
- # Remove mask_dims from image
91
- channels_to_keep = [i for i in range(image.shape[-1]) if i not in mask_dims]
92
- image = np.take(image, channels_to_keep, axis=-1)
93
-
94
274
  return overlayed_image, outlines, image
95
275
 
96
276
  def _convert_cq1_well_id(well_id):
@@ -350,43 +530,82 @@ def _annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['rh'], pa
350
530
  df['condition'] = df.apply(lambda row: '_'.join(filter(None, [row.get('pathogen'), row.get('treatment')])), axis=1)
351
531
  df['condition'] = df['condition'].apply(lambda x: x if x else 'none')
352
532
  return df
353
-
354
- def normalize_to_dtype(array, q1=2,q2=98, percentiles=None):
533
+
534
+ def normalize_to_dtype_v1(array, p1=2, p2=98):
355
535
  """
356
- Normalize the input array to a specified data type.
536
+ Normalize each image in the stack to its own percentiles.
357
537
 
358
538
  Parameters:
359
539
  - array: numpy array
360
- The input array to be normalized.
361
- - q1: int, optional
540
+ The input stack to be normalized.
541
+ - p1: int, optional
362
542
  The lower percentile value for normalization. Default is 2.
363
- - q2: int, optional
543
+ - p2: int, optional
364
544
  The upper percentile value for normalization. Default is 98.
365
- - percentiles: list of tuples, optional
366
- A list of tuples containing the percentile values for each image in the array.
367
- If provided, the percentiles for each image will be used instead of q1 and q2.
368
545
 
369
546
  Returns:
370
547
  - new_stack: numpy array
371
- The normalized array with the same shape as the input array.
548
+ The normalized stack with the same shape as the input stack.
372
549
  """
373
550
  nimg = array.shape[2]
374
551
  new_stack = np.empty_like(array)
375
- for i,v in enumerate(range(nimg)):
376
- img = np.squeeze(array[:, :, v])
552
+
553
+ for i in range(nimg):
554
+ img = array[:, :, i]
377
555
  non_zero_img = img[img > 0]
378
- if non_zero_img.size > 0: # check if there are non-zero values
379
- img_min = np.percentile(non_zero_img, q1) # change percentile from 0.02 to 2
380
- img_max = np.percentile(non_zero_img, q2) # change percentile from 0.98 to 98
381
- img = rescale_intensity(img, in_range=(img_min, img_max), out_range='dtype')
382
- else: # if there are no non-zero values, just use the image as it is
383
- if percentiles==None:
384
- img_min, img_max = img.min(), img.max()
385
- else:
386
- img_min, img_max = percentiles[i]
387
- img = rescale_intensity(img, in_range=(img_min, img_max), out_range='dtype')
388
- img = np.expand_dims(img, axis=2)
389
- new_stack[:, :, v] = img[:, :, 0]
556
+
557
+ if non_zero_img.size > 0:
558
+ img_min = np.percentile(non_zero_img, p1)
559
+ img_max = np.percentile(non_zero_img, p2)
560
+ else:
561
+ img_min = img.min()
562
+ img_max = img.max()
563
+
564
+ # Determine output range based on dtype
565
+ if np.issubdtype(array.dtype, np.integer):
566
+ out_range = (0, np.iinfo(array.dtype).max)
567
+ else:
568
+ out_range = (0.0, 1.0)
569
+
570
+ img = rescale_intensity(img, in_range=(img_min, img_max), out_range=out_range).astype(array.dtype)
571
+ new_stack[:, :, i] = img
572
+
573
+ return new_stack
574
+
575
+ def normalize_to_dtype(array, p1=2, p2=98):
576
+ """
577
+ Normalize each image in the stack to its own percentiles.
578
+
579
+ Parameters:
580
+ - array: numpy array
581
+ The input stack to be normalized.
582
+ - p1: int, optional
583
+ The lower percentile value for normalization. Default is 2.
584
+ - p2: int, optional
585
+ The upper percentile value for normalization. Default is 98.
586
+
587
+ Returns:
588
+ - new_stack: numpy array
589
+ The normalized stack with the same shape as the input stack.
590
+ """
591
+ nimg = array.shape[2]
592
+ new_stack = np.empty_like(array, dtype=np.float32)
593
+
594
+ for i in range(nimg):
595
+ img = array[:, :, i]
596
+ non_zero_img = img[img > 0]
597
+
598
+ if non_zero_img.size > 0:
599
+ img_min = np.percentile(non_zero_img, p1)
600
+ img_max = np.percentile(non_zero_img, p2)
601
+ else:
602
+ img_min = img.min()
603
+ img_max = img.max()
604
+
605
+ # Normalize to the range (0, 1) for visualization
606
+ img = rescale_intensity(img, in_range=(img_min, img_max), out_range=(0.0, 1.0))
607
+ new_stack[:, :, i] = img
608
+
390
609
  return new_stack
391
610
 
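Each plane is rescaled to float32 in (0, 1) using percentiles of its own non-zero pixels; a quick check:

    stack = np.random.randint(0, 65535, size=(256, 256, 3), dtype=np.uint16)
    norm = normalize_to_dtype(stack, p1=2, p2=98)
    print(norm.dtype, norm.min(), norm.max())  # float32 0.0 1.0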
392
611
  def _list_endpoint_subdirectories(base_dir):
@@ -744,7 +963,7 @@ def _get_diam(mag, obj):
744
963
  elif obj == 'nucleus':
745
964
  diamiter = 60
746
965
  elif obj == 'pathogen':
747
- diamiter = 30
966
+ diamiter = 20
748
967
  else:
749
968
  raise ValueError("Invalid magnification: Use 20, 40 or 60")
750
969
 
@@ -764,7 +983,7 @@ def _get_diam(mag, obj):
764
983
  if obj == 'nucleus':
765
984
  diamiter = 90
766
985
  if obj == 'pathogen':
767
- diamiter = 75
986
+ diamiter = 60
768
987
  else:
769
988
  raise ValueError("Invalid magnification: Use 20, 40 or 60")
770
989
  else:
@@ -800,8 +1019,9 @@ def _get_object_settings(object_type, settings):
800
1019
 
801
1020
  elif object_type == 'pathogen':
802
1021
  object_settings['model_name'] = 'cyto'
803
- object_settings['filter_size'] = True
1022
+ object_settings['filter_size'] = False
804
1023
  object_settings['filter_intensity'] = False
1024
+ object_settings['resample'] = False
805
1025
  object_settings['restore_type'] = settings.get('pathogen_restore_type', None)
806
1026
  object_settings['merge'] = settings['merge_pathogens']
807
1027
 
@@ -2751,15 +2971,37 @@ def _object_filter(df, object_type, size_range, intensity_range, mask_chans, mas
2751
2971
  print(f'After {object_type} maximum mean intensity filter: {len(df)}')
2752
2972
  return df
2753
2973
 
2754
- def _run_test_mode(src, regex, timelapse=False):
2974
+ def _get_regex(metadata_type, img_format, custom_regex=None):
2975
+
2976
+ if img_format == None:
2977
+ img_format == '.tif'
2978
+ if metadata_type == 'cellvoyager':
2979
+ regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
2980
+ elif metadata_type == 'cq1':
2981
+ regex = f'W(?P<wellID>.*)F(?P<fieldID>.*)T(?P<timeID>.*)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
2982
+ elif metadata_type == 'nikon':
2983
+ regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
2984
+ elif metadata_type == 'zeis':
2985
+ regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
2986
+ elif metadata_type == 'leica':
2987
+ regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
2988
+ elif metadata_type == 'custom':
2989
+ regex = f'({custom_regex}){img_format}'
2990
+
2991
+ print(f'regex mode: {metadata_type} regex: {regex}')
2992
+ return regex
2993
+
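For instance, a CellVoyager-style filename parses as follows (the filename is illustrative; the field layout is taken from the regex above):

    regex = _get_regex('cellvoyager', '.tif')
    m = re.match(regex, 'plate1_B02_T0001F001L01A01Z01C02.tif')
    print(m.group('wellID'), m.group('fieldID'), m.group('chanID'))  # B02 001 02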
2994
+ def _run_test_mode(src, regex, timelapse=False, test_images=10, random_test=True):
2995
+
2755
2996
  if timelapse:
2756
2997
  test_images = 1 # Use only 1 set for timelapse to ensure full sequence inclusion
2757
- else:
2758
- test_images = 10 # Use 10 sets for non-timelapse scenarios
2759
-
2998
+
2760
2999
  test_folder_path = os.path.join(src, 'test')
2761
3000
  os.makedirs(test_folder_path, exist_ok=True)
2762
3001
  regular_expression = re.compile(regex)
3002
+
3003
+ if os.path.exists(os.path.join(src, 'orig')):
3004
+ src = os.path.join(src, 'orig')
2763
3005
 
2764
3006
  all_filenames = [filename for filename in os.listdir(src) if regular_expression.match(filename)]
2765
3007
  print(f'Found {len(all_filenames)} files')
@@ -2771,25 +3013,20 @@ def _run_test_mode(src, regex, timelapse=False):
2771
3013
  plate = match.group('plateID') if 'plateID' in match.groupdict() else os.path.basename(src)
2772
3014
  well = match.group('wellID')
2773
3015
  field = match.group('fieldID')
2774
- # For timelapse experiments, group images by plate, well, and field only
2775
- if timelapse:
2776
- set_identifier = (plate, well, field)
2777
- else:
2778
- # For non-timelapse, you might want to distinguish sets more granularly
2779
- # Here, assuming you're grouping by plate, well, and field for simplicity
2780
- set_identifier = (plate, well, field)
3016
+ set_identifier = (plate, well, field)
2781
3017
  images_by_set[set_identifier].append(filename)
2782
3018
 
2783
3019
  # Prepare for random selection
2784
3020
  set_identifiers = list(images_by_set.keys())
2785
- random.seed(42)
3021
+ if random_test:
3022
+ random.seed(42)
2786
3023
  random.shuffle(set_identifiers) # Randomize the order
2787
3024
 
2788
3025
  # Select a subset based on the test_images count
2789
3026
  selected_sets = set_identifiers[:test_images]
2790
3027
 
2791
3028
  # Print information about the number of sets used
2792
- print(f'Using {test_images} random image set(s) for test model')
3029
+ print(f'Using {len(selected_sets)} random image set(s) for test model')
2793
3030
 
2794
3031
  # Copy files for selected sets to the test folder
2795
3032
  for set_identifier in selected_sets:
@@ -2798,24 +3035,1034 @@ def _run_test_mode(src, regex, timelapse=False):
2798
3035
 
2799
3036
  return test_folder_path
2800
3037
 
2801
- def _choose_model(model_name, device, object_type='cell', restore_type=None):
3038
+ def _choose_model(model_name, device, object_type='cell', restore_type=None, object_settings={}):
3039
+
3040
+ if object_type == 'pathogen':
3041
+ if model_name == 'toxo_pv_lumen':
3042
+ diameter = object_settings['diameter']
3043
+ current_dir = os.path.dirname(__file__)
3044
+ model_path = os.path.join(current_dir, 'models', 'cp', 'toxo_pv_lumen.CP_model')
3045
+ print(model_path)
3046
+ model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=model_path, diam_mean=diameter, device=device)
3047
+ #model = cp_models.Cellpose(gpu=torch.cuda.is_available(), model_type='cyto', device=device)
3048
+ print(f'Using Toxoplasma PV lumen model to generate pathogen masks')
3049
+ return model
3050
+
2802
3051
  restore_list = ['denoise', 'deblur', 'upsample', None]
2803
3052
  if restore_type not in restore_list:
2804
3053
  print(f"Invalid restore type. Choose from {restore_list} defaulting to None")
2805
3054
  restore_type = None
2806
3055
 
2807
3056
  if restore_type == None:
2808
- model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device)
3057
+ if model_name in ['cyto', 'cyto2', 'cyto3', 'nuclei']:
3058
+ model = cp_models.Cellpose(gpu=torch.cuda.is_available(), model_type=model_name, device=device)
3059
+
2809
3060
  else:
2810
3061
  if object_type == 'nucleus':
2811
3062
  restore = f'{restore_type}_nuclei'
2812
- model = denoise.CellposeDenoiseModel(gpu=True, model_type="nuclei",restore_type=restore, chan2_restore=False, device=device)
3063
+ model = denoise.CellposeDenoiseModel(gpu=torch.cuda.is_available(), model_type="nuclei",restore_type=restore, chan2_restore=False, device=device)
2813
3064
  else:
2814
3065
  restore = f'{restore_type}_cyto3'
2815
3066
  if model_name =='cyto2':
2816
3067
  chan2_restore = True
2817
3068
  if model_name =='cyto':
2818
3069
  chan2_restore = False
2819
- model = denoise.CellposeDenoiseModel(gpu=True, model_type="cyto3",restore_type=restore, chan2_restore=chan2_restore, device=device)
3070
+ model = denoise.CellposeDenoiseModel(gpu=torch.cuda.is_available(), model_type="cyto3",restore_type=restore, chan2_restore=chan2_restore, device=device)
3071
+
3072
+ return model
3073
+
3074
+ class SelectChannels:
3075
+ def __init__(self, channels):
3076
+ self.channels = channels
3077
+
3078
+ def __call__(self, img):
3079
+ img = img.clone()
3080
+ if 1 not in self.channels:
3081
+ img[0, :, :] = 0 # Zero out the red channel
3082
+ if 2 not in self.channels:
3083
+ img[1, :, :] = 0 # Zero out the green channel
3084
+ if 3 not in self.channels:
3085
+ img[2, :, :] = 0 # Zero out the blue channel
3086
+ return img
3087
+
3088
+ def preprocess_image(image_path, image_size=224, channels=[1,2,3], normalize=True):
3089
+
3090
+ if normalize:
3091
+ transform = transforms.Compose([
3092
+ transforms.ToTensor(),
3093
+ transforms.CenterCrop(size=(image_size, image_size)),
3094
+ SelectChannels(channels),
3095
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
3096
+ else:
3097
+ transform = transforms.Compose([
3098
+ transforms.ToTensor(),
3099
+ transforms.CenterCrop(size=(image_size, image_size)),
3100
+ SelectChannels(channels)])
3101
+
3102
+ image = Image.open(image_path).convert('RGB')
3103
+ input_tensor = transform(image).unsqueeze(0)
3104
+ return image, input_tensor
3105
+
3106
+
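A usage sketch (the file path is illustrative); with `channels=[2, 3]` the red channel is zeroed by `SelectChannels`:

    image, tensor = preprocess_image('cell_42.png', image_size=224, channels=[2, 3])
    print(tensor.shape)  # torch.Size([1, 3, 224, 224])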
3107
+ class SaliencyMapGenerator:
3108
+ def __init__(self, model):
3109
+ self.model = model
3110
+
3111
+ def compute_saliency_maps(self, X, y):
3112
+ self.model.eval()
3113
+ X.requires_grad_()
3114
+
3115
+ # Forward pass
3116
+ scores = self.model(X).squeeze()
3117
+
3118
+ # For binary classification, target scores can be the single output
3119
+ target_scores = scores * (2 * y - 1)
3120
+
3121
+ self.model.zero_grad()
3122
+ target_scores.backward(torch.ones_like(target_scores))
3123
+
3124
+ saliency = X.grad.abs()
3125
+ return saliency
3126
+
3127
+ def plot_saliency_maps(self, X, y, saliency, class_names):
3128
+ N = X.shape[0]
3129
+ for i in range(N):
3130
+ plt.subplot(2, N, i + 1)
3131
+ plt.imshow(X[i].permute(1, 2, 0).cpu().numpy())
3132
+ plt.axis('off')
3133
+ plt.title(class_names[y[i]])
3134
+ plt.subplot(2, N, N + i + 1)
3135
+ plt.imshow(saliency[i].cpu().numpy(), cmap=plt.cm.hot)
3136
+ plt.axis('off')
3137
+ plt.gcf().set_size_inches(12, 5)
3138
+ plt.show()
3139
+
3140
+ def preprocess_image(image_path, normalize=True, image_size=224, channels=[1,2,3]):
3141
+ preprocess = transforms.Compose([
3142
+ transforms.Resize((image_size, image_size)),
3143
+ transforms.ToTensor(),
3144
+ ])
3145
+
3146
+ image = Image.open(image_path).convert('RGB')
3147
+ input_tensor = preprocess(image)
3148
+ if normalize:
3149
+ input_tensor = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(input_tensor)
3150
+ input_tensor = input_tensor.unsqueeze(0)
3151
+
3152
+ return image, input_tensor
3153
+
3154
+ def class_visualization(target_y, model_path, dtype, img_size=224, channels=[0,1,2], l2_reg=1e-3, learning_rate=25, num_iterations=100, blur_every=10, max_jitter=16, show_every=25, class_names = ['nc', 'pc']):
3155
+
3156
+ def jitter(img, ox, oy):
3157
+ # Randomly jitter the image
3158
+ return torch.roll(torch.roll(img, ox, dims=2), oy, dims=3)
3159
+
3160
+ def blur_image(img, sigma=1):
3161
+ # Apply Gaussian blur to the image
3162
+ img_np = img.cpu().numpy()
3163
+ for i in range(img_np.shape[1]):
3164
+ img_np[:, i] = gaussian_filter(img_np[:, i], sigma=sigma)
3165
+ img.copy_(torch.tensor(img_np).to(img.device))
3166
+
3167
+ def deprocess(img_tensor):
3168
+ # Convert the tensor image to a numpy array for visualization
3169
+ img_tensor = img_tensor.clone()
3170
+ for c in range(3):
3171
+ img_tensor[:, c] = img_tensor[:, c] * SQUEEZENET_STD[c] + SQUEEZENET_MEAN[c]
3172
+ img_tensor = img_tensor.clamp(0, 1)
3173
+ return img_tensor.squeeze().permute(1, 2, 0).cpu().numpy()
3174
+
3175
+ # Assuming these are defined somewhere in your codebase
3176
+ SQUEEZENET_MEAN = [0.485, 0.456, 0.406]
3177
+ SQUEEZENET_STD = [0.229, 0.224, 0.225]
3178
+
3179
+ model = torch.load(model_path)
3180
+
3181
+ dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
3182
+ len_chans = len(channels)
3183
+ model.type(dtype)
3184
+
3185
+ # Randomly initialize the image as a PyTorch Tensor, and make it requires gradient.
3186
+ img = torch.randn(1, len_chans, img_size, img_size).mul_(1.0).type(dtype).requires_grad_()
3187
+
3188
+ for t in range(num_iterations):
3189
+ # Randomly jitter the image a bit; this gives slightly nicer results
3190
+ ox, oy = random.randint(0, max_jitter), random.randint(0, max_jitter)
3191
+ img.data.copy_(jitter(img.data, ox, oy))
3192
+
3193
+ # Forward pass
3194
+ score = model(img)
3195
+
3196
+ if target_y == 0:
3197
+ target_score = -score
3198
+ else:
3199
+ target_score = score
3200
+
3201
+ # Add regularization
3202
+ target_score = target_score - l2_reg * torch.norm(img)
3203
+
3204
+ # Backward pass
3205
+ target_score.backward()
3206
+
3207
+ # Gradient ascent step
3208
+ with torch.no_grad():
3209
+ img += learning_rate * img.grad / torch.norm(img.grad)
3210
+ img.grad.zero_()
3211
+
3212
+ # Undo the random jitter
3213
+ img.data.copy_(jitter(img.data, -ox, -oy))
3214
+
3215
+ # As regularizer, clamp and periodically blur the image
3216
+ for c in range(3):
3217
+ lo = float(-SQUEEZENET_MEAN[c] / SQUEEZENET_STD[c])
3218
+ hi = float((1.0 - SQUEEZENET_MEAN[c]) / SQUEEZENET_STD[c])
3219
+ img.data[:, c].clamp_(min=lo, max=hi)
3220
+ if t % blur_every == 0:
3221
+ blur_image(img.data, sigma=0.5)
3222
+
3223
+ # Periodically show the image
3224
+ if t == 0 or (t + 1) % show_every == 0 or t == num_iterations - 1:
3225
+ plt.imshow(deprocess(img.data.clone().cpu()))
3226
+ class_name = class_names[target_y]
3227
+ plt.title('%s\nIteration %d / %d' % (class_name, t + 1, num_iterations))
3228
+ plt.gcf().set_size_inches(4, 4)
3229
+ plt.axis('off')
3230
+ plt.show()
3231
+
3232
+ return deprocess(img.data.cpu())
3233
+
3234
+ def get_submodules(model, prefix=''):
3235
+ submodules = []
3236
+ for name, module in model.named_children():
3237
+ full_name = prefix + ('.' if prefix else '') + name
3238
+ submodules.append(full_name)
3239
+ submodules.extend(get_submodules(module, full_name))
3240
+ return submodules
3241
+
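For example, listing dotted submodule names of a torchvision ResNet (a sketch using the `models` import at the top of this file):

    net = models.resnet18(weights=ResNet18_Weights.DEFAULT)
    print(get_submodules(net)[:3])  # ['conv1', 'bn1', 'relu']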
3242
+ class GradCAM:
3243
+ def __init__(self, model, target_layers=None, use_cuda=True):
3244
+ self.model = model
3245
+ self.model.eval()
3246
+ self.target_layers = target_layers
3247
+ self.cuda = use_cuda
3248
+ if self.cuda:
3249
+ self.model = model.cuda()
3250
+
3251
+ def forward(self, input):
3252
+ return self.model(input)
3253
+
3254
+ def __call__(self, x, index=None):
3255
+ if self.cuda:
3256
+ x = x.cuda()
3257
+
3258
+ features = []
3259
+ def hook(module, input, output):
3260
+ output.retain_grad()  # retain .grad on this non-leaf activation so it is populated after backward
+ features.append(output)
3261
+
3262
+ handles = []
3263
+ for name, module in self.model.named_modules():
3264
+ if name in self.target_layers:
3265
+ handles.append(module.register_forward_hook(hook))
3266
+
3267
+ output = self.forward(x)
3268
+ if index is None:
3269
+ index = np.argmax(output.data.cpu().numpy())
3270
+
3271
+ one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
3272
+ one_hot[0][index] = 1
3273
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
3274
+ if self.cuda:
3275
+ one_hot = one_hot.cuda()
3276
+
3277
+ one_hot = torch.sum(one_hot * output)
3278
+ self.model.zero_grad()
3279
+ one_hot.backward(retain_graph=True)
3280
+
3281
+ grads_val = features[0].grad.cpu().data.numpy()
3282
+ target = features[0].cpu().data.numpy()[0, :]
3283
+
3284
+ weights = np.mean(grads_val, axis=(2, 3))[0, :]
3285
+ cam = np.zeros(target.shape[1:], dtype=np.float32)
3286
+
3287
+ for i, w in enumerate(weights):
3288
+ cam += w * target[i, :, :]
3289
+
3290
+ cam = np.maximum(cam, 0)
3291
+ cam = cv2.resize(cam, (x.size(3), x.size(2)))  # cv2.resize expects (width, height)
3292
+ cam = cam - np.min(cam)
3293
+ cam = cam / np.max(cam)
3294
+
3295
+ for handle in handles:
3296
+ handle.remove()
3297
+
3298
+ return cam
3299
+
3300
+ def show_cam_on_image(img, mask):
3301
+ heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
3302
+ heatmap = np.float32(heatmap) / 255
3303
+ cam = heatmap + np.float32(img)
3304
+ cam = cam / np.max(cam)
3305
+ return np.uint8(255 * cam)
3306
+
3307
+ def recommend_target_layers(model):
3308
+ target_layers = []
3309
+ for name, module in model.named_modules():
3310
+ if isinstance(module, torch.nn.Conv2d):
3311
+ target_layers.append(name)
3312
+ # Choose the last conv layer as the recommended target layer
3313
+ if target_layers:
3314
+ return [target_layers[-1]], target_layers
3315
+ else:
3316
+ raise ValueError("No convolutional layers found in the model.")
2820
3317
 
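Putting the pieces together, a hedged end-to-end sketch (file paths are illustrative):

    model = torch.load('models/binary_cnn.pth')
    target_layers, _ = recommend_target_layers(model)
    cam = GradCAM(model, target_layers=target_layers, use_cuda=torch.cuda.is_available())
    image, tensor = preprocess_image('cell_42.png')
    mask = cam(tensor)                       # H x W heatmap in [0, 1]
    overlay = show_cam_on_image(np.array(image.resize((224, 224))) / 255.0, mask)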
2821
- return model
3318
+ class IntegratedGradients:
3319
+ def __init__(self, model):
3320
+ self.model = model
3321
+ self.model.eval()
3322
+
3323
+ def generate_integrated_gradients(self, input_tensor, target_label_idx, baseline=None, num_steps=50):
3324
+ if baseline is None:
3325
+ baseline = torch.zeros_like(input_tensor)
3326
+
3327
+ assert baseline.shape == input_tensor.shape
3328
+
3329
+ # Scale input and compute gradients
3330
+ scaled_inputs = [(baseline + (float(i) / num_steps) * (input_tensor - baseline)).requires_grad_(True) for i in range(0, num_steps + 1)]
3331
+ grads = []
3332
+ for scaled_input in scaled_inputs:
3333
+ out = self.model(scaled_input)
3334
+ self.model.zero_grad()
3335
+ out[0, target_label_idx].backward(retain_graph=True)
3336
+ grads.append(scaled_input.grad.data.cpu().numpy())
3337
+
3338
+ avg_grads = np.mean(grads[:-1], axis=0)
3339
+ integrated_grads = (input_tensor.cpu().data.numpy() - baseline.cpu().data.numpy()) * avg_grads
3340
+ return integrated_grads
3341
+
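Continuing the sketch above, integrated gradients for the first output logit:

    ig = IntegratedGradients(model)
    attributions = ig.generate_integrated_gradients(tensor, target_label_idx=0, num_steps=50)
    print(attributions.shape)  # same shape as the input, e.g. (1, 3, 224, 224)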
3342
+ def get_db_paths(src):
3343
+ if isinstance(src, str):
3344
+ src = [src]
3345
+ db_paths = [os.path.join(source, 'measurements/measurements.db') for source in src]
3346
+ return db_paths
3347
+
3348
+ def get_sequencing_paths(src):
3349
+ if isinstance(src, str):
3350
+ src = [src]
3351
+ seq_paths = [os.path.join(source, 'sequencing/sequencing_data.csv') for source in src]
3352
+ return seq_paths
3353
+
3354
+ def load_image_paths(c, visualize):
3355
+ c.execute(f'SELECT * FROM png_list')
3356
+ data = c.fetchall()
3357
+ columns_info = c.execute(f'PRAGMA table_info(png_list)').fetchall()
3358
+ column_names = [col_info[1] for col_info in columns_info]
3359
+ image_paths_df = pd.DataFrame(data, columns=column_names)
3360
+ if visualize:
3361
+ object_visualize = visualize + '_png'
3362
+ image_paths_df = image_paths_df[image_paths_df['png_path'].str.contains(object_visualize)]
3363
+ image_paths_df = image_paths_df.set_index('prcfo')
3364
+ return image_paths_df
3365
+
3366
+ def merge_dataframes(df, image_paths_df, verbose):
3367
+ df.set_index('prcfo', inplace=True)
3368
+ df = image_paths_df.merge(df, left_index=True, right_index=True)
3369
+ if verbose:
3370
+ display(df)
3371
+ return df
3372
+
3373
+ def remove_highly_correlated_columns(df, threshold):
3374
+ corr_matrix = df.corr().abs()
3375
+ upper_tri = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
3376
+ to_drop = [column for column in upper_tri.columns if any(upper_tri[column] > threshold)]
3377
+ return df.drop(to_drop, axis=1)
3378
+
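A quick check with two perfectly correlated columns (one of them is dropped):

    df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [2, 4, 6, 8], 'c': [4, 1, 3, 2]})
    print(remove_highly_correlated_columns(df, threshold=0.95).columns.tolist())  # ['a', 'c']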
3379
+ def filter_columns(df, filter_by):
3380
+ if filter_by != 'morphology':
3381
+ cols_to_include = [col for col in df.columns if filter_by in str(col)]
3382
+ else:
3383
+ cols_to_include = [col for col in df.columns if 'channel' not in str(col)]
3384
+ df = df[cols_to_include]
3385
+ return df
3386
+
3387
+ def reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric, eps, min_samples, clustering, reduction_method='umap', verbose=False, embedding=None, n_jobs=-1, mode='fit', model=None):
3388
+ """
3389
+ Perform dimensionality reduction and clustering on the given data.
3390
+
3391
+ Parameters:
3392
+ numeric_data (np.ndarray): Numeric data for embedding and clustering.
3393
+ n_neighbors (int or float): Number of neighbors for UMAP or perplexity for t-SNE.
3394
+ min_dist (float): Minimum distance for UMAP.
3395
+ metric (str): Metric for UMAP and DBSCAN.
3396
+ eps (float): Epsilon for DBSCAN.
3397
+ min_samples (int): Minimum samples for DBSCAN or number of clusters for KMeans.
3398
+ clustering (str): Clustering method ('dbscan' or 'kmeans').
3399
+ reduction_method (str): Dimensionality reduction method ('umap' or 'tsne').
3400
+ verbose (bool): Whether to print verbose output.
3401
+ embedding (np.ndarray, optional): Precomputed embedding. Default is None.
3402
+ n_jobs (int): Number of parallel jobs. Default is -1 (all available cores).
+ mode (str): 'fit' to train a new reducer, anything else to transform with `model`. Default is 'fit'.
+ model (object, optional): Pre-fitted reducer used when mode != 'fit'. Default is None.
3403
+
3404
+ Returns:
3405
+ tuple: embedding, labels, reducer
3406
+ """
3407
+
3408
+ if verbose:
3409
+ v = 1
3410
+ else:
3411
+ v = 0
3412
+
3413
+ if isinstance(n_neighbors, float):
3414
+ n_neighbors = int(n_neighbors * len(numeric_data))
3415
+
3416
+ if n_neighbors <= 2:
3417
+ n_neighbors = 2
3418
+
3419
+ if mode == 'fit':
3420
+ if reduction_method == 'umap':
3421
+ reducer = umap.UMAP(n_neighbors=n_neighbors,
3422
+ n_components=2,
3423
+ metric=metric,
3424
+ n_epochs=None,
3425
+ learning_rate=1.0,
3426
+ init='spectral',
3427
+ min_dist=min_dist,
3428
+ spread=1.0,
3429
+ set_op_mix_ratio=1.0,
3430
+ local_connectivity=1,
3431
+ repulsion_strength=1.0,
3432
+ negative_sample_rate=5,
3433
+ transform_queue_size=4.0,
3434
+ a=None,
3435
+ b=None,
3436
+ random_state=42,
3437
+ metric_kwds=None,
3438
+ angular_rp_forest=False,
3439
+ target_n_neighbors=-1,
3440
+ target_metric='categorical',
3441
+ target_metric_kwds=None,
3442
+ target_weight=0.5,
3443
+ transform_seed=42,
3444
+ n_jobs=n_jobs,
3445
+ verbose=verbose)
3446
+
3447
+ elif reduction_method == 'tsne':
3448
+ reducer = TSNE(n_components=2,
3449
+ perplexity=n_neighbors,
3450
+ early_exaggeration=12.0,
3451
+ learning_rate=200.0,
3452
+ n_iter=1000,
3453
+ n_iter_without_progress=300,
3454
+ min_grad_norm=1e-7,
3455
+ metric=metric,
3456
+ init='random',
3457
+ verbose=v,
3458
+ random_state=42,
3459
+ method='barnes_hut',
3460
+ angle=0.5,
3461
+ n_jobs=n_jobs)
3462
+
3463
+ else:
3464
+ raise ValueError(f"Unsupported reduction method: {reduction_method}. Supported methods are 'umap' and 'tsne'")
3465
+
3466
+ embedding = reducer.fit_transform(numeric_data)
3467
+ if verbose:
3468
+ print(f'Trained and fit reducer')
3469
+
3470
+ else:
3471
+ if model is not None:
3472
+ embedding = model.transform(numeric_data)
3473
+ reducer = model
3474
+ if verbose:
3475
+ print(f'Fit data to reducer')
3476
+ else:
3477
+ raise ValueError(f"Model is None. Please provide a model for transform.")
3478
+
3479
+ if clustering == 'dbscan':
3480
+ clustering_model = DBSCAN(eps=eps, min_samples=min_samples, metric=metric, n_jobs=n_jobs)
3481
+ elif clustering == 'kmeans':
3482
+ clustering_model = KMeans(n_clusters=min_samples, random_state=42)
3483
+
3484
+ clustering_model.fit(embedding)
3485
+ labels = clustering_model.labels_ if clustering == 'dbscan' else clustering_model.predict(embedding)
3486
+
3487
+ if verbose:
3488
+ print(f'Embedding shape: {embedding.shape}')
3489
+
3490
+ return embedding, labels, reducer
3491
+
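A hedged usage sketch on standardized random features:

    data = StandardScaler().fit_transform(np.random.rand(500, 40))
    embedding, labels, reducer = reduction_and_clustering(
        data, n_neighbors=15, min_dist=0.1, metric='euclidean',
        eps=0.5, min_samples=10, clustering='dbscan', reduction_method='umap')
    print(embedding.shape, np.unique(labels))  # (500, 2) and the cluster ids (-1 = noise)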
3492
+ def reduction_and_clustering_v1(numeric_data, n_neighbors, min_dist, metric, eps, min_samples, clustering, reduction_method='umap', verbose=False, embedding=None, n_jobs=-1):
3493
+ """
3494
+ Perform dimensionality reduction and clustering on the given data.
3495
+
3496
+ Parameters:
3497
+ numeric_data (np.ndarray): Numeric data for embedding and clustering.
3498
+ n_neighbors (int or float): Number of neighbors for UMAP or perplexity for t-SNE.
3499
+ min_dist (float): Minimum distance for UMAP.
3500
+ metric (str): Metric for UMAP and DBSCAN.
3501
+ eps (float): Epsilon for DBSCAN.
3502
+ min_samples (int): Minimum samples for DBSCAN or number of clusters for KMeans.
3503
+ clustering (str): Clustering method ('DBSCAN' or 'KMeans').
3504
+ reduction_method (str): Dimensionality reduction method ('UMAP' or 'tSNE').
3505
+ verbose (bool): Whether to print verbose output.
3506
+ embedding (np.ndarray, optional): Precomputed embedding. Default is None.
3507
+
3508
+ Returns:
3509
+ tuple: embedding, labels
3510
+ """
3511
+
3512
+ if verbose:
3513
+ v=1
3514
+ else:
3515
+ v=0
3516
+
3517
+ if isinstance(n_neighbors, float):
3518
+ n_neighbors = int(n_neighbors * len(numeric_data))
3519
+
3520
+ if n_neighbors <= 2:
3521
+ n_neighbors = 2
3522
+
3523
+ if reduction_method == 'umap':
3524
+ reducer = umap.UMAP(n_neighbors=n_neighbors,
3525
+ n_components=2,
3526
+ metric=metric,
3527
+ n_epochs=None,
3528
+ learning_rate=1.0,
3529
+ init='spectral',
3530
+ min_dist=min_dist,
3531
+ spread=1.0,
3532
+ set_op_mix_ratio=1.0,
3533
+ local_connectivity=1,
3534
+ repulsion_strength=1.0,
3535
+ negative_sample_rate=5,
3536
+ transform_queue_size=4.0,
3537
+ a=None,
3538
+ b=None,
3539
+ random_state=42,
3540
+ metric_kwds=None,
3541
+ angular_rp_forest=False,
3542
+ target_n_neighbors=-1,
3543
+ target_metric='categorical',
3544
+ target_metric_kwds=None,
3545
+ target_weight=0.5,
3546
+ transform_seed=42,
3547
+ n_jobs=n_jobs,
3548
+ verbose=verbose)
3549
+
3550
+ elif reduction_method == 'tsne':
3554
+
3555
+ reducer = TSNE(n_components=2,
3556
+ perplexity=n_neighbors,
3557
+ early_exaggeration=12.0,
3558
+ learning_rate=200.0,
3559
+ n_iter=1000,
3560
+ n_iter_without_progress=300,
3561
+ min_grad_norm=1e-7,
3562
+ metric=metric,
3563
+ init='random',
3564
+ verbose=v,
3565
+ random_state=42,
3566
+ method='barnes_hut',
3567
+ angle=0.5,
3568
+ n_jobs=n_jobs)
3569
+
3570
+ else:
3571
+ raise ValueError(f"Unsupported reduction method: {reduction_method}. Supported methods are 'umap' and 'tsne'")
3572
+
3573
+ if embedding is None:
3574
+ embedding = reducer.fit_transform(numeric_data)
3575
+
3576
+ if clustering == 'dbscan':
3577
+ clustering_model = DBSCAN(eps=eps, min_samples=min_samples, metric=metric, n_jobs=n_jobs)
3578
+ elif clustering == 'kmeans':
3579
+ clustering_model = KMeans(n_clusters=min_samples, random_state=42)
3580
+ else:
3581
+ raise ValueError(f"Unsupported clustering method: {clustering}. Supported methods are 'dbscan' and 'kmeans'")
3582
+
3583
+ clustering_model.fit(embedding)
3584
+ labels = clustering_model.labels_ if clustering == 'dbscan' else clustering_model.predict(embedding)
3585
+
3586
+ if verbose:
3587
+ print(f'Embedding shape: {embedding.shape}')
3588
+
3589
+ return embedding, labels
3590
+
3591
+ def remove_noise(embedding, labels):
3592
+ non_noise_indices = labels != -1
3593
+ embedding = embedding[non_noise_indices]
3594
+ labels = labels[non_noise_indices]
3595
+ return embedding, labels
3596
+
3597
+ def plot_embedding(embedding, image_paths, labels, image_nr, img_zoom, colors, plot_by_cluster, plot_outlines, plot_points, plot_images, smooth_lines, black_background, figuresize, dot_size, remove_image_canvas, verbose):
3598
+ unique_labels = np.unique(labels)
3600
+ colors, label_to_color_index = assign_colors(unique_labels, colors)
3601
+ cluster_centers = [np.mean(embedding[labels == cluster_label], axis=0) for cluster_label in unique_labels]
3602
+ fig, ax = setup_plot(figuresize, black_background)
3603
+ plot_clusters(ax, embedding, labels, colors, cluster_centers, plot_outlines, plot_points, smooth_lines, figuresize, dot_size, verbose)
3604
+ if not image_paths is None and plot_images:
3605
+ plot_umap_images(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, plot_by_cluster, remove_image_canvas, verbose)
3606
+ plt.show()
3607
+ return fig
3608
+
3609
+ def generate_colors(num_clusters, black_background):
3610
+ random_colors = np.random.rand(num_clusters + 1, 4)
3611
+ random_colors[:, 3] = 1
3612
+ specific_colors = [
3613
+ [155 / 255, 55 / 255, 155 / 255, 1],
3614
+ [55 / 255, 155 / 255, 155 / 255, 1],
3615
+ [55 / 255, 155 / 255, 255 / 255, 1],
3616
+ [255 / 255, 55 / 255, 155 / 255, 1]
3617
+ ]
3618
+ random_colors = np.vstack((specific_colors, random_colors[len(specific_colors):]))
3619
+ if not black_background:
3620
+ random_colors = np.vstack(([0, 0, 0, 1], random_colors))
3621
+ return random_colors
3622
+
3623
+ def assign_colors(unique_labels, random_colors):
3624
+ normalized_colors = random_colors / 255
3625
+ colors_img = [tuple(color) for color in normalized_colors]
3626
+ colors = [tuple(color) for color in random_colors]
3627
+ label_to_color_index = {label: index for index, label in enumerate(unique_labels)}
3628
+ return colors, label_to_color_index
3629
+
3630
+ def setup_plot(figuresize, black_background):
3631
+ if black_background:
3632
+ plt.rcParams.update({'figure.facecolor': 'black', 'axes.facecolor': 'black', 'text.color': 'white', 'xtick.color': 'white', 'ytick.color': 'white', 'axes.labelcolor': 'white'})
3633
+ else:
3634
+ plt.rcParams.update({'figure.facecolor': 'white', 'axes.facecolor': 'white', 'text.color': 'black', 'xtick.color': 'black', 'ytick.color': 'black', 'axes.labelcolor': 'black'})
3635
+ fig, ax = plt.subplots(1, 1, figsize=(figuresize, figuresize))
3636
+ return fig, ax
3637
+
3638
+ def plot_clusters(ax, embedding, labels, colors, cluster_centers, plot_outlines, plot_points, smooth_lines, figuresize=50, dot_size=50, verbose=False):
3639
+ unique_labels = np.unique(labels)
3640
+ for cluster_label, color, center in zip(unique_labels, colors, cluster_centers):
3641
+ cluster_data = embedding[labels == cluster_label]
3642
+ if smooth_lines:
3643
+ if cluster_data.shape[0] > 2:
3644
+ x_smooth, y_smooth = smooth_hull_lines(cluster_data)
3645
+ if plot_outlines:
3646
+ plt.plot(x_smooth, y_smooth, color=color, linewidth=2)
3647
+ else:
3648
+ if cluster_data.shape[0] > 2:
3649
+ hull = ConvexHull(cluster_data)
3650
+ for simplex in hull.simplices:
3651
+ if plot_outlines:
3652
+ plt.plot(hull.points[simplex, 0], hull.points[simplex, 1], color=color, linewidth=4)
3653
+ if plot_points:
3654
+ scatter = ax.scatter(cluster_data[:, 0], cluster_data[:, 1], s=dot_size, c=[color], alpha=0.5, label=f'Cluster {cluster_label if cluster_label != -1 else "Noise"}')
3655
+ else:
3656
+ scatter = ax.scatter(cluster_data[:, 0], cluster_data[:, 1], s=dot_size, c=[color], alpha=0, label=f'Cluster {cluster_label if cluster_label != -1 else "Noise"}')
3657
+ ax.text(center[0], center[1], str(cluster_label), fontsize=12, ha='center', va='center')
3658
+ plt.legend(loc='best', fontsize=int(figuresize * 0.75))
3659
+ plt.xlabel('UMAP Dimension 1', fontsize=int(figuresize * 0.75))
3660
+ plt.ylabel('UMAP Dimension 2', fontsize=int(figuresize * 0.75))
3661
+ plt.tick_params(axis='both', which='major', labelsize=int(figuresize * 0.75))
3662
+
3663
+ def plot_umap_images(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, plot_by_cluster, remove_image_canvas, verbose):
3664
+ if plot_by_cluster:
3665
+ cluster_indices = {label: np.where(labels == label)[0] for label in np.unique(labels) if label != -1}
3666
+ plot_images_by_cluster(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, cluster_indices, remove_image_canvas, verbose)
3667
+ else:
3668
+ indices = random.sample(range(len(embedding)), image_nr)
3669
+ for i, index in enumerate(indices):
3670
+ x, y = embedding[index]
3671
+ img = Image.open(image_paths[index])
3672
+ plot_image(ax, x, y, img, img_zoom, remove_image_canvas)
3673
+
3674
+ def plot_images_by_cluster(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, cluster_indices, remove_image_canvas, verbose):
3675
+ for cluster_label, color in zip(np.unique(labels), colors):
3676
+ if cluster_label == -1:
3677
+ continue
3678
+ indices = cluster_indices.get(cluster_label, [])
3679
+ if len(indices) > image_nr:
3680
+ indices = random.sample(list(indices), image_nr)
3681
+ for index in indices:
3682
+ x, y = embedding[index]
3683
+ img = Image.open(image_paths[index])
3684
+ plot_image(ax, x, y, img, img_zoom, remove_image_canvas)
3685
+
3686
+ def plot_image(ax, x, y, img, img_zoom, remove_image_canvas=True):
3687
+ img = np.array(img)
3688
+ if remove_image_canvas:
3689
+ img = remove_canvas(img)
3690
+ imagebox = OffsetImage(img, zoom=img_zoom)
3691
+ ab = AnnotationBbox(imagebox, (x, y), frameon=False)
3692
+ ax.add_artist(ab)
3693
+
3694
+ def remove_canvas(img):
3695
+ if img.mode in ['L', 'I']:
3696
+ img_data = np.array(img)
3697
+ img_data = img_data / np.max(img_data)
3698
+ alpha_channel = (img_data > 0).astype(float)
3699
+ img_data_rgb = np.stack([img_data] * 3, axis=-1)
3700
+ img_data_with_alpha = np.dstack([img_data_rgb, alpha_channel])
3701
+ elif img.mode == 'RGB':
3702
+ img_data = np.array(img)
3703
+ img_data = img_data / 255.0
3704
+ alpha_channel = (np.sum(img_data, axis=-1) > 0).astype(float)
3705
+ img_data_with_alpha = np.dstack([img_data, alpha_channel])
3706
+ else:
3707
+ raise ValueError(f"Unsupported image mode: {img.mode}")
3708
+ return img_data_with_alpha
3709
+
3710
+ def plot_clusters_grid(embedding, labels, image_nr, image_paths, colors, figuresize, black_background, verbose):
3711
+ unique_labels = np.unique(labels)
3712
+ num_clusters = len(unique_labels[unique_labels != -1])
3713
+ if num_clusters == 0:
3714
+ print("No clusters found.")
3715
+ return
3716
+ cluster_images = {label: [] for label in unique_labels if label != -1}
3717
+ cluster_indices = {label: np.where(labels == label)[0] for label in unique_labels if label != -1}
3718
+ for cluster_label, indices in cluster_indices.items():
3719
+ if cluster_label == -1:
3720
+ continue
3721
+ if len(indices) > image_nr:
3722
+ indices = random.sample(list(indices), image_nr)
3723
+ for index in indices:
3724
+ img_path = image_paths[index]
3725
+ img_array = Image.open(img_path)
3726
+ img = np.array(img_array)
3727
+ cluster_images[cluster_label].append(img)
3728
+ fig = plot_grid(cluster_images, colors, figuresize, black_background, verbose)
3729
+ return fig
3730
+
3731
+ def plot_grid(cluster_images, colors, figuresize, black_background, verbose):
3732
+ num_clusters = len(cluster_images)
3733
+ max_figsize = 200 # Set a maximum figure size
3734
+ if figuresize * num_clusters > max_figsize:
3735
+ figuresize = max_figsize / num_clusters
3736
+
3737
+ grid_fig, grid_axes = plt.subplots(1, num_clusters, figsize=(figuresize * num_clusters, figuresize), gridspec_kw={'wspace': 0.2, 'hspace': 0})
3738
+ if num_clusters == 1:
3739
+ grid_axes = [grid_axes] # Ensure grid_axes is always iterable
3740
+ for cluster_label, axes in zip(cluster_images.keys(), grid_axes):
3741
+ images = cluster_images[cluster_label]
3742
+ num_images = len(images)
3743
+ grid_size = int(np.ceil(np.sqrt(num_images)))
3744
+ image_size = 0.9 / grid_size
3745
+ whitespace = (1 - grid_size * image_size) / (grid_size + 1)
3746
+
3747
+ if isinstance(cluster_label, str):
3748
+ idx = list(cluster_images.keys()).index(cluster_label)
3749
+ color = colors[idx]
3750
+ if verbose:
3751
+ print(f'Label: {cluster_label} index: {idx}')
3752
+ else:
3753
+ color = colors[cluster_label]
3754
+
3755
+ axes.add_patch(plt.Rectangle((0, 0), 1, 1, transform=axes.transAxes, color=color[:3]))
3756
+ axes.axis('off')
3757
+ for i, img in enumerate(images):
3758
+ row = i // grid_size
3759
+ col = i % grid_size
3760
+ x_pos = (col + 1) * whitespace + col * image_size
3761
+ y_pos = 1 - ((row + 1) * whitespace + (row + 1) * image_size)
3762
+ ax_img = axes.inset_axes([x_pos, y_pos, image_size, image_size], transform=axes.transAxes)
3763
+ ax_img.imshow(img, cmap='gray', aspect='auto')
3764
+ ax_img.axis('off')
3765
+ ax_img.set_aspect('equal')
3766
+ ax_img.set_facecolor(color[:3])
3767
+
3768
+ # Add cluster labels beside the UMAP plot
3769
+ spacing_factor = 0.5 # Adjust this value to control the spacing between labels
3770
+ for i, (cluster_label, color) in enumerate(zip(cluster_images.keys(), colors)):
3771
+ label_y = 1 - (i + 1) * (spacing_factor / num_clusters) # Adjust y position for each label
3772
+ grid_fig.text(1.05, label_y, f'Cluster {cluster_label}', verticalalignment='center', fontsize=figuresize, color='black' if not black_background else 'white')
3773
+ grid_fig.patches.append(plt.Rectangle((1, label_y - 0.02), 0.03, 0.03, transform=grid_fig.transFigure, color=color[:3], clip_on=False))
3774
+
3775
+ plt.show()
3776
+ return grid_fig
3777
+
3778
+ def correct_paths(df, base_path):
3779
+
3780
+ if 'png_path' not in df.columns:
3781
+ print("No 'png_path' column found in the dataframe.")
3782
+ return df, None
3783
+
3784
+ image_paths = df['png_path'].to_list()
3785
+
3786
+ adjusted_image_paths = []
3787
+ for path in image_paths:
3788
+ if base_path not in path:
3789
+ parts = path.split('/data/')
3790
+ if len(parts) > 1:
3791
+ new_path = os.path.join(base_path, 'data', parts[1])
3792
+ adjusted_image_paths.append(new_path)
3793
+ else:
3794
+ adjusted_image_paths.append(path)
3795
+ else:
3796
+ adjusted_image_paths.append(path)
3797
+
3798
+ df['png_path'] = adjusted_image_paths
3799
+ image_paths = df['png_path'].to_list()
3800
+ return df, image_paths
3801
+
3802
+ def correct_paths_v1(df, base_path):
3803
+ if 'png_path' not in df.columns:
3804
+ print("No 'png_path' column found in the dataframe.")
3805
+ return df, None
3806
+
3807
+ image_paths = df['png_path'].to_list()
3808
+
3809
+ adjusted_image_paths = []
3810
+ for path in image_paths:
3811
+ if base_path not in path:
3812
+ print(f"Adjusting path: {path}")
3813
+ parts = path.split('data/')
3814
+ if len(parts) > 1:
3815
+ new_path = os.path.join(base_path, 'data', parts[1])
3816
+ adjusted_image_paths.append(new_path)
3817
+ else:
3818
+ adjusted_image_paths.append(path)
3819
+ else:
3820
+ adjusted_image_paths.append(path)
3821
+
3822
+ df['png_path'] = adjusted_image_paths
3823
+ image_paths = df['png_path'].to_list()
3824
+ return df, image_paths
3825
+
3826
+ def get_umap_image_settings(settings={}):
3827
+ settings.setdefault('src', 'path')
3828
+ settings.setdefault('row_limit', 1000)
3829
+ settings.setdefault('tables', ['cell', 'cytoplasm', 'nucleus', 'pathogen'])
3830
+ settings.setdefault('visualize', 'cell')
3831
+ settings.setdefault('image_nr', 16)
3832
+ settings.setdefault('dot_size', 50)
3833
+ settings.setdefault('n_neighbors', 1000)
3834
+ settings.setdefault('min_dist', 0.1)
3835
+ settings.setdefault('metric', 'euclidean')
3836
+ settings.setdefault('eps', 0.5)
3837
+ settings.setdefault('min_samples', 1000)
3838
+ settings.setdefault('filter_by', 'channel_0')
3839
+ settings.setdefault('img_zoom', 0.5)
3840
+ settings.setdefault('plot_by_cluster', True)
3841
+ settings.setdefault('plot_cluster_grids', True)
3842
+ settings.setdefault('remove_cluster_noise', True)
3843
+ settings.setdefault('remove_highly_correlated', True)
3844
+ settings.setdefault('log_data', False)
3845
+ settings.setdefault('figuresize', 60)
3846
+ settings.setdefault('black_background', True)
3847
+ settings.setdefault('remove_image_canvas', False)
3848
+ settings.setdefault('plot_outlines', True)
3849
+ settings.setdefault('plot_points', True)
3850
+ settings.setdefault('smooth_lines', True)
3851
+ settings.setdefault('clustering', 'dbscan')
3852
+ settings.setdefault('exclude', None)
3853
+ settings.setdefault('col_to_compare', 'col')
3854
+ settings.setdefault('pos', 'c1')
3855
+ settings.setdefault('neg', 'c2')
3856
+ settings.setdefault('embedding_by_controls', False)
3857
+ settings.setdefault('plot_images', True)
3858
+ settings.setdefault('reduction_method','umap')
3859
+ settings.setdefault('save_figure', False)
3860
+ settings.setdefault('n_jobs', -1)
3861
+ settings.setdefault('color_by', None)
3862
+ settings.setdefault('mix', 'c3')
3866
+ settings.setdefault('exclude_conditions', None)
3867
+ settings.setdefault('analyze_clusters', False)
3868
+ settings.setdefault('resnet_features', False)
3869
+ settings.setdefault('verbose',True)
3870
+ return settings
3871
+
3872
+ def preprocess_data(df, filter_by, remove_highly_correlated, log_data, exclude):
+     """
+     Preprocesses the given dataframe by applying filtering, removing highly correlated columns,
+     applying log transformation, filling NaN values, and scaling the numeric data.
+
+     Args:
+         df (pandas.DataFrame): The input dataframe.
+         filter_by (str or None): The channel of interest to filter the dataframe by.
+         remove_highly_correlated (bool or float): Whether to remove highly correlated columns.
+             If a float is provided, it represents the correlation threshold.
+         log_data (bool): Whether to apply log transformation to the numeric data.
+         exclude (list or None): List of features to exclude from the filtering process.
+
+     Returns:
+         numpy.ndarray: The preprocessed numeric data.
+
+     Raises:
+         ValueError: If no numeric columns are available after filtering.
+     """
+     # Apply filtering based on the `filter_by` parameter
+     if filter_by is not None:
+         df, _ = filter_dataframe_features(df, channel_of_interest=filter_by, exclude=exclude)
+
+     # Select numerical features
+     numeric_data = df.select_dtypes(include=['number'])
+
+     # Check if numeric_data is empty
+     if numeric_data.empty:
+         raise ValueError("No numeric columns available after filtering. Please check the filter_by and exclude parameters.")
+
+     # Remove highly correlated columns
+     if remove_highly_correlated is not False:
+         if isinstance(remove_highly_correlated, float):
+             numeric_data = remove_highly_correlated_columns(numeric_data, remove_highly_correlated)
+         else:
+             numeric_data = remove_highly_correlated_columns(numeric_data, 0.95)
+
+     # Apply log transformation
+     if log_data:
+         numeric_data = np.log(numeric_data + 1e-6)
+
+     # Fill NaN values with the column mean
+     numeric_data = numeric_data.fillna(numeric_data.mean())
+
+     # Scale the numeric data
+     scaler = StandardScaler(copy=True, with_mean=True, with_std=True)
+     numeric_data = scaler.fit_transform(numeric_data)
+
+     return numeric_data
+
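A minimal sketch of the preprocessing pipeline on synthetic data; it assumes `remove_highly_correlated_columns` (defined elsewhere in this module) is available, and the column names are made up:

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(0)
    df = pd.DataFrame({
        'channel_0_intensity': rng.random(100),
        'channel_1_intensity': rng.random(100),
    })
    X = preprocess_data(df, filter_by=None, remove_highly_correlated=0.95,
                        log_data=True, exclude=None)
    print(X.shape)                          # (100, 2)
    print(np.allclose(X.mean(axis=0), 0))   # True: columns are standardized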
+ def filter_dataframe_features(df, channel_of_interest, exclude=None):
+     """
+     Filter the dataframe `df` based on the specified `channel_of_interest` and `exclude` parameters.
+
+     Parameters:
+     - df (pandas.DataFrame): The input dataframe to be filtered.
+     - channel_of_interest (str, int, list, None): The channel(s) of interest to filter the dataframe.
+       If None, no filtering is applied. If 'morphology', only morphology features are included.
+       If an integer, only the specified channel is included. If a list, only the specified channels
+       are included. If a string, only the specified channel is included.
+     - exclude (str, list, None): The feature(s) to exclude from the filtered dataframe.
+       If None, no features are excluded. If a string, the specified feature is excluded.
+       If a list, the specified features are excluded.
+
+     Returns:
+     - filtered_df (pandas.DataFrame): The filtered dataframe based on the specified parameters.
+     - features (list): The list of selected features after filtering.
+     """
+     if channel_of_interest is None:
+         feature_string = None
+     elif channel_of_interest == 'morphology':
+         feature_string = 'morphology'
+     elif isinstance(channel_of_interest, list):
+         feature_string = [f'channel_{i}' for i in channel_of_interest]
+     elif isinstance(channel_of_interest, int):
+         feature_string = f'channel_{channel_of_interest}'
+     elif isinstance(channel_of_interest, str):
+         feature_string = channel_of_interest
+
+     # Remove columns with a single value
+     df = df.loc[:, df.nunique() > 1]
+
+     # Select numerical features
+     features = df.select_dtypes(include=[np.number]).columns.tolist()
+
+     if feature_string is not None:
+         feature_list = ['channel_0', 'channel_1', 'channel_2', 'channel_3']
+
+         # Remove the channel(s) of interest from the removal list
+         if isinstance(feature_string, str):
+             if feature_string in feature_list:
+                 feature_list.remove(feature_string)
+         elif isinstance(feature_string, list):
+             feature_list = [feature for feature in feature_list if feature not in feature_string]
+
+         # Keep only features that mention the channel(s) of interest
+         if feature_string != 'morphology':
+             if isinstance(feature_string, list):
+                 features = [feature for feature in features if any(fs in feature for fs in feature_string)]
+             else:
+                 features = [feature for feature in features if feature_string in feature]
+
+         # Drop features belonging to the remaining (unwanted) channels
+         for feature_ in feature_list:
+             features = [feature for feature in features if feature_ not in feature]
+             print(f'After removing {feature_} features: {len(features)}')
+
+     if isinstance(exclude, list):
+         features = [feature for feature in features if feature not in exclude]
+     elif isinstance(exclude, str) and exclude in features:
+         features.remove(exclude)
+
+     filtered_df = df[features]
+
+     return filtered_df, features
+
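To illustrate the selection logic (column names invented): an integer channel keeps only that channel's features, dropping both the other channels and the untagged morphology columns.

    import pandas as pd

    df = pd.DataFrame({
        'channel_0_mean': [1.0, 2.0, 3.0],
        'channel_1_mean': [4.0, 5.0, 6.0],
        'morphology_area': [10.0, 20.0, 30.0],
    })
    filtered_df, features = filter_dataframe_features(df, channel_of_interest=0)
    print(features)  # ['channel_0_mean']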
+ # Create a function to check if images overlap
+ def check_overlap(current_position, other_positions, threshold):
+     for other_position in other_positions:
+         distance = np.linalg.norm(np.array(current_position) - np.array(other_position))
+         if distance < threshold:
+             return True
+     return False
+
+ # Define a function to try random positions around a given point
+ def find_non_overlapping_position(x, y, image_positions, threshold, max_attempts=100):
+     offset_range = 10  # Adjust the range for random offsets
+     attempts = 0
+     while attempts < max_attempts:
+         random_offset_x = random.uniform(-offset_range, offset_range)
+         random_offset_y = random.uniform(-offset_range, offset_range)
+         new_x = x + random_offset_x
+         new_y = y + random_offset_y
+         if not check_overlap((new_x, new_y), image_positions, threshold):
+             return new_x, new_y
+         attempts += 1
+     return x, y  # Return the original position if no suitable position is found
+
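Together these helpers place image thumbnails on a plot without collisions: a sketch with made-up coordinates, jittering the candidate point until it sits at least `threshold` units from every placed position (or giving up after `max_attempts`):

    image_positions = [(0.0, 0.0), (5.0, 5.0)]
    x, y = find_non_overlapping_position(1.0, 1.0, image_positions, threshold=3.0)
    image_positions.append((x, y))
    # False if a free spot was found within max_attempts
    print(check_overlap((x, y), image_positions[:-1], threshold=3.0))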
+ def search_reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric, eps, min_samples, clustering, reduction_method, verbose, reduction_param=None, embedding=None, n_jobs=-1):
+     """
+     Perform dimensionality reduction and clustering on the given data.
+
+     Parameters:
+     numeric_data (np.array): Numeric data to process.
+     n_neighbors (int): Number of neighbors for UMAP or perplexity for tSNE.
+     min_dist (float): Minimum distance for UMAP.
+     metric (str): Metric for UMAP, tSNE, and DBSCAN.
+     eps (float): Epsilon for DBSCAN clustering.
+     min_samples (int): Minimum samples for DBSCAN or number of clusters for KMeans.
+     clustering (str): Clustering method ('dbscan' or 'kmeans').
+     reduction_method (str): Dimensionality reduction method ('umap' or 'tsne').
+     verbose (bool): Whether to print verbose output.
+     reduction_param (dict): Additional parameters for the reduction method.
+     embedding (np.array): Precomputed embedding (optional).
+     n_jobs (int): Number of parallel jobs to run.
+
+     Returns:
+     embedding (np.array): Embedding of the data.
+     labels (np.array): Cluster labels.
+     """
+     if isinstance(n_neighbors, float):
+         n_neighbors = int(n_neighbors * len(numeric_data))
+     if n_neighbors <= 1:
+         n_neighbors = 2
+         print(f'n_neighbors cannot be less than 2. Setting n_neighbors to {n_neighbors}')
+
+     reduction_param = reduction_param or {}
+     reduction_param = {k: v for k, v in reduction_param.items() if k not in ['perplexity', 'n_neighbors', 'min_dist', 'metric', 'method']}
+
+     if reduction_method == 'umap':
+         reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric, n_jobs=n_jobs, **reduction_param)
+     elif reduction_method == 'tsne':
+         reducer = TSNE(n_components=2, perplexity=n_neighbors, metric=metric, n_jobs=n_jobs, **reduction_param)
+     else:
+         raise ValueError(f"Unsupported reduction method: {reduction_method}. Supported methods are 'umap' and 'tsne'")
+
+     if embedding is None:
+         embedding = reducer.fit_transform(numeric_data)
+
+     if clustering == 'dbscan':
+         clustering_model = DBSCAN(eps=eps, min_samples=min_samples, metric=metric)
+     elif clustering == 'kmeans':
+         clustering_model = KMeans(n_clusters=min_samples, random_state=42)
+     else:
+         raise ValueError(f"Unsupported clustering method: {clustering}. Supported methods are 'dbscan' and 'kmeans'")
+
+     clustering_model.fit(embedding)
+     labels = clustering_model.labels_ if clustering == 'dbscan' else clustering_model.predict(embedding)
+     if verbose:
+         print(f'Embedding shape: {embedding.shape}')
+     return embedding, labels
+
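A quick end-to-end sketch on random data (shapes and parameter values are illustrative only):

    import numpy as np

    X = np.random.rand(300, 12)
    embedding, labels = search_reduction_and_clustering(
        X, n_neighbors=15, min_dist=0.1, metric='euclidean',
        eps=0.5, min_samples=10, clustering='dbscan',
        reduction_method='umap', verbose=True)
    print(embedding.shape)    # (300, 2)
    print(np.unique(labels))  # cluster ids; -1 marks DBSCAN noise points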