spacr 0.0.20__py3-none-any.whl → 0.0.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/measure.py CHANGED
@@ -12,6 +12,8 @@ from scipy.ndimage import binary_dilation
12
12
  from skimage.segmentation import find_boundaries
13
13
  from skimage.feature import graycomatrix, graycoprops
14
14
  from mahotas.features import zernike_moments
15
+ from skimage import morphology, measure, filters
16
+ from skimage.util import img_as_bool
15
17
 
16
18
  from .logger import log_function_call
17
19
 
@@ -92,6 +94,69 @@ def _calculate_zernike(mask, df, degree=8):
92
94
  else:
93
95
  return df
94
96
 
97
def _count_skeleton_branch_points(skeleton):
    """
    Count the branch points (junctions) in a 2D binary skeleton.

    A skeleton pixel is treated as a junction pixel when more than two of its
    eight neighbours are also skeleton pixels. Junction pixels that touch each
    other (8-connectivity) are merged and counted once, because a single
    crossing in a one-pixel-wide skeleton typically yields a small cluster of
    such pixels.

    Parameters:
    skeleton : numpy array
        2D array; non-zero entries are skeleton pixels.

    Returns:
    int
        Number of junction clusters in the skeleton.
    """
    from scipy.ndimage import label as label_components

    skel = np.asarray(skeleton, dtype=bool)
    if not skel.any():
        return 0

    # Number of 8-connected skeleton neighbours of every pixel.
    height, width = skel.shape
    padded = np.pad(skel.astype(np.uint8), 1)
    neighbours = np.zeros((height, width), dtype=np.uint8)
    for dy in range(3):
        for dx in range(3):
            if dy == 1 and dx == 1:
                continue  # a pixel is not its own neighbour
            neighbours += padded[dy:dy + height, dx:dx + width]

    junction_pixels = skel & (neighbours > 2)
    _, n_junctions = label_components(junction_pixels, structure=np.ones((3, 3), dtype=int))
    return int(n_junctions)


def _analyze_cytoskeleton(array, mask, channel, block_size=35):
    """
    Measure skeleton properties of the stained network inside each labeled object.

    For every non-zero label in ``mask``, the selected channel is locally
    thresholded (skimage ``threshold_local``) within the object. The result is
    skeletonized, and the skeleton size and number of branch points are
    recorded.

    Parameters:
    array : numpy array
        Image stack indexed as ``array[:, :, channel]``; the selected 2D plane is
        the intensity image where the microtubules are stained.
    mask : numpy array
        2D label image (assumed to have the same shape as the selected plane).
        Each non-zero value is one object; 0 is background.
    channel : int
        Index into the last axis of ``array`` selecting the stained channel.
    block_size : int, optional
        Odd window size passed to ``threshold_local``. Defaults to 35; adjust it
        to object size and the level of detail needed.

    Returns:
    DataFrame
        One row per non-zero label with columns ``object_label``,
        ``skeleton_length`` (number of skeleton pixels) and
        ``skeleton_branch_points`` (number of junctions). Objects with fewer
        than two positive-intensity pixels get zeros.
    """
    image = array[:, :, channel]
    height, width = image.shape
    # Each object is processed on a crop of its bounding box. A margin of one
    # block_size keeps the local-threshold window of in-object pixels from
    # being cut off by the crop.
    margin = block_size

    properties_list = []
    for label in np.unique(mask):
        if label == 0:
            continue  # Skip background

        object_region = mask == label
        rows, cols = np.nonzero(object_region)
        r0, r1 = max(rows.min() - margin, 0), min(rows.max() + margin + 1, height)
        c0, c1 = max(cols.min() - margin, 0), min(cols.max() + margin + 1, width)
        region = object_region[r0:r1, c0:c1]
        region_intensity = np.where(region, image[r0:r1, c0:c1], 0)

        skeleton_length = 0
        branch_points = 0
        valid_pixels = region_intensity[region_intensity > 0]
        if valid_pixels.size > 1:  # percentiles need more than one pixel
            # Adaptive offset from the object's own intensity distribution
            offset = np.percentile(valid_pixels, 90) - np.percentile(valid_pixels, 50)
            local_thresh = filters.threshold_local(region_intensity, block_size=block_size, offset=offset)
            # Restrict the mask to the object. Outside it region_intensity is 0,
            # while the threshold (local weighted mean minus a non-negative
            # offset) can be negative, which would mark background as cytoskeleton.
            cytoskeleton = region & (region_intensity > local_thresh)
            skeleton = morphology.skeletonize(img_as_bool(cytoskeleton))
            skeleton_length = int(np.count_nonzero(skeleton))
            branch_points = _count_skeleton_branch_points(skeleton)

        properties_list.append({
            "object_label": label,
            "skeleton_length": skeleton_length,
            "skeleton_branch_points": branch_points,
        })

    return pd.DataFrame(properties_list, columns=["object_label", "skeleton_length", "skeleton_branch_points"])
