spacr 0.0.1__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/plot.py CHANGED
@@ -8,11 +8,9 @@ import scipy.ndimage as ndi
8
8
  import seaborn as sns
9
9
  import scipy.stats as stats
10
10
  import statsmodels.api as sm
11
-
11
+ import imageio.v2 as imageio
12
12
  from IPython.display import display
13
13
  from skimage.segmentation import find_boundaries
14
- from skimage.measure import find_contours
15
- from skimage.morphology import square, dilation
16
14
  from skimage import measure
17
15
 
18
16
  from ipywidgets import IntSlider, interact
@@ -195,6 +193,88 @@ def _get_colours_merged(outline_color):
195
193
  outline_colors = [[1, 0, 0], [0, 0, 1], [0, 1, 0]] # rbg
196
194
  return outline_colors
197
195
 
196
+ def plot_images_and_arrays(folders, lower_percentile=1, upper_percentile=99, threshold=1000, extensions=['.npy', '.tif', '.tiff', '.png']):
197
+ """
198
+ Plot images and arrays from the given folders.
199
+
200
+ Args:
201
+ folders (list): A list of folder paths containing the images and arrays.
202
+ lower_percentile (int, optional): The lower percentile for image normalization. Defaults to 1.
203
+ upper_percentile (int, optional): The upper percentile for image normalization. Defaults to 99.
204
+ threshold (int, optional): The threshold for determining whether to display an image as a mask or normalize it. Defaults to 1000.
205
+ extensions (list, optional): A list of file extensions to consider. Defaults to ['.npy', '.tif', '.tiff', '.png'].
206
+ """
207
+
208
+ def normalize_image(image, lower=1, upper=99):
209
+ p2, p98 = np.percentile(image, (lower, upper))
210
+ return np.clip((image - p2) / (p98 - p2), 0, 1)
211
+
212
+ def find_files(folders, extensions=['.npy', '.tif', '.tiff', '.png']):
213
+ file_dict = {}
214
+
215
+ for folder in folders:
216
+ for root, _, files in os.walk(folder):
217
+ for file in files:
218
+ if any(file.endswith(ext) for ext in extensions):
219
+ file_name_wo_ext = os.path.splitext(file)[0]
220
+ file_path = os.path.join(root, file)
221
+ if file_name_wo_ext not in file_dict:
222
+ file_dict[file_name_wo_ext] = {}
223
+ file_dict[file_name_wo_ext][folder] = file_path
224
+
225
+ # Filter out files that don't have paths in all folders
226
+ filtered_dict = {k: v for k, v in file_dict.items() if len(v) == len(folders)}
227
+ return filtered_dict
228
+
229
+ def plot_from_file_dict(file_dict, threshold=1000, lower_percentile=1, upper_percentile=99):
230
+ """
231
+ Plot images and arrays from the given file dictionary.
232
+
233
+ Args:
234
+ file_dict (dict): A dictionary containing the file paths for each image or array.
235
+ threshold (int, optional): The threshold for determining whether to display an image as a mask or normalize it. Defaults to 1000.
236
+ lower_percentile (int, optional): The lower percentile for image normalization. Defaults to 1.
237
+ upper_percentile (int, optional): The upper percentile for image normalization. Defaults to 99.
238
+ """
239
+
240
+ for filename, folder_paths in file_dict.items():
241
+ num_files = len(folder_paths)
242
+ fig, axes = plt.subplots(1, num_files, figsize=(15, 5))
243
+ #fig.suptitle(filename)
244
+
245
+ # Ensure axes is always a list
246
+ if num_files == 1:
247
+ axes = [axes]
248
+
249
+ for i, (folder, path) in enumerate(folder_paths.items()):
250
+ if path.endswith('.npy'):
251
+ data = np.load(path)
252
+ elif path.endswith('.tif') or path.endswith('.tiff'):
253
+ data = imageio.imread(path)
254
+ else:
255
+ continue
256
+
257
+ ax = axes[i]
258
+ unique_values = np.unique(data)
259
+ if len(unique_values) > threshold:
260
+ # Normalize image to percentiles
261
+ data = normalize_image(data, lower_percentile, upper_percentile)
262
+ ax.imshow(data, cmap='gray')
263
+ else:
264
+ # Display as mask with random colormap
265
+ cmap = random_cmap(num_objects=len(unique_values))
266
+ ax.imshow(data, cmap=cmap)
267
+
268
+ ax.set_title(f"{os.path.basename(folder)}: {os.path.basename(path)}")
269
+ ax.axis('off')
270
+ plt.tight_layout
271
+ plt.subplots_adjust(wspace=0.01, hspace=0.01)
272
+ plt.show()
273
+
274
+ file_dict = find_files(folders, extensions)
275
+ plot_from_file_dict(file_dict, threshold, lower_percentile, upper_percentile)
276
+ return
277
+
198
278
  def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, mask_dims, filter_min_max, include_multinucleated, include_multiinfected):
