spacr 0.0.2__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/plot.py CHANGED
@@ -8,11 +8,9 @@ import scipy.ndimage as ndi
8
8
  import seaborn as sns
9
9
  import scipy.stats as stats
10
10
  import statsmodels.api as sm
11
-
11
+ import imageio.v2 as imageio
12
12
  from IPython.display import display
13
13
  from skimage.segmentation import find_boundaries
14
- from skimage.measure import find_contours
15
- from skimage.morphology import square, dilation
16
14
  from skimage import measure
17
15
 
18
16
  from ipywidgets import IntSlider, interact
@@ -195,6 +193,88 @@ def _get_colours_merged(outline_color):
195
193
  outline_colors = [[1, 0, 0], [0, 0, 1], [0, 1, 0]] # rbg
196
194
  return outline_colors
197
195
 
196
def plot_images_and_arrays(folders, lower_percentile=1, upper_percentile=99, threshold=1000, extensions=['.npy', '.tif', '.tiff', '.png']):
    """
    Plot, side by side, the files that share a base name across all folders.

    Every folder is walked recursively. Files whose name ends with one of
    ``extensions`` are grouped by their name without extension, and only the
    groups that have a file in every folder are plotted (one figure per
    group, one panel per folder).

    Args:
        folders (list): A list of folder paths containing the images and arrays.
        lower_percentile (int, optional): The lower percentile for image normalization. Defaults to 1.
        upper_percentile (int, optional): The upper percentile for image normalization. Defaults to 99.
        threshold (int, optional): Number of unique values above which a file
            is shown as a percentile-normalized grayscale image; at or below
            it the file is shown as a mask with a random colormap. Defaults to 1000.
        extensions (list, optional): A list of file extensions to consider.
            Defaults to ['.npy', '.tif', '.tiff', '.png']. The default list is
            never mutated. Only '.npy' (``np.load``) and '.tif'/'.tiff'/'.png'
            (``imageio.imread``) can be loaded; a file matched through any
            other extension gets a blank panel.

    Returns:
        None
    """

    def normalize_image(image, lower=1, upper=99):
        """Rescale ``image`` to [0, 1] between its ``lower``/``upper`` percentiles."""
        low, high = np.percentile(image, (lower, upper))
        if high == low:
            # Degenerate window: dividing would give NaN/inf pixels. Keep the
            # pixels above the window bright and everything else dark.
            return (image > low).astype(float)
        return np.clip((image - low) / (high - low), 0, 1)

    def find_files(folders, extensions=['.npy', '.tif', '.tiff', '.png']):
        """Map base file name -> {folder: path}, keeping names found in every folder."""
        suffixes = tuple(extensions)  # str.endswith accepts a tuple of options
        file_dict = {}

        for folder in folders:
            for root, _, files in os.walk(folder):
                for file in files:
                    if file.endswith(suffixes):
                        file_name_wo_ext = os.path.splitext(file)[0]
                        file_dict.setdefault(file_name_wo_ext, {})[folder] = os.path.join(root, file)

        # Filter out files that don't have paths in all folders
        return {k: v for k, v in file_dict.items() if len(v) == len(folders)}

    def load_data(path):
        """Load one file as an array, or return None for an unsupported extension."""
        if path.endswith('.npy'):
            return np.load(path)
        if path.endswith(('.tif', '.tiff', '.png')):
            return imageio.imread(path)
        return None

    def plot_from_file_dict(file_dict, threshold=1000, lower_percentile=1, upper_percentile=99):
        """
        Plot images and arrays from the given file dictionary.

        Args:
            file_dict (dict): Maps a base file name to a ``{folder: path}`` dictionary.
            threshold (int, optional): The threshold for determining whether to display an image as a mask or normalize it. Defaults to 1000.
            lower_percentile (int, optional): The lower percentile for image normalization. Defaults to 1.
            upper_percentile (int, optional): The upper percentile for image normalization. Defaults to 99.
        """
        for folder_paths in file_dict.values():
            num_files = len(folder_paths)
            # squeeze=False always yields a 2-D array of axes, so the
            # single-panel case needs no special handling.
            fig, axes = plt.subplots(1, num_files, figsize=(15, 5), squeeze=False)

            for ax, (folder, path) in zip(axes[0], folder_paths.items()):
                # Hide the frame up front so a skipped file leaves a clean blank panel.
                ax.axis('off')

                data = load_data(path)
                if data is None:
                    continue

                unique_values = np.unique(data)
                if len(unique_values) > threshold:
                    # Many distinct values: treat as an intensity image.
                    data = normalize_image(data, lower_percentile, upper_percentile)
                    ax.imshow(data, cmap='gray')
                else:
                    # Few distinct values: treat as a label mask.
                    cmap = random_cmap(num_objects=len(unique_values))
                    ax.imshow(data, cmap=cmap)

                ax.set_title(f"{os.path.basename(folder)}: {os.path.basename(path)}")

            fig.tight_layout()
            fig.subplots_adjust(wspace=0.01, hspace=0.01)
            plt.show()

    file_dict = find_files(folders, extensions)
    plot_from_file_dict(file_dict, threshold, lower_percentile, upper_percentile)
    return
