spacr 0.0.18__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/plot.py CHANGED
@@ -8,7 +8,7 @@ import scipy.ndimage as ndi
8
8
  import seaborn as sns
9
9
  import scipy.stats as stats
10
10
  import statsmodels.api as sm
11
-
11
+ import imageio.v2 as imageio
12
12
  from IPython.display import display
13
13
  from skimage.segmentation import find_boundaries
14
14
  from skimage.measure import find_contours
@@ -195,6 +195,88 @@ def _get_colours_merged(outline_color):
195
195
  outline_colors = [[1, 0, 0], [0, 0, 1], [0, 1, 0]] # rbg
196
196
  return outline_colors
197
197
 
198
def plot_images_and_arrays(folders, lower_percentile=1, upper_percentile=99, threshold=1000, extensions=('.npy', '.tif', '.tiff', '.png')):
    """
    Plot images and arrays that share a file name across several folders.

    Every folder is searched recursively. Files are matched on their name
    without extension, and only names present in *all* folders are plotted:
    one figure per name, one panel per folder. An array with more than
    ``threshold`` unique values is treated as an intensity image and
    percentile-normalized to [0, 1] in grayscale; otherwise it is treated as
    a label mask and drawn with a random colormap.

    Args:
        folders (list): A list of folder paths containing the images and arrays.
        lower_percentile (int, optional): The lower percentile for image normalization. Defaults to 1.
        upper_percentile (int, optional): The upper percentile for image normalization. Defaults to 99.
        threshold (int, optional): Unique-value count above which an array is normalized as an
            image instead of being drawn as a mask. Defaults to 1000.
        extensions (list, tuple or str, optional): File extensions to consider. ``.npy`` files are
            read with ``np.load``; every other matching file with ``imageio.imread``.
            Defaults to ('.npy', '.tif', '.tiff', '.png').

    Returns:
        None
    """
    folders = list(folders)
    # A bare string would otherwise be iterated character by character.
    if isinstance(extensions, str):
        extensions = (extensions,)
    extensions = tuple(extensions)

    def normalize_image(image, lower=1, upper=99):
        """Rescale so the given percentiles map to 0 and 1, clipping values outside."""
        low, high = np.percentile(image, (lower, upper))
        if high <= low:
            # Degenerate range (e.g. a mostly-constant image): avoid 0/0 -> NaN.
            return np.zeros(np.shape(image), dtype=float)
        return np.clip((image - low) / (high - low), 0, 1)

    def find_files(folders, extensions=extensions):
        """Map file stem -> {folder: path}, keeping only stems found in every folder."""
        file_dict = {}
        for folder in folders:
            for root, _, files in os.walk(folder):
                for file in files:
                    if file.endswith(extensions):
                        stem = os.path.splitext(file)[0]
                        file_dict.setdefault(stem, {})[folder] = os.path.join(root, file)
        # A repeated folder only ever contributes one key per stem, so compare
        # against the number of distinct folders.
        n_folders = len(set(folders))
        return {k: v for k, v in file_dict.items() if len(v) == n_folders}

    def load_data(path):
        """Load a .npy file with numpy and any other matched file with imageio."""
        if path.endswith('.npy'):
            return np.load(path)
        return imageio.imread(path)

    def plot_from_file_dict(file_dict, threshold=1000, lower_percentile=1, upper_percentile=99):
        """Draw one figure per file stem, with one panel per folder."""
        for folder_paths in file_dict.values():
            # squeeze=False keeps `axes` 2-D, so a single panel needs no special case.
            fig, axes = plt.subplots(1, len(folder_paths), figsize=(15, 5), squeeze=False)
            for ax, (folder, path) in zip(axes[0], folder_paths.items()):
                data = load_data(path)
                unique_values = np.unique(data)
                if len(unique_values) > threshold:
                    # Many distinct values: treat as an intensity image.
                    ax.imshow(normalize_image(data, lower_percentile, upper_percentile), cmap='gray')
                else:
                    # Few distinct values: treat as a label mask.
                    ax.imshow(data, cmap=random_cmap(num_objects=len(unique_values)))
                # normpath drops a trailing separator so basename is not empty.
                folder_name = os.path.basename(os.path.normpath(folder))
                ax.set_title(f"{folder_name}: {os.path.basename(path)}")
                ax.axis('off')
            fig.tight_layout()
            fig.subplots_adjust(wspace=0.01, hspace=0.01)
            plt.show()

    plot_from_file_dict(find_files(folders, extensions), threshold, lower_percentile, upper_percentile)