199
279
  """
200
280
  Filters objects in a plot based on various criteria.
@@ -220,10 +300,11 @@ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mas
220
300
  if not filter_min_max is None:
221
301
  min_max = filter_min_max[i]
222
302
  else:
223
- min_max = [0, 100000]
303
+ min_max = [0, 100000000]
224
304
 
225
305
  mask = np.take(stack, mask_dim, axis=2)
226
306
  props = measure.regionprops_table(mask, properties=['label', 'area'])
307
+ #props = measure.regionprops_table(mask, intensity_image=intensity_image, properties=['label', 'area', 'mean_intensity'])
227
308
  avg_size_before = np.mean(props['area'])
228
309
  total_count_before = len(props['label'])
229
310
 
@@ -264,7 +345,55 @@ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mas
264
345
 
265
346
  return stack
266
347
 
267
- def _normalize_and_outline(image, remove_background, backgrounds, normalize, normalization_percentiles, overlay, overlay_chans, mask_dims, outline_colors, outline_thickness):
348
+ def plot_arrays(src, figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
349
+ """
350
+ Plot randomly selected arrays from a given directory.
351
+
352
+ Parameters:
353
+ - src (str): The directory path containing the arrays.
354
+ - figuresize (int): The size of the figure (default: 50).
355
+ - cmap (str): The colormap to use for displaying the arrays (default: 'inferno').
356
+ - nr (int): The number of arrays to plot (default: 1).
357
+ - normalize (bool): Whether to normalize the arrays (default: True).
358
+ - q1 (int): The lower percentile for normalization (default: 1).
359
+ - q2 (int): The upper percentile for normalization (default: 99).
360
+
361
+ Returns:
362
+ None
363
+ """
364
+ from .utils import normalize_to_dtype
365
+
366
+ mask_cmap = random_cmap()
367
+ paths = []
368
+ for file in os.listdir(src):
369
+ if file.endswith('.npy'):
370
+ path = os.path.join(src, file)
371
+ paths.append(path)
372
+ paths = random.sample(paths, nr)
373
+ for path in paths:
374
+ print(f'Image path:{path}')
375
+ img = np.load(path)
376
+ if normalize:
377
+ img = normalize_to_dtype(array=img, p1=q1, p2=q2)
378
+ dim = img.shape
379
+ if len(img.shape)>2:
380
+ array_nr = img.shape[2]
381
+ fig, axs = plt.subplots(1, array_nr,figsize=(figuresize,figuresize))
382
+ for channel in range(array_nr):
383
+ i = np.take(img, [channel], axis=2)
384
+ axs[channel].imshow(i, cmap=plt.get_cmap(cmap)) #_imshow
385
+ axs[channel].set_title('Channel '+str(channel),size=24)
386
+ axs[channel].axis('off')
387
+ else:
388
+ fig, ax = plt.subplots(1, 1,figsize=(figuresize,figuresize))
389
+ ax.imshow(img, cmap=plt.get_cmap(cmap)) #_imshow
390
+ ax.set_title('Channel 0',size=24)
391
+ ax.axis('off')
392
+ fig.tight_layout()
393
+ plt.show()
394
+ return
395
+
396
+ def _normalize_and_outline(image, remove_background, normalize, normalization_percentiles, overlay, overlay_chans, mask_dims, outline_colors, outline_thickness):
268
397
  """
269
398
  Normalize and outline an image.
270
399
 