159
+
95
160
  @log_function_call
96
161
  def _morphological_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, settings, zernike=True, degree=8):
97
162
  """
@@ -526,6 +591,7 @@ def _intensity_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_ma
526
591
 
527
592
  @log_function_call
528
593
  def _measure_crop_core(index, time_ls, file, settings):
594
+
529
595
  """
530
596
  Measure and crop the images based on specified settings.
531
597
 
@@ -624,9 +690,8 @@ def _measure_crop_core(index, time_ls, file, settings):
624
690
  if settings['cytoplasm_min_size'] is not None and settings['cytoplasm_min_size'] != 0:
625
691
  cytoplasm_mask = _filter_object(cytoplasm_mask, settings['cytoplasm_min_size'])
626
692
 
627
- if settings['cell_mask_dim'] is not None and settings['pathogen_mask_dim'] is not None:
628
- if settings['include_uninfected'] == False:
629
- cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask = _exclude_objects(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, include_uninfected=False)
693
+ if settings['cell_mask_dim'] is not None:
694
+ cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask = _exclude_objects(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, include_uninfected=settings['include_uninfected'])
630
695
 
631
696
  # Update data with the new masks
632
697
  if settings['cell_mask_dim'] is not None:
@@ -645,6 +710,10 @@ def _measure_crop_core(index, time_ls, file, settings):
645
710
 
646
711
  cell_df, nucleus_df, pathogen_df, cytoplasm_df = _morphological_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, settings)
647
712
 
713
+ #if settings['skeleton']:
714
+ #skeleton_df = _analyze_cytoskeleton(image=channel_arrays, mask=cell_mask, channel=1)
715
+ #merge skeleton_df with cell_df here
716
+
648
717
  cell_intensity_df, nucleus_intensity_df, pathogen_intensity_df, cytoplasm_intensity_df = _intensity_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, channel_arrays, settings, sizes=[1, 2, 3, 4, 5], periphery=True, outside=True)
649
718
  if settings['cell_mask_dim'] is not None:
650
719
  cell_merged_df = _merge_and_save_to_database(cell_df, cell_intensity_df, 'cell', source_folder, file_name, settings['experiment'], settings['timelapse'])
@@ -658,7 +727,6 @@ def _measure_crop_core(index, time_ls, file, settings):
658
727
  if settings['cytoplasm']:
659
728
  cytoplasm_merged_df = _merge_and_save_to_database(cytoplasm_df, cytoplasm_intensity_df, 'cytoplasm', source_folder, file_name, settings['experiment'], settings['timelapse'])
660
729
 
661
-
662
730
  if settings['save_png'] or settings['save_arrays'] or settings['plot']:
663
731
 
664
732
  if isinstance(settings['dialate_pngs'], bool):
@@ -731,7 +799,7 @@ def _measure_crop_core(index, time_ls, file, settings):
731
799
  png_channels = data[:, :, settings['png_dims']].astype(data_type)