279
+
198
280
  def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, mask_dims, filter_min_max, include_multinucleated, include_multiinfected):
199
281
  """
200
282
  Filters objects in a plot based on various criteria.
@@ -220,10 +302,11 @@ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mas
220
302
  if not filter_min_max is None:
221
303
  min_max = filter_min_max[i]
222
304
  else:
223
- min_max = [0, 100000]
305
+ min_max = [0, 100000000]
224
306
 
225
307
  mask = np.take(stack, mask_dim, axis=2)
226
308
  props = measure.regionprops_table(mask, properties=['label', 'area'])
309
+ #props = measure.regionprops_table(mask, intensity_image=intensity_image, properties=['label', 'area', 'mean_intensity'])
227
310
  avg_size_before = np.mean(props['area'])
228
311
  total_count_before = len(props['label'])
229
312
 
@@ -264,7 +347,55 @@ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mas
264
347
 
265
348
  return stack
266
349
 
267
- def _normalize_and_outline(image, remove_background, backgrounds, normalize, normalization_percentiles, overlay, overlay_chans, mask_dims, outline_colors, outline_thickness):
350
def plot_arrays(src, figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
    """
    Plot randomly selected arrays from a given directory.

    Up to ``nr`` ``.npy`` files directly inside ``src`` are picked at random.
    Each file's path is printed, the array is optionally normalized with
    ``normalize_to_dtype``, and it is shown in one figure. Arrays with more
    than two dimensions get one panel per index along axis 2; 2-D arrays get
    a single panel.

    Parameters:
    - src (str): The directory path containing the arrays.
    - figuresize (int): The size of the figure (default: 50).
    - cmap (str): The colormap to use for displaying the arrays (default: 'inferno').
    - nr (int): The maximum number of arrays to plot (default: 1). If the
      directory holds fewer ``.npy`` files, all of them are plotted.
    - normalize (bool): Whether to normalize the arrays (default: True).
    - q1 (int): The lower percentile for normalization (default: 1).
    - q2 (int): The upper percentile for normalization (default: 99).

    Returns:
    None
    """
    paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npy')]
    # random.sample raises ValueError when asked for more items than exist.
    paths = random.sample(paths, min(nr, len(paths)))
    if not paths:
        return

    if normalize:
        from .utils import normalize_to_dtype
    colormap = plt.get_cmap(cmap)

    for path in paths:
        print(f'Image path:{path}')
        img = np.load(path)
        if normalize:
            img = normalize_to_dtype(array=img, q1=q1, q2=q2)
        if img.ndim > 2:
            # squeeze=False keeps `axs` indexable even for a single channel.
            fig, axs = plt.subplots(1, img.shape[2], figsize=(figuresize, figuresize), squeeze=False)
            for channel, ax in enumerate(axs[0]):
                ax.imshow(img[:, :, channel], cmap=colormap)
                ax.set_title('Channel ' + str(channel), size=24)
                ax.axis('off')
        else:
            fig, ax = plt.subplots(1, 1, figsize=(figuresize, figuresize))
            ax.imshow(img, cmap=colormap)
            ax.set_title('Channel 0', size=24)
            ax.axis('off')
        fig.tight_layout()
        plt.show()
    return
397
+
398
+ def _normalize_and_outline(image, remove_background, normalize, normalization_percentiles, overlay, overlay_chans, mask_dims, outline_colors, outline_thickness):
268
399
  """
269
400
  Normalize and outline an image.
270
401
 