@@ -283,43 +412,37 @@ def _normalize_and_outline(image, remove_background, backgrounds, normalize, nor
283
412
  Returns:
284
413
  tuple: A tuple containing the overlayed image, the original image, and a list of outlines.
285
414
  """
286
- from .utils import normalize_to_dtype
287
-
288
- outlines = []
415
+ from .utils import normalize_to_dtype, _outline_and_overlay, _gen_rgb_image
416
+
289
417
  if remove_background:
290
- for chan_index, channel in enumerate(range(image.shape[-1])):
291
- single_channel = image[:, :, channel] # Extract the specific channel
292
- background = backgrounds[chan_index]
293
- single_channel[single_channel < background] = 0
294
- image[:, :, channel] = single_channel
418
+ backgrounds = np.percentile(image, 1, axis=(0, 1))
419
+ backgrounds = backgrounds[:, np.newaxis, np.newaxis]
420
+ mask = np.zeros_like(image, dtype=bool)
421
+ for chan_index in range(image.shape[-1]):
422
+ if chan_index not in mask_dims:
423
+ mask[:, :, chan_index] = image[:, :, chan_index] < backgrounds[chan_index]
424
+ image[mask] = 0
425
+
295
426
  if normalize:
296
- image = normalize_to_dtype(array=image, q1=normalization_percentiles[0], q2=normalization_percentiles[1])
297
- rgb_image = np.take(image, overlay_chans, axis=-1)
298
- rgb_image = rgb_image.astype(float)
299
- rgb_image -= rgb_image.min()
300
- rgb_image /= rgb_image.max()
427
+ image = normalize_to_dtype(array=image, p1=normalization_percentiles[0], p2=normalization_percentiles[1])
428
+ else:
429
+ image = normalize_to_dtype(array=image, p1=0, p2=100)
430
+
431
+ rgb_image = _gen_rgb_image(image, channels=overlay_chans)
432
+
301
433
  if overlay:
302
- overlayed_image = rgb_image.copy()
303
- for i, mask_dim in enumerate(mask_dims):
304
- mask = np.take(image, mask_dim, axis=2)
305
- outline = np.zeros_like(mask)
306
- # Find the contours of the objects in the mask
307
- for j in np.unique(mask)[1:]:
308
- contours = find_contours(mask == j, 0.5)
309
- for contour in contours:
310
- contour = contour.astype(int)
311
- outline[contour[:, 0], contour[:, 1]] = j
312
- # Make the outline thicker
313
- outline = dilation(outline, square(outline_thickness))
314
- outlines.append(outline)
315
- # Overlay the outlines onto the RGB image
316
- for j in np.unique(outline)[1:]:
317
- overlayed_image[outline == j] = outline_colors[i % len(outline_colors)]
434
+ overlayed_image, outlines, image = _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_thickness)
435
+
318
436
  return overlayed_image, image, outlines
319
437
  else:
438
+ # Remove mask_dims from image
439
+ channels_to_keep = [i for i in range(image.shape[-1]) if i not in mask_dims]
440
+ image = np.take(image, channels_to_keep, axis=-1)
320
441
  return [], image, []
321
442
 
443
+
322
444
  def _plot_merged_plot(overlay, image, stack, mask_dims, figuresize, overlayed_image, outlines, cmap, outline_colors, print_object_number):
445
+
323
446
  """
324
447
  Plot the merged plot with overlay, image channels, and masks.
325
448
 
@@ -338,6 +461,7 @@ def _plot_merged_plot(overlay, image, stack, mask_dims, figuresize, overlayed_im
338
461
  Returns:
339
462
  fig (Figure): The generated matplotlib figure.
340
463
  """
464
+
341
465
  if overlay:
342
466
  fig, ax = plt.subplots(1, image.shape[-1] + len(mask_dims) + 1, figsize=(4 * figuresize, figuresize))
343
467
  ax[0].imshow(overlayed_image) #_imshow
@@ -378,66 +502,20 @@ def _plot_merged_plot(overlay, image, stack, mask_dims, figuresize, overlayed_im
378
502
  plt.show()
379
503
  return fig
380
504
 
381
- def plot_arrays(src, figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
382
- """
383
- Plot randomly selected arrays from a given directory.
384
-
385
- Parameters:
386
- - src (str): The directory path containing the arrays.
387
- - figuresize (int): The size of the figure (default: 50).
388
- - cmap (str): The colormap to use for displaying the arrays (default: 'inferno').
389
- - nr (int): The number of arrays to plot (default: 1).
390
- - normalize (bool): Whether to normalize the arrays (default: True).
391
- - q1 (int): The lower percentile for normalization (default: 1).
392
- - q2 (int): The upper percentile for normalization (default: 99).
393
-
394
- Returns:
395
- None
396
- """
397
- from .utils import normalize_to_dtype
398
-
399
- mask_cmap = random_cmap()
400
- paths = []
401
- for file in os.listdir(src):
402
- if file.endswith('.npy'):
403
- path = os.path.join(src, file)
404
- paths.append(path)
405
- paths = random.sample(paths, nr)
406
- for path in paths:
407
- print(f'Image path:{path}')
408
- img = np.load(path)
409
- if normalize:
410
- img = normalize_to_dtype(array=img, q1=q1, q2=q2)
411
- dim = img.shape
412
- if len(img.shape)>2:
413
- array_nr = img.shape[2]
414
- fig, axs = plt.subplots(1, array_nr,figsize=(figuresize,figuresize))
415
- for channel in range(array_nr):
416
- i = np.take(img, [channel], axis=2)
417
- axs[channel].imshow(i, cmap=plt.get_cmap(cmap)) #_imshow
418
- axs[channel].set_title('Channel '+str(channel),size=24)
419
- axs[channel].axis('off')
420
- else:
421
- fig, ax = plt.subplots(1, 1,figsize=(figuresize,figuresize))
422
- ax.imshow(img, cmap=plt.get_cmap(cmap)) #_imshow
423
- ax.set_title('Channel 0',size=24)
424
- ax.axis('off')
425
- fig.tight_layout()
426
- plt.show()
427
- return
428
-
429
505
  def plot_merged(src, settings):
430
506
  """
431
507
  Plot the merged images after applying various filters and modifications.
432
508
 
433
509
  Args:
434
- src (ndarray): The source images.
510
+ src (path): Path to folder with images.
435
511
  settings (dict): The settings for the plot.
436
512
 
437
513
  Returns:
438
514
  None
439
515
  """
440
516
  from .utils import _remove_noninfected
517
+
518
+
441
519
 
442
520
  font = settings['figuresize']/2
443
521
  outline_colors = _get_colours_merged(settings['outline_color'])
@@ -463,13 +541,27 @@ def plot_merged(src, settings):
463
541
  if settings['include_multiinfected'] is not True or settings['include_multinucleated'] is not True or settings['filter_min_max'] is not None:
464
542
  stack = _filter_objects_in_plot(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'], mask_dims, settings['filter_min_max'], settings['include_multinucleated'], settings['include_multiinfected'])
465
543
 
466
- #image = np.take(stack, settings['channel_dims'], axis=2)
467
- print('stack.shape', stack.shape)
468
- overlayed_image, image, outlines = _normalize_and_outline(stack, settings['remove_background'], settings['backgrounds'], settings['normalize'], settings['normalization_percentiles'], settings['overlay'], settings['overlay_chans'], mask_dims, outline_colors, settings['outline_thickness'])
469
-
544
+ overlayed_image, image, outlines = _normalize_and_outline(image=stack,
545
+ remove_background=settings['remove_background'],
546
+ normalize=settings['normalize'],
547
+ normalization_percentiles=settings['normalization_percentiles'],
548
+ overlay=settings['overlay'],
549
+ overlay_chans=settings['overlay_chans'],
550
+ mask_dims=mask_dims,
551
+ outline_colors=outline_colors,
552
+ outline_thickness=settings['outline_thickness'])
470
553
  if index < settings['nr']:
471
554
  index += 1
472
- fig = _plot_merged_plot(settings['overlay'], image, stack, mask_dims, settings['figuresize'], overlayed_image, outlines, settings['cmap'], outline_colors, settings['print_object_number'])
555
+ fig = _plot_merged_plot(overlay=settings['overlay'],
556
+ image=image,
557
+ stack=stack,
558
+ mask_dims=mask_dims,
559
+ figuresize=settings['figuresize'],
560
+ overlayed_image=overlayed_image,
561
+ outlines=outlines,
562
+ cmap=settings['cmap'],
563
+ outline_colors=outline_colors,
564
+ print_object_number=settings['print_object_number'])
473
565
  else:
474
566
  return fig
475
567
 
@@ -700,7 +792,7 @@ def _visualize_and_save_timelapse_stack_with_tracks(masks, tracks_df, save, src,
700
792
  interactive (bool, optional): Flag indicating whether to display the timelapse stack interactively. Defaults to False.
701
793
  """
702
794
 
703
- from .timelapse import _save_mask_timelapse_as_gif
795
+ from .io import _save_mask_timelapse_as_gif
704
796
 
705
797
  highest_label = max(np.max(mask) for mask in masks)
706
798
  # Generate random colors for each label, including the background
@@ -948,7 +1040,7 @@ def _imshow(img, labels, nrow=20, color='white', fontsize=12):
948
1040
  if idx < n_images:
949
1041
  canvas[i * img_height:(i + 1) * img_height, j * img_width:(j + 1) * img_width] = np.transpose(img[idx], (1, 2, 0))
950
1042
  plt.figure(figsize=(50, 50))
951
- plt._imshow(canvas)
1043
+ plt.imshow(canvas)
952
1044
  plt.axis("off")
953
1045
  for i, label in enumerate(labels):
954
1046
  row = i // n_col
@@ -1036,7 +1128,7 @@ def _reg_v_plot(df, grouping, variable, plate_number):
1036
1128
  plt.axhline(y=-np.log10(0.05), color='gray', linestyle='--') # line for p=0.05
1037
1129
  plt.show()
1038
1130
 
1039
- def generate_plate_heatmap(df, plate_number, variable, grouping, min_max):
1131
+ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_count):
1040
1132
  df = df.copy() # Work on a copy to avoid SettingWithCopyWarning
1041
1133
  df['plate'], df['row'], df['col'] = zip(*df['prc'].str.split('_'))
1042
1134
 
@@ -1049,15 +1141,21 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max):
1049
1141
 
