spacr 0.0.2__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/utils.py CHANGED
@@ -1,12 +1,18 @@
- import os, re, sqlite3, gc, torch, torchvision, time, random, string, shutil, cv2, tarfile, glob
+ import sys, os, re, sqlite3, torch, torchvision, random, string, shutil, cv2, tarfile, glob

  import numpy as np
  from cellpose import models as cp_models
  from cellpose import denoise
+
  from skimage import morphology
  from skimage.measure import label, regionprops_table, regionprops
  import skimage.measure as measure
- from collections import defaultdict
+ from skimage.transform import resize as resizescikit
+ from skimage.morphology import dilation, square
+ from skimage.measure import find_contours
+ from skimage.segmentation import clear_border
+
+ from collections import defaultdict, OrderedDict
  from PIL import Image
  import pandas as pd
  from statsmodels.stats.outliers_influence import variance_inflation_factor
@@ -15,54 +21,225 @@ import statsmodels.formula.api as smf
  import statsmodels.api as sm
  from statsmodels.stats.multitest import multipletests
  from itertools import combinations
- from collections import OrderedDict
  from functools import reduce
- from IPython.display import display, clear_output
+ from IPython.display import display
+
  from multiprocessing import Pool, cpu_count
- from skimage.transform import resize as resizescikit
- from skimage.morphology import dilation, square
- from skimage.measure import find_contours
+ from concurrent.futures import ThreadPoolExecutor
+
  import torch.nn as nn
  import torch.nn.functional as F
- #from torchsummary import summary
  from torch.utils.checkpoint import checkpoint
  from torch.utils.data import Subset
  from torch.autograd import grad
- from torchvision import models
- from skimage.segmentation import clear_border
+
  import seaborn as sns
  import matplotlib.pyplot as plt
+ from matplotlib.offsetbox import OffsetImage, AnnotationBbox
+
  import scipy.ndimage as ndi
  from scipy.spatial import distance
  from scipy.stats import fisher_exact
- from scipy.ndimage import binary_erosion, binary_dilation
+ from scipy.ndimage.filters import gaussian_filter
+ from scipy.spatial import ConvexHull
+ from scipy.interpolate import splprep, splev
+
+ from sklearn.preprocessing import StandardScaler
  from skimage.exposure import rescale_intensity
  from sklearn.metrics import auc, precision_recall_curve
  from sklearn.model_selection import train_test_split
  from sklearn.linear_model import Lasso, Ridge
  from sklearn.preprocessing import OneHotEncoder
  from sklearn.cluster import KMeans
+ from sklearn.preprocessing import StandardScaler
+ from sklearn.cluster import DBSCAN
+ from sklearn.cluster import KMeans
+ from sklearn.manifold import TSNE
+
+ import umap.umap_ as umap
+
+ from torchvision import models
  from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
+ import torchvision.transforms as transforms

  from .logger import log_function_call

- #from .io import _read_and_join_tables, _save_figure
- #from .timelapse import _btrack_track_cells, _trackpy_track_cells
- #from .plot import _plot_images_on_grid, plot_masks, _plot_histograms_and_stats, plot_resize, _plot_plates, _reg_v_plot, plot_masks
- #from .core import identify_masks
+ def check_mask_folder(src,mask_fldr):
+
+     mask_folder = os.path.join(src,'norm_channel_stack',mask_fldr)
+     stack_folder = os.path.join(src,'stack')
+
+     if not os.path.exists(mask_folder):
+         return True
+
+     mask_count = sum(1 for file in os.listdir(mask_folder) if file.endswith('.npy'))
+     stack_count = sum(1 for file in os.listdir(stack_folder) if file.endswith('.npy'))
+
+     if mask_count == stack_count:
+         print(f'All masks have been generated for {mask_fldr}')
+         return False
+     else:
+         return True
+
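Editor's sketch (not part of the recorded diff): how the new `check_mask_folder` helper might be called, assuming a hypothetical `/data/exp1` source containing `stack/` and `norm_channel_stack/cell_mask_stack/`:

    # Hypothetical usage; folder names are assumptions for illustration.
    if check_mask_folder('/data/exp1', 'cell_mask_stack'):
        print('Masks still need to be generated')   # True means counts differ or folder is missing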
+ def set_default_plot_merge_settings():
+     settings = {}
+     settings.setdefault('include_noninfected', True)
+     settings.setdefault('include_multiinfected', True)
+     settings.setdefault('include_multinucleated', True)
+     settings.setdefault('remove_background', False)
+     settings.setdefault('filter_min_max', None)
+     settings.setdefault('channel_dims', [0,1,2,3])
+     settings.setdefault('backgrounds', [100,100,100,100])
+     settings.setdefault('cell_mask_dim', 4)
+     settings.setdefault('nucleus_mask_dim', 5)
+     settings.setdefault('pathogen_mask_dim', 6)
+     settings.setdefault('outline_thickness', 3)
+     settings.setdefault('outline_color', 'gbr')
+     settings.setdefault('overlay_chans', [1,2,3])
+     settings.setdefault('overlay', True)
+     settings.setdefault('normalization_percentiles', [2,98])
+     settings.setdefault('normalize', True)
+     settings.setdefault('print_object_number', True)
+     settings.setdefault('nr', 1)
+     settings.setdefault('figuresize', 50)
+     settings.setdefault('cmap', 'inferno')
+     settings.setdefault('verbose', True)
+
+     return settings
+
+ def set_default_settings_preprocess_generate_masks(src, settings={}):
+     # Main settings
+     settings['src'] = src
+     settings.setdefault('preprocess', True)
+     settings.setdefault('masks', True)
+     settings.setdefault('save', True)
+     settings.setdefault('batch_size', 50)
+     settings.setdefault('test_mode', False)
+     settings.setdefault('test_images', 10)
+     settings.setdefault('magnification', 20)
+     settings.setdefault('custom_regex', None)
+     settings.setdefault('metadata_type', 'cellvoyager')
+     settings.setdefault('workers', os.cpu_count()-4)
+     settings.setdefault('randomize', True)
+     settings.setdefault('verbose', True)
+
+     settings.setdefault('remove_background_cell', False)
+     settings.setdefault('remove_background_nucleus', False)
+     settings.setdefault('remove_background_pathogen', False)
+
+     # Channel settings
+     settings.setdefault('cell_channel', None)
+     settings.setdefault('nucleus_channel', None)
+     settings.setdefault('pathogen_channel', None)
+     settings.setdefault('channels', [0,1,2,3])
+     settings.setdefault('pathogen_background', 100)
+     settings.setdefault('pathogen_Signal_to_noise', 10)
+     settings.setdefault('pathogen_CP_prob', 0)
+     settings.setdefault('cell_background', 100)
+     settings.setdefault('cell_Signal_to_noise', 10)
+     settings.setdefault('cell_CP_prob', 0)
+     settings.setdefault('nucleus_background', 100)
+     settings.setdefault('nucleus_Signal_to_noise', 10)
+     settings.setdefault('nucleus_CP_prob', 0)
+
+     settings.setdefault('nucleus_FT', 100)
+     settings.setdefault('cell_FT', 100)
+     settings.setdefault('pathogen_FT', 100)
+
+     # Plot settings
+     settings.setdefault('plot', False)
+     settings.setdefault('figuresize', 50)
+     settings.setdefault('cmap', 'inferno')
+     settings.setdefault('normalize', True)
+     settings.setdefault('normalize_plots', True)
+     settings.setdefault('examples_to_plot', 1)
+
+     # Analysis settings
+     settings.setdefault('pathogen_model', None)
+     settings.setdefault('merge_pathogens', False)
+     settings.setdefault('filter', False)
+     settings.setdefault('lower_percentile', 2)
+
+     # Timelapse settings
+     settings.setdefault('timelapse', False)
+     settings.setdefault('fps', 2)
+     settings.setdefault('timelapse_displacement', None)
+     settings.setdefault('timelapse_memory', 3)
+     settings.setdefault('timelapse_frame_limits', None)
+     settings.setdefault('timelapse_remove_transient', False)
+     settings.setdefault('timelapse_mode', 'trackpy')
+     settings.setdefault('timelapse_objects', 'cells')
+
+     # Misc settings
+     settings.setdefault('all_to_mip', False)
+     settings.setdefault('pick_slice', False)
+     settings.setdefault('skip_mode', '01')
+     settings.setdefault('upscale', False)
+     settings.setdefault('upscale_factor', 2.0)
+     settings.setdefault('adjust_cells', False)
+
+     return settings
+
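Editor's sketch (not package code): these helpers rely on `dict.setdefault`, so any key the caller supplies wins over the default. A minimal illustration with a hypothetical path:

    settings = set_default_settings_preprocess_generate_masks('/data/exp1', {'batch_size': 25})
    print(settings['batch_size'])      # 25 - the caller-supplied value is kept
    print(settings['magnification'])   # 20 - the default is filled in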
+ def set_default_settings_preprocess_img_data(settings):
+
+     metadata_type = settings.setdefault('metadata_type', 'cellvoyager')
+     custom_regex = settings.setdefault('custom_regex', None)
+     nr = settings.setdefault('nr', 1)
+     plot = settings.setdefault('plot', True)
+     batch_size = settings.setdefault('batch_size', 50)
+     timelapse = settings.setdefault('timelapse', False)
+     lower_percentile = settings.setdefault('lower_percentile', 2)
+     randomize = settings.setdefault('randomize', True)
+     all_to_mip = settings.setdefault('all_to_mip', False)
+     pick_slice = settings.setdefault('pick_slice', False)
+     skip_mode = settings.setdefault('skip_mode', False)
+
+     cmap = settings.setdefault('cmap', 'inferno')
+     figuresize = settings.setdefault('figuresize', 50)
+     normalize = settings.setdefault('normalize', True)
+     save_dtype = settings.setdefault('save_dtype', 'uint16')
+
+     test_mode = settings.setdefault('test_mode', False)
+     test_images = settings.setdefault('test_images', 10)
+     random_test = settings.setdefault('random_test', True)
+
+     return settings, metadata_type, custom_regex, nr, plot, batch_size, timelapse, lower_percentile, randomize, all_to_mip, pick_slice, skip_mode, cmap, figuresize, normalize, save_dtype, test_mode, test_images, random_test
+
+ def smooth_hull_lines(cluster_data):
+     hull = ConvexHull(cluster_data)
+
+     # Extract vertices of the hull
+     vertices = hull.points[hull.vertices]
+
+     # Close the loop
+     vertices = np.vstack([vertices, vertices[0, :]])
+
+     # Parameterize the vertices
+     tck, u = splprep(vertices.T, u=None, s=0.0)
+
+     # Evaluate spline at new parameter values
+     new_points = splev(np.linspace(0, 1, 100), tck)
+
+     return new_points[0], new_points[1]
+
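Editor's sketch (illustrative only): `smooth_hull_lines` fits a closed spline through the convex hull of a point cloud. On random 2-D data:

    import numpy as np
    pts = np.random.rand(30, 2)        # 30 random 2-D points
    xs, ys = smooth_hull_lines(pts)    # 100 spline samples tracing a smoothed, closed hull
    # plt.plot(xs, ys) would draw the outline around the cloud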
+ def _gen_rgb_image(image, channels):
+     """
+     Generate an RGB image from the specified channels of the input image.

+     Args:
+         image (ndarray): The input image.
+         channels (list): List of channel indices to use for RGB.

- def _gen_rgb_image(image, cahnnels):
-     rgb_image = np.take(image, cahnnels, axis=-1)
-     rgb_image = rgb_image.astype(float)
-     rgb_image -= rgb_image.min()
-     rgb_image /= rgb_image.max()
+     Returns:
+         rgb_image (ndarray): The generated RGB image.
+     """
+     rgb_image = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.float32)
+     for i, chan in enumerate(channels):
+         if chan < image.shape[2]:
+             rgb_image[:, :, i] = image[:, :, chan]
      return rgb_image

  def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_thickness):
-     from concurrent.futures import ThreadPoolExecutor
-     import cv2
-
      outlines = []
      overlayed_image = rgb_image.copy()

@@ -71,11 +248,13 @@ def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_th
          outline = np.zeros_like(mask, dtype=np.uint8) # Use uint8 for contour detection efficiency

          # Find and draw contours
-         for j in np.unique(mask)[1:]:
+         for j in np.unique(mask):
+             if j == 0:
+                 continue # Skip background
              contours = find_contours(mask == j, 0.5)
              # Convert contours for OpenCV format and draw directly to optimize
              cv_contours = [np.flip(contour.astype(int), axis=1) for contour in contours]
-             cv2.drawContours(outline, cv_contours, -1, color=int(j), thickness=outline_thickness)
+             cv2.drawContours(outline, cv_contours, -1, color=255, thickness=outline_thickness)

          return dilation(outline, square(outline_thickness))

@@ -83,19 +262,15 @@ def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_th
      with ThreadPoolExecutor() as executor:
          outlines = list(executor.map(process_dim, mask_dims))

-     # Overlay outlines onto the RGB image in a batch/vectorized manner if possible
+     # Overlay outlines onto the RGB image
      for i, outline in enumerate(outlines):
-         # This part may need to be adapted to your specific use case and available functions
-         # The goal is to overlay each outline with its respective color more efficiently
-         color = outline_colors[i % len(outline_colors)]
-         for j in np.unique(outline)[1:]:
+         color = np.array(outline_colors[i % len(outline_colors)])
+         for j in np.unique(outline):
+             if j == 0:
+                 continue # Skip background
              mask = outline == j
              overlayed_image[mask] = color # Direct assignment with broadcasting

-     # Remove mask_dims from image
-     channels_to_keep = [i for i in range(image.shape[-1]) if i not in mask_dims]
-     image = np.take(image, channels_to_keep, axis=-1)
-
      return overlayed_image, outlines, image

  def _convert_cq1_well_id(well_id):
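Editor's sketch (not package code): the reworked contour loop above combines `skimage.measure.find_contours` with `cv2.drawContours`. A self-contained version on a toy mask; note the explicit `int32` reshape, which OpenCV requires and which this sketch adds on its own:

    import numpy as np, cv2
    from skimage.measure import find_contours
    mask = np.zeros((64, 64), dtype=np.uint16)
    mask[20:40, 20:40] = 1                        # one square object labeled 1
    outline = np.zeros_like(mask, dtype=np.uint8)
    for j in np.unique(mask):
        if j == 0:
            continue                              # skip background
        contours = find_contours(mask == j, 0.5)
        cv_contours = [np.flip(c, axis=1).astype(np.int32).reshape(-1, 1, 2) for c in contours]
        cv2.drawContours(outline, cv_contours, -1, color=255, thickness=2)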
@@ -355,43 +530,82 @@ def _annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['rh'], pa
      df['condition'] = df.apply(lambda row: '_'.join(filter(None, [row.get('pathogen'), row.get('treatment')])), axis=1)
      df['condition'] = df['condition'].apply(lambda x: x if x else 'none')
      return df
-
- def normalize_to_dtype(array, q1=2,q2=98, percentiles=None):
+
+ def normalize_to_dtype(array, p1=2, p2=98):
      """
-     Normalize the input array to a specified data type.
+     Normalize each image in the stack to its own percentiles.

      Parameters:
      - array: numpy array
-         The input array to be normalized.
-     - q1: int, optional
+         The input stack to be normalized.
+     - p1: int, optional
          The lower percentile value for normalization. Default is 2.
-     - q2: int, optional
+     - p2: int, optional
          The upper percentile value for normalization. Default is 98.
-     - percentiles: list of tuples, optional
-         A list of tuples containing the percentile values for each image in the array.
-         If provided, the percentiles for each image will be used instead of q1 and q2.

      Returns:
      - new_stack: numpy array
-         The normalized array with the same shape as the input array.
+         The normalized stack with the same shape as the input stack.
      """
      nimg = array.shape[2]
      new_stack = np.empty_like(array)
-     for i,v in enumerate(range(nimg)):
-         img = np.squeeze(array[:, :, v])
+
+     for i in range(nimg):
+         img = array[:, :, i]
          non_zero_img = img[img > 0]
-         if non_zero_img.size > 0: # check if there are non-zero values
-             img_min = np.percentile(non_zero_img, q1) # change percentile from 0.02 to 2
-             img_max = np.percentile(non_zero_img, q2) # change percentile from 0.98 to 98
-             img = rescale_intensity(img, in_range=(img_min, img_max), out_range='dtype')
-         else: # if there are no non-zero values, just use the image as it is
-             if percentiles==None:
-                 img_min, img_max = img.min(), img.max()
-             else:
-                 img_min, img_max = percentiles[i]
-             img = rescale_intensity(img, in_range=(img_min, img_max), out_range='dtype')
-         img = np.expand_dims(img, axis=2)
-         new_stack[:, :, v] = img[:, :, 0]
+
+         if non_zero_img.size > 0:
+             img_min = np.percentile(non_zero_img, p1)
+             img_max = np.percentile(non_zero_img, p2)
+         else:
+             img_min = img.min()
+             img_max = img.max()
+
+         # Determine output range based on dtype
+         if np.issubdtype(array.dtype, np.integer):
+             out_range = (0, np.iinfo(array.dtype).max)
+         else:
+             out_range = (0.0, 1.0)
+
+         img = rescale_intensity(img, in_range=(img_min, img_max), out_range=out_range).astype(array.dtype)
+         new_stack[:, :, i] = img
+
+     return new_stack
+
+ def normalize_to_dtype(array, p1=2, p2=98):
+     """
+     Normalize each image in the stack to its own percentiles.
+
+     Parameters:
+     - array: numpy array
+         The input stack to be normalized.
+     - p1: int, optional
+         The lower percentile value for normalization. Default is 2.
+     - p2: int, optional
+         The upper percentile value for normalization. Default is 98.
+
+     Returns:
+     - new_stack: numpy array
+         The normalized stack with the same shape as the input stack.
+     """
+     nimg = array.shape[2]
+     new_stack = np.empty_like(array, dtype=np.float32)
+
+     for i in range(nimg):
+         img = array[:, :, i]
+         non_zero_img = img[img > 0]
+
+         if non_zero_img.size > 0:
+             img_min = np.percentile(non_zero_img, p1)
+             img_max = np.percentile(non_zero_img, p2)
+         else:
+             img_min = img.min()
+             img_max = img.max()
+
+         # Normalize to the range (0, 1) for visualization
+         img = rescale_intensity(img, in_range=(img_min, img_max), out_range=(0.0, 1.0))
+         new_stack[:, :, i] = img
+
      return new_stack
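Editor's note and sketch (not package code): as rendered, this hunk adds `normalize_to_dtype` twice with the same signature; at import time the second (float, 0-1) definition replaces the first, so callers get per-image percentile scaling into [0, 1]:

    import numpy as np
    stack = np.random.randint(0, 65535, size=(64, 64, 3)).astype(np.uint16)
    norm = normalize_to_dtype(stack, p1=2, p2=98)     # second definition wins
    print(norm.dtype, norm.min(), norm.max())         # float32, ~0.0, ~1.0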

  def _list_endpoint_subdirectories(base_dir):
@@ -749,7 +963,7 @@ def _get_diam(mag, obj):
      elif obj == 'nucleus':
          diamiter = 60
      elif obj == 'pathogen':
-         diamiter = 30
+         diamiter = 20
      else:
          raise ValueError("Invalid magnification: Use 20, 40 or 60")

@@ -769,7 +983,7 @@ def _get_diam(mag, obj):
          if obj == 'nucleus':
              diamiter = 90
          if obj == 'pathogen':
-             diamiter = 75
+             diamiter = 60
          else:
              raise ValueError("Invalid magnification: Use 20, 40 or 60")
      else:
@@ -781,8 +995,8 @@ def _get_object_settings(object_type, settings):
      object_settings = {}

      object_settings['diameter'] = _get_diam(settings['magnification'], obj=object_type)
-     object_settings['minimum_size'] = (object_settings['diameter']**2)/5
-     object_settings['maximum_size'] = (object_settings['diameter']**2)*3
+     object_settings['minimum_size'] = (object_settings['diameter']**2)/4
+     object_settings['maximum_size'] = (object_settings['diameter']**2)*10
      object_settings['merge'] = False
      object_settings['resample'] = True
      object_settings['remove_border_objects'] = False
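Editor's worked example: combined with the new diameters above, a pathogen at 20x magnification (diameter 20 px) now gets minimum_size = 20**2/4 = 100 px² and maximum_size = 20**2*10 = 4000 px², versus 80 and 1200 px² under the old /5 and *3 factors, so the size filter admits a much wider range of objects.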
@@ -793,21 +1007,23 @@ def _get_object_settings(object_type, settings):
              object_settings['model_name'] = 'cyto'
          else:
              object_settings['model_name'] = 'cyto2'
-         object_settings['filter_size'] = True
-         object_settings['filter_intensity'] = True
+         object_settings['filter_size'] = False
+         object_settings['filter_intensity'] = False
          object_settings['restore_type'] = settings.get('cell_restore_type', None)

      elif object_type == 'nucleus':
          object_settings['model_name'] = 'nuclei'
-         object_settings['filter_size'] = True
-         object_settings['filter_intensity'] = True
+         object_settings['filter_size'] = False
+         object_settings['filter_intensity'] = False
          object_settings['restore_type'] = settings.get('nucleus_restore_type', None)

      elif object_type == 'pathogen':
          object_settings['model_name'] = 'cyto'
-         object_settings['filter_size'] = True
-         object_settings['filter_intensity'] = True
+         object_settings['filter_size'] = False
+         object_settings['filter_intensity'] = False
+         object_settings['resample'] = False
          object_settings['restore_type'] = settings.get('pathogen_restore_type', None)
+         object_settings['merge'] = settings['merge_pathogens']

      else:
          print(f'Object type: {object_type} not supported. Supported object types are : cell, nucleus and pathogen')
@@ -884,17 +1100,15 @@ def _get_cellpose_channels(src, nucleus_channel, pathogen_channel, cell_channel)

      if not pathogen_channel is None:
          if not nucleus_channel is None:
-             cellpose_channels['pathogen'] = [0,1]
+             if not pathogen_channel is None:
+                 cellpose_channels['pathogen'] = [0,2]
+             else:
+                 cellpose_channels['pathogen'] = [0,1]
          else:
              cellpose_channels['pathogen'] = [0,0]

      if not cell_channel is None:
          if not nucleus_channel is None:
-             if not pathogen_channel is None:
-                 cellpose_channels['cell'] = [0,2]
-             else:
-                 cellpose_channels['cell'] = [0,1]
-         elif not pathogen_channel is None:
              cellpose_channels['cell'] = [0,1]
          else:
              cellpose_channels['cell'] = [0,0]
@@ -1069,7 +1283,7 @@ class Cache:
          cache (OrderedDict): The cache data structure.
      """

-     def _init__(self, max_size):
+     def __init__(self, max_size):
          self.cache = OrderedDict()
          self.max_size = max_size
@@ -1100,7 +1314,7 @@ class ScaledDotProductAttention(nn.Module):

      """

-     def _init__(self, d_k):
+     def __init__(self, d_k):
          super(ScaledDotProductAttention, self).__init__()
          self.d_k = d_k
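Editor's sketch (not package code): with the constructor fixed, `d_k` is actually set, and the module can compute the standard scaled dot-product attention, softmax(QK^T / sqrt(d_k)) V. A tensor-shape walkthrough:

    import torch
    import torch.nn.functional as F
    Q = torch.randn(2, 5, 16)                      # (batch, tokens, d_k)
    K = torch.randn(2, 5, 16)
    V = torch.randn(2, 5, 16)
    scores = Q @ K.transpose(-2, -1) / 16 ** 0.5   # (2, 5, 5) attention logits
    out = F.softmax(scores, dim=-1) @ V            # (2, 5, 16) attended values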

@@ -1131,7 +1345,7 @@ class SelfAttention(nn.Module):
          d_k (int): Dimensionality of the key and query vectors.
      """

-     def _init__(self, in_channels, d_k):
+     def __init__(self, in_channels, d_k):
          super(SelfAttention, self).__init__()
          self.W_q = nn.Linear(in_channels, d_k)
          self.W_k = nn.Linear(in_channels, d_k)
@@ -1155,7 +1369,7 @@ class SelfAttention(nn.Module):
          return output

  class ScaledDotProductAttention(nn.Module):
-     def _init__(self, d_k):
+     def __init__(self, d_k):
          """
          Initializes the ScaledDotProductAttention module.

@@ -1192,7 +1406,7 @@ class SelfAttention(nn.Module):
          in_channels (int): Number of input channels.
          d_k (int): Dimensionality of the key and query vectors.
      """
-     def _init__(self, in_channels, d_k):
+     def __init__(self, in_channels, d_k):
          super(SelfAttention, self).__init__()
          self.W_q = nn.Linear(in_channels, d_k)
          self.W_k = nn.Linear(in_channels, d_k)
@@ -1223,7 +1437,7 @@ class EarlyFusion(nn.Module):
      Args:
          in_channels (int): Number of input channels.
      """
-     def _init__(self, in_channels):
+     def __init__(self, in_channels):
          super(EarlyFusion, self).__init__()
          self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=1, stride=1)

@@ -1242,7 +1456,7 @@ class EarlyFusion(nn.Module):

  # Spatial Attention Mechanism
  class SpatialAttention(nn.Module):
-     def _init__(self, kernel_size=7):
+     def __init__(self, kernel_size=7):
          """
          Initializes the SpatialAttention module.

@@ -1287,7 +1501,7 @@ class MultiScaleBlockWithAttention(nn.Module):
          forward: Forward method for the module.
      """

-     def _init__(self, in_channels, out_channels):
+     def __init__(self, in_channels, out_channels):
          super(MultiScaleBlockWithAttention, self).__init__()
          self.dilated_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, dilation=1, padding=1)
          self.spatial_attention = nn.Conv2d(out_channels, out_channels, kernel_size=1)
@@ -1320,7 +1534,7 @@ class MultiScaleBlockWithAttention(nn.Module):

  # Final Classifier
  class CustomCellClassifier(nn.Module):
-     def _init__(self, num_classes, pathogen_channel, use_attention, use_checkpoint, dropout_rate):
+     def __init__(self, num_classes, pathogen_channel, use_attention, use_checkpoint, dropout_rate):
          super(CustomCellClassifier, self).__init__()
          self.early_fusion = EarlyFusion(in_channels=3)

@@ -1349,7 +1563,7 @@ class CustomCellClassifier(nn.Module):

  #CNN and Transformer class, pick any Torch model.
  class TorchModel(nn.Module):
-     def _init__(self, model_name='resnet50', pretrained=True, dropout_rate=None, use_checkpoint=False):
+     def __init__(self, model_name='resnet50', pretrained=True, dropout_rate=None, use_checkpoint=False):
          super(TorchModel, self).__init__()
          self.model_name = model_name
          self.use_checkpoint = use_checkpoint
@@ -1423,7 +1637,7 @@ class TorchModel(nn.Module):
          return logits

  class FocalLossWithLogits(nn.Module):
-     def _init__(self, alpha=1, gamma=2):
+     def __init__(self, alpha=1, gamma=2):
          super(FocalLossWithLogits, self).__init__()
          self.alpha = alpha
          self.gamma = gamma
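Editor's sketch (illustrative, not necessarily the module's exact forward): the constructor fix means `alpha` and `gamma` now exist as attributes. For reference, focal loss down-weights easy examples via FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t); a common logits-based formulation is:

    import torch
    import torch.nn.functional as F

    def focal_loss_with_logits(logits, targets, alpha=1.0, gamma=2.0):
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        p_t = torch.exp(-bce)                          # probability of the true class
        return (alpha * (1 - p_t) ** gamma * bce).mean()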
@@ -1435,7 +1649,7 @@ class FocalLossWithLogits(nn.Module):
          return focal_loss.mean()

  class ResNet(nn.Module):
-     def _init__(self, resnet_type='resnet50', dropout_rate=None, use_checkpoint=False, init_weights='imagenet'):
+     def __init__(self, resnet_type='resnet50', dropout_rate=None, use_checkpoint=False, init_weights='imagenet'):
          super(ResNet, self).__init__()

          resnet_map = {
@@ -1788,25 +2002,24 @@ def annotate_predictions(csv_loc):
      df['cond'] = df.apply(assign_condition, axis=1)
      return df

- def init_globals(counter_, lock_):
+ def initiate_counter(counter_, lock_):
      global counter, lock
      counter = counter_
      lock = lock_

- def add_images_to_tar(args):
-     global counter, lock, total_images
-     paths_chunk, tar_path = args
+ def add_images_to_tar(paths_chunk, tar_path, total_images):
      with tarfile.open(tar_path, 'w') as tar:
-         for img_path in paths_chunk:
+         for i, img_path in enumerate(paths_chunk):
              arcname = os.path.basename(img_path)
              try:
                  tar.add(img_path, arcname=arcname)
                  with lock:
                      counter.value += 1
-                     print(f"\rProcessed: {counter.value}/{total_images}", end='', flush=True)
+                     if counter.value % 100 == 0: # Print every 100 updates
+                         progress = (counter.value / total_images) * 100
+                         print(f"Progress: {counter.value}/{total_images} ({progress:.2f}%)", end='\r', file=sys.stdout, flush=True)
              except FileNotFoundError:
                  print(f"File not found: {img_path}")
-     return tar_path
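Editor's sketch (paths hypothetical, wiring inferred from the signatures above): `initiate_counter` is designed as a pool initializer that plants the shared counter and lock as globals in each worker before `add_images_to_tar` runs:

    from multiprocessing import Pool, Manager

    manager = Manager()
    counter, lock = manager.Value('i', 0), manager.Lock()
    paths = ['/data/pngs/a.png', '/data/pngs/b.png']            # hypothetical
    with Pool(2, initializer=initiate_counter, initargs=(counter, lock)) as pool:
        pool.starmap(add_images_to_tar, [(paths[:1], '/tmp/part0.tar', len(paths)),
                                         (paths[1:], '/tmp/part1.tar', len(paths))])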

  def generate_fraction_map(df, gene_column, min_frequency=0.0):
      df['fraction'] = df['count']/df['well_read_sum']
@@ -2255,8 +2468,8 @@ def dice_coefficient(mask1, mask2):
  def extract_boundaries(mask, dilation_radius=1):
      binary_mask = (mask > 0).astype(np.uint8)
      struct_elem = np.ones((dilation_radius*2+1, dilation_radius*2+1))
-     dilated = binary_dilation(binary_mask, footprint=struct_elem)
-     eroded = binary_erosion(binary_mask, footprint=struct_elem)
+     dilated = morphology.binary_dilation(binary_mask, footprint=struct_elem)
+     eroded = morphology.binary_erosion(binary_mask, footprint=struct_elem)
      boundary = dilated ^ eroded
      return boundary

@@ -2669,6 +2882,13 @@ def _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size,
              print(f'Number of objects before filtration: {num_objects}')
              plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)

+         if merge:
+             mask = merge_touching_objects(mask, threshold=0.66)
+             if plot and idx == 0:
+                 num_objects = mask_object_count(mask)
+                 print(f'Number of objects after merging adjacent objects, : {num_objects}')
+                 plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
+
          if filter_size:
              props = measure.regionprops_table(mask, properties=['label', 'area'])
              valid_labels = props['label'][np.logical_and(props['area'] > minimum_size, props['area'] < maximum_size)]
@@ -2714,13 +2934,6 @@ def _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size,
              print(f'Number of objects after removing border objects, : {num_objects}')
              plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)

-         if merge:
-             mask = merge_touching_objects(mask, threshold=0.25)
-             if plot and idx == 0:
-                 num_objects = mask_object_count(mask)
-                 print(f'Number of objects after merging adjacent objects, : {num_objects}')
-                 plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
-
          mask_stack.append(mask)

      return mask_stack
@@ -2758,15 +2971,37 @@ def _object_filter(df, object_type, size_range, intensity_range, mask_chans, mas
          print(f'After {object_type} maximum mean intensity filter: {len(df)}')
      return df

- def _run_test_mode(src, regex, timelapse=False):
+ def _get_regex(metadata_type, img_format, custom_regex=None):
+
+     if img_format == None:
+         img_format = '.tif'
+     if metadata_type == 'cellvoyager':
+         regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+     elif metadata_type == 'cq1':
+         regex = f'W(?P<wellID>.*)F(?P<fieldID>.*)T(?P<timeID>.*)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+     elif metadata_type == 'nikon':
+         regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+     elif metadata_type == 'zeis':
+         regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+     elif metadata_type == 'leica':
+         regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+     elif metadata_type == 'custom':
+         regex = f'({custom_regex}){img_format}'
+
+     print(f'regex mode:{metadata_type} regex:{regex}')
+     return regex
+
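Editor's sketch of the cellvoyager pattern in action (the filename is hypothetical):

    import re
    pattern = _get_regex('cellvoyager', '.tif')
    m = re.match(pattern, 'plate1_B02_T0001F001L01A01Z01C1.tif')   # hypothetical file
    print(m.group('wellID'), m.group('chanID'))                    # B02 1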
+ def _run_test_mode(src, regex, timelapse=False, test_images=10, random_test=True):
+
      if timelapse:
          test_images = 1 # Use only 1 set for timelapse to ensure full sequence inclusion
-     else:
-         test_images = 10 # Use 10 sets for non-timelapse scenarios
-
+
      test_folder_path = os.path.join(src, 'test')
      os.makedirs(test_folder_path, exist_ok=True)
      regular_expression = re.compile(regex)
+
+     if os.path.exists(os.path.join(src, 'orig')):
+         src = os.path.join(src, 'orig')

      all_filenames = [filename for filename in os.listdir(src) if regular_expression.match(filename)]
      print(f'Found {len(all_filenames)} files')
@@ -2778,24 +3013,20 @@ def _run_test_mode(src, regex, timelapse=False):
          plate = match.group('plateID') if 'plateID' in match.groupdict() else os.path.basename(src)
          well = match.group('wellID')
          field = match.group('fieldID')
-         # For timelapse experiments, group images by plate, well, and field only
-         if timelapse:
-             set_identifier = (plate, well, field)
-         else:
-             # For non-timelapse, you might want to distinguish sets more granularly
-             # Here, assuming you're grouping by plate, well, and field for simplicity
-             set_identifier = (plate, well, field)
+         set_identifier = (plate, well, field)
          images_by_set[set_identifier].append(filename)

      # Prepare for random selection
      set_identifiers = list(images_by_set.keys())
+     if random_test:
+         random.seed(42)
      random.shuffle(set_identifiers) # Randomize the order

      # Select a subset based on the test_images count
      selected_sets = set_identifiers[:test_images]

      # Print information about the number of sets used
-     print(f'Using {test_images} random image set(s) for test model')
+     print(f'Using {len(selected_sets)} random image set(s) for test model')

      # Copy files for selected sets to the test folder
      for set_identifier in selected_sets:
@@ -2804,24 +3035,1034 @@ def _run_test_mode(src, regex, timelapse=False):

      return test_folder_path

- def _choose_model(model_name, device, object_type='cell', restore_type=None):
+ def _choose_model(model_name, device, object_type='cell', restore_type=None, object_settings={}):
+
+     if object_type == 'pathogen':
+         if model_name == 'toxo_pv_lumen':
+             diameter = object_settings['diameter']
+             current_dir = os.path.dirname(__file__)
+             model_path = os.path.join(current_dir, 'models', 'cp', 'toxo_pv_lumen.CP_model')
+             print(model_path)
+             model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=model_path, diam_mean=diameter, device=device)
+             #model = cp_models.Cellpose(gpu=torch.cuda.is_available(), model_type='cyto', device=device)
+             print(f'Using Toxoplasma PV lumen model to generate pathogen masks')
+             return model
+
      restore_list = ['denoise', 'deblur', 'upsample', None]
      if restore_type not in restore_list:
          print(f"Invalid restore type. Choose from {restore_list} defaulting to None")
          restore_type = None

      if restore_type == None:
-         model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device)
+         if model_name in ['cyto', 'cyto2', 'cyto3', 'nuclei']:
+             model = cp_models.Cellpose(gpu=torch.cuda.is_available(), model_type=model_name, device=device)
+
      else:
          if object_type == 'nucleus':
              restore = f'{type}_nuclei'
-             model = denoise.CellposeDenoiseModel(gpu=True, model_type="nuclei",restore_type=restore, chan2_restore=False, device=device)
+             model = denoise.CellposeDenoiseModel(gpu=torch.cuda.is_available(), model_type="nuclei",restore_type=restore, chan2_restore=False, device=device)
          else:
              restore = f'{type}_cyto3'
              if model_name =='cyto2':
                  chan2_restore = True
              if model_name =='cyto':
                  chan2_restore = False
-             model = denoise.CellposeDenoiseModel(gpu=True, model_type="cyto3",restore_type=restore, chan2_restore=chan2_restore, device=device)
+             model = denoise.CellposeDenoiseModel(gpu=torch.cuda.is_available(), model_type="cyto3",restore_type=restore, chan2_restore=chan2_restore, device=device)
+
+     return model
+
+ class SelectChannels:
+     def __init__(self, channels):
+         self.channels = channels
+
+     def __call__(self, img):
+         img = img.clone()
+         if 1 not in self.channels:
+             img[0, :, :] = 0 # Zero out the red channel
+         if 2 not in self.channels:
+             img[1, :, :] = 0 # Zero out the green channel
+         if 3 not in self.channels:
+             img[2, :, :] = 0 # Zero out the blue channel
+         return img
+
+ def preprocess_image(image_path, image_size=224, channels=[1,2,3], normalize=True):
+
+     if normalize:
+         transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.CenterCrop(size=(image_size, image_size)),
+             SelectChannels(channels),
+             transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+     else:
+         transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.CenterCrop(size=(image_size, image_size)),
+             SelectChannels(channels)])
+
+     image = Image.open(image_path).convert('RGB')
+     input_tensor = transform(image).unsqueeze(0)
+     return image, input_tensor
+
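Editor's sketch (path hypothetical): `SelectChannels` zeroes out whichever RGB planes are not requested, so this variant of `preprocess_image` yields a masked tensor. Note that a second `preprocess_image` added later in this diff shadows this one at import time, so the channel-zeroing behavior only applies if this version is the one called:

    image, tensor = preprocess_image('/data/pngs/cell_0001.png', image_size=224, channels=[1, 3])
    print(tensor.shape)        # torch.Size([1, 3, 224, 224]); the green plane is zeroed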
+
+ class SaliencyMapGenerator:
+     def __init__(self, model):
+         self.model = model
+
+     def compute_saliency_maps(self, X, y):
+         self.model.eval()
+         X.requires_grad_()
+
+         # Forward pass
+         scores = self.model(X).squeeze()
+
+         # For binary classification, target scores can be the single output
+         target_scores = scores * (2 * y - 1)
+
+         self.model.zero_grad()
+         target_scores.backward(torch.ones_like(target_scores))
+
+         saliency = X.grad.abs()
+         return saliency
+
+     def plot_saliency_maps(self, X, y, saliency, class_names):
+         N = X.shape[0]
+         for i in range(N):
+             plt.subplot(2, N, i + 1)
+             plt.imshow(X[i].permute(1, 2, 0).cpu().numpy())
+             plt.axis('off')
+             plt.title(class_names[y[i]])
+             plt.subplot(2, N, N + i + 1)
+             plt.imshow(saliency[i].cpu().numpy(), cmap=plt.cm.hot)
+             plt.axis('off')
+         plt.gcf().set_size_inches(12, 5)
+         plt.show()
+
+ def preprocess_image(image_path, normalize=True, image_size=224, channels=[1,2,3]):
+     preprocess = transforms.Compose([
+         transforms.Resize((image_size, image_size)),
+         transforms.ToTensor(),
+     ])
+
+     image = Image.open(image_path).convert('RGB')
+     input_tensor = preprocess(image)
+     if normalize:
+         input_tensor = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(input_tensor)
+     input_tensor = input_tensor.unsqueeze(0)
+
+     return image, input_tensor
+
+ def class_visualization(target_y, model_path, dtype, img_size=224, channels=[0,1,2], l2_reg=1e-3, learning_rate=25, num_iterations=100, blur_every=10, max_jitter=16, show_every=25, class_names = ['nc', 'pc']):
+
+     def jitter(img, ox, oy):
+         # Randomly jitter the image
+         return torch.roll(torch.roll(img, ox, dims=2), oy, dims=3)
+
+     def blur_image(img, sigma=1):
+         # Apply Gaussian blur to the image
+         img_np = img.cpu().numpy()
+         for i in range(img_np.shape[1]):
+             img_np[:, i] = gaussian_filter(img_np[:, i], sigma=sigma)
+         img.copy_(torch.tensor(img_np).to(img.device))
+
+     def deprocess(img_tensor):
+         # Convert the tensor image to a numpy array for visualization
+         img_tensor = img_tensor.clone()
+         for c in range(3):
+             img_tensor[:, c] = img_tensor[:, c] * SQUEEZENET_STD[c] + SQUEEZENET_MEAN[c]
+         img_tensor = img_tensor.clamp(0, 1)
+         return img_tensor.squeeze().permute(1, 2, 0).cpu().numpy()
+
+     # Assuming these are defined somewhere in your codebase
+     SQUEEZENET_MEAN = [0.485, 0.456, 0.406]
+     SQUEEZENET_STD = [0.229, 0.224, 0.225]
+
+     model = torch.load(model_path)
+
+     dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
+     len_chans = len(channels)
+     model.type(dtype)
+
+     # Randomly initialize the image as a PyTorch Tensor, and make it requires gradient.
+     img = torch.randn(1, len_chans, img_size, img_size).mul_(1.0).type(dtype).requires_grad_()
+
+     for t in range(num_iterations):
+         # Randomly jitter the image a bit; this gives slightly nicer results
+         ox, oy = random.randint(0, max_jitter), random.randint(0, max_jitter)
+         img.data.copy_(jitter(img.data, ox, oy))
+
+         # Forward pass
+         score = model(img)
+
+         if target_y == 0:
+             target_score = -score
+         else:
+             target_score = score
+
+         # Add regularization
+         target_score = target_score - l2_reg * torch.norm(img)
+
+         # Backward pass
+         target_score.backward()
+
+         # Gradient ascent step
+         with torch.no_grad():
+             img += learning_rate * img.grad / torch.norm(img.grad)
+             img.grad.zero_()
+
+         # Undo the random jitter
+         img.data.copy_(jitter(img.data, -ox, -oy))
+
+         # As regularizer, clamp and periodically blur the image
+         for c in range(3):
+             lo = float(-SQUEEZENET_MEAN[c] / SQUEEZENET_STD[c])
+             hi = float((1.0 - SQUEEZENET_MEAN[c]) / SQUEEZENET_STD[c])
+             img.data[:, c].clamp_(min=lo, max=hi)
+         if t % blur_every == 0:
+             blur_image(img.data, sigma=0.5)
+
+         # Periodically show the image
+         if t == 0 or (t + 1) % show_every == 0 or t == num_iterations - 1:
+             plt.imshow(deprocess(img.data.clone().cpu()))
+             class_name = class_names[target_y]
+             plt.title('%s\nIteration %d / %d' % (class_name, t + 1, num_iterations))
+             plt.gcf().set_size_inches(4, 4)
+             plt.axis('off')
+             plt.show()
+
+     return deprocess(img.data.cpu())
+
+ def get_submodules(model, prefix=''):
+     submodules = []
+     for name, module in model.named_children():
+         full_name = prefix + ('.' if prefix else '') + name
+         submodules.append(full_name)
+         submodules.extend(get_submodules(module, full_name))
+     return submodules
+
+ class GradCAM:
+     def __init__(self, model, target_layers=None, use_cuda=True):
+         self.model = model
+         self.model.eval()
+         self.target_layers = target_layers
+         self.cuda = use_cuda
+         if self.cuda:
+             self.model = model.cuda()
+
+     def forward(self, input):
+         return self.model(input)
+
+     def __call__(self, x, index=None):
+         if self.cuda:
+             x = x.cuda()
+
+         features = []
+         def hook(module, input, output):
+             features.append(output)
+
+         handles = []
+         for name, module in self.model.named_modules():
+             if name in self.target_layers:
+                 handles.append(module.register_forward_hook(hook))
+
+         output = self.forward(x)
+         if index is None:
+             index = np.argmax(output.data.cpu().numpy())
+
+         one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
+         one_hot[0][index] = 1
+         one_hot = torch.from_numpy(one_hot).requires_grad_(True)
+         if self.cuda:
+             one_hot = one_hot.cuda()
+
+         one_hot = torch.sum(one_hot * output)
+         self.model.zero_grad()
+         one_hot.backward(retain_graph=True)
+
+         grads_val = features[0].grad.cpu().data.numpy()
+         target = features[0].cpu().data.numpy()[0, :]
+
+         weights = np.mean(grads_val, axis=(2, 3))[0, :]
+         cam = np.zeros(target.shape[1:], dtype=np.float32)
+
+         for i, w in enumerate(weights):
+             cam += w * target[i, :, :]
+
+         cam = np.maximum(cam, 0)
+         cam = cv2.resize(cam, (x.size(2), x.size(3)))
+         cam = cam - np.min(cam)
+         cam = cam / np.max(cam)
+
+         for handle in handles:
+             handle.remove()
+
+         return cam
+
+ def show_cam_on_image(img, mask):
+     heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
+     heatmap = np.float32(heatmap) / 255
+     cam = heatmap + np.float32(img)
+     cam = cam / np.max(cam)
+     return np.uint8(255 * cam)
+
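Editor's sketch (synthetic data, not package code): `show_cam_on_image` expects the base image as float RGB in [0, 1] and the CAM as a [0, 1] mask of matching height and width:

    import numpy as np
    img = np.random.rand(224, 224, 3).astype(np.float32)    # float RGB in [0, 1]
    mask = np.random.rand(224, 224).astype(np.float32)      # synthetic CAM mask
    overlay = show_cam_on_image(img, mask)                  # uint8 heatmap overlay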
+ def recommend_target_layers(model):
+     target_layers = []
+     for name, module in model.named_modules():
+         if isinstance(module, torch.nn.Conv2d):
+             target_layers.append(name)
+     # Choose the last conv layer as the recommended target layer
+     if target_layers:
+         return [target_layers[-1]], target_layers
+     else:
+         raise ValueError("No convolutional layers found in the model.")
+
+ class IntegratedGradients:
+     def __init__(self, model):
+         self.model = model
+         self.model.eval()
+
+     def generate_integrated_gradients(self, input_tensor, target_label_idx, baseline=None, num_steps=50):
+         if baseline is None:
+             baseline = torch.zeros_like(input_tensor)
+
+         assert baseline.shape == input_tensor.shape
+
+         # Scale input and compute gradients
+         scaled_inputs = [(baseline + (float(i) / num_steps) * (input_tensor - baseline)).requires_grad_(True) for i in range(0, num_steps + 1)]
+         grads = []
+         for scaled_input in scaled_inputs:
+             out = self.model(scaled_input)
+             self.model.zero_grad()
+             out[0, target_label_idx].backward(retain_graph=True)
+             grads.append(scaled_input.grad.data.cpu().numpy())
+
+         avg_grads = np.mean(grads[:-1], axis=0)
+         integrated_grads = (input_tensor.cpu().data.numpy() - baseline.cpu().data.numpy()) * avg_grads
+         return integrated_grads
+
+ def get_db_paths(src):
+     if isinstance(src, str):
+         src = [src]
+     db_paths = [os.path.join(source, 'measurements/measurements.db') for source in src]
+     return db_paths
+
+ def get_sequencing_paths(src):
+     if isinstance(src, str):
+         src = [src]
+     seq_paths = [os.path.join(source, 'sequencing/sequencing_data.csv') for source in src]
+     return seq_paths
+
+ def load_image_paths(c, visualize):
+     c.execute(f'SELECT * FROM png_list')
+     data = c.fetchall()
+     columns_info = c.execute(f'PRAGMA table_info(png_list)').fetchall()
+     column_names = [col_info[1] for col_info in columns_info]
+     image_paths_df = pd.DataFrame(data, columns=column_names)
+     if visualize:
+         object_visualize = visualize + '_png'
+         image_paths_df = image_paths_df[image_paths_df['png_path'].str.contains(object_visualize)]
+     image_paths_df = image_paths_df.set_index('prcfo')
+     return image_paths_df
+
+ def merge_dataframes(df, image_paths_df, verbose):
+     df.set_index('prcfo', inplace=True)
+     df = image_paths_df.merge(df, left_index=True, right_index=True)
+     if verbose:
+         display(df)
+     return df
+
+ def remove_highly_correlated_columns(df, threshold):
+     corr_matrix = df.corr().abs()
+     upper_tri = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
+     to_drop = [column for column in upper_tri.columns if any(upper_tri[column] > threshold)]
+     return df.drop(to_drop, axis=1)
+
+ def filter_columns(df, filter_by):
+     if filter_by != 'morphology':
+         cols_to_include = [col for col in df.columns if filter_by in str(col)]
+     else:
+         cols_to_include = [col for col in df.columns if 'channel' not in str(col)]
+     df = df[cols_to_include]
+     return df
+
+ def reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric, eps, min_samples, clustering, reduction_method='umap', verbose=False, embedding=None, n_jobs=-1, mode='fit', model=False):
+     """
+     Perform dimensionality reduction and clustering on the given data.

-     return model
+     Parameters:
+     numeric_data (np.ndarray): Numeric data for embedding and clustering.
+     n_neighbors (int or float): Number of neighbors for UMAP or perplexity for t-SNE.
+     min_dist (float): Minimum distance for UMAP.
+     metric (str): Metric for UMAP and DBSCAN.
+     eps (float): Epsilon for DBSCAN.
+     min_samples (int): Minimum samples for DBSCAN or number of clusters for KMeans.
+     clustering (str): Clustering method ('DBSCAN' or 'KMeans').
+     reduction_method (str): Dimensionality reduction method ('UMAP' or 'tSNE').
+     verbose (bool): Whether to print verbose output.
+     embedding (np.ndarray, optional): Precomputed embedding. Default is None.
+     return_model (bool): Whether to return the reducer model. Default is False.
+
+     Returns:
+     tuple: embedding, labels (and optionally the reducer model)
+     """
+
+     if verbose:
+         v = 1
+     else:
+         v = 0
+
+     if isinstance(n_neighbors, float):
+         n_neighbors = int(n_neighbors * len(numeric_data))
+
+     if n_neighbors <= 2:
+         n_neighbors = 2
+
+     if mode == 'fit':
+         if reduction_method == 'umap':
+             reducer = umap.UMAP(n_neighbors=n_neighbors,
+                                 n_components=2,
+                                 metric=metric,
+                                 n_epochs=None,
+                                 learning_rate=1.0,
+                                 init='spectral',
+                                 min_dist=min_dist,
+                                 spread=1.0,
+                                 set_op_mix_ratio=1.0,
+                                 local_connectivity=1,
+                                 repulsion_strength=1.0,
+                                 negative_sample_rate=5,
+                                 transform_queue_size=4.0,
+                                 a=None,
+                                 b=None,
+                                 random_state=42,
+                                 metric_kwds=None,
+                                 angular_rp_forest=False,
+                                 target_n_neighbors=-1,
+                                 target_metric='categorical',
+                                 target_metric_kwds=None,
+                                 target_weight=0.5,
+                                 transform_seed=42,
+                                 n_jobs=n_jobs,
+                                 verbose=verbose)
+
+         elif reduction_method == 'tsne':
+             reducer = TSNE(n_components=2,
+                            perplexity=n_neighbors,
+                            early_exaggeration=12.0,
+                            learning_rate=200.0,
+                            n_iter=1000,
+                            n_iter_without_progress=300,
+                            min_grad_norm=1e-7,
+                            metric=metric,
+                            init='random',
+                            verbose=v,
+                            random_state=42,
+                            method='barnes_hut',
+                            angle=0.5,
+                            n_jobs=n_jobs)
+
+         else:
+             raise ValueError(f"Unsupported reduction method: {reduction_method}. Supported methods are 'umap' and 'tsne'")
+
+         embedding = reducer.fit_transform(numeric_data)
+         if verbose:
+             print(f'Trained and fit reducer')
+
+     else:
+         if not model is None:
+             embedding = model.transform(numeric_data)
+             reducer = model
+             if verbose:
+                 print(f'Fit data to reducer')
+         else:
+             raise ValueError(f"Model is None. Please provide a model for transform.")
+
+     if clustering == 'dbscan':
+         clustering_model = DBSCAN(eps=eps, min_samples=min_samples, metric=metric, n_jobs=n_jobs)
+     elif clustering == 'kmeans':
+         clustering_model = KMeans(n_clusters=min_samples, random_state=42)
+
+     clustering_model.fit(embedding)
+     labels = clustering_model.labels_ if clustering == 'dbscan' else clustering_model.predict(embedding)
+
+     if verbose:
+         print(f'Embedding shape: {embedding.shape}')
+
+     return embedding, labels, reducer
+
+ def reduction_and_clustering_v1(numeric_data, n_neighbors, min_dist, metric, eps, min_samples, clustering, reduction_method='umap', verbose=False, embedding=None, n_jobs=-1):
3493
+ """
3494
+ Perform dimensionality reduction and clustering on the given data.
3495
+
3496
+ Parameters:
3497
+ numeric_data (np.ndarray): Numeric data for embedding and clustering.
3498
+ n_neighbors (int or float): Number of neighbors for UMAP or perplexity for t-SNE.
3499
+ min_dist (float): Minimum distance for UMAP.
3500
+ metric (str): Metric for UMAP and DBSCAN.
3501
+ eps (float): Epsilon for DBSCAN.
3502
+ min_samples (int): Minimum samples for DBSCAN or number of clusters for KMeans.
3503
+ clustering (str): Clustering method ('DBSCAN' or 'KMeans').
3504
+ reduction_method (str): Dimensionality reduction method ('UMAP' or 'tSNE').
3505
+ verbose (bool): Whether to print verbose output.
3506
+ embedding (np.ndarray, optional): Precomputed embedding. Default is None.
3507
+
3508
+ Returns:
3509
+ tuple: embedding, labels
3510
+ """
3511
+
3512
+ if verbose:
3513
+ v=1
3514
+ else:
3515
+ v=0
3516
+
3517
+ if isinstance(n_neighbors, float):
3518
+ n_neighbors = int(n_neighbors * len(numeric_data))
3519
+
3520
+ if n_neighbors <= 2:
3521
+ n_neighbors = 2
3522
+
3523
+ if reduction_method == 'umap':
3524
+ reducer = umap.UMAP(n_neighbors=n_neighbors,
3525
+ n_components=2,
3526
+ metric=metric,
3527
+ n_epochs=None,
3528
+ learning_rate=1.0,
3529
+ init='spectral',
3530
+ min_dist=min_dist,
3531
+ spread=1.0,
3532
+ set_op_mix_ratio=1.0,
3533
+ local_connectivity=1,
3534
+ repulsion_strength=1.0,
3535
+ negative_sample_rate=5,
3536
+ transform_queue_size=4.0,
3537
+ a=None,
3538
+ b=None,
3539
+ random_state=42,
3540
+ metric_kwds=None,
3541
+ angular_rp_forest=False,
3542
+ target_n_neighbors=-1,
3543
+ target_metric='categorical',
3544
+ target_metric_kwds=None,
3545
+ target_weight=0.5,
3546
+ transform_seed=42,
3547
+ n_jobs=n_jobs,
3548
+ verbose=verbose)
3549
+
3550
+ elif reduction_method == 'tsne':
3551
+
3552
+ #tsne_params.setdefault('n_components', 2)
3553
+ #reducer = TSNE(**tsne_params)
3554
+
3555
+ reducer = TSNE(n_components=2,
3556
+ perplexity=n_neighbors,
3557
+ early_exaggeration=12.0,
3558
+ learning_rate=200.0,
3559
+ n_iter=1000,
3560
+ n_iter_without_progress=300,
3561
+ min_grad_norm=1e-7,
3562
+ metric=metric,
3563
+ init='random',
3564
+ verbose=v,
3565
+ random_state=42,
3566
+ method='barnes_hut',
3567
+ angle=0.5,
3568
+ n_jobs=n_jobs)
3569
+
3570
+ else:
3571
+ raise ValueError(f"Unsupported reduction method: {reduction_method}. Supported methods are 'umap' and 'tsne'")
3572
+
3573
+ if embedding is None:
3574
+ embedding = reducer.fit_transform(numeric_data)
3575
+
3576
+ if clustering == 'dbscan':
3577
+ clustering_model = DBSCAN(eps=eps, min_samples=min_samples, metric=metric, n_jobs=n_jobs)
3578
+ elif clustering == 'kmeans':
3579
+ clustering_model = KMeans(n_clusters=min_samples, random_state=42)
3580
+ else:
3581
+ raise ValueError(f"Unsupported clustering method: {clustering}. Supported methods are 'dbscan' and 'kmeans'")
3582
+
3583
+ clustering_model.fit(embedding)
3584
+ labels = clustering_model.labels_ if clustering == 'dbscan' else clustering_model.predict(embedding)
3585
+
3586
+ if verbose:
3587
+ print(f'Embedding shape: {embedding.shape}')
3588
+
3589
+ return embedding, labels
3590
+
3591
+ def remove_noise(embedding, labels):
3592
+ non_noise_indices = labels != -1
3593
+ embedding = embedding[non_noise_indices]
3594
+ labels = labels[non_noise_indices]
3595
+ return embedding, labels
3596
+
3597
+ def plot_embedding(embedding, image_paths, labels, image_nr, img_zoom, colors, plot_by_cluster, plot_outlines, plot_points, plot_images, smooth_lines, black_background, figuresize, dot_size, remove_image_canvas, verbose):
3598
+ unique_labels = np.unique(labels)
3599
+ #num_clusters = len(unique_labels[unique_labels != 0])
3600
+ colors, label_to_color_index = assign_colors(unique_labels, colors)
3601
+ cluster_centers = [np.mean(embedding[labels == cluster_label], axis=0) for cluster_label in unique_labels]
3602
+ fig, ax = setup_plot(figuresize, black_background)
3603
+ plot_clusters(ax, embedding, labels, colors, cluster_centers, plot_outlines, plot_points, smooth_lines, figuresize, dot_size, verbose)
3604
+ if not image_paths is None and plot_images:
3605
+ plot_umap_images(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, plot_by_cluster, remove_image_canvas, verbose)
3606
+ plt.show()
3607
+ return fig
3608
+
3609
+ def generate_colors(num_clusters, black_background):
3610
+ random_colors = np.random.rand(num_clusters + 1, 4)
3611
+ random_colors[:, 3] = 1
3612
+ specific_colors = [
3613
+ [155 / 255, 55 / 255, 155 / 255, 1],
3614
+ [55 / 255, 155 / 255, 155 / 255, 1],
3615
+ [55 / 255, 155 / 255, 255 / 255, 1],
3616
+ [255 / 255, 55 / 255, 155 / 255, 1]
3617
+ ]
3618
+ random_colors = np.vstack((specific_colors, random_colors[len(specific_colors):]))
3619
+ if not black_background:
3620
+ random_colors = np.vstack(([0, 0, 0, 1], random_colors))
3621
+ return random_colors
3622
+
3623
+ def assign_colors(unique_labels, random_colors):
3624
+ normalized_colors = random_colors / 255
3625
+ colors_img = [tuple(color) for color in normalized_colors]
3626
+ colors = [tuple(color) for color in random_colors]
3627
+ label_to_color_index = {label: index for index, label in enumerate(unique_labels)}
3628
+ return colors, label_to_color_index
3629
+
3630
+ def setup_plot(figuresize, black_background):
+     if black_background:
+         plt.rcParams.update({'figure.facecolor': 'black', 'axes.facecolor': 'black', 'text.color': 'white', 'xtick.color': 'white', 'ytick.color': 'white', 'axes.labelcolor': 'white'})
+     else:
+         plt.rcParams.update({'figure.facecolor': 'white', 'axes.facecolor': 'white', 'text.color': 'black', 'xtick.color': 'black', 'ytick.color': 'black', 'axes.labelcolor': 'black'})
+     fig, ax = plt.subplots(1, 1, figsize=(figuresize, figuresize))
+     return fig, ax
+
+ def plot_clusters(ax, embedding, labels, colors, cluster_centers, plot_outlines, plot_points, smooth_lines, figuresize=50, dot_size=50, verbose=False):
+     unique_labels = np.unique(labels)
+     for cluster_label, color, center in zip(unique_labels, colors, cluster_centers):
+         cluster_data = embedding[labels == cluster_label]
+         if smooth_lines:
+             if cluster_data.shape[0] > 2:
+                 x_smooth, y_smooth = smooth_hull_lines(cluster_data)
+                 if plot_outlines:
+                     plt.plot(x_smooth, y_smooth, color=color, linewidth=2)
+         else:
+             if cluster_data.shape[0] > 2:
+                 hull = ConvexHull(cluster_data)
+                 for simplex in hull.simplices:
+                     if plot_outlines:
+                         plt.plot(hull.points[simplex, 0], hull.points[simplex, 1], color=color, linewidth=4)
+         if plot_points:
+             scatter = ax.scatter(cluster_data[:, 0], cluster_data[:, 1], s=dot_size, c=[color], alpha=0.5, label=f'Cluster {cluster_label if cluster_label != -1 else "Noise"}')
+         else:
+             scatter = ax.scatter(cluster_data[:, 0], cluster_data[:, 1], s=dot_size, c=[color], alpha=0, label=f'Cluster {cluster_label if cluster_label != -1 else "Noise"}')
+         ax.text(center[0], center[1], str(cluster_label), fontsize=12, ha='center', va='center')
+     plt.legend(loc='best', fontsize=int(figuresize * 0.75))
+     plt.xlabel('UMAP Dimension 1', fontsize=int(figuresize * 0.75))
+     plt.ylabel('UMAP Dimension 2', fontsize=int(figuresize * 0.75))
+     plt.tick_params(axis='both', which='major', labelsize=int(figuresize * 0.75))
+
+ def plot_umap_images(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, plot_by_cluster, remove_image_canvas, verbose):
+     if plot_by_cluster:
+         cluster_indices = {label: np.where(labels == label)[0] for label in np.unique(labels) if label != -1}
+         plot_images_by_cluster(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, cluster_indices, remove_image_canvas, verbose)
+     else:
+         indices = random.sample(range(len(embedding)), image_nr)
+         for i, index in enumerate(indices):
+             x, y = embedding[index]
+             img = Image.open(image_paths[index])
+             plot_image(ax, x, y, img, img_zoom, remove_image_canvas)
+
+ def plot_images_by_cluster(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, cluster_indices, remove_image_canvas, verbose):
+     for cluster_label, color in zip(np.unique(labels), colors):
+         if cluster_label == -1:
+             continue
+         indices = cluster_indices.get(cluster_label, [])
+         if len(indices) > image_nr:
+             indices = random.sample(list(indices), image_nr)
+         for index in indices:
+             x, y = embedding[index]
+             img = Image.open(image_paths[index])
+             plot_image(ax, x, y, img, img_zoom, remove_image_canvas)
+
+ def plot_image(ax, x, y, img, img_zoom, remove_image_canvas=True):
+     if remove_image_canvas:
+         # remove_canvas() inspects img.mode, so it must receive the PIL image, not a numpy array
+         img = remove_canvas(img)
+     else:
+         img = np.array(img)
+     imagebox = OffsetImage(img, zoom=img_zoom)
+     ab = AnnotationBbox(imagebox, (x, y), frameon=False)
+     ax.add_artist(ab)
+
+ def remove_canvas(img):
+     if img.mode in ['L', 'I']:
+         img_data = np.array(img)
+         img_data = img_data / np.max(img_data)
+         alpha_channel = (img_data > 0).astype(float)
+         img_data_rgb = np.stack([img_data] * 3, axis=-1)
+         img_data_with_alpha = np.dstack([img_data_rgb, alpha_channel])
+     elif img.mode == 'RGB':
+         img_data = np.array(img)
+         img_data = img_data / 255.0
+         alpha_channel = (np.sum(img_data, axis=-1) > 0).astype(float)
+         img_data_with_alpha = np.dstack([img_data, alpha_channel])
+     else:
+         raise ValueError(f"Unsupported image mode: {img.mode}")
+     return img_data_with_alpha
+
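A small sketch of what remove_canvas() returns for an 8-bit grayscale ('L') input: zero-valued background pixels get alpha 0, so only the segmented object is visible when the thumbnail is drawn onto the embedding:

    import numpy as np
    from PIL import Image

    img = Image.fromarray(np.array([[0, 128], [255, 0]], dtype=np.uint8))  # mode 'L'
    rgba = remove_canvas(img)
    print(rgba.shape)    # (2, 2, 4): grayscale replicated to RGB plus an alpha channel
    print(rgba[..., 3])  # [[0. 1.] [1. 0.]] -> the zero pixels are fully transparent
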
+ def plot_clusters_grid(embedding, labels, image_nr, image_paths, colors, figuresize, black_background, verbose):
+     unique_labels = np.unique(labels)
+     num_clusters = len(unique_labels[unique_labels != -1])
+     if num_clusters == 0:
+         print("No clusters found.")
+         return
+     cluster_images = {label: [] for label in unique_labels if label != -1}
+     cluster_indices = {label: np.where(labels == label)[0] for label in unique_labels if label != -1}
+     for cluster_label, indices in cluster_indices.items():
+         if len(indices) > image_nr:
+             indices = random.sample(list(indices), image_nr)
+         for index in indices:
+             img_path = image_paths[index]
+             img_array = Image.open(img_path)
+             img = np.array(img_array)
+             cluster_images[cluster_label].append(img)
+     fig = plot_grid(cluster_images, colors, figuresize, black_background, verbose)
+     return fig
+
+ def plot_grid(cluster_images, colors, figuresize, black_background, verbose):
+     num_clusters = len(cluster_images)
+     max_figsize = 200  # cap the total figure width
+     if figuresize * num_clusters > max_figsize:
+         figuresize = max_figsize / num_clusters
+
+     grid_fig, grid_axes = plt.subplots(1, num_clusters, figsize=(figuresize * num_clusters, figuresize), gridspec_kw={'wspace': 0.2, 'hspace': 0})
+     if num_clusters == 1:
+         grid_axes = [grid_axes]  # ensure grid_axes is always iterable
+     for cluster_label, axes in zip(cluster_images.keys(), grid_axes):
+         images = cluster_images[cluster_label]
+         num_images = len(images)
+         grid_size = int(np.ceil(np.sqrt(num_images)))
+         image_size = 0.9 / grid_size
+         whitespace = (1 - grid_size * image_size) / (grid_size + 1)
+
+         if isinstance(cluster_label, str):
+             idx = list(cluster_images.keys()).index(cluster_label)
+             color = colors[idx]
+             if verbose:
+                 print(f'Label: {cluster_label} index: {idx}')
+         else:
+             color = colors[cluster_label]
+
+         axes.add_patch(plt.Rectangle((0, 0), 1, 1, transform=axes.transAxes, color=color[:3]))
+         axes.axis('off')
+         for i, img in enumerate(images):
+             row = i // grid_size
+             col = i % grid_size
+             x_pos = (col + 1) * whitespace + col * image_size
+             y_pos = 1 - ((row + 1) * whitespace + (row + 1) * image_size)
+             ax_img = axes.inset_axes([x_pos, y_pos, image_size, image_size], transform=axes.transAxes)
+             ax_img.imshow(img, cmap='gray', aspect='auto')
+             ax_img.axis('off')
+             ax_img.set_aspect('equal')
+             ax_img.set_facecolor(color[:3])
+
+     # Add cluster labels beside the grid
+     spacing_factor = 0.5  # controls the vertical spacing between labels
+     for i, (cluster_label, color) in enumerate(zip(cluster_images.keys(), colors)):
+         label_y = 1 - (i + 1) * (spacing_factor / num_clusters)  # y position for each label
+         grid_fig.text(1.05, label_y, f'Cluster {cluster_label}', verticalalignment='center', fontsize=figuresize, color='black' if not black_background else 'white')
+         grid_fig.patches.append(plt.Rectangle((1, label_y - 0.02), 0.03, 0.03, transform=grid_fig.transFigure, color=color[:3], clip_on=False))
+
+     plt.show()
+     return grid_fig
+
+ def correct_paths(df, base_path):
+     if 'png_path' not in df.columns:
+         print("No 'png_path' column found in the dataframe.")
+         return df, None
+
+     image_paths = df['png_path'].to_list()
+
+     adjusted_image_paths = []
+     for path in image_paths:
+         if base_path not in path:
+             parts = path.split('/data/')
+             if len(parts) > 1:
+                 new_path = os.path.join(base_path, 'data', parts[1])
+                 adjusted_image_paths.append(new_path)
+             else:
+                 adjusted_image_paths.append(path)
+         else:
+             adjusted_image_paths.append(path)
+
+     df['png_path'] = adjusted_image_paths
+     image_paths = df['png_path'].to_list()
+     return df, image_paths
+
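A usage sketch with hypothetical paths: rows whose png_path does not contain the new base_path are rebased onto it at the 'data' directory:

    import pandas as pd

    df = pd.DataFrame({'png_path': ['/old/root/data/plate1/cell_001.png']})
    df, image_paths = correct_paths(df, base_path='/new/root')
    print(image_paths)  # ['/new/root/data/plate1/cell_001.png']
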
+ def correct_paths_v1(df, base_path):
+     if 'png_path' not in df.columns:
+         print("No 'png_path' column found in the dataframe.")
+         return df, None
+
+     image_paths = df['png_path'].to_list()
+
+     adjusted_image_paths = []
+     for path in image_paths:
+         if base_path not in path:
+             print(f"Adjusting path: {path}")
+             parts = path.split('data/')
+             if len(parts) > 1:
+                 new_path = os.path.join(base_path, 'data', parts[1])
+                 adjusted_image_paths.append(new_path)
+             else:
+                 adjusted_image_paths.append(path)
+         else:
+             adjusted_image_paths.append(path)
+
+     df['png_path'] = adjusted_image_paths
+     image_paths = df['png_path'].to_list()
+     return df, image_paths
+
+ def get_umap_image_settings(settings=None):
+     # Avoid a shared mutable default argument
+     settings = {} if settings is None else settings
+     settings.setdefault('src', 'path')
+     settings.setdefault('row_limit', 1000)
+     settings.setdefault('tables', ['cell', 'cytoplasm', 'nucleus', 'pathogen'])
+     settings.setdefault('visualize', 'cell')
+     settings.setdefault('image_nr', 16)
+     settings.setdefault('dot_size', 50)
+     settings.setdefault('n_neighbors', 1000)
+     settings.setdefault('min_dist', 0.1)
+     settings.setdefault('metric', 'euclidean')
+     settings.setdefault('eps', 0.5)
+     settings.setdefault('min_samples', 1000)
+     settings.setdefault('filter_by', 'channel_0')
+     settings.setdefault('img_zoom', 0.5)
+     settings.setdefault('plot_by_cluster', True)
+     settings.setdefault('plot_cluster_grids', True)
+     settings.setdefault('remove_cluster_noise', True)
+     settings.setdefault('remove_highly_correlated', True)
+     settings.setdefault('log_data', False)
+     settings.setdefault('figuresize', 60)
+     settings.setdefault('black_background', True)
+     settings.setdefault('remove_image_canvas', False)
+     settings.setdefault('plot_outlines', True)
+     settings.setdefault('plot_points', True)
+     settings.setdefault('smooth_lines', True)
+     settings.setdefault('clustering', 'dbscan')
+     settings.setdefault('exclude', None)
+     settings.setdefault('col_to_compare', 'col')
+     settings.setdefault('pos', 'c1')
+     settings.setdefault('neg', 'c2')
+     settings.setdefault('mix', 'c3')
+     settings.setdefault('embedding_by_controls', False)
+     settings.setdefault('plot_images', True)
+     settings.setdefault('reduction_method', 'umap')
+     settings.setdefault('save_figure', False)
+     settings.setdefault('n_jobs', -1)
+     settings.setdefault('color_by', None)
+     settings.setdefault('exclude_conditions', None)
+     settings.setdefault('analyze_clusters', False)
+     settings.setdefault('resnet_features', False)
+     settings.setdefault('verbose', True)
+     return settings
+
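Usage sketch (the src path is hypothetical): setdefault() only fills in missing keys, so anything passed explicitly wins over the defaults:

    settings = get_umap_image_settings({'src': '/path/to/experiment',
                                        'reduction_method': 'tsne',
                                        'clustering': 'kmeans',
                                        'min_samples': 8})
    print(settings['clustering'], settings['eps'])  # kmeans 0.5 (eps fell back to its default)
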
+ def preprocess_data(df, filter_by, remove_highly_correlated, log_data, exclude):
+     """
+     Preprocess the given dataframe: filter features, drop highly correlated columns,
+     optionally log-transform, fill NaN values, and scale the numeric data.
+
+     Args:
+         df (pandas.DataFrame): The input dataframe.
+         filter_by (str or None): The channel of interest to filter the dataframe by.
+         remove_highly_correlated (bool or float): Whether to remove highly correlated columns.
+             If a float is provided, it is used as the correlation threshold (default 0.95).
+         log_data (bool): Whether to apply a log transformation to the numeric data.
+         exclude (list or None): Features to exclude from the filtering process.
+
+     Returns:
+         numpy.ndarray: The preprocessed numeric data.
+
+     Raises:
+         ValueError: If no numeric columns are available after filtering.
+     """
+     # Apply filtering based on the `filter_by` parameter
+     if filter_by is not None:
+         df, _ = filter_dataframe_features(df, channel_of_interest=filter_by, exclude=exclude)
+
+     # Select numerical features
+     numeric_data = df.select_dtypes(include=['number'])
+
+     if numeric_data.empty:
+         raise ValueError("No numeric columns available after filtering. Please check the filter_by and exclude parameters.")
+
+     # Remove highly correlated columns
+     if remove_highly_correlated is not False:
+         if isinstance(remove_highly_correlated, float):
+             numeric_data = remove_highly_correlated_columns(numeric_data, remove_highly_correlated)
+         else:
+             numeric_data = remove_highly_correlated_columns(numeric_data, 0.95)
+
+     # Apply log transformation; the small offset keeps the log defined at zero
+     if log_data:
+         numeric_data = np.log(numeric_data + 1e-6)
+
+     # Fill NaN values with the column mean
+     numeric_data = numeric_data.fillna(numeric_data.mean())
+
+     # Scale the numeric data to zero mean and unit variance
+     scaler = StandardScaler(copy=True, with_mean=True, with_std=True)
+     numeric_data = scaler.fit_transform(numeric_data)
+
+     return numeric_data
+
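A minimal sketch of the expected call, on a made-up two-column measurements frame; the result is a scaled numpy array ready for search_reduction_and_clustering():

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({'channel_0_mean': np.random.rand(100),
                       'channel_1_mean': np.random.rand(100)})
    X = preprocess_data(df, filter_by=None, remove_highly_correlated=False,
                        log_data=False, exclude=None)
    print(X.shape, np.round(X.mean(axis=0), 6))  # (100, 2), each column centred near 0
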
+ def filter_dataframe_features(df, channel_of_interest, exclude=None):
+     """
+     Filter the dataframe `df` based on the specified `channel_of_interest` and `exclude` parameters.
+
+     Parameters:
+     - df (pandas.DataFrame): The input dataframe to be filtered.
+     - channel_of_interest (str, int, list, None): The channel(s) of interest used to filter the dataframe.
+       If None, no filtering is applied. If 'morphology', only morphology features are included.
+       If an int or a string, only the corresponding channel is included. If a list, all listed channels are included.
+     - exclude (str, list, None): The feature(s) to exclude from the filtered dataframe.
+       If None, no features are excluded. A string excludes that feature; a list excludes all listed features.
+
+     Returns:
+     - filtered_df (pandas.DataFrame): The filtered dataframe.
+     - features (list): The list of selected features after filtering.
+     """
+     if channel_of_interest is None:
+         feature_string = None
+     elif channel_of_interest == 'morphology':
+         feature_string = 'morphology'
+     elif isinstance(channel_of_interest, list):
+         feature_string = [f'channel_{i}' for i in channel_of_interest]
+     elif isinstance(channel_of_interest, int):
+         feature_string = f'channel_{channel_of_interest}'
+     elif isinstance(channel_of_interest, str):
+         feature_string = channel_of_interest
+     else:
+         raise ValueError(f"Unsupported channel_of_interest type: {type(channel_of_interest)}")
+
+     # Remove columns with a single value
+     df = df.loc[:, df.nunique() > 1]
+
+     # Select numerical features
+     features = df.select_dtypes(include=[np.number]).columns.tolist()
+
+     if feature_string is not None:
+         feature_list = ['channel_0', 'channel_1', 'channel_2', 'channel_3']
+
+         # Remove the channel(s) of interest from the list of channels to drop
+         if isinstance(feature_string, str):
+             if feature_string in feature_list:
+                 feature_list.remove(feature_string)
+         elif isinstance(feature_string, list):
+             feature_list = [feature for feature in feature_list if feature not in feature_string]
+
+         if feature_string != 'morphology':
+             if isinstance(feature_string, list):
+                 features = [feature for feature in features if any(fs in feature for fs in feature_string)]
+             else:
+                 features = [feature for feature in features if feature_string in feature]
+
+         # Drop features belonging to the remaining channels
+         for feature_ in feature_list:
+             features = [feature for feature in features if feature_ not in feature]
+             print(f'After removing {feature_} features: {len(features)}')
+
+     if isinstance(exclude, list):
+         features = [feature for feature in features if feature not in exclude]
+     elif isinstance(exclude, str) and exclude in features:
+         features.remove(exclude)
+
+     filtered_df = df[features]
+
+     return filtered_df, features
+
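A toy example of the channel filtering, with made-up column names following the channel_<n> convention:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({'channel_0_mean': np.random.rand(10),
                       'channel_1_mean': np.random.rand(10),
                       'area': np.random.rand(10)})
    filtered_df, features = filter_dataframe_features(df, channel_of_interest=1)
    print(features)  # ['channel_1_mean'], other channels and non-channel columns are dropped
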
+ # Check whether a candidate image position overlaps any already placed image
+ def check_overlap(current_position, other_positions, threshold):
+     for other_position in other_positions:
+         distance = np.linalg.norm(np.array(current_position) - np.array(other_position))
+         if distance < threshold:
+             return True
+     return False
+
+ # Try random positions around a given point until one does not overlap
+ def find_non_overlapping_position(x, y, image_positions, threshold, max_attempts=100):
+     offset_range = 10  # range for the random offsets
+     attempts = 0
+     while attempts < max_attempts:
+         random_offset_x = random.uniform(-offset_range, offset_range)
+         random_offset_y = random.uniform(-offset_range, offset_range)
+         new_x = x + random_offset_x
+         new_y = y + random_offset_y
+         if not check_overlap((new_x, new_y), image_positions, threshold):
+             return new_x, new_y
+         attempts += 1
+     return x, y  # fall back to the original position if no suitable position is found
+
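A short sketch of the rejection sampling these two helpers implement, with made-up coordinates: the new thumbnail position is jittered until it clears the distance threshold to every placed image (or max_attempts runs out):

    placed = [(0.0, 0.0), (1.0, 1.0)]
    x, y = find_non_overlapping_position(0.1, 0.1, placed, threshold=2.0)
    print(check_overlap((x, y), placed, threshold=2.0))  # usually False
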
+ def search_reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric, eps, min_samples, clustering, reduction_method, verbose, reduction_param=None, embedding=None, n_jobs=-1):
+     """
+     Perform dimensionality reduction and clustering on the given data.
+
+     Parameters:
+     numeric_data (np.array): Numeric data to process.
+     n_neighbors (int or float): Number of neighbors for UMAP or perplexity for tSNE.
+         A float is interpreted as a fraction of the number of rows.
+     min_dist (float): Minimum distance for UMAP.
+     metric (str): Metric for UMAP, tSNE, and DBSCAN.
+     eps (float): Epsilon for DBSCAN clustering.
+     min_samples (int): Minimum samples for DBSCAN or number of clusters for KMeans.
+     clustering (str): Clustering method ('dbscan' or 'kmeans').
+     reduction_method (str): Dimensionality reduction method ('umap' or 'tsne').
+     verbose (bool): Whether to print verbose output.
+     reduction_param (dict): Additional parameters for the reduction method.
+     embedding (np.array): Precomputed embedding (optional).
+     n_jobs (int): Number of parallel jobs to run.
+
+     Returns:
+     embedding (np.array): Embedding of the data.
+     labels (np.array): Cluster labels.
+     """
+     if isinstance(n_neighbors, float):
+         n_neighbors = int(n_neighbors * len(numeric_data))
+     if n_neighbors <= 1:
+         n_neighbors = 2
+         print(f'n_neighbors cannot be less than 2. Setting n_neighbors to {n_neighbors}')
+
+     reduction_param = reduction_param or {}
+     reduction_param = {k: v for k, v in reduction_param.items() if k not in ['perplexity', 'n_neighbors', 'min_dist', 'metric', 'method']}
+
+     if reduction_method == 'umap':
+         reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric, n_jobs=n_jobs, **reduction_param)
+     elif reduction_method == 'tsne':
+         reducer = TSNE(n_components=2, perplexity=n_neighbors, metric=metric, n_jobs=n_jobs, **reduction_param)
+     else:
+         raise ValueError(f"Unsupported reduction method: {reduction_method}. Supported methods are 'umap' and 'tsne'")
+
+     if embedding is None:
+         embedding = reducer.fit_transform(numeric_data)
+
+     if clustering == 'dbscan':
+         clustering_model = DBSCAN(eps=eps, min_samples=min_samples, metric=metric)
+     elif clustering == 'kmeans':
+         clustering_model = KMeans(n_clusters=min_samples, random_state=42)
+     else:
+         raise ValueError(f"Unsupported clustering method: {clustering}. Supported methods are 'dbscan' and 'kmeans'")
+
+     clustering_model.fit(embedding)
+     labels = clustering_model.labels_ if clustering == 'dbscan' else clustering_model.predict(embedding)
+
+     if verbose:
+         print(f'Embedding shape: {embedding.shape}')
+
+     return embedding, labels
+
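An end-to-end sketch, assuming umap-learn and scikit-learn are available: two well-separated Gaussian blobs should come back as two DBSCAN clusters (label -1 marks noise):

    import numpy as np

    rng = np.random.default_rng(0)
    X = np.vstack([rng.normal(0, 0.3, (50, 5)), rng.normal(5, 0.3, (50, 5))])
    embedding, labels = search_reduction_and_clustering(
        X, n_neighbors=15, min_dist=0.1, metric='euclidean',
        eps=0.5, min_samples=5, clustering='dbscan',
        reduction_method='umap', verbose=True)
    print(np.unique(labels))  # typically [0 1], possibly with -1 for noise points
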