277
+
198
278
  def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, mask_dims, filter_min_max, include_multinucleated, include_multiinfected):
199
279
  """
200
280
  Filters objects in a plot based on various criteria.
@@ -294,7 +374,7 @@ def plot_arrays(src, figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1,
294
374
  print(f'Image path:{path}')
295
375
  img = np.load(path)
296
376
  if normalize:
297
- img = normalize_to_dtype(array=img, q1=q1, q2=q2)
377
+ img = normalize_to_dtype(array=img, p1=q1, p2=q2)
298
378
  dim = img.shape
299
379
  if len(img.shape)>2:
300
380
  array_nr = img.shape[2]
@@ -344,9 +424,11 @@ def _normalize_and_outline(image, remove_background, normalize, normalization_pe
344
424
  image[mask] = 0
345
425
 
346
426
  if normalize:
347
- image = normalize_to_dtype(array=image, q1=normalization_percentiles[0], q2=normalization_percentiles[1])
427
+ image = normalize_to_dtype(array=image, p1=normalization_percentiles[0], p2=normalization_percentiles[1])
428
+ else:
429
+ image = normalize_to_dtype(array=image, p1=0, p2=100)
348
430
 
349
- rgb_image = _gen_rgb_image(image, cahnnels=overlay_chans)
431
+ rgb_image = _gen_rgb_image(image, channels=overlay_chans)
350
432
 
351
433
  if overlay:
352
434
  overlayed_image, outlines, image = _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_thickness)
@@ -358,6 +440,7 @@ def _normalize_and_outline(image, remove_background, normalize, normalization_pe
358
440
  image = np.take(image, channels_to_keep, axis=-1)
359
441
  return [], image, []
360
442
 
443
+
361
444
  def _plot_merged_plot(overlay, image, stack, mask_dims, figuresize, overlayed_image, outlines, cmap, outline_colors, print_object_number):
362
445
 
363
446
  """
@@ -431,6 +514,8 @@ def plot_merged(src, settings):
431
514
  None