1050
1142
  df['row'] = pd.Categorical(df['row'], categories=row_order, ordered=True)
1051
1143
  df['col'] = pd.Categorical(df['col'], categories=col_order, ordered=True)
1052
-
1144
+ df['count'] = df.groupby(['row', 'col'])['row'].transform('count')
1145
+
1146
+ if min_count > 0:
1147
+ df = df[df['count'] >= min_count]
1148
+
1053
1149
  # Explicitly set observed=True to avoid FutureWarning
1054
- grouped = df.groupby(['row', 'col'], observed=True)
1150
+ grouped = df.groupby(['row', 'col'], observed=True)
1151
+
1055
1152
 
1056
1153
  if grouping == 'mean':
1057
1154
  plate = grouped[variable].mean().reset_index()
1058
1155
  elif grouping == 'sum':
1059
1156
  plate = grouped[variable].sum().reset_index()
1060
1157
  elif grouping == 'count':
1158
+ variable = 'count'
1061
1159
  plate = grouped[variable].count().reset_index()
1062
1160
  else:
1063
1161
  raise ValueError(f"Unsupported grouping: {grouping}")
@@ -1067,21 +1165,24 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max):
1067
1165
  if min_max == 'all':
1068
1166
  min_max = [plate_map.min().min(), plate_map.max().max()]
