spacr 0.0.18__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
spacr/utils.py CHANGED
@@ -1,6 +1,8 @@
-import os, re, sqlite3, gc, torch, torchvision, time, random, string, shutil, cv2, tarfile, glob
+import sys, os, re, sqlite3, gc, torch, torchvision, time, random, string, shutil, cv2, tarfile, glob
 
 import numpy as np
+from cellpose import models as cp_models
+from cellpose import denoise
 from skimage import morphology
 from skimage.measure import label, regionprops_table, regionprops
 import skimage.measure as measure
@@ -18,6 +20,8 @@ from functools import reduce
 from IPython.display import display, clear_output
 from multiprocessing import Pool, cpu_count
 from skimage.transform import resize as resizescikit
+from skimage.morphology import dilation, square
+from skimage.measure import find_contours
 import torch.nn as nn
 import torch.nn.functional as F
 #from torchsummary import summary
@@ -29,6 +33,7 @@ from skimage.segmentation import clear_border
 import seaborn as sns
 import matplotlib.pyplot as plt
 import scipy.ndimage as ndi
+from scipy.spatial import distance
 from scipy.stats import fisher_exact
 from scipy.ndimage import binary_erosion, binary_dilation
 from skimage.exposure import rescale_intensity
@@ -36,6 +41,7 @@ from sklearn.metrics import auc, precision_recall_curve
 from sklearn.model_selection import train_test_split
 from sklearn.linear_model import Lasso, Ridge
 from sklearn.preprocessing import OneHotEncoder
+from sklearn.cluster import KMeans
 from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
 
 from .logger import log_function_call
@@ -45,6 +51,54 @@ from .logger import log_function_call
 #from .plot import _plot_images_on_grid, plot_masks, _plot_histograms_and_stats, plot_resize, _plot_plates, _reg_v_plot, plot_masks
 #from .core import identify_masks
 
+
+def _gen_rgb_image(image, cahnnels):
+    rgb_image = np.take(image, cahnnels, axis=-1)
+    rgb_image = rgb_image.astype(float)
+    rgb_image -= rgb_image.min()
+    rgb_image /= rgb_image.max()
+    return rgb_image
+
+def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_thickness):
+    from concurrent.futures import ThreadPoolExecutor
+    import cv2
+
+    outlines = []
+    overlayed_image = rgb_image.copy()
+
+    def process_dim(mask_dim):
+        mask = np.take(image, mask_dim, axis=-1)
+        outline = np.zeros_like(mask, dtype=np.uint8) # Use uint8 for contour detection efficiency
+
+        # Find and draw contours
+        for j in np.unique(mask):
+        #for j in np.unique(mask)[1:]:
+            contours = find_contours(mask == j, 0.5)
+            # Convert contours for OpenCV format and draw directly to optimize
+            cv_contours = [np.flip(contour.astype(int), axis=1) for contour in contours]
+            cv2.drawContours(outline, cv_contours, -1, color=int(j), thickness=outline_thickness)
+
+        return dilation(outline, square(outline_thickness))
+
+    # Parallel processing
+    with ThreadPoolExecutor() as executor:
+        outlines = list(executor.map(process_dim, mask_dims))
+
+    # Overlay outlines onto the RGB image in a batch/vectorized manner if possible
+    for i, outline in enumerate(outlines):
+        # This part may need to be adapted to your specific use case and available functions
+        # The goal is to overlay each outline with its respective color more efficiently
+        color = outline_colors[i % len(outline_colors)]
+        for j in np.unique(outline)[1:]:
+            mask = outline == j
+            overlayed_image[mask] = color # Direct assignment with broadcasting
+
+    # Remove mask_dims from image
+    channels_to_keep = [i for i in range(image.shape[-1]) if i not in mask_dims]
+    image = np.take(image, channels_to_keep, axis=-1)
+
+    return overlayed_image, outlines, image
+
 def _convert_cq1_well_id(well_id):
     """
     Converts a well ID to the CQ1 well format.
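
A minimal call sketch for the two helpers added above; the array name, channel indices, colors and thickness are illustrative assumptions, not code from spacr. Given stack, an (H, W, C) NumPy array whose last two channels hold labelled object masks:

rgb = _gen_rgb_image(stack, cahnnels=[0, 1, 2])            # RGB float image from three intensity channels (the parameter is spelled 'cahnnels' in the source)
overlay, outlines, signal = _outline_and_overlay(
    stack, rgb,
    mask_dims=[3, 4],                                      # channels that contain labelled masks
    outline_colors=[(1.0, 0.0, 0.0), (0.0, 1.0, 0.0)],     # one color per mask channel, in 0-1 to match the normalised RGB image
    outline_thickness=3,
)
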
@@ -114,7 +168,7 @@ def _extract_filename_metadata(filenames, src, images_by_key, regular_expression
     if metadata_type =='cq1':
         orig_wellID = wellID
         wellID = _convert_cq1_well_id(wellID)
-        clear_output(wait=True)
+        #clear_output(wait=True)
         print(f'Converted Well ID: {orig_wellID} to {wellID}', end='\r', flush=True)
 
     if pick_slice:
@@ -673,9 +727,6 @@ def _crop_center(img, cell_mask, new_width, new_height, normalize=(2,98)):
     img = img[start_y:end_y, start_x:end_x, :]
     return img
 
-
-
-
 def _masks_to_masks_stack(masks):
     """
     Convert a list of masks into a stack of masks.
@@ -692,53 +743,50 @@ def _masks_to_masks_stack(masks):
     return mask_stack
 
 def _get_diam(mag, obj):
-    if obj == 'cell':
-        if mag == 20:
-            scale = 6
-        if mag == 40:
-            scale = 4.5
-        if mag == 60:
-            scale = 3
-    elif obj == 'nucleus':
-        if mag == 20:
-            scale = 3
-        if mag == 40:
-            scale = 2
-        if mag == 60:
-            scale = 1.5
-    elif obj == 'pathogen':
-        if mag == 20:
-            scale = 1.5
-        if mag == 40:
-            scale = 1
-        if mag == 60:
-            scale = 1.25
-    elif obj == 'pathogen_nucleus':
-        if mag == 20:
-            scale = 0.25
-        if mag == 40:
-            scale = 0.2
-        if mag == 60:
-            scale = 0.2
+
+    if mag == 20:
+        if obj == 'cell':
+            diamiter = 120
+        elif obj == 'nucleus':
+            diamiter = 60
+        elif obj == 'pathogen':
+            diamiter = 30
+        else:
+            raise ValueError("Invalid magnification: Use 20, 40 or 60")
+
+    elif mag == 40:
+        if obj == 'cell':
+            diamiter = 160
+        elif obj == 'nucleus':
+            diamiter = 80
+        elif obj == 'pathogen':
+            diamiter = 40
+        else:
+            raise ValueError("Invalid magnification: Use 20, 40 or 60")
+
+    elif mag == 60:
+        if obj == 'cell':
+            diamiter = 200
+        if obj == 'nucleus':
+            diamiter = 90
+        if obj == 'pathogen':
+            diamiter = 75
+        else:
+            raise ValueError("Invalid magnification: Use 20, 40 or 60")
     else:
-        raise ValueError("Invalid object type")
-    diamiter = mag*scale
+        raise ValueError("Invalid magnification: Use 20, 40 or 60")
+
     return diamiter
 
 def _get_object_settings(object_type, settings):
-
     object_settings = {}
-    object_settings['refine_masks'] = False
-    object_settings['filter_size'] = False
-    object_settings['filter_dimm'] = False
-    print(object_type)
+
     object_settings['diameter'] = _get_diam(settings['magnification'], obj=object_type)
-    object_settings['remove_border_objects'] = False
-    object_settings['minimum_size'] = (object_settings['diameter']**2)/10
-    object_settings['maximum_size'] = object_settings['minimum_size']*50
+    object_settings['minimum_size'] = (object_settings['diameter']**2)/4
+    object_settings['maximum_size'] = (object_settings['diameter']**2)*10
     object_settings['merge'] = False
-    object_settings['net_avg'] = True
     object_settings['resample'] = True
+    object_settings['remove_border_objects'] = False
     object_settings['model_name'] = 'cyto'
 
     if object_type == 'cell':
@@ -746,20 +794,28 @@ def _get_object_settings(object_type, settings):
             object_settings['model_name'] = 'cyto'
         else:
             object_settings['model_name'] = 'cyto2'
-
+        object_settings['filter_size'] = False
+        object_settings['filter_intensity'] = False
+        object_settings['restore_type'] = settings.get('cell_restore_type', None)
+
     elif object_type == 'nucleus':
         object_settings['model_name'] = 'nuclei'
+        object_settings['filter_size'] = False
+        object_settings['filter_intensity'] = False
+        object_settings['restore_type'] = settings.get('nucleus_restore_type', None)
 
     elif object_type == 'pathogen':
-        object_settings['model_name'] = 'cyto3'
-
-    elif object_type == 'pathogen_nucleus':
-        object_settings['filter_size'] = True
         object_settings['model_name'] = 'cyto'
+        object_settings['filter_size'] = True
+        object_settings['filter_intensity'] = False
+        object_settings['restore_type'] = settings.get('pathogen_restore_type', None)
+        object_settings['merge'] = settings['merge_pathogens']
 
     else:
        print(f'Object type: {object_type} not supported. Supported object types are : cell, nucleus and pathogen')
-    print(f'using settings: {object_settings}')
+
+    if settings['verbose']:
+        print(object_settings)
 
    return object_settings
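
For reference, a worked example of the new defaults, using values taken directly from the two hunks above (the variable names are illustrative):

diameter = _get_diam(mag=20, obj='cell')     # -> 120 at 20x magnification
minimum_size = (diameter**2) / 4             # -> 3600.0 pixels
maximum_size = (diameter**2) * 10            # -> 144000 pixels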
 
@@ -786,6 +842,7 @@ def _pivot_counts_table(db_path):
         return df
 
     def _pivot_dataframe(df):
+
         """
         Pivot the DataFrame.
 
@@ -812,61 +869,32 @@
     pivoted_df.to_sql('pivoted_counts', conn, if_exists='replace', index=False)
     conn.close()
 
-def _get_cellpose_channels_v1(mask_channels, nucleus_chann_dim, pathogen_chann_dim, cell_chann_dim):
-    cellpose_channels = {}
-    if nucleus_chann_dim in mask_channels:
-        cellpose_channels['nucleus'] = [0, mask_channels.index(nucleus_chann_dim)]
-    if pathogen_chann_dim in mask_channels:
-        cellpose_channels['pathogen'] = [0, mask_channels.index(pathogen_chann_dim)]
-    if cell_chann_dim in mask_channels:
-        cellpose_channels['cell'] = [0, mask_channels.index(cell_chann_dim)]
-    return cellpose_channels
+def _get_cellpose_channels(src, nucleus_channel, pathogen_channel, cell_channel):
 
-def _get_cellpose_channels_v1(cell_channel, nucleus_channel, pathogen_channel):
-    # Initialize a dictionary to hold the new indices for the specified channels
-    cellpose_channels = {}
+    cell_mask_path = os.path.join(src, 'norm_channel_stack', 'cell_mask_stack')
+    nucleus_mask_path = os.path.join(src, 'norm_channel_stack', 'nucleus_mask_stack')
+    pathogen_mask_path = os.path.join(src, 'norm_channel_stack', 'pathogen_mask_stack')
 
-    # Initialize a list to keep track of the channels in their new order
-    new_channel_order = []
-
-    # Add each channel to the new order list if it is not None
-    if cell_channel is not None:
-        new_channel_order.append(('cell', cell_channel))
-    if nucleus_channel is not None:
-        new_channel_order.append(('nucleus', nucleus_channel))
-    if pathogen_channel is not None:
-        new_channel_order.append(('pathogen', pathogen_channel))
-
-    # Sort the list based on the original channel indices to maintain the original order
-    new_channel_order.sort(key=lambda x: x[1])
-    print(new_channel_order)
-    # Assign new indices based on the sorted order
-    for new_index, (channel_name, _) in enumerate(new_channel_order):
-        cellpose_channels[channel_name] = [new_index, 0]
-
-    if cell_channel is not None and nucleus_channel is not None:
-        cellpose_channels['cell'][1] = cellpose_channels['nucleus'][0]
-
-    return cellpose_channels
 
-def _get_cellpose_channels(nucleus_channel, pathogen_channel, cell_channel):
+    if os.path.exists(cell_mask_path) or os.path.exists(nucleus_mask_path) or os.path.exists(pathogen_mask_path):
+        if nucleus_channel is None or nucleus_channel is None or nucleus_channel is None:
+            print('Warning: Cellpose masks already exist. Unexpected behaviour when setting any object dimention to None when the object masks have been created.')
+
     cellpose_channels = {}
     if not nucleus_channel is None:
         cellpose_channels['nucleus'] = [0,0]
 
     if not pathogen_channel is None:
         if not nucleus_channel is None:
-            cellpose_channels['pathogen'] = [0,1]
+            if not pathogen_channel is None:
+                cellpose_channels['pathogen'] = [0,2]
+            else:
+                cellpose_channels['pathogen'] = [0,1]
         else:
             cellpose_channels['pathogen'] = [0,0]
 
     if not cell_channel is None:
         if not nucleus_channel is None:
-            if not pathogen_channel is None:
-                cellpose_channels['cell'] = [0,2]
-            else:
-                cellpose_channels['cell'] = [0,1]
-        elif not pathogen_channel is None:
             cellpose_channels['cell'] = [0,1]
         else:
             cellpose_channels['cell'] = [0,0]
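
An illustrative call of the rewritten helper; the src path is hypothetical, and the expected dictionary follows from the branch logic above, assuming the per-object mask stacks do not yet exist under src:

channels = _get_cellpose_channels(src='/data/plate1', nucleus_channel=0, pathogen_channel=2, cell_channel=1)
# channels == {'nucleus': [0, 0], 'pathogen': [0, 2], 'cell': [0, 1]}
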
@@ -1027,9 +1055,6 @@ def _group_by_well(df):
     # Apply mean function to numeric columns and first to non-numeric
     df_grouped = df.groupby(['plate', 'row', 'col']).agg({**{col: np.mean for col in numeric_cols}, **{col: 'first' for col in non_numeric_cols}})
     return df_grouped
-
-
-
 
 ###################################################
 # Classify
@@ -1044,7 +1069,7 @@ class Cache:
         cache (OrderedDict): The cache data structure.
     """
 
-    def _init__(self, max_size):
+    def __init__(self, max_size):
         self.cache = OrderedDict()
         self.max_size = max_size
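
The next several hunks apply the same one-character rename, _init__ to __init__, to a number of classes. A minimal sketch of why it matters (generic Python behaviour, not spacr code): a method named _init__ is just an ordinary method, so the class falls back to the inherited constructor and instantiation with arguments fails.

class Broken:
    def _init__(self, max_size):   # typo: not the __init__ dunder, so it is never called on construction
        self.max_size = max_size

Broken(10)   # TypeError: Broken() takes no arguments
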
 
@@ -1075,7 +1100,7 @@ class ScaledDotProductAttention(nn.Module):
 
     """
 
-    def _init__(self, d_k):
+    def __init__(self, d_k):
         super(ScaledDotProductAttention, self).__init__()
         self.d_k = d_k
 
@@ -1106,7 +1131,7 @@ class SelfAttention(nn.Module):
         d_k (int): Dimensionality of the key and query vectors.
     """
 
-    def _init__(self, in_channels, d_k):
+    def __init__(self, in_channels, d_k):
         super(SelfAttention, self).__init__()
         self.W_q = nn.Linear(in_channels, d_k)
         self.W_k = nn.Linear(in_channels, d_k)
@@ -1130,7 +1155,7 @@ class SelfAttention(nn.Module):
         return output
 
 class ScaledDotProductAttention(nn.Module):
-    def _init__(self, d_k):
+    def __init__(self, d_k):
         """
         Initializes the ScaledDotProductAttention module.
 
@@ -1167,7 +1192,7 @@ class SelfAttention(nn.Module):
         in_channels (int): Number of input channels.
         d_k (int): Dimensionality of the key and query vectors.
     """
-    def _init__(self, in_channels, d_k):
+    def __init__(self, in_channels, d_k):
         super(SelfAttention, self).__init__()
         self.W_q = nn.Linear(in_channels, d_k)
         self.W_k = nn.Linear(in_channels, d_k)
@@ -1198,7 +1223,7 @@ class EarlyFusion(nn.Module):
     Args:
         in_channels (int): Number of input channels.
     """
-    def _init__(self, in_channels):
+    def __init__(self, in_channels):
         super(EarlyFusion, self).__init__()
         self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=1, stride=1)
 
@@ -1217,7 +1242,7 @@ class EarlyFusion(nn.Module):
 
 # Spatial Attention Mechanism
 class SpatialAttention(nn.Module):
-    def _init__(self, kernel_size=7):
+    def __init__(self, kernel_size=7):
         """
         Initializes the SpatialAttention module.
 
@@ -1262,7 +1287,7 @@ class MultiScaleBlockWithAttention(nn.Module):
         forward: Forward method for the module.
     """
 
-    def _init__(self, in_channels, out_channels):
+    def __init__(self, in_channels, out_channels):
         super(MultiScaleBlockWithAttention, self).__init__()
         self.dilated_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, dilation=1, padding=1)
         self.spatial_attention = nn.Conv2d(out_channels, out_channels, kernel_size=1)
@@ -1295,7 +1320,7 @@ class MultiScaleBlockWithAttention(nn.Module):
 
 # Final Classifier
 class CustomCellClassifier(nn.Module):
-    def _init__(self, num_classes, pathogen_channel, use_attention, use_checkpoint, dropout_rate):
+    def __init__(self, num_classes, pathogen_channel, use_attention, use_checkpoint, dropout_rate):
         super(CustomCellClassifier, self).__init__()
         self.early_fusion = EarlyFusion(in_channels=3)
 
@@ -1324,7 +1349,7 @@ class CustomCellClassifier(nn.Module):
 
 #CNN and Transformer class, pick any Torch model.
 class TorchModel(nn.Module):
-    def _init__(self, model_name='resnet50', pretrained=True, dropout_rate=None, use_checkpoint=False):
+    def __init__(self, model_name='resnet50', pretrained=True, dropout_rate=None, use_checkpoint=False):
         super(TorchModel, self).__init__()
         self.model_name = model_name
         self.use_checkpoint = use_checkpoint
@@ -1398,7 +1423,7 @@ class TorchModel(nn.Module):
         return logits
 
 class FocalLossWithLogits(nn.Module):
-    def _init__(self, alpha=1, gamma=2):
+    def __init__(self, alpha=1, gamma=2):
         super(FocalLossWithLogits, self).__init__()
         self.alpha = alpha
         self.gamma = gamma
@@ -1410,7 +1435,7 @@ class FocalLossWithLogits(nn.Module):
         return focal_loss.mean()
 
 class ResNet(nn.Module):
-    def _init__(self, resnet_type='resnet50', dropout_rate=None, use_checkpoint=False, init_weights='imagenet'):
+    def __init__(self, resnet_type='resnet50', dropout_rate=None, use_checkpoint=False, init_weights='imagenet'):
         super(ResNet, self).__init__()
 
         resnet_map = {
@@ -1763,25 +1788,24 @@ def annotate_predictions(csv_loc):
     df['cond'] = df.apply(assign_condition, axis=1)
     return df
 
-def init_globals(counter_, lock_):
+def initiate_counter(counter_, lock_):
     global counter, lock
     counter = counter_
     lock = lock_
 
-def add_images_to_tar(args):
-    global counter, lock, total_images
-    paths_chunk, tar_path = args
+def add_images_to_tar(paths_chunk, tar_path, total_images):
     with tarfile.open(tar_path, 'w') as tar:
-        for img_path in paths_chunk:
+        for i, img_path in enumerate(paths_chunk):
             arcname = os.path.basename(img_path)
             try:
                 tar.add(img_path, arcname=arcname)
                 with lock:
                     counter.value += 1
-                    print(f"\rProcessed: {counter.value}/{total_images}", end='', flush=True)
+                    if counter.value % 100 == 0: # Print every 100 updates
+                        progress = (counter.value / total_images) * 100
+                        print(f"Progress: {counter.value}/{total_images} ({progress:.2f}%)", end='\r', file=sys.stdout, flush=True)
             except FileNotFoundError:
                 print(f"File not found: {img_path}")
-    return tar_path
 
 def generate_fraction_map(df, gene_column, min_frequency=0.0):
     df['fraction'] = df['count']/df['well_read_sum']
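
A sketch of how the reworked tar helpers might be driven from a process pool; the pool wiring and chunking here are assumptions, only initiate_counter and add_images_to_tar come from the diff above.

from multiprocessing import Pool, Value, Lock, cpu_count

def tar_chunks_in_parallel(path_chunks, tar_paths, total_images):
    # Shared progress counter and lock, handed to every worker via the pool initializer.
    counter = Value('i', 0)
    lock = Lock()
    args = [(chunk, tar, total_images) for chunk, tar in zip(path_chunks, tar_paths)]
    with Pool(processes=cpu_count(), initializer=initiate_counter, initargs=(counter, lock)) as pool:
        pool.starmap(add_images_to_tar, args)
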
@@ -2230,8 +2254,8 @@ def dice_coefficient(mask1, mask2):
 def extract_boundaries(mask, dilation_radius=1):
     binary_mask = (mask > 0).astype(np.uint8)
     struct_elem = np.ones((dilation_radius*2+1, dilation_radius*2+1))
-    dilated = binary_dilation(binary_mask, footprint=struct_elem)
-    eroded = binary_erosion(binary_mask, footprint=struct_elem)
+    dilated = morphology.binary_dilation(binary_mask, footprint=struct_elem)
+    eroded = morphology.binary_erosion(binary_mask, footprint=struct_elem)
     boundary = dilated ^ eroded
     return boundary
 
@@ -2612,24 +2636,21 @@ def _filter_object(mask, min_value):
     mask[np.isin(mask, to_remove)] = 0
     return mask
 
-def _filter_cp_masks(masks, flows, filter_size, minimum_size, maximum_size, remove_border_objects, merge, filter_dimm, batch, moving_avg_q1, moving_avg_q3, moving_count, plot, figuresize):
+def _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size, maximum_size, remove_border_objects, merge, batch, plot, figuresize):
+
     """
     Filter the masks based on various criteria such as size, border objects, merging, and intensity.
 
     Args:
         masks (list): List of masks.
         flows (list): List of flows.
-        refine_masks (bool): Flag indicating whether to refine masks.
         filter_size (bool): Flag indicating whether to filter based on size.
+        filter_intensity (bool): Flag indicating whether to filter based on intensity.
         minimum_size (int): Minimum size of objects to keep.
         maximum_size (int): Maximum size of objects to keep.
         remove_border_objects (bool): Flag indicating whether to remove border objects.
         merge (bool): Flag indicating whether to merge adjacent objects.
-        filter_dimm (bool): Flag indicating whether to filter based on intensity.
         batch (ndarray): Batch of images.
-        moving_avg_q1 (float): Moving average of the first quartile of object intensities.
-        moving_avg_q3 (float): Moving average of the third quartile of object intensities.
-        moving_count (int): Count of moving averages.
         plot (bool): Flag indicating whether to plot the masks.
         figuresize (tuple): Size of the figure.
 
@@ -2641,51 +2662,66 @@ def _filter_cp_masks(masks, flows, filter_size, minimum_size, maximum_size, remo
 
     mask_stack = []
     for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
+
         if plot and idx == 0:
             num_objects = mask_object_count(mask)
             print(f'Number of objects before filtration: {num_objects}')
             plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
 
-        if filter_size:
-            props = measure.regionprops_table(mask, properties=['label', 'area']) # Measure properties of labeled image regions.
-            valid_labels = props['label'][np.logical_and(props['area'] > minimum_size, props['area'] < maximum_size)] # Select labels of valid size.
-            masks[idx] = np.isin(mask, valid_labels) * mask # Keep only valid objects.
+        if merge:
+            mask = merge_touching_objects(mask, threshold=0.66)
             if plot and idx == 0:
                 num_objects = mask_object_count(mask)
-                print(f'Number of objects after size filtration >{minimum_size} and <{maximum_size} : {num_objects}')
+                print(f'Number of objects after merging adjacent objects, : {num_objects}')
                 plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
-        if remove_border_objects:
-            mask = clear_border(mask)
+
+        if filter_size:
+            props = measure.regionprops_table(mask, properties=['label', 'area'])
+            valid_labels = props['label'][np.logical_and(props['area'] > minimum_size, props['area'] < maximum_size)]
+            mask = np.isin(mask, valid_labels) * mask
             if plot and idx == 0:
                 num_objects = mask_object_count(mask)
-                print(f'Number of objects after removing border objects, : {num_objects}')
+                print(f'Number of objects after size filtration >{minimum_size} and <{maximum_size} : {num_objects}')
                 plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
-        if merge:
-            mask = merge_touching_objects(mask, threshold=0.25)
+
+        if filter_intensity:
+            intensity_image = image[:, :, 1]
+            props = measure.regionprops_table(mask, intensity_image=intensity_image, properties=['label', 'mean_intensity'])
+            mean_intensities = np.array(props['mean_intensity']).reshape(-1, 1)
+
+            if mean_intensities.shape[0] >= 2:
+                kmeans = KMeans(n_clusters=2, random_state=0).fit(mean_intensities)
+                centroids = kmeans.cluster_centers_
+
+                # Calculate the Euclidean distance between the two centroids
+                dist_between_centroids = distance.euclidean(centroids[0], centroids[1])
+
+                # Set a threshold for the minimum distance to consider clusters distinct
+                distance_threshold = 0.25
+
+                if dist_between_centroids > distance_threshold:
+                    high_intensity_cluster = np.argmax(centroids)
+                    valid_labels = np.array(props['label'])[kmeans.labels_ == high_intensity_cluster]
+                    mask = np.isin(mask, valid_labels) * mask
+
            if plot and idx == 0:
                num_objects = mask_object_count(mask)
-                print(f'Number of objects after merging adjacent objects, : {num_objects}')
+                props_after = measure.regionprops_table(mask, intensity_image=intensity_image, properties=['label', 'mean_intensity'])
+                mean_intensities_after = np.mean(np.array(props_after['mean_intensity']))
+                average_intensity_before = np.mean(mean_intensities)
+                print(f'Number of objects after potential intensity clustering: {num_objects}. Mean intensity before:{average_intensity_before:.4f}. After:{mean_intensities_after:.4f}.')
                plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
-        if filter_dimm:
-            unique_labels = np.unique(mask)
-            if len(unique_labels) == 1 and unique_labels[0] == 0:
-                continue
-            object_intensities = [np.mean(batch[idx, :, :, 1][mask == label]) for label in unique_labels if label != 0]
-            object_q1s = [np.percentile(intensities, 25) for intensities in object_intensities if intensities.size > 0]
-            object_q3s = [np.percentile(intensities, 75) for intensities in object_intensities if intensities.size > 0]
-            if object_q1s:
-                object_q1_mean = np.mean(object_q1s)
-                object_q3_mean = np.mean(object_q3s)
-                moving_avg_q1 = (moving_avg_q1 * moving_count + object_q1_mean) / (moving_count + 1)
-                moving_avg_q3 = (moving_avg_q3 * moving_count + object_q3_mean) / (moving_count + 1)
-                moving_count += 1
-            mask = remove_intensity_objects(batch[idx, :, :, 1], mask, intensity_threshold=moving_avg_q1, mode='low')
-            mask = remove_intensity_objects(batch[idx, :, :, 1], mask, intensity_threshold=moving_avg_q3, mode='high')
+
+
+        if remove_border_objects:
+            mask = clear_border(mask)
            if plot and idx == 0:
                num_objects = mask_object_count(mask)
-                print(f'Objects after intensity filtration > {moving_avg_q1} and <{moving_avg_q3}: {num_objects}')
+                print(f'Number of objects after removing border objects, : {num_objects}')
                plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
+
         mask_stack.append(mask)
+
     return mask_stack
 
 def _object_filter(df, object_type, size_range, intensity_range, mask_chans, mask_chan):
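
The filter_intensity branch added above, isolated as a standalone sketch (the helper name is illustrative; the logic mirrors the diff): per-object mean intensities are clustered into two groups with KMeans and, when the two cluster centres are clearly separated, only the brighter group is kept.

import numpy as np
from sklearn.cluster import KMeans
from scipy.spatial import distance
from skimage import measure

def _keep_bright_objects(mask, intensity_image, distance_threshold=0.25):
    # mask: labelled integer mask; intensity_image: single-channel image of the same shape.
    props = measure.regionprops_table(mask, intensity_image=intensity_image, properties=['label', 'mean_intensity'])
    mean_intensities = np.array(props['mean_intensity']).reshape(-1, 1)
    if mean_intensities.shape[0] < 2:
        return mask  # nothing to cluster
    kmeans = KMeans(n_clusters=2, random_state=0).fit(mean_intensities)
    centroids = kmeans.cluster_centers_
    if distance.euclidean(centroids[0], centroids[1]) <= distance_threshold:
        return mask  # clusters are not distinct enough to trust
    high_intensity_cluster = np.argmax(centroids)
    valid_labels = np.array(props['label'])[kmeans.labels_ == high_intensity_cluster]
    return np.isin(mask, valid_labels) * mask
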
@@ -2721,6 +2757,71 @@ def _object_filter(df, object_type, size_range, intensity_range, mask_chans, mas
         print(f'After {object_type} maximum mean intensity filter: {len(df)}')
     return df
 
-###################################################
-# Classify
-###################################################
+def _run_test_mode(src, regex, timelapse=False):
+    if timelapse:
+        test_images = 1 # Use only 1 set for timelapse to ensure full sequence inclusion
+    else:
+        test_images = 10 # Use 10 sets for non-timelapse scenarios
+
+    test_folder_path = os.path.join(src, 'test')
+    os.makedirs(test_folder_path, exist_ok=True)
+    regular_expression = re.compile(regex)
+
+    all_filenames = [filename for filename in os.listdir(src) if regular_expression.match(filename)]
+    print(f'Found {len(all_filenames)} files')
+    images_by_set = defaultdict(list)
+
+    for filename in all_filenames:
+        match = regular_expression.match(filename)
+        if match:
+            plate = match.group('plateID') if 'plateID' in match.groupdict() else os.path.basename(src)
+            well = match.group('wellID')
+            field = match.group('fieldID')
+            # For timelapse experiments, group images by plate, well, and field only
+            if timelapse:
+                set_identifier = (plate, well, field)
+            else:
+                # For non-timelapse, you might want to distinguish sets more granularly
+                # Here, assuming you're grouping by plate, well, and field for simplicity
+                set_identifier = (plate, well, field)
+            images_by_set[set_identifier].append(filename)
+
+    # Prepare for random selection
+    set_identifiers = list(images_by_set.keys())
+    random.seed(42)
+    random.shuffle(set_identifiers) # Randomize the order
+
+    # Select a subset based on the test_images count
+    selected_sets = set_identifiers[:test_images]
+
+    # Print information about the number of sets used
+    print(f'Using {test_images} random image set(s) for test model')
+
+    # Copy files for selected sets to the test folder
+    for set_identifier in selected_sets:
+        for filename in images_by_set[set_identifier]:
+            shutil.copy(os.path.join(src, filename), test_folder_path)
+
+    return test_folder_path
+
+def _choose_model(model_name, device, object_type='cell', restore_type=None):
+    restore_list = ['denoise', 'deblur', 'upsample', None]
+    if restore_type not in restore_list:
+        print(f"Invalid restore type. Choose from {restore_list} defaulting to None")
+        restore_type = None
+
+    if restore_type == None:
+        model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device)
+    else:
+        if object_type == 'nucleus':
+            restore = f'{type}_nuclei'
+            model = denoise.CellposeDenoiseModel(gpu=True, model_type="nuclei",restore_type=restore, chan2_restore=False, device=device)
+        else:
+            restore = f'{type}_cyto3'
+            if model_name =='cyto2':
+                chan2_restore = True
+            if model_name =='cyto':
+                chan2_restore = False
+            model = denoise.CellposeDenoiseModel(gpu=True, model_type="cyto3",restore_type=restore, chan2_restore=chan2_restore, device=device)
+
+    return model
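
Illustrative calls for the two new helpers; the source path, regular expression and device handling are assumptions for the sketch, not part of the diff.

import torch

# Copy a small random subset of a plate into <src>/test for a quick dry run.
test_src = _run_test_mode('/data/plate1', regex=r'(?P<plateID>.+)_(?P<wellID>[A-P]\d+)_(?P<fieldID>\d+)\.tif', timelapse=False)

# Pick a Cellpose model; with restore_type=None this returns a plain cp_models.Cellpose,
# otherwise a denoise.CellposeDenoiseModel is configured for the object type.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = _choose_model('cyto', device, object_type='cell', restore_type=None)
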