732
800
 
733
801
  if settings['normalize_by'] == 'fov':
734
- percentiles_list = _get_percentiles(png_channels, settings['normalize_percentiles'][0],q2=settings['normalize_percentiles'][1])
802
+ percentiles_list = _get_percentiles(png_channels, settings['normalize'][0],q2=settings['normalize'][1])
735
803
 
736
804
  png_channels = _crop_center(png_channels, region, new_width=width, new_height=height)
737
805
 
@@ -788,6 +856,7 @@ def _measure_crop_core(index, time_ls, file, settings):
788
856
  conn.commit()
789
857
  except sqlite3.OperationalError as e:
790
858
  print(f"SQLite error: {e}", flush=True)
859
+ traceback.print_exc()
791
860
 
792
861
  if settings['plot']:
793
862
  _plot_cropped_arrays(png_channels)
@@ -819,14 +888,13 @@ def _measure_crop_core(index, time_ls, file, settings):
819
888
  return average_time, cells
820
889
 
821
890
  @log_function_call
822
- def measure_crop(settings, annotation_settings, advanced_settings):
891
+ def measure_crop(settings):
892
+
823
893
  """
824
894
  Measure the crop of an image based on the provided settings.
825
895
 
826
896
  Args:
827
897
  settings (dict): The settings for measuring the crop.
828
- annotation_settings (dict): The annotation settings.
829
- advanced_settings (dict): The advanced settings.
830
898
 
831
899
  Returns:
832
900
  None
@@ -845,19 +913,6 @@ def measure_crop(settings, annotation_settings, advanced_settings):
845
913
  from .plot import _save_scimg_plot
846
914
  from .utils import _list_endpoint_subdirectories, _generate_representative_images
847
915
 
848
- settings = {**settings, **annotation_settings, **advanced_settings}
849
-
850
- dirname = os.path.dirname(settings['input_folder'])
851
- settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
852
- settings_csv = os.path.join(dirname,'settings','measure_crop_settings.csv')
853
- os.makedirs(os.path.join(dirname,'settings'), exist_ok=True)
854
- settings_df.to_csv(settings_csv, index=False)
855
-
856
- if settings['timelapse_objects'] == 'nucleus':
857
- if not settings['cell_mask_dim'] is None:
858
- tlo = settings['timelapse_objects']
859
- print(f'timelapse object:{tlo}, cells will be relabeled to nucleus labels to track cells.')
860
-
861
916
  #general settings
862
917
  settings['merge_edge_pathogen_cells'] = True
863
918
  settings['radial_dist'] = True
@@ -866,6 +921,26 @@ def measure_crop(settings, annotation_settings, advanced_settings):
866
921
  settings['homogeneity'] = True
867
922
  settings['homogeneity_distances'] = [8,16,32]
868
923
  settings['save_arrays'] = False
924
+
925
+ settings['dialate_pngs'] = False
926
+ settings['dialate_png_ratios'] = [0.2]
927
+ settings['timelapse'] = False
928
+ settings['representative_images'] = False
929
+ settings['timelapse_objects'] = 'cell'
930
+ settings['max_workers'] = os.cpu_count()-2
931
+ settings['experiment'] = 'test'
932
+ settings['cells'] = 'HeLa'
933
+ settings['cell_loc'] = None
934
+ settings['pathogens'] = ['ME49Dku80WT', 'ME49Dku80dgra8:GRA8', 'ME49Dku80dgra8', 'ME49Dku80TKO']
935
+ settings['pathogen_loc'] = [['c1', 'c2', 'c3', 'c4', 'c5', 'c6'], ['c7', 'c8', 'c9', 'c10', 'c11', 'c12'], ['c13', 'c14', 'c15', 'c16', 'c17', 'c18'], ['c19', 'c20', 'c21', 'c22', 'c23', 'c24']]
936
+ settings['treatments'] = ['BR1', 'BR2', 'BR3']
937
+ settings['treatment_loc'] = [['c1', 'c2', 'c7', 'c8', 'c13', 'c14', 'c19', 'c20'], ['c3', 'c4', 'c9', 'c10', 'c15', 'c16', 'c21', 'c22'], ['c5', 'c6', 'c11', 'c12', 'c17', 'c18', 'c23', 'c24']]
938
+ settings['channel_of_interest'] = 2
939
+ settings['compartments'] = ['pathogen', 'cytoplasm']
940
+ settings['measurement'] = 'mean_intensity'
941
+ settings['nr_imgs'] = 32
942
+ settings['um_per_pixel'] = 0.1
943
+ settings['center_crop'] = True
869
944
 
870
945
  if settings['cell_mask_dim'] is None:
871
946
  settings['include_uninfected'] = True
@@ -878,7 +953,18 @@ def measure_crop(settings, annotation_settings, advanced_settings):
878
953
  else:
879
954
  settings['cytoplasm'] = False
880
955
 
881
- settings['center_crop'] = True
956
+ #settings = {**settings, **annotation_settings, **advanced_settings}
957
+
958
+ dirname = os.path.dirname(settings['input_folder'])
959
+ settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
960
+ settings_csv = os.path.join(dirname,'settings','measure_crop_settings.csv')
961
+ os.makedirs(os.path.join(dirname,'settings'), exist_ok=True)
962
+ settings_df.to_csv(settings_csv, index=False)
963
+
964
+ if settings['timelapse_objects'] == 'nucleus':
965
+ if not settings['cell_mask_dim'] is None:
966
+ tlo = settings['timelapse_objects']
967
+ print(f'timelapse object:{tlo}, cells will be relabeled to nucleus labels to track cells.')
882
968
 
883
969
  int_setting_keys = ['cell_mask_dim', 'nucleus_mask_dim', 'pathogen_mask_dim', 'cell_min_size', 'nucleus_min_size', 'pathogen_min_size', 'cytoplasm_min_size']
884
970
 
@@ -922,49 +1008,49 @@ def measure_crop(settings, annotation_settings, advanced_settings):
922
1008
  time_left = (((files_to_process-files_processed)*average_time)/max_workers)/60
923
1009
  print(f'Progress: {files_processed}/{files_to_process} Time/img {average_time:.3f}sec, Time Remaining {time_left:.3f} min.', end='\r', flush=True)
924
1010
  result.get()
925
-
926
- if settings['save_png']:
927
- img_fldr = os.path.join(os.path.dirname(settings['input_folder']), 'data')
928
- sc_img_fldrs = _list_endpoint_subdirectories(img_fldr)
929
-
930
- for i, well_src in enumerate(sc_img_fldrs):
931
- if len(os.listdir(well_src)) < 16:
932
- nr_imgs = len(os.listdir(well_src))
933
- standardize = False
934
- else:
935
- nr_imgs = 16
936
- standardize = True
937
- try:
938
- all_folders = len(sc_img_fldrs)
939
- _save_scimg_plot(src=well_src, nr_imgs=nr_imgs, channel_indices=settings['png_dims'], um_per_pixel=0.1, scale_bar_length_um=10, standardize=standardize, fontsize=12, show_filename=True, channel_names=['red','green','blue'], dpi=300, plot=False, i=i, all_folders=all_folders)
940
-
941
- except Exception as e:
942
- print(f"Unable to generate figure for folder {well_src}: {e}", end='\r', flush=True)
943
- #traceback.print_exc()
1011
+
1012
+ if settings['representative_images']:
1013
+ if settings['save_png']:
1014
+ img_fldr = os.path.join(os.path.dirname(settings['input_folder']), 'data')
1015
+ sc_img_fldrs = _list_endpoint_subdirectories(img_fldr)
1016
+
1017
+ for i, well_src in enumerate(sc_img_fldrs):
1018
+ if len(os.listdir(well_src)) < 16:
1019
+ nr_imgs = len(os.listdir(well_src))
1020
+ standardize = False
1021
+ else:
1022
+ nr_imgs = 16
1023
+ standardize = True
1024
+ try:
1025
+ all_folders = len(sc_img_fldrs)
1026
+ _save_scimg_plot(src=well_src, nr_imgs=nr_imgs, channel_indices=settings['png_dims'], um_per_pixel=0.1, scale_bar_length_um=10, standardize=standardize, fontsize=12, show_filename=True, channel_names=['red','green','blue'], dpi=300, plot=False, i=i, all_folders=all_folders)
1027
+
1028
+ except Exception as e:
1029
+ print(f"Unable to generate figure for folder {well_src}: {e}", end='\r', flush=True)
1030
+ #traceback.print_exc()
944
1031
 
945
1032
  if settings['save_measurements']:
946
- if settings['representative_images']:
947
- db_path = os.path.join(os.path.dirname(settings['input_folder']), 'measurements', 'measurements.db')
948
- channel_indices = settings['png_dims']
949
- channel_indices = [min(value, 2) for value in channel_indices]
950
- _generate_representative_images(db_path,
951
- cells=settings['cells'],
952
- cell_loc=settings['cell_loc'],
953
- pathogens=settings['pathogens'],
954
- pathogen_loc=settings['pathogen_loc'],
955
- treatments=settings['treatments'],
956
- treatment_loc=settings['treatment_loc'],
957
- channel_of_interest=settings['channel_of_interest'],
958
- compartments = settings['compartments'],
959
- measurement = settings['measurement'],
960
- nr_imgs=settings['nr_imgs'],
961
- channel_indices=channel_indices,
962
- um_per_pixel=settings['um_per_pixel'],
963
- scale_bar_length_um=10,
964
- plot=False,
965
- fontsize=12,
966
- show_filename=True,
967
- channel_names=None)
1033
+ db_path = os.path.join(os.path.dirname(settings['input_folder']), 'measurements', 'measurements.db')
1034
+ channel_indices = settings['png_dims']
1035
+ channel_indices = [min(value, 2) for value in channel_indices]
1036
+ _generate_representative_images(db_path,
1037
+ cells=settings['cells'],
1038
+ cell_loc=settings['cell_loc'],
1039
+ pathogens=settings['pathogens'],
1040
+ pathogen_loc=settings['pathogen_loc'],
1041
+ treatments=settings['treatments'],
1042
+ treatment_loc=settings['treatment_loc'],
1043
+ channel_of_interest=settings['channel_of_interest'],
1044
+ compartments = settings['compartments'],
1045
+ measurement = settings['measurement'],
1046
+ nr_imgs=settings['nr_imgs'],
1047
+ channel_indices=channel_indices,
1048
+ um_per_pixel=settings['um_per_pixel'],
1049
+ scale_bar_length_um=10,
1050
+ plot=False,
1051
+ fontsize=12,
1052
+ show_filename=True,
1053
+ channel_names=None)
968
1054
 
969
1055
  if settings['timelapse']:
970
1056
  if settings['timelapse_objects'] == 'nucleus':
@@ -973,7 +1059,7 @@ def measure_crop(settings, annotation_settings, advanced_settings):
973
1059
  object_types = ['nucleus','pathogen','cell']
974
1060
  _timelapse_masks_to_gif(folder_path, mask_channels, object_types)
975
1061
 
976
- if settings['save_png']:
1062
+ #if settings['save_png']:
977
1063
  img_fldr = os.path.join(os.path.dirname(settings['input_folder']), 'data')
978
1064
  sc_img_fldrs = _list_endpoint_subdirectories(img_fldr)
979
1065
  _scmovie(sc_img_fldrs)
spacr/plot.py CHANGED
@@ -8,7 +8,7 @@ import scipy.ndimage as ndi
8
8
  import seaborn as sns
9
9
  import scipy.stats as stats
10
10
  import statsmodels.api as sm
11
-
11
+ import imageio.v2 as imageio
12
12
  from IPython.display import display
13
13
  from skimage.segmentation import find_boundaries
14
14
  from skimage.measure import find_contours
@@ -195,6 +195,88 @@ def _get_colours_merged(outline_color):
195
195
  outline_colors = [[1, 0, 0], [0, 0, 1], [0, 1, 0]] # rbg
196
196
  return outline_colors
197
197
 
198
def plot_images_and_arrays(folders, lower_percentile=1, upper_percentile=99, threshold=1000, extensions=('.npy', '.tif', '.tiff', '.png')):
    """
    Plot images and arrays that share a file name across the given folders.

    Files are matched by name without extension. Only names found in every
    folder are plotted: one figure per name, with one panel per folder.

    Args:
        folders (list): Folder paths searched recursively for images and arrays.
        lower_percentile (int, optional): Lower percentile for image normalization. Defaults to 1.
        upper_percentile (int, optional): Upper percentile for image normalization. Defaults to 99.
        threshold (int, optional): Data with more unique values than this is shown as a
            percentile-normalized grayscale image; otherwise it is shown as a mask with a
            random colormap. Defaults to 1000.
        extensions (str or sequence of str, optional): File extensions to consider.
            Defaults to ('.npy', '.tif', '.tiff', '.png').
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    extensions = tuple(extensions)

    def normalize_image(image, lower=1, upper=99):
        low, high = np.percentile(image, (lower, upper))
        if high <= low:
            # Flat intensity range: avoid dividing by zero.
            return np.zeros_like(image, dtype=float)
        return np.clip((image - low) / (high - low), 0, 1)

    def load_file(path):
        if path.endswith('.npy'):
            return np.load(path)
        # imageio reads .tif/.tiff/.png and other common image formats.
        return imageio.imread(path)

    def find_files(folders, extensions):
        file_dict = {}
        for folder in folders:
            for root, _, files in os.walk(folder):
                for file in files:
                    if file.endswith(extensions):
                        file_name_wo_ext = os.path.splitext(file)[0]
                        file_dict.setdefault(file_name_wo_ext, {})[folder] = os.path.join(root, file)

        # Keep only names that have a file in every folder
        return {k: v for k, v in file_dict.items() if len(v) == len(folders)}

    def plot_from_file_dict(file_dict, threshold=1000, lower_percentile=1, upper_percentile=99):
        for folder_paths in file_dict.values():
            num_files = len(folder_paths)
            # squeeze=False always yields a 2D axes array, even for one panel
            fig, axes = plt.subplots(1, num_files, figsize=(15, 5), squeeze=False)

            for ax, (folder, path) in zip(axes[0], folder_paths.items()):
                data = load_file(path)
                unique_values = np.unique(data)
                if len(unique_values) > threshold:
                    # Intensity image: normalize to percentiles
                    ax.imshow(normalize_image(data, lower_percentile, upper_percentile), cmap='gray')
                else:
                    # Few distinct values: display as a mask with a random colormap
                    ax.imshow(data, cmap=random_cmap(num_objects=len(unique_values)))

                folder_name = os.path.basename(os.path.normpath(folder))
                ax.set_title(f"{folder_name}: {os.path.basename(path)}")
                ax.axis('off')

            plt.tight_layout()
            plt.subplots_adjust(wspace=0.01, hspace=0.01)
            plt.show()
            plt.close(fig)  # free the figure once shown

    file_dict = find_files(folders, extensions)
    plot_from_file_dict(file_dict, threshold, lower_percentile, upper_percentile)
    return