1069
1167
  elif min_max == 'allq':
1070
- min_max = np.quantile(plate_map.values, [0.2, 0.98])
1071
- elif min_max == 'plate':
1072
- min_max = [plate_map.min().min(), plate_map.max().max()]
1168
+ min_max = np.quantile(plate_map.values, [0.02, 0.98])
1169
+ elif isinstance(min_max, (list, tuple)) and len(min_max) == 2:
1170
+ if isinstance(min_max[0], (float)) and isinstance(min_max[1], (float)):
1171
+ min_max = np.quantile(plate_map.values, [min_max[0], min_max[1]])
1172
+ if isinstance(min_max[0], (int)) and isinstance(min_max[1], (int)):
1173
+ min_max = [min_max[0], min_max[1]]
1073
1174
 
1074
1175
  return plate_map, min_max
1075
1176
 
1076
- def _plot_plates(df, variable, grouping, min_max, cmap):
1177
+ def _plot_plates(df, variable, grouping, min_max, cmap, min_count=0):
1077
1178
  plates = df['prc'].str.split('_', expand=True)[0].unique()
1078
1179
  n_rows, n_cols = (len(plates) + 3) // 4, 4
1079
1180
  fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
1080
1181
  ax = ax.flatten()
1081
1182
 
1082
1183
  for index, plate in enumerate(plates):