@@ -283,43 +414,34 @@ def _normalize_and_outline(image, remove_background, backgrounds, normalize, nor
283
414
  Returns:
284
415
  tuple: A tuple containing the overlayed image, the original image, and a list of outlines.
285
416
  """
286
- from .utils import normalize_to_dtype
287
-
288
- outlines = []
417
+ from .utils import normalize_to_dtype, _outline_and_overlay, _gen_rgb_image
418
+
289
419
  if remove_background:
290
- for chan_index, channel in enumerate(range(image.shape[-1])):
291
- single_channel = image[:, :, channel] # Extract the specific channel
292
- background = backgrounds[chan_index]
293
- single_channel[single_channel < background] = 0
294
- image[:, :, channel] = single_channel
420
+ backgrounds = np.percentile(image, 1, axis=(0, 1))
421
+ backgrounds = backgrounds[:, np.newaxis, np.newaxis]
422
+ mask = np.zeros_like(image, dtype=bool)
423
+ for chan_index in range(image.shape[-1]):
424
+ if chan_index not in mask_dims:
425
+ mask[:, :, chan_index] = image[:, :, chan_index] < backgrounds[chan_index]
426
+ image[mask] = 0
427
+
295
428
  if normalize:
296
429
  image = normalize_to_dtype(array=image, q1=normalization_percentiles[0], q2=normalization_percentiles[1])
297
- rgb_image = np.take(image, overlay_chans, axis=-1)
298
- rgb_image = rgb_image.astype(float)
299
- rgb_image -= rgb_image.min()
300
- rgb_image /= rgb_image.max()
430
+
431
+ rgb_image = _gen_rgb_image(image, cahnnels=overlay_chans)
432
+
301
433
  if overlay:
302
- overlayed_image = rgb_image.copy()
303
- for i, mask_dim in enumerate(mask_dims):
304
- mask = np.take(image, mask_dim, axis=2)
305
- outline = np.zeros_like(mask)
306
- # Find the contours of the objects in the mask
307
- for j in np.unique(mask)[1:]:
308
- contours = find_contours(mask == j, 0.5)
309
- for contour in contours:
310
- contour = contour.astype(int)
311
- outline[contour[:, 0], contour[:, 1]] = j
312
- # Make the outline thicker
313
- outline = dilation(outline, square(outline_thickness))
314
- outlines.append(outline)
315
- # Overlay the outlines onto the RGB image
316
- for j in np.unique(outline)[1:]:
317
- overlayed_image[outline == j] = outline_colors[i % len(outline_colors)]
434
+ overlayed_image, outlines, image = _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_thickness)
435
+
318
436
  return overlayed_image, image, outlines
319
437
  else:
438
+ # Remove mask_dims from image
439
+ channels_to_keep = [i for i in range(image.shape[-1]) if i not in mask_dims]
440
+ image = np.take(image, channels_to_keep, axis=-1)
320
441
  return [], image, []
321
442
 
322
443
  def _plot_merged_plot(overlay, image, stack, mask_dims, figuresize, overlayed_image, outlines, cmap, outline_colors, print_object_number):
444
+
323
445
  """
324
446
  Plot the merged plot with overlay, image channels, and masks.
325
447
 
@@ -338,6 +460,7 @@ def _plot_merged_plot(overlay, image, stack, mask_dims, figuresize, overlayed_im
338
460
  Returns:
339
461
  fig (Figure): The generated matplotlib figure.
340
462
  """
463
+
341
464
  if overlay:
342
465
  fig, ax = plt.subplots(1, image.shape[-1] + len(mask_dims) + 1, figsize=(4 * figuresize, figuresize))
343
466
  ax[0].imshow(overlayed_image) #_imshow
@@ -378,60 +501,12 @@ def _plot_merged_plot(overlay, image, stack, mask_dims, figuresize, overlayed_im
378
501
  plt.show()
379
502
  return fig
380
503
 
381
- def plot_arrays(src, figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
382
- """
383
- Plot randomly selected arrays from a given directory.
384
-
385
- Parameters:
386
- - src (str): The directory path containing the arrays.
387
- - figuresize (int): The size of the figure (default: 50).
388
- - cmap (str): The colormap to use for displaying the arrays (default: 'inferno').
389
- - nr (int): The number of arrays to plot (default: 1).
390
- - normalize (bool): Whether to normalize the arrays (default: True).
391
- - q1 (int): The lower percentile for normalization (default: 1).
392
- - q2 (int): The upper percentile for normalization (default: 99).
393
-
394
- Returns:
395
- None
396
- """
397
- from .utils import normalize_to_dtype
398
-
399
- mask_cmap = random_cmap()
400
- paths = []
401
- for file in os.listdir(src):
402
- if file.endswith('.npy'):
403
- path = os.path.join(src, file)
404
- paths.append(path)
405
- paths = random.sample(paths, nr)
406
- for path in paths:
407
- print(f'Image path:{path}')
408
- img = np.load(path)
409
- if normalize:
410
- img = normalize_to_dtype(array=img, q1=q1, q2=q2)
411
- dim = img.shape
412
- if len(img.shape)>2:
413
- array_nr = img.shape[2]
414
- fig, axs = plt.subplots(1, array_nr,figsize=(figuresize,figuresize))
415
- for channel in range(array_nr):
416
- i = np.take(img, [channel], axis=2)
417
- axs[channel].imshow(i, cmap=plt.get_cmap(cmap)) #_imshow
418
- axs[channel].set_title('Channel '+str(channel),size=24)
419
- axs[channel].axis('off')
420
- else:
421
- fig, ax = plt.subplots(1, 1,figsize=(figuresize,figuresize))
422
- ax.imshow(img, cmap=plt.get_cmap(cmap)) #_imshow
423
- ax.set_title('Channel 0',size=24)
424
- ax.axis('off')
425
- fig.tight_layout()
426
- plt.show()
427
- return
428
-
429
504
  def plot_merged(src, settings):
430
505
  """
431
506
  Plot the merged images after applying various filters and modifications.
432
507
 
433
508
  Args:
434
- src (ndarray): The source images.
509
+ src (path): Path to folder with images.
435
510
  settings (dict): The settings for the plot.
436
511
 
437
512
  Returns:
@@ -463,15 +538,27 @@ def plot_merged(src, settings):
463
538
  if settings['include_multiinfected'] is not True or settings['include_multinucleated'] is not True or settings['filter_min_max'] is not None:
464
539
  stack = _filter_objects_in_plot(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'], mask_dims, settings['filter_min_max'], settings['include_multinucleated'], settings['include_multiinfected'])
465
540
 
466
- #channels_to_keep = np.ones(stack.shape[-1], dtype=bool)
467
- #channels_to_keep[mask_dims] = False
468
- #channel_image = stack[:, :, channels_to_keep]
469
-
470
- overlayed_image, image, outlines = _normalize_and_outline(stack, settings['remove_background'], settings['backgrounds'], settings['normalize'], settings['normalization_percentiles'], settings['overlay'], settings['overlay_chans'], mask_dims, outline_colors, settings['outline_thickness'])
471
-
541
+ overlayed_image, image, outlines = _normalize_and_outline(image=stack,
542
+ remove_background=settings['remove_background'],
543
+ normalize=settings['normalize'],
544
+ normalization_percentiles=settings['normalization_percentiles'],
545
+ overlay=settings['overlay'],
546
+ overlay_chans=settings['overlay_chans'],
547
+ mask_dims=mask_dims,
548
+ outline_colors=outline_colors,
549
+ outline_thickness=settings['outline_thickness'])
472
550
  if index < settings['nr']:
473
551
  index += 1
474
- fig = _plot_merged_plot(settings['overlay'], image, stack, mask_dims, settings['figuresize'], overlayed_image, outlines, settings['cmap'], outline_colors, settings['print_object_number'])
552
+ fig = _plot_merged_plot(overlay=settings['overlay'],
553
+ image=image,
554
+ stack=stack,
555
+ mask_dims=mask_dims,
556
+ figuresize=settings['figuresize'],
557
+ overlayed_image=overlayed_image,
558
+ outlines=outlines,
559
+ cmap=settings['cmap'],
560
+ outline_colors=outline_colors,
561
+ print_object_number=settings['print_object_number'])
475
562
  else:
476
563
  return fig
477
564
 
@@ -702,7 +789,7 @@ def _visualize_and_save_timelapse_stack_with_tracks(masks, tracks_df, save, src,
702
789
  interactive (bool, optional): Flag indicating whether to display the timelapse stack interactively. Defaults to False.
703
790
  """
704
791
 
705
- from .timelapse import _save_mask_timelapse_as_gif
792
+ from .io import _save_mask_timelapse_as_gif
706
793
 
707
794
  highest_label = max(np.max(mask) for mask in masks)
708
795
  # Generate random colors for each label, including the background
@@ -950,7 +1037,7 @@ def _imshow(img, labels, nrow=20, color='white', fontsize=12):
950
1037
  if idx < n_images:
951
1038
  canvas[i * img_height:(i + 1) * img_height, j * img_width:(j + 1) * img_width] = np.transpose(img[idx], (1, 2, 0))
952
1039
  plt.figure(figsize=(50, 50))
953
- plt._imshow(canvas)
1040
+ plt.imshow(canvas)
954
1041
  plt.axis("off")
955
1042
  for i, label in enumerate(labels):
956
1043
  row = i // n_col
@@ -1038,7 +1125,7 @@ def _reg_v_plot(df, grouping, variable, plate_number):
1038
1125
  plt.axhline(y=-np.log10(0.05), color='gray', linestyle='--') # line for p=0.05
1039
1126
  plt.show()
1040
1127
 
1041
- def generate_plate_heatmap(df, plate_number, variable, grouping, min_max):
1128
+ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_count):
1042
1129
  df = df.copy() # Work on a copy to avoid SettingWithCopyWarning
1043
1130
  df['plate'], df['row'], df['col'] = zip(*df['prc'].str.split('_'))
1044
1131
 
@@ -1051,15 +1138,21 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max):
1051
1138
 
1052
1139
  df['row'] = pd.Categorical(df['row'], categories=row_order, ordered=True)
1053
1140
  df['col'] = pd.Categorical(df['col'], categories=col_order, ordered=True)
1054
-
1141
+ df['count'] = df.groupby(['row', 'col'])['row'].transform('count')
1142
+
1143
+ if min_count > 0:
1144
+ df = df[df['count'] >= min_count]
1145
+
1055
1146
  # Explicitly set observed=True to avoid FutureWarning
1056
- grouped = df.groupby(['row', 'col'], observed=True)
1147
+ grouped = df.groupby(['row', 'col'], observed=True)
1148
+
1057
1149
 
1058
1150
  if grouping == 'mean':
1059
1151
  plate = grouped[variable].mean().reset_index()
1060
1152
  elif grouping == 'sum':
1061
1153
  plate = grouped[variable].sum().reset_index()
1062
1154
  elif grouping == 'count':
1155
+ variable = 'count'
1063
1156
  plate = grouped[variable].count().reset_index()
1064
1157
  else:
1065
1158
  raise ValueError(f"Unsupported grouping: {grouping}")
@@ -1069,21 +1162,24 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max):
1069
1162
  if min_max == 'all':
1070
1163
  min_max = [plate_map.min().min(), plate_map.max().max()]
1071
1164
  elif min_max == 'allq':
1072
- min_max = np.quantile(plate_map.values, [0.2, 0.98])
1073
- elif min_max == 'plate':
1074
- min_max = [plate_map.min().min(), plate_map.max().max()]
1165
+ min_max = np.quantile(plate_map.values, [0.02, 0.98])
1166
+ elif isinstance(min_max, (list, tuple)) and len(min_max) == 2:
1167
+ if isinstance(min_max[0], (float)) and isinstance(min_max[1], (float)):
1168
+ min_max = np.quantile(plate_map.values, [min_max[0], min_max[1]])
1169
+ if isinstance(min_max[0], (int)) and isinstance(min_max[1], (int)):
1170
+ min_max = [min_max[0], min_max[1]]
1075
1171
 
1076
1172
  return plate_map, min_max
1077
1173
 
1078
- def _plot_plates(df, variable, grouping, min_max, cmap):
1174
+ def _plot_plates(df, variable, grouping, min_max, cmap, min_count=0):
1079
1175
  plates = df['prc'].str.split('_', expand=True)[0].unique()
1080
1176
  n_rows, n_cols = (len(plates) + 3) // 4, 4
1081
1177
  fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
1082
1178
  ax = ax.flatten()
1083
1179
 
1084
1180
  for index, plate in enumerate(plates):
1085
- plate_map, min_max_values = generate_plate_heatmap(df, plate, variable, grouping, min_max)
1086
- sns.heatmap(plate_map, cmap=cmap, vmin=0, vmax=2, ax=ax[index])
1181
+ plate_map, min_max_values = generate_plate_heatmap(df, plate, variable, grouping, min_max, min_count)
1182
+ sns.heatmap(plate_map, cmap=cmap, vmin=min_max_values[0], vmax=min_max_values[1], ax=ax[index])
1087
1183
  ax[index].set_title(plate)
1088
1184
 
1089
1185
  for i in range(len(plates), n_rows * n_cols):
@@ -1091,7 +1187,26 @@ def _plot_plates(df, variable, grouping, min_max, cmap):
1091
1187
 
1092
1188
  plt.subplots_adjust(wspace=0.1, hspace=0.4)
1093
1189
  plt.show()
1094
- return
1190
+ return fig
1191
+
1192
+ #def plate_heatmap(src, variable='recruitment', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, verbose=False):
1193
+ # db_loc = [src+'/measurements/measurements.db']
1194
+ # tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
1195
+ # include_multinucleated, include_multiinfected, include_noninfected = True, 2.0, True
1196
+ # df, _ = spacr.io._read_and_merge_data(db_loc,
1197
+ # tables,
1198
+ # verbose=verbose,
1199
+ # include_multinucleated=include_multinucleated,
1200
+ # include_multiinfected=include_multiinfected,
1201
+ # include_noninfected=include_noninfected)
1202
+ #
1203
+ # df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_outside_75_percentile']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
1204
+ #
1205
+ # spacr.plot._plot_plates(df, variable, grouping, min_max, cmap, min_count)
1206
+ # #display(df)
1207
+ # #for col in df.columns:
1208
+ # # print(col)
1209
+ # return
1095
1210
 
1096
1211
  #from finetune cellpose
1097
1212
  #def plot_arrays(src, figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
@@ -1231,6 +1346,35 @@ def visualize_masks(mask1, mask2, mask3, title="Masks Comparison"):
1231
1346
  ax.axis('off')
1232
1347
  plt.suptitle(title)
1233
1348
  plt.show()
1349
+
1350
def visualize_cellpose_masks(masks, titles=None, comparison_title="Masks Comparison"):
    """
    Visualize multiple masks side by side with optional titles.

    Each mask gets its own panel. It is coloured with a colormap from
    ``generate_mask_random_cmap(mask)`` and normalized to the range
    [0, mask.max()].

    Parameters:
    masks (list of np.ndarray): A list of masks to visualize.
    titles (list of str, optional): A list of titles for the masks. If None, default titles will be used.
    comparison_title (str): Title for the entire figure.

    Raises:
    ValueError: If ``masks`` is empty or ``titles`` does not have one entry per mask.
    """
    masks = list(masks)
    titles = [f'Mask {i+1}' for i in range(len(masks))] if titles is None else list(titles)

    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    if len(titles) != len(masks):
        raise ValueError("Number of titles and masks must match")
    if not masks:
        raise ValueError("At least one mask is required")

    num_masks = len(masks)
    # squeeze=False keeps `axs` 2-D, so a single mask is handled like many.
    fig, axs = plt.subplots(1, num_masks, figsize=(10 * num_masks, 10), squeeze=False)

    for ax, mask, title in zip(axs[0], masks, titles):
        cmap = generate_mask_random_cmap(mask)
        norm = plt.Normalize(vmin=0, vmax=mask.max())
        ax.imshow(mask, cmap=cmap, norm=norm)
        ax.set_title(title)
        ax.axis('off')

    plt.suptitle(comparison_title)
    plt.show()
1234
1378
 
1235
1379
  def plot_comparison_results(comparison_results):
1236
1380
  df = pd.DataFrame(comparison_results)