spacr 0.0.1__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/utils.py CHANGED
@@ -1,10 +1,18 @@
- import os, re, sqlite3, gc, torch, torchvision, time, random, string, shutil, cv2, tarfile, glob
+ import sys, os, re, sqlite3, torch, torchvision, random, string, shutil, cv2, tarfile, glob

  import numpy as np
+ from cellpose import models as cp_models
+ from cellpose import denoise
+
  from skimage import morphology
  from skimage.measure import label, regionprops_table, regionprops
  import skimage.measure as measure
- from collections import defaultdict
+ from skimage.transform import resize as resizescikit
+ from skimage.morphology import dilation, square
+ from skimage.measure import find_contours
+ from skimage.segmentation import clear_border
+
+ from collections import defaultdict, OrderedDict
  from PIL import Image
  import pandas as pd
  from statsmodels.stats.outliers_influence import variance_inflation_factor
@@ -13,37 +21,257 @@ import statsmodels.formula.api as smf
  import statsmodels.api as sm
  from statsmodels.stats.multitest import multipletests
  from itertools import combinations
- from collections import OrderedDict
  from functools import reduce
- from IPython.display import display, clear_output
+ from IPython.display import display
+
  from multiprocessing import Pool, cpu_count
- from skimage.transform import resize as resizescikit
+ from concurrent.futures import ThreadPoolExecutor
+
  import torch.nn as nn
  import torch.nn.functional as F
- #from torchsummary import summary
  from torch.utils.checkpoint import checkpoint
  from torch.utils.data import Subset
  from torch.autograd import grad
- from torchvision import models
- from skimage.segmentation import clear_border
+
  import seaborn as sns
  import matplotlib.pyplot as plt
+ from matplotlib.offsetbox import OffsetImage, AnnotationBbox
+
  import scipy.ndimage as ndi
+ from scipy.spatial import distance
  from scipy.stats import fisher_exact
- from scipy.ndimage import binary_erosion, binary_dilation
+ from scipy.ndimage.filters import gaussian_filter
+ from scipy.spatial import ConvexHull
+ from scipy.interpolate import splprep, splev
+
+ from sklearn.preprocessing import StandardScaler
  from skimage.exposure import rescale_intensity
  from sklearn.metrics import auc, precision_recall_curve
  from sklearn.model_selection import train_test_split
  from sklearn.linear_model import Lasso, Ridge
  from sklearn.preprocessing import OneHotEncoder
+ from sklearn.cluster import KMeans
+ from sklearn.preprocessing import StandardScaler
+ from sklearn.cluster import DBSCAN
+ from sklearn.cluster import KMeans
+ from sklearn.manifold import TSNE
+
+ import umap.umap_ as umap
+
+ from torchvision import models
  from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
+ import torchvision.transforms as transforms

  from .logger import log_function_call

- #from .io import _read_and_join_tables, _save_figure
- #from .timelapse import _btrack_track_cells, _trackpy_track_cells
- #from .plot import _plot_images_on_grid, plot_masks, _plot_histograms_and_stats, plot_resize, _plot_plates, _reg_v_plot, plot_masks
- #from .core import identify_masks
+ def check_mask_folder(src,mask_fldr):
+
+     mask_folder = os.path.join(src,'norm_channel_stack',mask_fldr)
+     stack_folder = os.path.join(src,'stack')
+
+     if not os.path.exists(mask_folder):
+         return True
+
+     mask_count = sum(1 for file in os.listdir(mask_folder) if file.endswith('.npy'))
+     stack_count = sum(1 for file in os.listdir(stack_folder) if file.endswith('.npy'))
+
+     if mask_count == stack_count:
+         print(f'All masks have been generated for {mask_fldr}')
+         return False
+     else:
+         return True
+
+ def set_default_plot_merge_settings():
+     settings = {}
+     settings.setdefault('include_noninfected', True)
+     settings.setdefault('include_multiinfected', True)
+     settings.setdefault('include_multinucleated', True)
+     settings.setdefault('remove_background', False)
+     settings.setdefault('filter_min_max', None)
+     settings.setdefault('channel_dims', [0,1,2,3])
+     settings.setdefault('backgrounds', [100,100,100,100])
+     settings.setdefault('cell_mask_dim', 4)
+     settings.setdefault('nucleus_mask_dim', 5)
+     settings.setdefault('pathogen_mask_dim', 6)
+     settings.setdefault('outline_thickness', 3)
+     settings.setdefault('outline_color', 'gbr')
+     settings.setdefault('overlay_chans', [1,2,3])
+     settings.setdefault('overlay', True)
+     settings.setdefault('normalization_percentiles', [2,98])
+     settings.setdefault('normalize', True)
+     settings.setdefault('print_object_number', True)
+     settings.setdefault('nr', 1)
+     settings.setdefault('figuresize', 50)
+     settings.setdefault('cmap', 'inferno')
+     settings.setdefault('verbose', True)
+
+     return settings
+
+ def set_default_settings_preprocess_generate_masks(src, settings={}):
+     # Main settings
+     settings['src'] = src
+     settings.setdefault('preprocess', True)
+     settings.setdefault('masks', True)
+     settings.setdefault('save', True)
+     settings.setdefault('batch_size', 50)
+     settings.setdefault('test_mode', False)
+     settings.setdefault('test_images', 10)
+     settings.setdefault('magnification', 20)
+     settings.setdefault('custom_regex', None)
+     settings.setdefault('metadata_type', 'cellvoyager')
+     settings.setdefault('workers', os.cpu_count()-4)
+     settings.setdefault('randomize', True)
+     settings.setdefault('verbose', True)
+
+     settings.setdefault('remove_background_cell', False)
+     settings.setdefault('remove_background_nucleus', False)
+     settings.setdefault('remove_background_pathogen', False)
+
+     # Channel settings
+     settings.setdefault('cell_channel', None)
+     settings.setdefault('nucleus_channel', None)
+     settings.setdefault('pathogen_channel', None)
+     settings.setdefault('channels', [0,1,2,3])
+     settings.setdefault('pathogen_background', 100)
+     settings.setdefault('pathogen_Signal_to_noise', 10)
+     settings.setdefault('pathogen_CP_prob', 0)
+     settings.setdefault('cell_background', 100)
+     settings.setdefault('cell_Signal_to_noise', 10)
+     settings.setdefault('cell_CP_prob', 0)
+     settings.setdefault('nucleus_background', 100)
+     settings.setdefault('nucleus_Signal_to_noise', 10)
+     settings.setdefault('nucleus_CP_prob', 0)
+
+     settings.setdefault('nucleus_FT', 100)
+     settings.setdefault('cell_FT', 100)
+     settings.setdefault('pathogen_FT', 100)
+
+     # Plot settings
+     settings.setdefault('plot', False)
+     settings.setdefault('figuresize', 50)
+     settings.setdefault('cmap', 'inferno')
+     settings.setdefault('normalize', True)
+     settings.setdefault('normalize_plots', True)
+     settings.setdefault('examples_to_plot', 1)
+
+     # Analasys settings
+     settings.setdefault('pathogen_model', None)
+     settings.setdefault('merge_pathogens', False)
+     settings.setdefault('filter', False)
+     settings.setdefault('lower_percentile', 2)
+
+     # Timelapse settings
+     settings.setdefault('timelapse', False)
+     settings.setdefault('fps', 2)
+     settings.setdefault('timelapse_displacement', None)
+     settings.setdefault('timelapse_memory', 3)
+     settings.setdefault('timelapse_frame_limits', None)
+     settings.setdefault('timelapse_remove_transient', False)
+     settings.setdefault('timelapse_mode', 'trackpy')
+     settings.setdefault('timelapse_objects', 'cells')
+
+     # Misc settings
+     settings.setdefault('all_to_mip', False)
+     settings.setdefault('pick_slice', False)
+     settings.setdefault('skip_mode', '01')
+     settings.setdefault('upscale', False)
+     settings.setdefault('upscale_factor', 2.0)
+     settings.setdefault('adjust_cells', False)
+
+     return settings
+
+ def set_default_settings_preprocess_img_data(settings):
+
+     metadata_type = settings.setdefault('metadata_type', 'cellvoyager')
+     custom_regex = settings.setdefault('custom_regex', None)
+     nr = settings.setdefault('nr', 1)
+     plot = settings.setdefault('plot', True)
+     batch_size = settings.setdefault('batch_size', 50)
+     timelapse = settings.setdefault('timelapse', False)
+     lower_percentile = settings.setdefault('lower_percentile', 2)
+     randomize = settings.setdefault('randomize', True)
+     all_to_mip = settings.setdefault('all_to_mip', False)
+     pick_slice = settings.setdefault('pick_slice', False)
+     skip_mode = settings.setdefault('skip_mode', False)
+
+     cmap = settings.setdefault('cmap', 'inferno')
+     figuresize = settings.setdefault('figuresize', 50)
+     normalize = settings.setdefault('normalize', True)
+     save_dtype = settings.setdefault('save_dtype', 'uint16')
+
+     test_mode = settings.setdefault('test_mode', False)
+     test_images = settings.setdefault('test_images', 10)
+     random_test = settings.setdefault('random_test', True)
+
+     return settings, metadata_type, custom_regex, nr, plot, batch_size, timelapse, lower_percentile, randomize, all_to_mip, pick_slice, skip_mode, cmap, figuresize, normalize, save_dtype, test_mode, test_images, random_test
+
+ def smooth_hull_lines(cluster_data):
+     hull = ConvexHull(cluster_data)
+
+     # Extract vertices of the hull
+     vertices = hull.points[hull.vertices]
+
+     # Close the loop
+     vertices = np.vstack([vertices, vertices[0, :]])
+
+     # Parameterize the vertices
+     tck, u = splprep(vertices.T, u=None, s=0.0)
+
+     # Evaluate spline at new parameter values
+     new_points = splev(np.linspace(0, 1, 100), tck)
+
+     return new_points[0], new_points[1]
+
+ def _gen_rgb_image(image, channels):
+     """
+     Generate an RGB image from the specified channels of the input image.
+
+     Args:
+         image (ndarray): The input image.
+         channels (list): List of channel indices to use for RGB.
+
+     Returns:
+         rgb_image (ndarray): The generated RGB image.
+     """
+     rgb_image = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.float32)
+     for i, chan in enumerate(channels):
+         if chan < image.shape[2]:
+             rgb_image[:, :, i] = image[:, :, chan]
+     return rgb_image
+
+ def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_thickness):
+     outlines = []
+     overlayed_image = rgb_image.copy()
+
+     def process_dim(mask_dim):
+         mask = np.take(image, mask_dim, axis=-1)
+         outline = np.zeros_like(mask, dtype=np.uint8) # Use uint8 for contour detection efficiency
+
+         # Find and draw contours
+         for j in np.unique(mask):
+             if j == 0:
+                 continue # Skip background
+             contours = find_contours(mask == j, 0.5)
+             # Convert contours for OpenCV format and draw directly to optimize
+             cv_contours = [np.flip(contour.astype(int), axis=1) for contour in contours]
+             cv2.drawContours(outline, cv_contours, -1, color=255, thickness=outline_thickness)
+
+         return dilation(outline, square(outline_thickness))
+
+     # Parallel processing
+     with ThreadPoolExecutor() as executor:
+         outlines = list(executor.map(process_dim, mask_dims))
+
+     # Overlay outlines onto the RGB image
+     for i, outline in enumerate(outlines):
+         color = np.array(outline_colors[i % len(outline_colors)])
+         for j in np.unique(outline):
+             if j == 0:
+                 continue # Skip background
+             mask = outline == j
+             overlayed_image[mask] = color # Direct assignment with broadcasting
+
+     return overlayed_image, outlines, image

  def _convert_cq1_well_id(well_id):
      """
@@ -114,8 +342,8 @@ def _extract_filename_metadata(filenames, src, images_by_key, regular_expression
          if metadata_type =='cq1':
              orig_wellID = wellID
              wellID = _convert_cq1_well_id(wellID)
-             clear_output(wait=True)
-             print(f'\033[KConverted Well ID: {orig_wellID} to {wellID}', end='\r', flush=True)
+             #clear_output(wait=True)
+             print(f'Converted Well ID: {orig_wellID} to {wellID}', end='\r', flush=True)

          if pick_slice:
              try:
@@ -302,43 +530,82 @@ def _annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['rh'], pa
      df['condition'] = df.apply(lambda row: '_'.join(filter(None, [row.get('pathogen'), row.get('treatment')])), axis=1)
      df['condition'] = df['condition'].apply(lambda x: x if x else 'none')
      return df
-
- def normalize_to_dtype(array, q1=2,q2=98, percentiles=None):
+
+ def normalize_to_dtype(array, p1=2, p2=98):
      """
-     Normalize the input array to a specified data type.
+     Normalize each image in the stack to its own percentiles.

      Parameters:
      - array: numpy array
-       The input array to be normalized.
-     - q1: int, optional
+       The input stack to be normalized.
+     - p1: int, optional
        The lower percentile value for normalization. Default is 2.
-     - q2: int, optional
+     - p2: int, optional
        The upper percentile value for normalization. Default is 98.
-     - percentiles: list of tuples, optional
-       A list of tuples containing the percentile values for each image in the array.
-       If provided, the percentiles for each image will be used instead of q1 and q2.

      Returns:
      - new_stack: numpy array
-       The normalized array with the same shape as the input array.
+       The normalized stack with the same shape as the input stack.
      """
      nimg = array.shape[2]
      new_stack = np.empty_like(array)
-     for i,v in enumerate(range(nimg)):
-         img = np.squeeze(array[:, :, v])
+
+     for i in range(nimg):
+         img = array[:, :, i]
          non_zero_img = img[img > 0]
-         if non_zero_img.size > 0: # check if there are non-zero values
-             img_min = np.percentile(non_zero_img, q1) # change percentile from 0.02 to 2
-             img_max = np.percentile(non_zero_img, q2) # change percentile from 0.98 to 98
-             img = rescale_intensity(img, in_range=(img_min, img_max), out_range='dtype')
-         else: # if there are no non-zero values, just use the image as it is
-             if percentiles==None:
-                 img_min, img_max = img.min(), img.max()
-             else:
-                 img_min, img_max = percentiles[i]
-             img = rescale_intensity(img, in_range=(img_min, img_max), out_range='dtype')
-         img = np.expand_dims(img, axis=2)
-         new_stack[:, :, v] = img[:, :, 0]
+
+         if non_zero_img.size > 0:
+             img_min = np.percentile(non_zero_img, p1)
+             img_max = np.percentile(non_zero_img, p2)
+         else:
+             img_min = img.min()
+             img_max = img.max()
+
+         # Determine output range based on dtype
+         if np.issubdtype(array.dtype, np.integer):
+             out_range = (0, np.iinfo(array.dtype).max)
+         else:
+             out_range = (0.0, 1.0)
+
+         img = rescale_intensity(img, in_range=(img_min, img_max), out_range=out_range).astype(array.dtype)
+         new_stack[:, :, i] = img
+
+     return new_stack
+
+ def normalize_to_dtype(array, p1=2, p2=98):
+     """
+     Normalize each image in the stack to its own percentiles.
+
+     Parameters:
+     - array: numpy array
+       The input stack to be normalized.
+     - p1: int, optional
+       The lower percentile value for normalization. Default is 2.
+     - p2: int, optional
+       The upper percentile value for normalization. Default is 98.
+
+     Returns:
+     - new_stack: numpy array
+       The normalized stack with the same shape as the input stack.
+     """
+     nimg = array.shape[2]
+     new_stack = np.empty_like(array, dtype=np.float32)
+
+     for i in range(nimg):
+         img = array[:, :, i]
+         non_zero_img = img[img > 0]
+
+         if non_zero_img.size > 0:
+             img_min = np.percentile(non_zero_img, p1)
+             img_max = np.percentile(non_zero_img, p2)
+         else:
+             img_min = img.min()
+             img_max = img.max()
+
+         # Normalize to the range (0, 1) for visualization
+         img = rescale_intensity(img, in_range=(img_min, img_max), out_range=(0.0, 1.0))
+         new_stack[:, :, i] = img
+
      return new_stack

  def _list_endpoint_subdirectories(base_dir):
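Review note: the new version defines `normalize_to_dtype` twice with the same signature. Python keeps only the second definition, so the dtype-preserving variant is dead code and every caller gets the float32, (0, 1)-range variant. If both behaviours are intended, a merged sketch (the `preserve_dtype` flag is hypothetical, not part of the package):

    def normalize_to_dtype(array, p1=2, p2=98, preserve_dtype=False):
        # Percentile-normalize each image in the stack; either rescale back to
        # the input dtype's full range (preserve_dtype=True) or to float32 in
        # (0, 1) for visualization, the behaviour of the second definition.
        nimg = array.shape[2]
        out_dtype = array.dtype if preserve_dtype else np.float32
        new_stack = np.empty_like(array, dtype=out_dtype)
        for i in range(nimg):
            img = array[:, :, i]
            non_zero = img[img > 0]
            if non_zero.size > 0:
                img_min, img_max = np.percentile(non_zero, (p1, p2))
            else:
                img_min, img_max = img.min(), img.max()
            if preserve_dtype and np.issubdtype(array.dtype, np.integer):
                out_range = (0, np.iinfo(array.dtype).max)
            else:
                out_range = (0.0, 1.0)
            new_stack[:, :, i] = rescale_intensity(img, in_range=(img_min, img_max), out_range=out_range).astype(out_dtype)
        return new_stack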
@@ -673,9 +940,6 @@ def _crop_center(img, cell_mask, new_width, new_height, normalize=(2,98)):
          img = img[start_y:end_y, start_x:end_x, :]
      return img

-
-
-
  def _masks_to_masks_stack(masks):
      """
      Convert a list of masks into a stack of masks.
@@ -692,53 +956,50 @@ def _masks_to_masks_stack(masks):
      return mask_stack

  def _get_diam(mag, obj):
-     if obj == 'cell':
-         if mag == 20:
-             scale = 6
-         if mag == 40:
-             scale = 4.5
-         if mag == 60:
-             scale = 3
-     elif obj == 'nucleus':
-         if mag == 20:
-             scale = 3
-         if mag == 40:
-             scale = 2
-         if mag == 60:
-             scale = 1.5
-     elif obj == 'pathogen':
-         if mag == 20:
-             scale = 1.5
-         if mag == 40:
-             scale = 1
-         if mag == 60:
-             scale = 1.25
-     elif obj == 'pathogen_nucleus':
-         if mag == 20:
-             scale = 0.25
-         if mag == 40:
-             scale = 0.2
-         if mag == 60:
-             scale = 0.2
+
+     if mag == 20:
+         if obj == 'cell':
+             diamiter = 120
+         elif obj == 'nucleus':
+             diamiter = 60
+         elif obj == 'pathogen':
+             diamiter = 20
+         else:
+             raise ValueError("Invalid magnification: Use 20, 40 or 60")
+
+     elif mag == 40:
+         if obj == 'cell':
+             diamiter = 160
+         elif obj == 'nucleus':
+             diamiter = 80
+         elif obj == 'pathogen':
+             diamiter = 40
+         else:
+             raise ValueError("Invalid magnification: Use 20, 40 or 60")
+
+     elif mag == 60:
+         if obj == 'cell':
+             diamiter = 200
+         if obj == 'nucleus':
+             diamiter = 90
+         if obj == 'pathogen':
+             diamiter = 60
+         else:
+             raise ValueError("Invalid magnification: Use 20, 40 or 60")
      else:
-         raise ValueError("Invalid object type")
-     diamiter = mag*scale
+         raise ValueError("Invalid magnification: Use 20, 40 or 60")
+
      return diamiter

  def _get_object_settings(object_type, settings):
-
      object_settings = {}
-     object_settings['refine_masks'] = False
-     object_settings['filter_size'] = False
-     object_settings['filter_dimm'] = False
-     print(object_type)
+
      object_settings['diameter'] = _get_diam(settings['magnification'], obj=object_type)
-     object_settings['remove_border_objects'] = False
-     object_settings['minimum_size'] = (object_settings['diameter']**2)/10
-     object_settings['maximum_size'] = object_settings['minimum_size']*50
+     object_settings['minimum_size'] = (object_settings['diameter']**2)/4
+     object_settings['maximum_size'] = (object_settings['diameter']**2)*10
      object_settings['merge'] = False
-     object_settings['net_avg'] = True
      object_settings['resample'] = True
+     object_settings['remove_border_objects'] = False
      object_settings['model_name'] = 'cyto'

      if object_type == 'cell':
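Review note: `_get_diam` now returns fixed pixel diameters per magnification and object type instead of the old `mag*scale` products (also note the `diamiter` spelling is carried over from the package). In the `mag == 60` branch, however, the `nucleus` and `pathogen` checks use `if` rather than `elif`, so the trailing `else` binds only to the last `if`; at 60x, `obj='cell'` or `obj='nucleus'` assigns a diameter and then still raises. The error text also says "Invalid magnification" where an invalid object type is meant. A sketch of the presumably intended branch:

    elif mag == 60:
        if obj == 'cell':
            diamiter = 200
        elif obj == 'nucleus':
            diamiter = 90
        elif obj == 'pathogen':
            diamiter = 60
        else:
            raise ValueError("Invalid object type: Use cell, nucleus or pathogen")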
@@ -746,20 +1007,29 @@ def _get_object_settings(object_type, settings):
              object_settings['model_name'] = 'cyto'
          else:
              object_settings['model_name'] = 'cyto2'
-
+         object_settings['filter_size'] = False
+         object_settings['filter_intensity'] = False
+         object_settings['restore_type'] = settings.get('cell_restore_type', None)
+
      elif object_type == 'nucleus':
          object_settings['model_name'] = 'nuclei'
+         object_settings['filter_size'] = False
+         object_settings['filter_intensity'] = False
+         object_settings['restore_type'] = settings.get('nucleus_restore_type', None)

      elif object_type == 'pathogen':
-         object_settings['model_name'] = 'cyto3'
-
-     elif object_type == 'pathogen_nucleus':
-         object_settings['filter_size'] = True
          object_settings['model_name'] = 'cyto'
+         object_settings['filter_size'] = False
+         object_settings['filter_intensity'] = False
+         object_settings['resample'] = False
+         object_settings['restore_type'] = settings.get('pathogen_restore_type', None)
+         object_settings['merge'] = settings['merge_pathogens']

      else:
          print(f'Object type: {object_type} not supported. Supported object types are : cell, nucleus and pathogen')
-         print(f'using settings: {object_settings}')
+
+     if settings['verbose']:
+         print(object_settings)

      return object_settings

@@ -786,6 +1056,7 @@ def _pivot_counts_table(db_path):
          return df

      def _pivot_dataframe(df):
+
          """
          Pivot the DataFrame.

@@ -812,61 +1083,32 @@ def _pivot_counts_table(db_path):
      pivoted_df.to_sql('pivoted_counts', conn, if_exists='replace', index=False)
      conn.close()

- def _get_cellpose_channels_v1(mask_channels, nucleus_chann_dim, pathogen_chann_dim, cell_chann_dim):
-     cellpose_channels = {}
-     if nucleus_chann_dim in mask_channels:
-         cellpose_channels['nucleus'] = [0, mask_channels.index(nucleus_chann_dim)]
-     if pathogen_chann_dim in mask_channels:
-         cellpose_channels['pathogen'] = [0, mask_channels.index(pathogen_chann_dim)]
-     if cell_chann_dim in mask_channels:
-         cellpose_channels['cell'] = [0, mask_channels.index(cell_chann_dim)]
-     return cellpose_channels
+ def _get_cellpose_channels(src, nucleus_channel, pathogen_channel, cell_channel):

- def _get_cellpose_channels_v1(cell_channel, nucleus_channel, pathogen_channel):
-     # Initialize a dictionary to hold the new indices for the specified channels
-     cellpose_channels = {}
+     cell_mask_path = os.path.join(src, 'norm_channel_stack', 'cell_mask_stack')
+     nucleus_mask_path = os.path.join(src, 'norm_channel_stack', 'nucleus_mask_stack')
+     pathogen_mask_path = os.path.join(src, 'norm_channel_stack', 'pathogen_mask_stack')

-     # Initialize a list to keep track of the channels in their new order
-     new_channel_order = []
-
-
-     if cell_channel is not None:
-         new_channel_order.append(('cell', cell_channel))
-     if nucleus_channel is not None:
-         new_channel_order.append(('nucleus', nucleus_channel))
-     if pathogen_channel is not None:
-         new_channel_order.append(('pathogen', pathogen_channel))
-
-     # Sort the list based on the original channel indices to maintain the original order
-     new_channel_order.sort(key=lambda x: x[1])
-     print(new_channel_order)
-     # Assign new indices based on the sorted order
-     for new_index, (channel_name, _) in enumerate(new_channel_order):
-         cellpose_channels[channel_name] = [new_index, 0]
-
-     if cell_channel is not None and nucleus_channel is not None:
-         cellpose_channels['cell'][1] = cellpose_channels['nucleus'][0]
-
-     return cellpose_channels

- def _get_cellpose_channels(nucleus_channel, pathogen_channel, cell_channel):
+     if os.path.exists(cell_mask_path) or os.path.exists(nucleus_mask_path) or os.path.exists(pathogen_mask_path):
+         if nucleus_channel is None or nucleus_channel is None or nucleus_channel is None:
+             print('Warning: Cellpose masks already exist. Unexpected behaviour when setting any object dimention to None when the object masks have been created.')
+
      cellpose_channels = {}
      if not nucleus_channel is None:
          cellpose_channels['nucleus'] = [0,0]

      if not pathogen_channel is None:
          if not nucleus_channel is None:
-             cellpose_channels['pathogen'] = [0,1]
+             if not pathogen_channel is None:
+                 cellpose_channels['pathogen'] = [0,2]
+             else:
+                 cellpose_channels['pathogen'] = [0,1]
          else:
              cellpose_channels['pathogen'] = [0,0]

      if not cell_channel is None:
          if not nucleus_channel is None:
-             if not pathogen_channel is None:
-                 cellpose_channels['cell'] = [0,2]
-             else:
-                 cellpose_channels['cell'] = [0,1]
-         elif not pathogen_channel is None:
              cellpose_channels['cell'] = [0,1]
          else:
              cellpose_channels['cell'] = [0,0]
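Review note: in the rewritten `_get_cellpose_channels`, the guard `if nucleus_channel is None or nucleus_channel is None or nucleus_channel is None` repeats the same test three times; it presumably meant to check all three channel arguments. Likewise the inner `if not pathogen_channel is None` sits inside an identical outer check, so it is always true and the `[0,1]` branch is unreachable (perhaps `cell_channel` was meant). A sketch of the presumably intended guard:

    if os.path.exists(cell_mask_path) or os.path.exists(nucleus_mask_path) or os.path.exists(pathogen_mask_path):
        if nucleus_channel is None or pathogen_channel is None or cell_channel is None:
            print('Warning: Cellpose masks already exist. Unexpected behaviour when setting any object dimension to None when the object masks have been created.')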
@@ -1027,9 +1269,6 @@ def _group_by_well(df):
      # Apply mean function to numeric columns and first to non-numeric
      df_grouped = df.groupby(['plate', 'row', 'col']).agg({**{col: np.mean for col in numeric_cols}, **{col: 'first' for col in non_numeric_cols}})
      return df_grouped
-
-
-

  ###################################################
  # Classify
@@ -1044,7 +1283,7 @@ class Cache:
          cache (OrderedDict): The cache data structure.
      """

-     def _init__(self, max_size):
+     def __init__(self, max_size):
          self.cache = OrderedDict()
          self.max_size = max_size

@@ -1075,7 +1314,7 @@ class ScaledDotProductAttention(nn.Module):

      """

-     def _init__(self, d_k):
+     def __init__(self, d_k):
          super(ScaledDotProductAttention, self).__init__()
          self.d_k = d_k

@@ -1106,7 +1345,7 @@ class SelfAttention(nn.Module):
          d_k (int): Dimensionality of the key and query vectors.
      """

-     def _init__(self, in_channels, d_k):
+     def __init__(self, in_channels, d_k):
          super(SelfAttention, self).__init__()
          self.W_q = nn.Linear(in_channels, d_k)
          self.W_k = nn.Linear(in_channels, d_k)
@@ -1130,7 +1369,7 @@ class SelfAttention(nn.Module):
          return output

  class ScaledDotProductAttention(nn.Module):
-     def _init__(self, d_k):
+     def __init__(self, d_k):
          """
          Initializes the ScaledDotProductAttention module.

@@ -1167,7 +1406,7 @@ class SelfAttention(nn.Module):
          in_channels (int): Number of input channels.
          d_k (int): Dimensionality of the key and query vectors.
      """
-     def _init__(self, in_channels, d_k):
+     def __init__(self, in_channels, d_k):
          super(SelfAttention, self).__init__()
          self.W_q = nn.Linear(in_channels, d_k)
          self.W_k = nn.Linear(in_channels, d_k)
@@ -1198,7 +1437,7 @@ class EarlyFusion(nn.Module):
      Args:
          in_channels (int): Number of input channels.
      """
-     def _init__(self, in_channels):
+     def __init__(self, in_channels):
          super(EarlyFusion, self).__init__()
          self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=1, stride=1)

@@ -1217,7 +1456,7 @@ class EarlyFusion(nn.Module):

  # Spatial Attention Mechanism
  class SpatialAttention(nn.Module):
-     def _init__(self, kernel_size=7):
+     def __init__(self, kernel_size=7):
          """
          Initializes the SpatialAttention module.

@@ -1262,7 +1501,7 @@ class MultiScaleBlockWithAttention(nn.Module):
          forward: Forward method for the module.
      """

-     def _init__(self, in_channels, out_channels):
+     def __init__(self, in_channels, out_channels):
          super(MultiScaleBlockWithAttention, self).__init__()
          self.dilated_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, dilation=1, padding=1)
          self.spatial_attention = nn.Conv2d(out_channels, out_channels, kernel_size=1)
@@ -1295,7 +1534,7 @@ class MultiScaleBlockWithAttention(nn.Module):

  # Final Classifier
  class CustomCellClassifier(nn.Module):
-     def _init__(self, num_classes, pathogen_channel, use_attention, use_checkpoint, dropout_rate):
+     def __init__(self, num_classes, pathogen_channel, use_attention, use_checkpoint, dropout_rate):
          super(CustomCellClassifier, self).__init__()
          self.early_fusion = EarlyFusion(in_channels=3)

@@ -1324,7 +1563,7 @@ class CustomCellClassifier(nn.Module):

  #CNN and Transformer class, pick any Torch model.
  class TorchModel(nn.Module):
-     def _init__(self, model_name='resnet50', pretrained=True, dropout_rate=None, use_checkpoint=False):
+     def __init__(self, model_name='resnet50', pretrained=True, dropout_rate=None, use_checkpoint=False):
          super(TorchModel, self).__init__()
          self.model_name = model_name
          self.use_checkpoint = use_checkpoint
@@ -1398,7 +1637,7 @@ class TorchModel(nn.Module):
          return logits

  class FocalLossWithLogits(nn.Module):
-     def _init__(self, alpha=1, gamma=2):
+     def __init__(self, alpha=1, gamma=2):
          super(FocalLossWithLogits, self).__init__()
          self.alpha = alpha
          self.gamma = gamma
@@ -1410,7 +1649,7 @@ class FocalLossWithLogits(nn.Module):
          return focal_loss.mean()

  class ResNet(nn.Module):
-     def _init__(self, resnet_type='resnet50', dropout_rate=None, use_checkpoint=False, init_weights='imagenet'):
+     def __init__(self, resnet_type='resnet50', dropout_rate=None, use_checkpoint=False, init_weights='imagenet'):
          super(ResNet, self).__init__()

          resnet_map = {
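Review note: the twelve `_init__` to `__init__` renames in the hunks above (Cache, ScaledDotProductAttention, SelfAttention, EarlyFusion, SpatialAttention, MultiScaleBlockWithAttention, CustomCellClassifier, TorchModel, FocalLossWithLogits, ResNet) are functional fixes, not cosmetic ones. With a single leading underscore, Python never calls `_init__` as the constructor, so instantiation fell through to the parent class and none of the attributes were ever set. A minimal demonstration:

    class Cache:
        def _init__(self, max_size):   # not a dunder, never invoked by Python
            self.max_size = max_size

    Cache(10)   # TypeError: Cache() takes no arguments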
@@ -1763,25 +2002,24 @@ def annotate_predictions(csv_loc):
      df['cond'] = df.apply(assign_condition, axis=1)
      return df

- def init_globals(counter_, lock_):
+ def initiate_counter(counter_, lock_):
      global counter, lock
      counter = counter_
      lock = lock_

- def add_images_to_tar(args):
-     global counter, lock, total_images
-     paths_chunk, tar_path = args
+ def add_images_to_tar(paths_chunk, tar_path, total_images):
      with tarfile.open(tar_path, 'w') as tar:
-         for img_path in paths_chunk:
+         for i, img_path in enumerate(paths_chunk):
              arcname = os.path.basename(img_path)
              try:
                  tar.add(img_path, arcname=arcname)
                  with lock:
                      counter.value += 1
-                     print(f"\rProcessed: {counter.value}/{total_images}", end='', flush=True)
+                     if counter.value % 100 == 0: # Print every 100 updates
+                         progress = (counter.value / total_images) * 100
+                         print(f"Progress: {counter.value}/{total_images} ({progress:.2f}%)", end='\r', file=sys.stdout, flush=True)
              except FileNotFoundError:
                  print(f"File not found: {img_path}")
-     return tar_path

  def generate_fraction_map(df, gene_column, min_frequency=0.0):
      df['fraction'] = df['count']/df['well_read_sum']
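Review note: `initiate_counter` (renamed from `init_globals`) is a process-pool initializer that publishes a shared counter and lock as module globals, which `add_images_to_tar` still reads; the rewrite also drops the `return tar_path` and introduces a loop index `i` that is never used. A usage sketch under those assumptions (the pool wiring and paths are illustrative, not from the package):

    from multiprocessing import Pool, Value, Lock

    counter = Value('i', 0)   # shared progress counter across workers
    lock = Lock()             # guards counter increments
    # Each task: (paths_chunk, tar_path, total_images); paths are hypothetical.
    chunks = [(['img_0.png', 'img_1.png'], 'part_0.tar', 4),
              (['img_2.png', 'img_3.png'], 'part_1.tar', 4)]
    with Pool(processes=2, initializer=initiate_counter, initargs=(counter, lock)) as pool:
        pool.starmap(add_images_to_tar, chunks)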
@@ -2230,8 +2468,8 @@ def dice_coefficient(mask1, mask2):
  def extract_boundaries(mask, dilation_radius=1):
      binary_mask = (mask > 0).astype(np.uint8)
      struct_elem = np.ones((dilation_radius*2+1, dilation_radius*2+1))
-     dilated = binary_dilation(binary_mask, footprint=struct_elem)
-     eroded = binary_erosion(binary_mask, footprint=struct_elem)
+     dilated = morphology.binary_dilation(binary_mask, footprint=struct_elem)
+     eroded = morphology.binary_erosion(binary_mask, footprint=struct_elem)
      boundary = dilated ^ eroded
      return boundary

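Review note: this change fixes a latent TypeError. The old names came from `scipy.ndimage`, whose `binary_dilation`/`binary_erosion` take the structuring element as `structure=`, not `footprint=`, so the old calls could not run as written; the skimage `morphology` functions do accept `footprint=`. For comparison:

    from scipy import ndimage
    from skimage import morphology
    ndimage.binary_dilation(binary_mask, structure=struct_elem)     # SciPy keyword
    morphology.binary_dilation(binary_mask, footprint=struct_elem)  # skimage keyword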
@@ -2612,24 +2850,21 @@ def _filter_object(mask, min_value):
      mask[np.isin(mask, to_remove)] = 0
      return mask

- def _filter_cp_masks(masks, flows, filter_size, minimum_size, maximum_size, remove_border_objects, merge, filter_dimm, batch, moving_avg_q1, moving_avg_q3, moving_count, plot, figuresize):
+ def _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size, maximum_size, remove_border_objects, merge, batch, plot, figuresize):
+
      """
      Filter the masks based on various criteria such as size, border objects, merging, and intensity.

      Args:
          masks (list): List of masks.
          flows (list): List of flows.
-         refine_masks (bool): Flag indicating whether to refine masks.
          filter_size (bool): Flag indicating whether to filter based on size.
+         filter_intensity (bool): Flag indicating whether to filter based on intensity.
          minimum_size (int): Minimum size of objects to keep.
          maximum_size (int): Maximum size of objects to keep.
          remove_border_objects (bool): Flag indicating whether to remove border objects.
          merge (bool): Flag indicating whether to merge adjacent objects.
-         filter_dimm (bool): Flag indicating whether to filter based on intensity.
          batch (ndarray): Batch of images.
-         moving_avg_q1 (float): Moving average of the first quartile of object intensities.
-         moving_avg_q3 (float): Moving average of the third quartile of object intensities.
-         moving_count (int): Count of moving averages.
          plot (bool): Flag indicating whether to plot the masks.
          figuresize (tuple): Size of the figure.

@@ -2641,51 +2876,66 @@ def _filter_cp_masks(masks, flows, filter_size, minimum_size, maximum_size, remo

      mask_stack = []
      for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
+
          if plot and idx == 0:
              num_objects = mask_object_count(mask)
              print(f'Number of objects before filtration: {num_objects}')
              plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)

-         if filter_size:
-             props = measure.regionprops_table(mask, properties=['label', 'area']) # Measure properties of labeled image regions.
-             valid_labels = props['label'][np.logical_and(props['area'] > minimum_size, props['area'] < maximum_size)] # Select labels of valid size.
-             masks[idx] = np.isin(mask, valid_labels) * mask # Keep only valid objects.
+         if merge:
+             mask = merge_touching_objects(mask, threshold=0.66)
              if plot and idx == 0:
                  num_objects = mask_object_count(mask)
-                 print(f'Number of objects after size filtration >{minimum_size} and <{maximum_size} : {num_objects}')
+                 print(f'Number of objects after merging adjacent objects, : {num_objects}')
                  plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
-         if remove_border_objects:
-             mask = clear_border(mask)
+
+         if filter_size:
+             props = measure.regionprops_table(mask, properties=['label', 'area'])
+             valid_labels = props['label'][np.logical_and(props['area'] > minimum_size, props['area'] < maximum_size)]
+             mask = np.isin(mask, valid_labels) * mask
              if plot and idx == 0:
                  num_objects = mask_object_count(mask)
-                 print(f'Number of objects after removing border objects, : {num_objects}')
+                 print(f'Number of objects after size filtration >{minimum_size} and <{maximum_size} : {num_objects}')
                  plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
-         if merge:
-             mask = merge_touching_objects(mask, threshold=0.25)
+
+         if filter_intensity:
+             intensity_image = image[:, :, 1]
+             props = measure.regionprops_table(mask, intensity_image=intensity_image, properties=['label', 'mean_intensity'])
+             mean_intensities = np.array(props['mean_intensity']).reshape(-1, 1)
+
+             if mean_intensities.shape[0] >= 2:
+                 kmeans = KMeans(n_clusters=2, random_state=0).fit(mean_intensities)
+                 centroids = kmeans.cluster_centers_
+
+                 # Calculate the Euclidean distance between the two centroids
+                 dist_between_centroids = distance.euclidean(centroids[0], centroids[1])
+
+                 # Set a threshold for the minimum distance to consider clusters distinct
+                 distance_threshold = 0.25
+
+                 if dist_between_centroids > distance_threshold:
+                     high_intensity_cluster = np.argmax(centroids)
+                     valid_labels = np.array(props['label'])[kmeans.labels_ == high_intensity_cluster]
+                     mask = np.isin(mask, valid_labels) * mask
+
              if plot and idx == 0:
                  num_objects = mask_object_count(mask)
-                 print(f'Number of objects after merging adjacent objects, : {num_objects}')
+                 props_after = measure.regionprops_table(mask, intensity_image=intensity_image, properties=['label', 'mean_intensity'])
+                 mean_intensities_after = np.mean(np.array(props_after['mean_intensity']))
+                 average_intensity_before = np.mean(mean_intensities)
+                 print(f'Number of objects after potential intensity clustering: {num_objects}. Mean intensity before:{average_intensity_before:.4f}. After:{mean_intensities_after:.4f}.')
                  plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
-         if filter_dimm:
-             unique_labels = np.unique(mask)
-             if len(unique_labels) == 1 and unique_labels[0] == 0:
-                 continue
-             object_intensities = [np.mean(batch[idx, :, :, 1][mask == label]) for label in unique_labels if label != 0]
-             object_q1s = [np.percentile(intensities, 25) for intensities in object_intensities if intensities.size > 0]
-             object_q3s = [np.percentile(intensities, 75) for intensities in object_intensities if intensities.size > 0]
-             if object_q1s:
-                 object_q1_mean = np.mean(object_q1s)
-                 object_q3_mean = np.mean(object_q3s)
-                 moving_avg_q1 = (moving_avg_q1 * moving_count + object_q1_mean) / (moving_count + 1)
-                 moving_avg_q3 = (moving_avg_q3 * moving_count + object_q3_mean) / (moving_count + 1)
-                 moving_count += 1
-             mask = remove_intensity_objects(batch[idx, :, :, 1], mask, intensity_threshold=moving_avg_q1, mode='low')
-             mask = remove_intensity_objects(batch[idx, :, :, 1], mask, intensity_threshold=moving_avg_q3, mode='high')
+
+
+         if remove_border_objects:
+             mask = clear_border(mask)
              if plot and idx == 0:
                  num_objects = mask_object_count(mask)
-                 print(f'Objects after intensity filtration > {moving_avg_q1} and <{moving_avg_q3}: {num_objects}')
+                 print(f'Number of objects after removing border objects, : {num_objects}')
                  plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
+
          mask_stack.append(mask)
+
      return mask_stack

  def _object_filter(df, object_type, size_range, intensity_range, mask_chans, mask_chan):
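Review note: the new `filter_intensity` branch replaces the old moving-average quartile filter with a two-means split on per-object mean intensity, keeping only the brighter cluster when the centroids are more than 0.25 intensity units apart; that fixed threshold appears to assume images normalized to (0, 1), consistent with the float `normalize_to_dtype` above. A standalone sketch of the same gate (names are illustrative, not from the package):

    import numpy as np
    from sklearn.cluster import KMeans

    def bright_cluster_labels(labels, mean_intensities, min_separation=0.25):
        # Split objects into two intensity clusters; if the clusters are
        # clearly separated, keep only the labels in the brighter one.
        X = np.asarray(mean_intensities, dtype=float).reshape(-1, 1)
        if X.shape[0] < 2:
            return list(labels)
        km = KMeans(n_clusters=2, random_state=0, n_init=10).fit(X)
        centers = km.cluster_centers_.ravel()
        if abs(centers[0] - centers[1]) <= min_separation:
            return list(labels)          # clusters overlap; filter nothing
        bright = int(np.argmax(centers))
        return [l for l, c in zip(labels, km.labels_) if c == bright]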
@@ -2721,6 +2971,1098 @@ def _object_filter(df, object_type, size_range, intensity_range, mask_chans, mas
          print(f'After {object_type} maximum mean intensity filter: {len(df)}')
      return df

- ###################################################
- # Classify
- ###################################################
+ def _get_regex(metadata_type, img_format, custom_regex=None):
+
+     if img_format == None:
+         img_format == '.tif'
+     if metadata_type == 'cellvoyager':
+         regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+     elif metadata_type == 'cq1':
+         regex = f'W(?P<wellID>.*)F(?P<fieldID>.*)T(?P<timeID>.*)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+     elif metadata_type == 'nikon':
+         regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+     elif metadata_type == 'zeis':
+         regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+     elif metadata_type == 'leica':
+         regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+     elif metadata_type == 'custom':
+         regex = f'({custom_regex}){img_format}'
+
+     print(f'regex mode:{metadata_type} regex:{regex}')
+     return regex
+
+ def _run_test_mode(src, regex, timelapse=False, test_images=10, random_test=True):
+
+     if timelapse:
+         test_images = 1 # Use only 1 set for timelapse to ensure full sequence inclusion
+
+     test_folder_path = os.path.join(src, 'test')
+     os.makedirs(test_folder_path, exist_ok=True)
+     regular_expression = re.compile(regex)
+
+     if os.path.exists(os.path.join(src, 'orig')):
+         src = os.path.join(src, 'orig')
+
+     all_filenames = [filename for filename in os.listdir(src) if regular_expression.match(filename)]
+     print(f'Found {len(all_filenames)} files')
+     images_by_set = defaultdict(list)
+
+     for filename in all_filenames:
+         match = regular_expression.match(filename)
+         if match:
+             plate = match.group('plateID') if 'plateID' in match.groupdict() else os.path.basename(src)
+             well = match.group('wellID')
+             field = match.group('fieldID')
+             set_identifier = (plate, well, field)
+             images_by_set[set_identifier].append(filename)
+
+     # Prepare for random selection
+     set_identifiers = list(images_by_set.keys())
+     if random_test:
+         random.seed(42)
+         random.shuffle(set_identifiers) # Randomize the order
+
+     # Select a subset based on the test_images count
+     selected_sets = set_identifiers[:test_images]
+
+     # Print information about the number of sets used
+     print(f'Using {len(selected_sets)} random image set(s) for test model')
+
+     # Copy files for selected sets to the test folder
+     for set_identifier in selected_sets:
+         for filename in images_by_set[set_identifier]:
+             shutil.copy(os.path.join(src, filename), test_folder_path)
+
+     return test_folder_path
+
+ def _choose_model(model_name, device, object_type='cell', restore_type=None, object_settings={}):
+
+     if object_type == 'pathogen':
+         if model_name == 'toxo_pv_lumen':
+             diameter = object_settings['diameter']
+             current_dir = os.path.dirname(__file__)
+             model_path = os.path.join(current_dir, 'models', 'cp', 'toxo_pv_lumen.CP_model')
+             print(model_path)
+             model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=model_path, diam_mean=diameter, device=device)
+             #model = cp_models.Cellpose(gpu=torch.cuda.is_available(), model_type='cyto', device=device)
+             print(f'Using Toxoplasma PV lumen model to generate pathogen masks')
+             return model
+
+     restore_list = ['denoise', 'deblur', 'upsample', None]
+     if restore_type not in restore_list:
+         print(f"Invalid restore type. Choose from {restore_list} defaulting to None")
+         restore_type = None
+
+     if restore_type == None:
+         if model_name in ['cyto', 'cyto2', 'cyto3', 'nuclei']:
+             model = cp_models.Cellpose(gpu=torch.cuda.is_available(), model_type=model_name, device=device)
+
+     else:
+         if object_type == 'nucleus':
+             restore = f'{type}_nuclei'
+             model = denoise.CellposeDenoiseModel(gpu=torch.cuda.is_available(), model_type="nuclei",restore_type=restore, chan2_restore=False, device=device)
+         else:
+             restore = f'{type}_cyto3'
+             if model_name =='cyto2':
+                 chan2_restore = True
+             if model_name =='cyto':
+                 chan2_restore = False
+             model = denoise.CellposeDenoiseModel(gpu=torch.cuda.is_available(), model_type="cyto3",restore_type=restore, chan2_restore=chan2_restore, device=device)
+
+     return model
+
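Review note on the two new helpers above. In `_get_regex`, `img_format == '.tif'` is a comparison whose result is discarded, so the `.tif` default is never applied; and in `_choose_model`, the f-strings interpolate the builtin `type` rather than the `restore_type` argument, producing strings like "<class 'type'>_nuclei". Sketches of the presumably intended lines:

    # _get_regex
    if img_format == None:
        img_format = '.tif'                 # assignment, not comparison

    # _choose_model
    restore = f'{restore_type}_nuclei'      # e.g. 'denoise_nuclei'
    restore = f'{restore_type}_cyto3'       # e.g. 'denoise_cyto3'

Also note that when `restore_type` is None and `model_name` is not one of the four built-in models, `model` is never assigned and the final `return model` raises UnboundLocalError.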
+ class SelectChannels:
+     def __init__(self, channels):
+         self.channels = channels
+
+     def __call__(self, img):
+         img = img.clone()
+         if 1 not in self.channels:
+             img[0, :, :] = 0 # Zero out the red channel
+         if 2 not in self.channels:
+             img[1, :, :] = 0 # Zero out the green channel
+         if 3 not in self.channels:
+             img[2, :, :] = 0 # Zero out the blue channel
+         return img
+
+ def preprocess_image(image_path, image_size=224, channels=[1,2,3], normalize=True):
+
+     if normalize:
+         transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.CenterCrop(size=(image_size, image_size)),
+             SelectChannels(channels),
+             transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+     else:
+         transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.CenterCrop(size=(image_size, image_size)),
+             SelectChannels(channels)])
+
+     image = Image.open(image_path).convert('RGB')
+     input_tensor = transform(image).unsqueeze(0)
+     return image, input_tensor
+
+
+ class SaliencyMapGenerator:
+     def __init__(self, model):
+         self.model = model
+
+     def compute_saliency_maps(self, X, y):
+         self.model.eval()
+         X.requires_grad_()
+
+         # Forward pass
+         scores = self.model(X).squeeze()
+
+         # For binary classification, target scores can be the single output
+         target_scores = scores * (2 * y - 1)
+
+         self.model.zero_grad()
+         target_scores.backward(torch.ones_like(target_scores))
+
+         saliency = X.grad.abs()
+         return saliency
+
+     def plot_saliency_maps(self, X, y, saliency, class_names):
+         N = X.shape[0]
+         for i in range(N):
+             plt.subplot(2, N, i + 1)
+             plt.imshow(X[i].permute(1, 2, 0).cpu().numpy())
+             plt.axis('off')
+             plt.title(class_names[y[i]])
+             plt.subplot(2, N, N + i + 1)
+             plt.imshow(saliency[i].cpu().numpy(), cmap=plt.cm.hot)
+             plt.axis('off')
+         plt.gcf().set_size_inches(12, 5)
+         plt.show()
+
+ def preprocess_image(image_path, normalize=True, image_size=224, channels=[1,2,3]):
+     preprocess = transforms.Compose([
+         transforms.Resize((image_size, image_size)),
+         transforms.ToTensor(),
+     ])
+
+     image = Image.open(image_path).convert('RGB')
+     input_tensor = preprocess(image)
+     if normalize:
+         input_tensor = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(input_tensor)
+     input_tensor = input_tensor.unsqueeze(0)
+
+     return image, input_tensor
+
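Review note: `preprocess_image` appears to be defined twice in this version with different pipelines (CenterCrop plus `SelectChannels` with 0.5/0.5 normalization, versus Resize with ImageNet normalization and an unused `channels` argument). If both sit at module level as rendered here, the later definition silently shadows the earlier one, just as with `normalize_to_dtype` above:

    def preprocess_image(path): return 'crop pipeline'
    def preprocess_image(path): return 'resize pipeline'   # replaces the first
    print(preprocess_image('x.png'))                        # -> 'resize pipeline'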
+ def class_visualization(target_y, model_path, dtype, img_size=224, channels=[0,1,2], l2_reg=1e-3, learning_rate=25, num_iterations=100, blur_every=10, max_jitter=16, show_every=25, class_names = ['nc', 'pc']):
+
+     def jitter(img, ox, oy):
+         # Randomly jitter the image
+         return torch.roll(torch.roll(img, ox, dims=2), oy, dims=3)
+
+     def blur_image(img, sigma=1):
+         # Apply Gaussian blur to the image
+         img_np = img.cpu().numpy()
+         for i in range(img_np.shape[1]):
+             img_np[:, i] = gaussian_filter(img_np[:, i], sigma=sigma)
+         img.copy_(torch.tensor(img_np).to(img.device))
+
+     def deprocess(img_tensor):
+         # Convert the tensor image to a numpy array for visualization
+         img_tensor = img_tensor.clone()
+         for c in range(3):
+             img_tensor[:, c] = img_tensor[:, c] * SQUEEZENET_STD[c] + SQUEEZENET_MEAN[c]
+         img_tensor = img_tensor.clamp(0, 1)
+         return img_tensor.squeeze().permute(1, 2, 0).cpu().numpy()
+
+     # Assuming these are defined somewhere in your codebase
+     SQUEEZENET_MEAN = [0.485, 0.456, 0.406]
+     SQUEEZENET_STD = [0.229, 0.224, 0.225]
+
+     model = torch.load(model_path)
+
+     dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
+     len_chans = len(channels)
+     model.type(dtype)
+
+     # Randomly initialize the image as a PyTorch Tensor, and make it requires gradient.
+     img = torch.randn(1, len_chans, img_size, img_size).mul_(1.0).type(dtype).requires_grad_()
+
+     for t in range(num_iterations):
+         # Randomly jitter the image a bit; this gives slightly nicer results
+         ox, oy = random.randint(0, max_jitter), random.randint(0, max_jitter)
+         img.data.copy_(jitter(img.data, ox, oy))
+
+         # Forward pass
+         score = model(img)
+
+         if target_y == 0:
+             target_score = -score
+         else:
+             target_score = score
+
+         # Add regularization
+         target_score = target_score - l2_reg * torch.norm(img)
+
+         # Backward pass
+         target_score.backward()
+
+         # Gradient ascent step
+         with torch.no_grad():
+             img += learning_rate * img.grad / torch.norm(img.grad)
+             img.grad.zero_()
+
+         # Undo the random jitter
+         img.data.copy_(jitter(img.data, -ox, -oy))
+
+         # As regularizer, clamp and periodically blur the image
+         for c in range(3):
+             lo = float(-SQUEEZENET_MEAN[c] / SQUEEZENET_STD[c])
+             hi = float((1.0 - SQUEEZENET_MEAN[c]) / SQUEEZENET_STD[c])
+             img.data[:, c].clamp_(min=lo, max=hi)
+         if t % blur_every == 0:
+             blur_image(img.data, sigma=0.5)
+
+         # Periodically show the image
+         if t == 0 or (t + 1) % show_every == 0 or t == num_iterations - 1:
+             plt.imshow(deprocess(img.data.clone().cpu()))
+             class_name = class_names[target_y]
+             plt.title('%s\nIteration %d / %d' % (class_name, t + 1, num_iterations))
+             plt.gcf().set_size_inches(4, 4)
+             plt.axis('off')
+             plt.show()
+
+     return deprocess(img.data.cpu())
+
+ def get_submodules(model, prefix=''):
+     submodules = []
+     for name, module in model.named_children():
+         full_name = prefix + ('.' if prefix else '') + name
+         submodules.append(full_name)
+         submodules.extend(get_submodules(module, full_name))
+     return submodules
+
+ class GradCAM:
+     def __init__(self, model, target_layers=None, use_cuda=True):
+         self.model = model
+         self.model.eval()
+         self.target_layers = target_layers
+         self.cuda = use_cuda
+         if self.cuda:
+             self.model = model.cuda()
+
+     def forward(self, input):
+         return self.model(input)
+
+     def __call__(self, x, index=None):
+         if self.cuda:
+             x = x.cuda()
+
+         features = []
+         def hook(module, input, output):
+             features.append(output)
+
+         handles = []
+         for name, module in self.model.named_modules():
+             if name in self.target_layers:
+                 handles.append(module.register_forward_hook(hook))
+
+         output = self.forward(x)
+         if index is None:
+             index = np.argmax(output.data.cpu().numpy())
+
+         one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
+         one_hot[0][index] = 1
+         one_hot = torch.from_numpy(one_hot).requires_grad_(True)
+         if self.cuda:
+             one_hot = one_hot.cuda()
+
+         one_hot = torch.sum(one_hot * output)
+         self.model.zero_grad()
+         one_hot.backward(retain_graph=True)
+
+         grads_val = features[0].grad.cpu().data.numpy()
+         target = features[0].cpu().data.numpy()[0, :]
+
+         weights = np.mean(grads_val, axis=(2, 3))[0, :]
+         cam = np.zeros(target.shape[1:], dtype=np.float32)
+
+         for i, w in enumerate(weights):
+             cam += w * target[i, :, :]
+
+         cam = np.maximum(cam, 0)
+         cam = cv2.resize(cam, (x.size(2), x.size(3)))
+         cam = cam - np.min(cam)
+         cam = cam / np.max(cam)
+
+         for handle in handles:
+             handle.remove()
+
+         return cam
+
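Review note: in `GradCAM.__call__`, `features[0]` is an activation captured by a forward hook, and PyTorch does not populate `.grad` on such non-leaf tensors by default, so `features[0].grad` is None and the `.cpu()` call on it raises. A common fix (a sketch, not from the package) is to ask autograd to retain the gradient on the hooked activation:

    def hook(module, input, output):
        output.retain_grad()        # keep .grad on this non-leaf tensor
        features.append(output)
    # ... after one_hot.backward(retain_graph=True):
    grads_val = features[0].grad.cpu().data.numpy()   # now populated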
3300
+ def show_cam_on_image(img, mask):
3301
+ heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
3302
+ heatmap = np.float32(heatmap) / 255
3303
+ cam = heatmap + np.float32(img)
3304
+ cam = cam / np.max(cam)
3305
+ return np.uint8(255 * cam)
3306
+
3307
+ def recommend_target_layers(model):
3308
+ target_layers = []
3309
+ for name, module in model.named_modules():
3310
+ if isinstance(module, torch.nn.Conv2d):
3311
+ target_layers.append(name)
3312
+ # Choose the last conv layer as the recommended target layer
3313
+ if target_layers:
3314
+ return [target_layers[-1]], target_layers
3315
+ else:
3316
+ raise ValueError("No convolutional layers found in the model.")
3317
+
3318
+ class IntegratedGradients:
3319
+ def __init__(self, model):
3320
+ self.model = model
3321
+ self.model.eval()
3322
+
3323
+ def generate_integrated_gradients(self, input_tensor, target_label_idx, baseline=None, num_steps=50):
3324
+ if baseline is None:
3325
+ baseline = torch.zeros_like(input_tensor)
3326
+
3327
+ assert baseline.shape == input_tensor.shape
3328
+
3329
+ # Scale input and compute gradients
3330
+ scaled_inputs = [(baseline + (float(i) / num_steps) * (input_tensor - baseline)).requires_grad_(True) for i in range(0, num_steps + 1)]
3331
+ grads = []
3332
+ for scaled_input in scaled_inputs:
3333
+ out = self.model(scaled_input)
3334
+ self.model.zero_grad()
3335
+ out[0, target_label_idx].backward(retain_graph=True)
3336
+ grads.append(scaled_input.grad.data.cpu().numpy())
3337
+
3338
+ avg_grads = np.mean(grads[:-1], axis=0)
3339
+ integrated_grads = (input_tensor.cpu().data.numpy() - baseline.cpu().data.numpy()) * avg_grads
3340
+ return integrated_grads
3341
+
3342
+ def get_db_paths(src):
3343
+ if isinstance(src, str):
3344
+ src = [src]
3345
+ db_paths = [os.path.join(source, 'measurements/measurements.db') for source in src]
3346
+ return db_paths
3347
+
3348
+ def get_sequencing_paths(src):
3349
+ if isinstance(src, str):
3350
+ src = [src]
3351
+ seq_paths = [os.path.join(source, 'sequencing/sequencing_data.csv') for source in src]
3352
+ return seq_paths
3353
+
3354
+ def load_image_paths(c, visualize):
3355
+ c.execute(f'SELECT * FROM png_list')
3356
+ data = c.fetchall()
3357
+ columns_info = c.execute(f'PRAGMA table_info(png_list)').fetchall()
3358
+ column_names = [col_info[1] for col_info in columns_info]
3359
+ image_paths_df = pd.DataFrame(data, columns=column_names)
3360
+ if visualize:
3361
+ object_visualize = visualize + '_png'
3362
+ image_paths_df = image_paths_df[image_paths_df['png_path'].str.contains(object_visualize)]
3363
+ image_paths_df = image_paths_df.set_index('prcfo')
3364
+ return image_paths_df
3365
+
3366
+ def merge_dataframes(df, image_paths_df, verbose):
3367
+ df.set_index('prcfo', inplace=True)
3368
+ df = image_paths_df.merge(df, left_index=True, right_index=True)
3369
+ if verbose:
3370
+ display(df)
3371
+ return df
3372
+
3373
+ def remove_highly_correlated_columns(df, threshold):
3374
+ corr_matrix = df.corr().abs()
3375
+ upper_tri = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
3376
+ to_drop = [column for column in upper_tri.columns if any(upper_tri[column] > threshold)]
3377
+ return df.drop(to_drop, axis=1)
3378
+
3379
+ def filter_columns(df, filter_by):
3380
+ if filter_by != 'morphology':
3381
+ cols_to_include = [col for col in df.columns if filter_by in str(col)]
3382
+ else:
3383
+ cols_to_include = [col for col in df.columns if 'channel' not in str(col)]
3384
+ df = df[cols_to_include]
3385
+ return df
3386
+
3387
+ def reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric, eps, min_samples, clustering, reduction_method='umap', verbose=False, embedding=None, n_jobs=-1, mode='fit', model=False):
3388
+ """
3389
+ Perform dimensionality reduction and clustering on the given data.
3390
+
3391
+ Parameters:
3392
+ numeric_data (np.ndarray): Numeric data for embedding and clustering.
3393
+ n_neighbors (int or float): Number of neighbors for UMAP or perplexity for t-SNE.
3394
+ min_dist (float): Minimum distance for UMAP.
3395
+ metric (str): Metric for UMAP and DBSCAN.
3396
+ eps (float): Epsilon for DBSCAN.
3397
+ min_samples (int): Minimum samples for DBSCAN or number of clusters for KMeans.
3398
+ clustering (str): Clustering method ('dbscan' or 'kmeans').
3399
+ reduction_method (str): Dimensionality reduction method ('umap' or 'tsne').
3400
+ verbose (bool): Whether to print verbose output.
3401
+ embedding (np.ndarray, optional): Precomputed embedding. Default is None.
3402
+ n_jobs (int): Number of parallel jobs. Default is -1 (all available cores).
+ mode (str): 'fit' trains a new reducer; any other value transforms with `model`.
+ model: A fitted reducer to reuse when mode is not 'fit'. Default is False.
3403
+
3404
+ Returns:
3405
+ tuple: (embedding, labels, reducer)
3406
+ """
3407
+
3408
+ if verbose:
3409
+ v = 1
3410
+ else:
3411
+ v = 0
3412
+
3413
+ if isinstance(n_neighbors, float):
3414
+ n_neighbors = int(n_neighbors * len(numeric_data))
3415
+
3416
+ if n_neighbors <= 2:
3417
+ n_neighbors = 2
3418
+
3419
+ if mode == 'fit':
3420
+ if reduction_method == 'umap':
3421
+ reducer = umap.UMAP(n_neighbors=n_neighbors,
3422
+ n_components=2,
3423
+ metric=metric,
3424
+ n_epochs=None,
3425
+ learning_rate=1.0,
3426
+ init='spectral',
3427
+ min_dist=min_dist,
3428
+ spread=1.0,
3429
+ set_op_mix_ratio=1.0,
3430
+ local_connectivity=1,
3431
+ repulsion_strength=1.0,
3432
+ negative_sample_rate=5,
3433
+ transform_queue_size=4.0,
3434
+ a=None,
3435
+ b=None,
3436
+ random_state=42,
3437
+ metric_kwds=None,
3438
+ angular_rp_forest=False,
3439
+ target_n_neighbors=-1,
3440
+ target_metric='categorical',
3441
+ target_metric_kwds=None,
3442
+ target_weight=0.5,
3443
+ transform_seed=42,
3444
+ n_jobs=n_jobs,
3445
+ verbose=verbose)
3446
+
3447
+ elif reduction_method == 'tsne':
3448
+ reducer = TSNE(n_components=2,
3449
+ perplexity=n_neighbors,
3450
+ early_exaggeration=12.0,
3451
+ learning_rate=200.0,
3452
+ n_iter=1000,
3453
+ n_iter_without_progress=300,
3454
+ min_grad_norm=1e-7,
3455
+ metric=metric,
3456
+ init='random',
3457
+ verbose=v,
3458
+ random_state=42,
3459
+ method='barnes_hut',
3460
+ angle=0.5,
3461
+ n_jobs=n_jobs)
3462
+
3463
+ else:
3464
+ raise ValueError(f"Unsupported reduction method: {reduction_method}. Supported methods are 'umap' and 'tsne'")
3465
+
3466
+ embedding = reducer.fit_transform(numeric_data)
3467
+ if verbose:
3468
+ print('Trained and fit reducer')
3469
+
3470
+ else:
3471
+ if model is not None:
3472
+ embedding = model.transform(numeric_data)
3473
+ reducer = model
3474
+ if verbose:
3475
+ print('Transformed data with the fitted reducer')
3476
+ else:
3477
+ raise ValueError("Model is None. Please provide a fitted model when mode is not 'fit'.")
3478
+
3479
+ if clustering == 'dbscan':
3480
+ clustering_model = DBSCAN(eps=eps, min_samples=min_samples, metric=metric, n_jobs=n_jobs)
3481
+ elif clustering == 'kmeans':
3482
+ clustering_model = KMeans(n_clusters=min_samples, random_state=42)
+ else:
+ raise ValueError(f"Unsupported clustering method: {clustering}. Supported methods are 'dbscan' and 'kmeans'")
3483
+
3484
+ clustering_model.fit(embedding)
3485
+ labels = clustering_model.labels_ if clustering == 'dbscan' else clustering_model.predict(embedding)
3486
+
3487
+ if verbose:
3488
+ print(f'Embedding shape: {embedding.shape}')
3489
+
3490
+ return embedding, labels, reducer
3491
+
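+ # Hedged usage sketch for reduction_and_clustering (assumes `X` is a scaled numeric
+ # matrix, e.g. the output of preprocess_data defined below):
+ #
+ #     embedding, labels, reducer = reduction_and_clustering(
+ #         X, n_neighbors=15, min_dist=0.1, metric='euclidean',
+ #         eps=0.5, min_samples=10, clustering='dbscan',
+ #         reduction_method='umap', mode='fit')
+ #     # The fitted `reducer` can then place held-out data (here `X_new`) in the same
+ #     # space; any mode other than 'fit' takes the transform path:
+ #     embedding2, labels2, _ = reduction_and_clustering(
+ #         X_new, 15, 0.1, 'euclidean', 0.5, 10, 'dbscan',
+ #         mode='transform', model=reducer)
+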
3492
+ def reduction_and_clustering_v1(numeric_data, n_neighbors, min_dist, metric, eps, min_samples, clustering, reduction_method='umap', verbose=False, embedding=None, n_jobs=-1):
3493
+ """
3494
+ Perform dimensionality reduction and clustering on the given data.
3495
+
3496
+ Parameters:
3497
+ numeric_data (np.ndarray): Numeric data for embedding and clustering.
3498
+ n_neighbors (int or float): Number of neighbors for UMAP or perplexity for t-SNE.
3499
+ min_dist (float): Minimum distance for UMAP.
3500
+ metric (str): Metric for UMAP and DBSCAN.
3501
+ eps (float): Epsilon for DBSCAN.
3502
+ min_samples (int): Minimum samples for DBSCAN or number of clusters for KMeans.
3503
+ clustering (str): Clustering method ('dbscan' or 'kmeans').
3504
+ reduction_method (str): Dimensionality reduction method ('umap' or 'tsne').
3505
+ verbose (bool): Whether to print verbose output.
3506
+ embedding (np.ndarray, optional): Precomputed embedding. Default is None.
3507
+
3508
+ Returns:
3509
+ tuple: embedding, labels
3510
+ """
3511
+
3512
+ if verbose:
3513
+ v = 1
3514
+ else:
3515
+ v = 0
3516
+
3517
+ if isinstance(n_neighbors, float):
3518
+ n_neighbors = int(n_neighbors * len(numeric_data))
3519
+
3520
+ if n_neighbors <= 2:
3521
+ n_neighbors = 2
3522
+
3523
+ if reduction_method == 'umap':
3524
+ reducer = umap.UMAP(n_neighbors=n_neighbors,
3525
+ n_components=2,
3526
+ metric=metric,
3527
+ n_epochs=None,
3528
+ learning_rate=1.0,
3529
+ init='spectral',
3530
+ min_dist=min_dist,
3531
+ spread=1.0,
3532
+ set_op_mix_ratio=1.0,
3533
+ local_connectivity=1,
3534
+ repulsion_strength=1.0,
3535
+ negative_sample_rate=5,
3536
+ transform_queue_size=4.0,
3537
+ a=None,
3538
+ b=None,
3539
+ random_state=42,
3540
+ metric_kwds=None,
3541
+ angular_rp_forest=False,
3542
+ target_n_neighbors=-1,
3543
+ target_metric='categorical',
3544
+ target_metric_kwds=None,
3545
+ target_weight=0.5,
3546
+ transform_seed=42,
3547
+ n_jobs=n_jobs,
3548
+ verbose=verbose)
3549
+
3550
+ elif reduction_method == 'tsne':
3551
+
3552
+ #tsne_params.setdefault('n_components', 2)
3553
+ #reducer = TSNE(**tsne_params)
3554
+
3555
+ reducer = TSNE(n_components=2,
3556
+ perplexity=n_neighbors,
3557
+ early_exaggeration=12.0,
3558
+ learning_rate=200.0,
3559
+ n_iter=1000,
3560
+ n_iter_without_progress=300,
3561
+ min_grad_norm=1e-7,
3562
+ metric=metric,
3563
+ init='random',
3564
+ verbose=v,
3565
+ random_state=42,
3566
+ method='barnes_hut',
3567
+ angle=0.5,
3568
+ n_jobs=n_jobs)
3569
+
3570
+ else:
3571
+ raise ValueError(f"Unsupported reduction method: {reduction_method}. Supported methods are 'umap' and 'tsne'")
3572
+
3573
+ if embedding is None:
3574
+ embedding = reducer.fit_transform(numeric_data)
3575
+
3576
+ if clustering == 'dbscan':
3577
+ clustering_model = DBSCAN(eps=eps, min_samples=min_samples, metric=metric, n_jobs=n_jobs)
3578
+ elif clustering == 'kmeans':
3579
+ clustering_model = KMeans(n_clusters=min_samples, random_state=42)
3580
+ else:
3581
+ raise ValueError(f"Unsupported clustering method: {clustering}. Supported methods are 'dbscan' and 'kmeans'")
3582
+
3583
+ clustering_model.fit(embedding)
3584
+ labels = clustering_model.labels_ if clustering == 'dbscan' else clustering_model.predict(embedding)
3585
+
3586
+ if verbose:
3587
+ print(f'Embedding shape: {embedding.shape}')
3588
+
3589
+ return embedding, labels
3590
+
3591
+ def remove_noise(embedding, labels):
3592
+ non_noise_indices = labels != -1
3593
+ embedding = embedding[non_noise_indices]
3594
+ labels = labels[non_noise_indices]
3595
+ return embedding, labels
3596
+
3597
+ def plot_embedding(embedding, image_paths, labels, image_nr, img_zoom, colors, plot_by_cluster, plot_outlines, plot_points, plot_images, smooth_lines, black_background, figuresize, dot_size, remove_image_canvas, verbose):
3598
+ unique_labels = np.unique(labels)
3599
+ #num_clusters = len(unique_labels[unique_labels != 0])
3600
+ colors, label_to_color_index = assign_colors(unique_labels, colors)
3601
+ cluster_centers = [np.mean(embedding[labels == cluster_label], axis=0) for cluster_label in unique_labels]
3602
+ fig, ax = setup_plot(figuresize, black_background)
3603
+ plot_clusters(ax, embedding, labels, colors, cluster_centers, plot_outlines, plot_points, smooth_lines, figuresize, dot_size, verbose)
3604
+ if image_paths is not None and plot_images:
3605
+ plot_umap_images(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, plot_by_cluster, remove_image_canvas, verbose)
3606
+ plt.show()
3607
+ return fig
3608
+
3609
+ def generate_colors(num_clusters, black_background):
3610
+ random_colors = np.random.rand(num_clusters + 1, 4)
3611
+ random_colors[:, 3] = 1
3612
+ specific_colors = [
3613
+ [155 / 255, 55 / 255, 155 / 255, 1],
3614
+ [55 / 255, 155 / 255, 155 / 255, 1],
3615
+ [55 / 255, 155 / 255, 255 / 255, 1],
3616
+ [255 / 255, 55 / 255, 155 / 255, 1]
3617
+ ]
3618
+ random_colors = np.vstack((specific_colors, random_colors[len(specific_colors):]))
3619
+ if not black_background:
3620
+ random_colors = np.vstack(([0, 0, 0, 1], random_colors))
3621
+ return random_colors
3622
+
3623
+ def assign_colors(unique_labels, random_colors):
3624
+ colors = [tuple(color) for color in random_colors]
3627
+ label_to_color_index = {label: index for index, label in enumerate(unique_labels)}
3628
+ return colors, label_to_color_index
3629
+
3630
+ def setup_plot(figuresize, black_background):
3631
+ if black_background:
3632
+ plt.rcParams.update({'figure.facecolor': 'black', 'axes.facecolor': 'black', 'text.color': 'white', 'xtick.color': 'white', 'ytick.color': 'white', 'axes.labelcolor': 'white'})
3633
+ else:
3634
+ plt.rcParams.update({'figure.facecolor': 'white', 'axes.facecolor': 'white', 'text.color': 'black', 'xtick.color': 'black', 'ytick.color': 'black', 'axes.labelcolor': 'black'})
3635
+ fig, ax = plt.subplots(1, 1, figsize=(figuresize, figuresize))
3636
+ return fig, ax
3637
+
3638
+ def plot_clusters(ax, embedding, labels, colors, cluster_centers, plot_outlines, plot_points, smooth_lines, figuresize=50, dot_size=50, verbose=False):
3639
+ unique_labels = np.unique(labels)
3640
+ for cluster_label, color, center in zip(unique_labels, colors, cluster_centers):
3641
+ cluster_data = embedding[labels == cluster_label]
3642
+ if smooth_lines:
3643
+ if cluster_data.shape[0] > 2:
3644
+ x_smooth, y_smooth = smooth_hull_lines(cluster_data)
3645
+ if plot_outlines:
3646
+ plt.plot(x_smooth, y_smooth, color=color, linewidth=2)
3647
+ else:
3648
+ if cluster_data.shape[0] > 2:
3649
+ hull = ConvexHull(cluster_data)
3650
+ for simplex in hull.simplices:
3651
+ if plot_outlines:
3652
+ plt.plot(hull.points[simplex, 0], hull.points[simplex, 1], color=color, linewidth=4)
3653
+ if plot_points:
3654
+ scatter = ax.scatter(cluster_data[:, 0], cluster_data[:, 1], s=dot_size, c=[color], alpha=0.5, label=f'Cluster {cluster_label if cluster_label != -1 else "Noise"}')
3655
+ else:
3656
+ scatter = ax.scatter(cluster_data[:, 0], cluster_data[:, 1], s=dot_size, c=[color], alpha=0, label=f'Cluster {cluster_label if cluster_label != -1 else "Noise"}')
3657
+ ax.text(center[0], center[1], str(cluster_label), fontsize=12, ha='center', va='center')
3658
+ plt.legend(loc='best', fontsize=int(figuresize * 0.75))
3659
+ plt.xlabel('UMAP Dimension 1', fontsize=int(figuresize * 0.75))
3660
+ plt.ylabel('UMAP Dimension 2', fontsize=int(figuresize * 0.75))
3661
+ plt.tick_params(axis='both', which='major', labelsize=int(figuresize * 0.75))
3662
+
3663
+ def plot_umap_images(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, plot_by_cluster, remove_image_canvas, verbose):
3664
+ if plot_by_cluster:
3665
+ cluster_indices = {label: np.where(labels == label)[0] for label in np.unique(labels) if label != -1}
3666
+ plot_images_by_cluster(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, cluster_indices, remove_image_canvas, verbose)
3667
+ else:
3668
+ indices = random.sample(range(len(embedding)), min(image_nr, len(embedding)))
3669
+ for i, index in enumerate(indices):
3670
+ x, y = embedding[index]
3671
+ img = Image.open(image_paths[index])
3672
+ plot_image(ax, x, y, img, img_zoom, remove_image_canvas)
3673
+
3674
+ def plot_images_by_cluster(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, cluster_indices, remove_image_canvas, verbose):
3675
+ for cluster_label, color in zip(np.unique(labels), colors):
3676
+ if cluster_label == -1:
3677
+ continue
3678
+ indices = cluster_indices.get(cluster_label, [])
3679
+ if len(indices) > image_nr:
3680
+ indices = random.sample(list(indices), image_nr)
3681
+ for index in indices:
3682
+ x, y = embedding[index]
3683
+ img = Image.open(image_paths[index])
3684
+ plot_image(ax, x, y, img, img_zoom, remove_image_canvas)
3685
+
3686
+ def plot_image(ax, x, y, img, img_zoom, remove_image_canvas=True):
3687
+ if remove_image_canvas:
3688
+ img = remove_canvas(img)  # remove_canvas expects the PIL image (it reads img.mode)
3689
+ else:
+ img = np.array(img)
3690
+ imagebox = OffsetImage(img, zoom=img_zoom)
3691
+ ab = AnnotationBbox(imagebox, (x, y), frameon=False)
3692
+ ax.add_artist(ab)
3693
+
3694
+ def remove_canvas(img):
3695
+ if img.mode in ['L', 'I']:
3696
+ img_data = np.array(img)
3697
+ img_data = img_data / np.max(img_data)
3698
+ alpha_channel = (img_data > 0).astype(float)
3699
+ img_data_rgb = np.stack([img_data] * 3, axis=-1)
3700
+ img_data_with_alpha = np.dstack([img_data_rgb, alpha_channel])
3701
+ elif img.mode == 'RGB':
3702
+ img_data = np.array(img)
3703
+ img_data = img_data / 255.0
3704
+ alpha_channel = (np.sum(img_data, axis=-1) > 0).astype(float)
3705
+ img_data_with_alpha = np.dstack([img_data, alpha_channel])
3706
+ else:
3707
+ raise ValueError(f"Unsupported image mode: {img.mode}")
3708
+ return img_data_with_alpha
3709
+
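+ # remove_canvas expects a PIL image and returns an RGBA float array in which
+ # zero-intensity pixels become fully transparent, so only the segmented object
+ # is drawn on the embedding. For instance (path hypothetical):
+ #
+ #     img = Image.open('cell_crop.png')
+ #     rgba = remove_canvas(img)  # shape (H, W, 4); alpha is 0 where the canvas was empty
+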
3710
+ def plot_clusters_grid(embedding, labels, image_nr, image_paths, colors, figuresize, black_background, verbose):
3711
+ unique_labels = np.unique(labels)
3712
+ num_clusters = len(unique_labels[unique_labels != -1])
3713
+ if num_clusters == 0:
3714
+ print("No clusters found.")
3715
+ return
3716
+ cluster_images = {label: [] for label in unique_labels if label != -1}
3717
+ cluster_indices = {label: np.where(labels == label)[0] for label in unique_labels if label != -1}
3718
+ for cluster_label, indices in cluster_indices.items():
3719
+ if cluster_label == -1:
3720
+ continue
3721
+ if len(indices) > image_nr:
3722
+ indices = random.sample(list(indices), image_nr)
3723
+ for index in indices:
3724
+ img_path = image_paths[index]
3725
+ img_array = Image.open(img_path)
3726
+ img = np.array(img_array)
3727
+ cluster_images[cluster_label].append(img)
3728
+ fig = plot_grid(cluster_images, colors, figuresize, black_background, verbose)
3729
+ return fig
3730
+
3731
+ def plot_grid(cluster_images, colors, figuresize, black_background, verbose):
3732
+ num_clusters = len(cluster_images)
3733
+ max_figsize = 200 # Set a maximum figure size
3734
+ if figuresize * num_clusters > max_figsize:
3735
+ figuresize = max_figsize / num_clusters
3736
+
3737
+ grid_fig, grid_axes = plt.subplots(1, num_clusters, figsize=(figuresize * num_clusters, figuresize), gridspec_kw={'wspace': 0.2, 'hspace': 0})
3738
+ if num_clusters == 1:
3739
+ grid_axes = [grid_axes] # Ensure grid_axes is always iterable
3740
+ for cluster_label, axes in zip(cluster_images.keys(), grid_axes):
3741
+ images = cluster_images[cluster_label]
3742
+ num_images = len(images)
3743
+ grid_size = int(np.ceil(np.sqrt(num_images)))
3744
+ image_size = 0.9 / grid_size
3745
+ whitespace = (1 - grid_size * image_size) / (grid_size + 1)
3746
+
3747
+ if isinstance(cluster_label, str):
3748
+ idx = list(cluster_images.keys()).index(cluster_label)
3749
+ color = colors[idx]
3750
+ if verbose:
3751
+ print(f'Label: {cluster_label} index: {idx}')
3752
+ else:
3753
+ color = colors[cluster_label]
3754
+
3755
+ axes.add_patch(plt.Rectangle((0, 0), 1, 1, transform=axes.transAxes, color=color[:3]))
3756
+ axes.axis('off')
3757
+ for i, img in enumerate(images):
3758
+ row = i // grid_size
3759
+ col = i % grid_size
3760
+ x_pos = (col + 1) * whitespace + col * image_size
3761
+ y_pos = 1 - ((row + 1) * whitespace + (row + 1) * image_size)
3762
+ ax_img = axes.inset_axes([x_pos, y_pos, image_size, image_size], transform=axes.transAxes)
3763
+ ax_img.imshow(img, cmap='gray', aspect='auto')
3764
+ ax_img.axis('off')
3765
+ ax_img.set_aspect('equal')
3766
+ ax_img.set_facecolor(color[:3])
3767
+
3768
+ # Add cluster labels beside the UMAP plot
3769
+ spacing_factor = 0.5 # Adjust this value to control the spacing between labels
3770
+ for i, (cluster_label, color) in enumerate(zip(cluster_images.keys(), colors)):
3771
+ label_y = 1 - (i + 1) * (spacing_factor / num_clusters) # Adjust y position for each label
3772
+ grid_fig.text(1.05, label_y, f'Cluster {cluster_label}', verticalalignment='center', fontsize=figuresize, color='black' if not black_background else 'white')
3773
+ grid_fig.patches.append(plt.Rectangle((1, label_y - 0.02), 0.03, 0.03, transform=grid_fig.transFigure, color=color[:3], clip_on=False))
3774
+
3775
+ plt.show()
3776
+ return grid_fig
3777
+
3778
+ def correct_paths(df, base_path):
3779
+
3780
+ if 'png_path' not in df.columns:
3781
+ print("No 'png_path' column found in the dataframe.")
3782
+ return df, None
3783
+
3784
+ image_paths = df['png_path'].to_list()
3785
+
3786
+ adjusted_image_paths = []
3787
+ for path in image_paths:
3788
+ if base_path not in path:
3789
+ parts = path.split('/data/')
3790
+ if len(parts) > 1:
3791
+ new_path = os.path.join(base_path, 'data', parts[1])
3792
+ adjusted_image_paths.append(new_path)
3793
+ else:
3794
+ adjusted_image_paths.append(path)
3795
+ else:
3796
+ adjusted_image_paths.append(path)
3797
+
3798
+ df['png_path'] = adjusted_image_paths
3799
+ image_paths = df['png_path'].to_list()
3800
+ return df, image_paths
3801
+
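+ # Sketch of the rewrite performed by correct_paths (paths are hypothetical): a row
+ # with png_path '/old/location/data/plate1/img.png' and base_path '/new/location'
+ # becomes '/new/location/data/plate1/img.png'; paths already containing base_path
+ # are left untouched.
+ #
+ #     df, image_paths = correct_paths(df, base_path='/new/location')
+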
3802
+ def correct_paths_v1(df, base_path):
3803
+ if 'png_path' not in df.columns:
3804
+ print("No 'png_path' column found in the dataframe.")
3805
+ return df, None
3806
+
3807
+ image_paths = df['png_path'].to_list()
3808
+
3809
+ adjusted_image_paths = []
3810
+ for path in image_paths:
3811
+ if base_path not in path:
3812
+ print(f"Adjusting path: {path}")
3813
+ parts = path.split('data/')
3814
+ if len(parts) > 1:
3815
+ new_path = os.path.join(base_path, 'data', parts[1])
3816
+ adjusted_image_paths.append(new_path)
3817
+ else:
3818
+ adjusted_image_paths.append(path)
3819
+ else:
3820
+ adjusted_image_paths.append(path)
3821
+
3822
+ df['png_path'] = adjusted_image_paths
3823
+ image_paths = df['png_path'].to_list()
3824
+ return df, image_paths
3825
+
3826
+ def get_umap_image_settings(settings={}):
3827
+ settings.setdefault('src', 'path')
3828
+ settings.setdefault('row_limit', 1000)
3829
+ settings.setdefault('tables', ['cell', 'cytoplasm', 'nucleus', 'pathogen'])
3830
+ settings.setdefault('visualize', 'cell')
3831
+ settings.setdefault('image_nr', 16)
3832
+ settings.setdefault('dot_size', 50)
3833
+ settings.setdefault('n_neighbors', 1000)
3834
+ settings.setdefault('min_dist', 0.1)
3835
+ settings.setdefault('metric', 'euclidean')
3836
+ settings.setdefault('eps', 0.5)
3837
+ settings.setdefault('min_samples', 1000)
3838
+ settings.setdefault('filter_by', 'channel_0')
3839
+ settings.setdefault('img_zoom', 0.5)
3840
+ settings.setdefault('plot_by_cluster', True)
3841
+ settings.setdefault('plot_cluster_grids', True)
3842
+ settings.setdefault('remove_cluster_noise', True)
3843
+ settings.setdefault('remove_highly_correlated', True)
3844
+ settings.setdefault('log_data', False)
3845
+ settings.setdefault('figuresize', 60)
3846
+ settings.setdefault('black_background', True)
3847
+ settings.setdefault('remove_image_canvas', False)
3848
+ settings.setdefault('plot_outlines', True)
3849
+ settings.setdefault('plot_points', True)
3850
+ settings.setdefault('smooth_lines', True)
3851
+ settings.setdefault('clustering', 'dbscan')
3852
+ settings.setdefault('exclude', None)
3853
+ settings.setdefault('col_to_compare', 'col')
3854
+ settings.setdefault('pos', 'c1')
3855
+ settings.setdefault('neg', 'c2')
3856
+ settings.setdefault('embedding_by_controls', False)
3857
+ settings.setdefault('plot_images', True)
3858
+ settings.setdefault('reduction_method', 'umap')
3859
+ settings.setdefault('save_figure', False)
3860
+ settings.setdefault('n_jobs', -1)
3861
+ settings.setdefault('color_by', None)
3862
+ settings.setdefault('mix', 'c3')
3866
+ settings.setdefault('exclude_conditions', None)
3867
+ settings.setdefault('analyze_clusters', False)
3868
+ settings.setdefault('resnet_features', False)
3869
+ settings.setdefault('verbose', True)
3870
+ return settings
3871
+
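+ # Typical use: start from an empty dict and override only the settings that differ
+ # from the defaults above (the 'src' path here is a placeholder):
+ #
+ #     settings = get_umap_image_settings({'src': '/path/to/experiment',
+ #                                         'reduction_method': 'tsne',
+ #                                         'clustering': 'kmeans',
+ #                                         'min_samples': 8})
+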
3872
+ def preprocess_data(df, filter_by, remove_highly_correlated, log_data, exclude):
3873
+ """
3874
+ Preprocesses the given dataframe by applying filtering, removing highly correlated columns,
3875
+ applying log transformation, filling NaN values, and scaling the numeric data.
3876
+
3877
+ Args:
3878
+ df (pandas.DataFrame): The input dataframe.
3879
+ filter_by (str or None): The channel of interest to filter the dataframe by.
3880
+ remove_highly_correlated (bool or float): Whether to remove highly correlated columns.
3881
+ If a float is provided, it represents the correlation threshold.
3882
+ log_data (bool): Whether to apply log transformation to the numeric data.
3883
+ exclude (list or None): List of features to exclude from the filtering process.
3884
+
3886
+ Returns:
3887
+ numpy.ndarray: The preprocessed numeric data.
3888
+
3889
+ Raises:
3890
+ ValueError: If no numeric columns are available after filtering.
3891
+
3892
+ """
3893
+ # Apply filtering based on the `filter_by` parameter
3894
+ if filter_by is not None:
3895
+ df, _ = filter_dataframe_features(df, channel_of_interest=filter_by, exclude=exclude)
3896
+
3897
+ # Select numerical features
3898
+ numeric_data = df.select_dtypes(include=['number'])
3899
+
3900
+ # Check if numeric_data is empty
3901
+ if numeric_data.empty:
3902
+ raise ValueError("No numeric columns available after filtering. Please check the filter_by and exclude parameters.")
3903
+
3904
+ # Remove highly correlated columns
3905
+ if remove_highly_correlated is not False:
3906
+ if isinstance(remove_highly_correlated, float):
3907
+ numeric_data = remove_highly_correlated_columns(numeric_data, remove_highly_correlated)
3908
+ else:
3909
+ numeric_data = remove_highly_correlated_columns(numeric_data, 0.95)
3910
+
3911
+ # Apply log transformation
3912
+ if log_data:
3913
+ numeric_data = np.log(numeric_data + 1e-6)
3914
+
3915
+ # Fill NaN values with the column mean
3916
+ numeric_data = numeric_data.fillna(numeric_data.mean())
3917
+
3918
+ # Scale the numeric data
3919
+ scaler = StandardScaler(copy=True, with_mean=True, with_std=True)
3920
+ numeric_data = scaler.fit_transform(numeric_data)
3921
+
3922
+ return numeric_data
3923
+
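+ # Minimal sketch chaining the preprocessing into the embedding step (assumes `df`
+ # was assembled by load_image_paths/merge_dataframes above):
+ #
+ #     X = preprocess_data(df, filter_by='channel_0', remove_highly_correlated=0.95,
+ #                         log_data=False, exclude=None)
+ #     # X is a scaled numpy array, ready for reduction_and_clustering.
+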
3924
+ def filter_dataframe_features(df, channel_of_interest, exclude=None):
3925
+ """
3926
+ Filter the dataframe `df` based on the specified `channel_of_interest` and `exclude` parameters.
3927
+
3928
+ Parameters:
3929
+ - df (pandas.DataFrame): The input dataframe to be filtered.
3930
+ - channel_of_interest (str, int, list, None): The channel(s) of interest to filter the dataframe.
3931
+ If None, no filtering is applied. If 'morphology', only morphology features are included.
3932
+ If an integer, only the specified channel is included. If a list, only the specified channels are included.
3933
+ If a string, only the specified channel is included.
3934
+ - exclude (str, list, None): The feature(s) to exclude from the filtered dataframe.
3935
+ If None, no features are excluded. If a string, the specified feature is excluded.
3936
+ If a list, the specified features are excluded.
3937
+
3938
+ Returns:
3939
+ - filtered_df (pandas.DataFrame): The filtered dataframe based on the specified parameters.
3940
+ - features (list): The list of selected features after filtering.
3941
+
3942
+ """
3943
+ if channel_of_interest is None:
3944
+ feature_string = None
3945
+ elif channel_of_interest == 'morphology':
3946
+ feature_string = 'morphology'
3947
+ elif isinstance(channel_of_interest, list):
3948
+ feature_string = []
3949
+ for i in channel_of_interest:
3950
+ feature_string_tmp = f'channel_{i}'
3951
+ feature_string.append(feature_string_tmp)
3952
+ elif isinstance(channel_of_interest, int):
3953
+ feature_string = f'channel_{channel_of_interest}'
3954
+ elif isinstance(channel_of_interest, str):
3955
+ feature_string = channel_of_interest
3956
+
3957
+ # Remove columns with a single value
3958
+ df = df.loc[:, df.nunique() > 1]
3959
+
3960
+ # Select numerical features
3961
+ features = df.select_dtypes(include=[np.number]).columns.tolist()
3962
+
3963
+ if feature_string is not None:
3964
+ feature_list = ['channel_0', 'channel_1', 'channel_2', 'channel_3']
3965
+
3966
+ # Remove feature_string from the list if it exists
3967
+ if isinstance(feature_string, str):
3968
+ if feature_string in feature_list:
3969
+ feature_list.remove(feature_string)
3970
+ elif isinstance(feature_string, list):
3971
+ feature_list = [feature for feature in feature_list if feature not in feature_string]
3972
+
3973
+ if feature_string != 'morphology':
3974
+ features = [feature for feature in features if feature_string in feature]
3975
+
3976
+ # Iterate through the list and remove columns from df
3977
+ for feature_ in feature_list:
3978
+ features = [feature for feature in features if feature_ not in feature]
3979
+ print(f'After removing {feature_} features: {len(features)}')
3980
+
3981
+ if isinstance(exclude, list):
3982
+ features = [feature for feature in features if feature not in exclude]
3983
+ elif isinstance(exclude, str):
3984
+ features = [feature for feature in features if feature != exclude]
3985
+
3986
+ filtered_df = df[features]
3987
+
3988
+ return filtered_df, features
3989
+
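+ # Example of the channel filter (column names are illustrative): with
+ # channel_of_interest=1 only 'channel_1' intensity features survive, while the
+ # other channel_* columns are dropped; with 'morphology', every feature whose
+ # name lacks a channel_* tag is kept instead.
+ #
+ #     filtered_df, features = filter_dataframe_features(df, channel_of_interest=1)
+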
3990
+ # Create a function to check if images overlap
3991
+ def check_overlap(current_position, other_positions, threshold):
3992
+ for other_position in other_positions:
3993
+ distance = np.linalg.norm(np.array(current_position) - np.array(other_position))
3994
+ if distance < threshold:
3995
+ return True
3996
+ return False
3997
+
3998
+ # Define a function to try random positions around a given point
3999
+ def find_non_overlapping_position(x, y, image_positions, threshold, max_attempts=100):
4000
+ offset_range = 10 # Adjust the range for random offsets
4001
+ attempts = 0
4002
+ while attempts < max_attempts:
4003
+ random_offset_x = random.uniform(-offset_range, offset_range)
4004
+ random_offset_y = random.uniform(-offset_range, offset_range)
4005
+ new_x = x + random_offset_x
4006
+ new_y = y + random_offset_y
4007
+ if not check_overlap((new_x, new_y), image_positions, threshold):
4008
+ return new_x, new_y
4009
+ attempts += 1
4010
+ return x, y # Return the original position if no suitable position found
4011
+
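+ # Sketch of how the two helpers above cooperate when scattering thumbnails on an
+ # embedding (positions and threshold are arbitrary):
+ #
+ #     placed = []
+ #     for x, y in [(0.0, 0.0), (0.5, 0.5), (0.6, 0.55)]:
+ #         x_new, y_new = find_non_overlapping_position(x, y, placed, threshold=1.0)
+ #         placed.append((x_new, y_new))
+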
4012
+ def search_reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric, eps, min_samples, clustering, reduction_method, verbose, reduction_param=None, embedding=None, n_jobs=-1):
4013
+ """
4014
+ Perform dimensionality reduction and clustering on the given data.
4015
+
4016
+ Parameters:
4017
+ numeric_data (np.array): Numeric data to process.
4018
+ n_neighbors (int): Number of neighbors for UMAP or perplexity for tSNE.
4019
+ min_dist (float): Minimum distance for UMAP.
4020
+ metric (str): Metric for UMAP, tSNE, and DBSCAN.
4021
+ eps (float): Epsilon for DBSCAN clustering.
4022
+ min_samples (int): Minimum samples for DBSCAN or number of clusters for KMeans.
4023
+ clustering (str): Clustering method ('dbscan' or 'kmeans').
4024
+ reduction_method (str): Dimensionality reduction method ('umap' or 'tsne').
4025
+ verbose (bool): Whether to print verbose output.
4026
+ reduction_param (dict): Additional parameters for the reduction method.
4027
+ embedding (np.array): Precomputed embedding (optional).
4028
+ n_jobs (int): Number of parallel jobs to run.
4029
+
4030
+ Returns:
4031
+ embedding (np.array): Embedding of the data.
4032
+ labels (np.array): Cluster labels.
4033
+ """
4034
+
4035
+ if isinstance(n_neighbors, float):
4036
+ n_neighbors = int(n_neighbors * len(numeric_data))
4037
+ if n_neighbors <= 1:
4038
+ n_neighbors = 2
4039
+ print(f'n_neighbors cannot be less than 2. Setting n_neighbors to {n_neighbors}')
4040
+
4041
+ reduction_param = reduction_param or {}
4042
+ reduction_param = {k: v for k, v in reduction_param.items() if k not in ['perplexity', 'n_neighbors', 'min_dist', 'metric', 'method']}
4043
+
4044
+ if reduction_method == 'umap':
4045
+ reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric, n_jobs=n_jobs, **reduction_param)
4046
+ elif reduction_method == 'tsne':
4047
+ reducer = TSNE(n_components=2, perplexity=n_neighbors, metric=metric, n_jobs=n_jobs, **reduction_param)
4048
+ else:
4049
+ raise ValueError(f"Unsupported reduction method: {reduction_method}. Supported methods are 'umap' and 'tsne'")
4050
+
4051
+ if embedding is None:
4052
+ embedding = reducer.fit_transform(numeric_data)
4053
+
4054
+ if clustering == 'dbscan':
4055
+ clustering_model = DBSCAN(eps=eps, min_samples=min_samples, metric=metric)
4056
+ elif clustering == 'kmeans':
4057
+ clustering_model = KMeans(n_clusters=min_samples, random_state=42)
4059
+ else:
4060
+ raise ValueError(f"Unsupported clustering method: {clustering}. Supported methods are 'dbscan' and 'kmeans'")
4061
+ clustering_model.fit(embedding)
4062
+ labels = clustering_model.labels_ if clustering == 'dbscan' else clustering_model.predict(embedding)
4063
+ if verbose:
4064
+ print(f'Embedding shape: {embedding.shape}')
4065
+ return embedding, labels
4066
+
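+ # Hedged sketch of a small hyperparameter search built on the function above
+ # (grid values are illustrative):
+ #
+ #     for nn in (15, 30):
+ #         for eps in (0.3, 0.5):
+ #             emb, lab = search_reduction_and_clustering(
+ #                 X, n_neighbors=nn, min_dist=0.1, metric='euclidean',
+ #                 eps=eps, min_samples=10, clustering='dbscan',
+ #                 reduction_method='umap', verbose=False)
+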
4067
+
4068
+