432
515
  """
433
516
  from .utils import _remove_noninfected
517
+
518
+
434
519
 
435
520
  font = settings['figuresize']/2
436
521
  outline_colors = _get_colours_merged(settings['outline_color'])
@@ -955,7 +1040,7 @@ def _imshow(img, labels, nrow=20, color='white', fontsize=12):
955
1040
  if idx < n_images:
956
1041
  canvas[i * img_height:(i + 1) * img_height, j * img_width:(j + 1) * img_width] = np.transpose(img[idx], (1, 2, 0))
957
1042
  plt.figure(figsize=(50, 50))
958
- plt._imshow(canvas)
1043
+ plt.imshow(canvas)
959
1044
  plt.axis("off")
960
1045
  for i, label in enumerate(labels):
961
1046
  row = i // n_col
@@ -1043,7 +1128,7 @@ def _reg_v_plot(df, grouping, variable, plate_number):
1043
1128
  plt.axhline(y=-np.log10(0.05), color='gray', linestyle='--') # line for p=0.05
1044
1129
  plt.show()
1045
1130
 
1046
- def generate_plate_heatmap(df, plate_number, variable, grouping, min_max):
1131
+ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_count):
1047
1132
  df = df.copy() # Work on a copy to avoid SettingWithCopyWarning
1048
1133
  df['plate'], df['row'], df['col'] = zip(*df['prc'].str.split('_'))
1049
1134
 
@@ -1056,15 +1141,21 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max):
1056
1141
 
1057
1142
  df['row'] = pd.Categorical(df['row'], categories=row_order, ordered=True)
1058
1143
  df['col'] = pd.Categorical(df['col'], categories=col_order, ordered=True)
1059
-
1144
+ df['count'] = df.groupby(['row', 'col'])['row'].transform('count')
1145
+
1146
+ if min_count > 0:
1147
+ df = df[df['count'] >= min_count]
1148
+
1060
1149
  # Explicitly set observed=True to avoid FutureWarning
1061
- grouped = df.groupby(['row', 'col'], observed=True)
1150
+ grouped = df.groupby(['row', 'col'], observed=True)
1151
+
1062
1152
 
1063
1153
  if grouping == 'mean':
1064
1154
  plate = grouped[variable].mean().reset_index()
1065
1155
  elif grouping == 'sum':
1066
1156
  plate = grouped[variable].sum().reset_index()
1067
1157
  elif grouping == 'count':
1158
+ variable = 'count'
1068
1159
  plate = grouped[variable].count().reset_index()
1069
1160
  else:
1070
1161
  raise ValueError(f"Unsupported grouping: {grouping}")
@@ -1074,21 +1165,24 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max):
1074
1165
  if min_max == 'all':
1075
1166
  min_max = [plate_map.min().min(), plate_map.max().max()]
1076
1167
  elif min_max == 'allq':
1077
- min_max = np.quantile(plate_map.values, [0.2, 0.98])
1078
- elif min_max == 'plate':
1079
- min_max = [plate_map.min().min(), plate_map.max().max()]
1168
+ min_max = np.quantile(plate_map.values, [0.02, 0.98])
1169
+ elif isinstance(min_max, (list, tuple)) and len(min_max) == 2:
1170
+ if isinstance(min_max[0], (float)) and isinstance(min_max[1], (float)):
1171
+ min_max = np.quantile(plate_map.values, [min_max[0], min_max[1]])
1172
+ if isinstance(min_max[0], (int)) and isinstance(min_max[1], (int)):
1173
+ min_max = [min_max[0], min_max[1]]
1080
1174
 
1081
1175
  return plate_map, min_max
1082
1176
 
1083
- def _plot_plates(df, variable, grouping, min_max, cmap):
1177
+ def _plot_plates(df, variable, grouping, min_max, cmap, min_count=0):
1084
1178
  plates = df['prc'].str.split('_', expand=True)[0].unique()
1085
1179
  n_rows, n_cols = (len(plates) + 3) // 4, 4
1086
1180
  fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
1087
1181
  ax = ax.flatten()
1088
1182
 
1089
1183
  for index, plate in enumerate(plates):
1090
- plate_map, min_max_values = generate_plate_heatmap(df, plate, variable, grouping, min_max)
1091
- sns.heatmap(plate_map, cmap=cmap, vmin=0, vmax=2, ax=ax[index])
1184
+ plate_map, min_max_values = generate_plate_heatmap(df, plate, variable, grouping, min_max, min_count)
1185
+ sns.heatmap(plate_map, cmap=cmap, vmin=min_max_values[0], vmax=min_max_values[1], ax=ax[index])
1092
1186
  ax[index].set_title(plate)
1093
1187
 
1094
1188
  for i in range(len(plates), n_rows * n_cols):
@@ -1096,7 +1190,26 @@ def _plot_plates(df, variable, grouping, min_max, cmap):
1096
1190
 
1097
1191
  plt.subplots_adjust(wspace=0.1, hspace=0.4)
1098
1192
  plt.show()
1099
- return
1193
+ return fig
1194
+
1195
+ #def plate_heatmap(src, variable='recruitment', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, verbose=False):
1196
+ # db_loc = [src+'/measurements/measurements.db']
1197
+ # tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
1198
+ # include_multinucleated, include_multiinfected, include_noninfected = True, 2.0, True
1199
+ # df, _ = spacr.io._read_and_merge_data(db_loc,
1200
+ # tables,
1201
+ # verbose=verbose,
1202
+ # include_multinucleated=include_multinucleated,
1203
+ # include_multiinfected=include_multiinfected,
1204
+ # include_noninfected=include_noninfected)
1205
+ #
1206
+ # df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_outside_75_percentile']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
1207
+ #
1208
+ # spacr.plot._plot_plates(df, variable, grouping, min_max, cmap, min_count)
1209
+ # #display(df)
1210
+ # #for col in df.columns:
1211
+ # # print(col)
1212
+ # return
1100
1213
 
1101
1214
  #from finetune cellpose
1102
1215
  #def plot_arrays(src, figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
@@ -1236,6 +1349,49 @@ def visualize_masks(mask1, mask2, mask3, title="Masks Comparison"):
1236
1349
  ax.axis('off')
1237
1350
  plt.suptitle(title)
1238
1351
  plt.show()
1352
+
1353
def visualize_cellpose_masks(masks, titles=None, filename=None, save=False, src=None):
    """
    Show several masks side by side and optionally save the figure as a PDF.

    Parameters:
        masks (list of np.ndarray): A list of masks to visualize.
        titles (list of str, optional): A list of titles for the masks. If None,
            'Mask 1', 'Mask 2', ... are used.
        filename (str, optional): Inserted into the figure title
            ("Masks Comparison for <filename>") and used as the base name of
            the saved PDF. It is formatted as-is, so None appears as "None".
        save (bool, optional): If True, write the figure to
            ``<src>/results/<filename>.pdf``. Defaults to False.
        src (str, optional): Directory in which the ``results`` folder is
            created when saving. Defaults to the current working directory.

    Raises:
        ValueError: If ``masks`` is empty or the number of titles differs from
            the number of masks.
    """
    num_masks = len(masks)
    if num_masks == 0:
        raise ValueError("At least one mask is required")

    comparison_title = f"Masks Comparison for {filename}"

    if titles is None:
        titles = [f'Mask {i+1}' for i in range(num_masks)]

    # Explicit check instead of ``assert`` so it still runs under ``python -O``;
    # otherwise zip() below would silently drop the unmatched entries.
    if len(titles) != num_masks:
        raise ValueError("Number of titles and masks must match")

    # squeeze=False keeps ``axs`` a 2-D array even for a single mask, so the
    # loop below works for any positive number of masks.
    fig, axs = plt.subplots(1, num_masks, figsize=(10 * num_masks, 10), squeeze=False)

    for ax, mask, title in zip(axs[0], masks, titles):
        cmap = generate_mask_random_cmap(mask)
        # Map the value range 0..mask.max() onto the colormap.
        norm = plt.Normalize(vmin=0, vmax=mask.max())
        ax.imshow(mask, cmap=cmap, norm=norm)
        ax.set_title(title)
        ax.axis('off')

    fig.suptitle(comparison_title)

    # Save before showing so the file does not depend on the figure
    # surviving plt.show().
    if save:
        if src is None:
            src = os.getcwd()
        results_dir = os.path.join(src, 'results')
        os.makedirs(results_dir, exist_ok=True)
        fig_path = os.path.join(results_dir, f'{filename}.pdf')
        fig.savefig(fig_path, format='pdf')
        print(f'Saved figure to {fig_path}')

    plt.show()
    return
1394
+
1239
1395
 
1240
1396
  def plot_comparison_results(comparison_results):
1241
1397
  df = pd.DataFrame(comparison_results)