279
+
198
280
  def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, mask_dims, filter_min_max, include_multinucleated, include_multiinfected):
199
281
  """
200
282
  Filters objects in a plot based on various criteria.
@@ -955,7 +1037,7 @@ def _imshow(img, labels, nrow=20, color='white', fontsize=12):
955
1037
  if idx < n_images:
956
1038
  canvas[i * img_height:(i + 1) * img_height, j * img_width:(j + 1) * img_width] = np.transpose(img[idx], (1, 2, 0))
957
1039
  plt.figure(figsize=(50, 50))
958
- plt._imshow(canvas)
1040
+ plt.imshow(canvas)
959
1041
  plt.axis("off")
960
1042
  for i, label in enumerate(labels):
961
1043
  row = i // n_col
@@ -1043,7 +1125,7 @@ def _reg_v_plot(df, grouping, variable, plate_number):
1043
1125
  plt.axhline(y=-np.log10(0.05), color='gray', linestyle='--') # line for p=0.05
1044
1126
  plt.show()
1045
1127
 
1046
- def generate_plate_heatmap(df, plate_number, variable, grouping, min_max):
1128
+ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_count):
1047
1129
  df = df.copy() # Work on a copy to avoid SettingWithCopyWarning
1048
1130
  df['plate'], df['row'], df['col'] = zip(*df['prc'].str.split('_'))
1049
1131
 
@@ -1056,15 +1138,21 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max):
1056
1138
 
1057
1139
  df['row'] = pd.Categorical(df['row'], categories=row_order, ordered=True)
1058
1140
  df['col'] = pd.Categorical(df['col'], categories=col_order, ordered=True)
1059
-
1141
+ df['count'] = df.groupby(['row', 'col'])['row'].transform('count')
1142
+
1143
+ if min_count > 0:
1144
+ df = df[df['count'] >= min_count]
1145
+
1060
1146
  # Explicitly set observed=True to avoid FutureWarning
1061
- grouped = df.groupby(['row', 'col'], observed=True)
1147
+ grouped = df.groupby(['row', 'col'], observed=True)
1148
+
1062
1149
 
1063
1150
  if grouping == 'mean':
1064
1151
  plate = grouped[variable].mean().reset_index()
1065
1152
  elif grouping == 'sum':
1066
1153
  plate = grouped[variable].sum().reset_index()
1067
1154
  elif grouping == 'count':
1155
+ variable = 'count'
1068
1156
  plate = grouped[variable].count().reset_index()
1069
1157
  else:
1070
1158
  raise ValueError(f"Unsupported grouping: {grouping}")
@@ -1074,21 +1162,24 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max):
1074
1162
  if min_max == 'all':
1075
1163
  min_max = [plate_map.min().min(), plate_map.max().max()]
1076
1164
  elif min_max == 'allq':
1077
- min_max = np.quantile(plate_map.values, [0.2, 0.98])
1078
- elif min_max == 'plate':
1079
- min_max = [plate_map.min().min(), plate_map.max().max()]
1165
+ min_max = np.quantile(plate_map.values, [0.02, 0.98])
1166
+ elif isinstance(min_max, (list, tuple)) and len(min_max) == 2:
1167
+ if isinstance(min_max[0], (float)) and isinstance(min_max[1], (float)):
1168
+ min_max = np.quantile(plate_map.values, [min_max[0], min_max[1]])
1169
+ if isinstance(min_max[0], (int)) and isinstance(min_max[1], (int)):
1170
+ min_max = [min_max[0], min_max[1]]
1080
1171
 
1081
1172
  return plate_map, min_max
1082
1173
 
1083
- def _plot_plates(df, variable, grouping, min_max, cmap):
1174
+ def _plot_plates(df, variable, grouping, min_max, cmap, min_count=0):
1084
1175
  plates = df['prc'].str.split('_', expand=True)[0].unique()
1085
1176
  n_rows, n_cols = (len(plates) + 3) // 4, 4
1086
1177
  fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
1087
1178
  ax = ax.flatten()
1088
1179
 
1089
1180
  for index, plate in enumerate(plates):
1090
- plate_map, min_max_values = generate_plate_heatmap(df, plate, variable, grouping, min_max)
1091
- sns.heatmap(plate_map, cmap=cmap, vmin=0, vmax=2, ax=ax[index])
1181
+ plate_map, min_max_values = generate_plate_heatmap(df, plate, variable, grouping, min_max, min_count)
1182
+ sns.heatmap(plate_map, cmap=cmap, vmin=min_max_values[0], vmax=min_max_values[1], ax=ax[index])
1092
1183
  ax[index].set_title(plate)
1093
1184
 
1094
1185
  for i in range(len(plates), n_rows * n_cols):
@@ -1096,7 +1187,26 @@ def _plot_plates(df, variable, grouping, min_max, cmap):
1096
1187
 
1097
1188
  plt.subplots_adjust(wspace=0.1, hspace=0.4)
1098
1189
  plt.show()
1099
- return
1190
+ return fig
1191
+
1192
+ #def plate_heatmap(src, variable='recruitment', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, verbose=False):
1193
+ # db_loc = [src+'/measurements/measurements.db']
1194
+ # tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
1195
+ # include_multinucleated, include_multiinfected, include_noninfected = True, 2.0, True
1196
+ # df, _ = spacr.io._read_and_merge_data(db_loc,
1197
+ # tables,
1198
+ # verbose=verbose,
1199
+ # include_multinucleated=include_multinucleated,
1200
+ # include_multiinfected=include_multiinfected,
1201
+ # include_noninfected=include_noninfected)
1202
+ #
1203
+ # df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_outside_75_percentile']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
1204
+ #
1205
+ # spacr.plot._plot_plates(df, variable, grouping, min_max, cmap, min_count)
1206
+ # #display(df)
1207
+ # #for col in df.columns:
1208
+ # # print(col)
1209
+ # return
1100
1210
 
1101
1211
  #from finetune cellpose
1102
1212
  #def plot_arrays(src, figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
@@ -1236,6 +1346,35 @@ def visualize_masks(mask1, mask2, mask3, title="Masks Comparison"):
1236
1346
  ax.axis('off')
1237
1347
  plt.suptitle(title)
1238
1348
  plt.show()
1349
+
1350
def visualize_cellpose_masks(masks, titles=None, comparison_title="Masks Comparison"):
    """
    Visualize multiple masks side by side with optional titles.

    Each mask is drawn in its own panel, using the colormap returned by
    ``generate_mask_random_cmap`` for that mask, normalized from 0 to the
    mask's maximum value.

    Parameters:
    masks (list of np.ndarray): The masks to visualize. An empty list draws nothing.
    titles (list of str, optional): One title per mask. If None, 'Mask 1', 'Mask 2', ... are used.
    comparison_title (str): Title for the entire figure.

    Raises:
    ValueError: If the number of titles does not match the number of masks.
    """
    if titles is None:
        titles = [f'Mask {i+1}' for i in range(len(masks))]

    # Explicit check instead of assert, so it still runs under `python -O`
    if len(titles) != len(masks):
        raise ValueError("Number of titles and masks must match")

    num_masks = len(masks)
    if num_masks == 0:
        return  # nothing to draw; plt.subplots(1, 0) would fail

    # squeeze=False keeps axs iterable even when there is a single mask
    fig, axs = plt.subplots(1, num_masks, figsize=(10 * num_masks, 10), squeeze=False)

    for ax, mask, title in zip(axs[0], masks, titles):
        mask = np.asarray(mask)
        cmap = generate_mask_random_cmap(mask)
        norm = plt.Normalize(vmin=0, vmax=mask.max())
        ax.imshow(mask, cmap=cmap, norm=norm)
        ax.set_title(title)
        ax.axis('off')

    plt.suptitle(comparison_title)
    plt.show()
1239
1378
 
1240
1379
  def plot_comparison_results(comparison_results):
1241
1380
  df = pd.DataFrame(comparison_results)