1083
- plate_map, min_max_values = generate_plate_heatmap(df, plate, variable, grouping, min_max)
1084
- sns.heatmap(plate_map, cmap=cmap, vmin=0, vmax=2, ax=ax[index])
1184
+ plate_map, min_max_values = generate_plate_heatmap(df, plate, variable, grouping, min_max, min_count)
1185
+ sns.heatmap(plate_map, cmap=cmap, vmin=min_max_values[0], vmax=min_max_values[1], ax=ax[index])
1085
1186
  ax[index].set_title(plate)
1086
1187
 
1087
1188
  for i in range(len(plates), n_rows * n_cols):
@@ -1089,7 +1190,26 @@ def _plot_plates(df, variable, grouping, min_max, cmap):
1089
1190
 
1090
1191
  plt.subplots_adjust(wspace=0.1, hspace=0.4)
1091
1192
  plt.show()
1092
- return
1193
+ return fig
1194
+
1195
+ #def plate_heatmap(src, variable='recruitment', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, verbose=False):
1196
+ # db_loc = [src+'/measurements/measurements.db']
1197
+ # tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
1198
+ # include_multinucleated, include_multiinfected, include_noninfected = True, 2.0, True
1199
+ # df, _ = spacr.io._read_and_merge_data(db_loc,
1200
+ # tables,
1201
+ # verbose=verbose,
1202
+ # include_multinucleated=include_multinucleated,
1203
+ # include_multiinfected=include_multiinfected,
1204
+ # include_noninfected=include_noninfected)
1205
+ #
1206
+ # df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_outside_75_percentile']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
1207
+ #
1208
+ # spacr.plot._plot_plates(df, variable, grouping, min_max, cmap, min_count)
1209
+ # #display(df)
1210
+ # #for col in df.columns:
1211
+ # # print(col)
1212
+ # return
1093
1213
 
1094
1214
  #from finetune cellpose
1095
1215
  #def plot_arrays(src, figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
@@ -1229,6 +1349,49 @@ def visualize_masks(mask1, mask2, mask3, title="Masks Comparison"):
1229
1349
  ax.axis('off')
1230
1350
  plt.suptitle(title)
1231
1351
  plt.show()
1352
+
1353
+ def visualize_cellpose_masks(masks, titles=None, filename=None, save=False, src=None):
1354
+ """
1355
+ Visualize multiple masks with optional titles.
1356
+
1357
+ Parameters:
1358
+ masks (list of np.ndarray): A list of masks to visualize.
1359
+ titles (list of str, optional): A list of titles for the masks. If None, default titles will be used.
1360
+ comparison_title (str): Title for the entire figure.
1361
+ """
1362
+
1363
+ comparison_title=f"Masks Comparison for {filename}"
1364
+
1365
+ if titles is None:
1366
+ titles = [f'Mask {i+1}' for i in range(len(masks))]
1367
+
1368
+ # Ensure the length of titles matches the number of masks
1369
+ assert len(titles) == len(masks), "Number of titles and masks must match"
1370
+
1371
+ num_masks = len(masks)
1372
+ fig, axs = plt.subplots(1, num_masks, figsize=(10 * num_masks, 10)) # Adjusting figure size dynamically
1373
+
1374
+ for ax, mask, title in zip(axs, masks, titles):
1375
+ cmap = generate_mask_random_cmap(mask)
1376
+ # Normalize and display the mask
1377
+ norm = plt.Normalize(vmin=0, vmax=mask.max())
1378
+ ax.imshow(mask, cmap=cmap, norm=norm)
1379
+ ax.set_title(title)
1380
+ ax.axis('off')
1381
+
1382
+ plt.suptitle(comparison_title)
1383
+ plt.show()
1384
+
1385
+ if save:
1386
+ if src is None:
1387
+ src = os.getcwd()
1388
+ results_dir = os.path.join(src, 'results')
1389
+ os.makedirs(results_dir, exist_ok=True)
1390
+ fig_path = os.path.join(results_dir, f'{filename}.pdf')
1391
+ fig.savefig(fig_path, format='pdf')
1392
+ print(f'Saved figure to {fig_path}')
1393
+ return
1394
+
1232
1395
 
1233
1396
  def plot_comparison_results(comparison_results):
1234
1397
  df = pd.DataFrame(comparison_results)