spacr 0.3.47__py3-none-any.whl → 0.3.52__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/plot.py CHANGED
@@ -16,6 +16,7 @@ from skimage import measure
  from skimage.measure import find_contours, label, regionprops
  from skimage.segmentation import mark_boundaries
  from skimage.transform import resize as sk_resize
+ import scikit_posthocs as sp
 
  import tifffile as tiff
 
@@ -32,7 +33,340 @@ import matplotlib.patches as patches
  from collections import defaultdict
  from matplotlib.gridspec import GridSpec
 
- def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, pathogen_channel, figuresize=10, percentiles=(2,98), thickness=3, save_pdf=True, mode='outlines', export_tiffs=False):
+ #filter_dict={'cell':[(0,100000), (0, 65000)],'nucleus':[(3000,100000), (1500, 65000)],'pathogen':[(500,100000), (0, 65000)]}
+ def plot_image_mask_overlay(
+     file,
+     channels,
+     cell_channel,
+     nucleus_channel,
+     pathogen_channel,
+     figuresize=10,
+     percentiles=(2, 98),
+     thickness=3,
+     save_pdf=True,
+     mode='outlines',
+     export_tiffs=False,
+     all_on_all=False,
+     all_outlines=False,
+     filter_dict=None
+ ):
+     """Plot image and mask overlays."""
+
+     def random_color_cmap(n_labels, seed=None):
+         """Generates a random color map for a given number of labels."""
+         if seed is not None:
+             np.random.seed(seed)
+         rand_colors = np.random.rand(n_labels, 3)
+         rand_colors = np.vstack([[0, 0, 0], rand_colors]) # Ensure background is black
+         cmap = ListedColormap(rand_colors)
+         return cmap
+
+     def _plot_merged_plot(
+         image,
+         outlines,
+         outline_colors,
+         figuresize,
+         thickness,
+         percentiles,
+         mode='outlines',
+         all_on_all=False,
+         all_outlines=False,
+         channels=None,
+         cell_channel=None,
+         nucleus_channel=None,
+         pathogen_channel=None,
+         cell_outlines=None,
+         nucleus_outlines=None,
+         pathogen_outlines=None,
+         save_pdf=True
+     ):
+         """Plot the merged plot with overlay, image channels, and masks."""
+
+         def _generate_colored_mask(mask, cmap):
+             """Generate a colored mask using the given colormap."""
+             mask_norm = mask / (mask.max() + 1e-5) # Normalize mask
+             colored_mask = cmap(mask_norm)
+             colored_mask[..., 3] = np.where(mask > 0, 1, 0) # Alpha channel
+             return colored_mask
+
+         def _overlay_mask(image, mask):
+             """Overlay the colored mask onto the original image."""
+             combined = np.clip(image * (1 - mask[..., 3:]) + mask[..., :3] * mask[..., 3:], 0, 1)
+             return combined
+
+         def _normalize_image(image, percentiles):
+             """Normalize the image based on given percentiles."""
+             v_min, v_max = np.percentile(image, percentiles)
+             image_normalized = np.clip((image - v_min) / (v_max - v_min + 1e-5), 0, 1)
+             return image_normalized
+
+         def _generate_contours(mask):
+             """Generate contours from the mask using OpenCV."""
+             contours, _ = cv2.findContours(
+                 mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+             )
+             return contours
+
+         def _apply_contours(image, mask, color, thickness):
+             """Apply contours to the image."""
+             unique_labels = np.unique(mask)
+             for label in unique_labels:
+                 if label == 0:
+                     continue # Skip background
+                 label_mask = (mask == label).astype(np.uint8)
+                 contours = _generate_contours(label_mask)
+                 cv2.drawContours(
+                     image, contours, -1, mpl.colors.to_rgb(color), thickness
+                 )
+             return image
+
+         num_channels = image.shape[-1]
+         fig, ax = plt.subplots(1, num_channels + 1, figsize=(4 * figuresize, figuresize))
+
+         # Identify channels without associated outlines
+         channels_with_outlines = []
+         if cell_channel is not None:
+             channels_with_outlines.append(cell_channel)
+         if nucleus_channel is not None:
+             channels_with_outlines.append(nucleus_channel)
+         if pathogen_channel is not None:
+             channels_with_outlines.append(pathogen_channel)
+
+         for v in range(num_channels):
+             channel_image = image[..., v]
+             channel_image_normalized = _normalize_image(channel_image, percentiles)
+             channel_image_rgb = np.dstack([channel_image_normalized] * 3)
+
+             current_channel = channels[v]
+
+             if all_on_all:
+                 # Apply all outlines to all channels
+                 for outline, color in zip(outlines, outline_colors):
+                     if mode == 'outlines':
+                         channel_image_rgb = _apply_contours(
+                             channel_image_rgb, outline, color, thickness
+                         )
+                     else:
+                         cmap = random_color_cmap(int(outline.max() + 1), random.randint(0, 100))
+                         mask = _generate_colored_mask(outline, cmap)
+                         channel_image_rgb = _overlay_mask(channel_image_rgb, mask)
+             elif current_channel in channels_with_outlines:
+                 # Apply only the relevant outline to each channel
+                 outline = None
+                 color = None
+
+                 if current_channel == cell_channel and cell_outlines is not None:
+                     outline = cell_outlines
+                 elif current_channel == nucleus_channel and nucleus_outlines is not None:
+                     outline = nucleus_outlines
+                 elif current_channel == pathogen_channel and pathogen_outlines is not None:
+                     outline = pathogen_outlines
+
+                 if outline is not None:
+                     if mode == 'outlines':
+                         # Use magenta color when all_on_all=False
+                         channel_image_rgb = _apply_contours(
+                             channel_image_rgb, outline, '#FF00FF', thickness
+                         )
+                     else:
+                         cmap = random_color_cmap(int(outline.max() + 1), random.randint(0, 100))
+                         mask = _generate_colored_mask(outline, cmap)
+                         channel_image_rgb = _overlay_mask(channel_image_rgb, mask)
+             else:
+                 # Channel without associated outlines
+                 if all_outlines:
+                     # Apply all outlines with specified colors
+                     for outline, color in zip(outlines, ['blue', 'red', 'green']):
+                         if mode == 'outlines':
+                             channel_image_rgb = _apply_contours(
+                                 channel_image_rgb, outline, color, thickness
+                             )
+                         else:
+                             cmap = random_color_cmap(int(outline.max() + 1), random.randint(0, 100))
+                             mask = _generate_colored_mask(outline, cmap)
+                             channel_image_rgb = _overlay_mask(channel_image_rgb, mask)
+
+             ax[v].imshow(channel_image_rgb)
+             ax[v].set_title(f'Image - Channel {current_channel}')
+
+         # Create an image combining all objects filled with colors
+         combined_mask = np.zeros_like(outlines[0])
+         for outline in outlines:
+             combined_mask = np.maximum(combined_mask, outline)
+
+         cmap = random_color_cmap(int(combined_mask.max() + 1), random.randint(0, 100))
+         mask = _generate_colored_mask(combined_mask, cmap)
+         blank_image = np.zeros((*combined_mask.shape, 3))
+         filled_image = _overlay_mask(blank_image, mask)
+
+         ax[-1].imshow(filled_image)
+         ax[-1].set_title('Combined Objects Image')
+
+         plt.tight_layout()
+
+         # Save the figure as a PDF
+         if save_pdf:
+             pdf_dir = os.path.join(
+                 os.path.dirname(os.path.dirname(file)), 'results', 'overlay'
+             )
+             os.makedirs(pdf_dir, exist_ok=True)
+             pdf_path = os.path.join(
+                 pdf_dir, os.path.basename(file).replace('.npy', '.pdf')
+             )
+             fig.savefig(pdf_path, format='pdf')
+
+         plt.show()
+         return fig
+
+     def _save_channels_as_tiff(stack, save_dir, filename):
+         """Save each channel in the stack as a grayscale TIFF."""
+         os.makedirs(save_dir, exist_ok=True)
+         for i in range(stack.shape[-1]):
+             channel = stack[..., i]
+             tiff_path = os.path.join(save_dir, f"{filename}_channel_{i}.tiff")
+             tiff.imwrite(tiff_path, channel.astype(np.uint16), photometric='minisblack')
+             print(f"Saved {tiff_path}")
+
+     def _filter_object(mask, intensity_image, min_max_area=(0, 10000000), min_max_intensity=(0, 65000), type_='object'):
+         """
+         Filter objects in a mask based on their area (size) and mean intensity.
+
+         Args:
+             mask (ndarray): The input mask.
+             intensity_image (ndarray): The corresponding intensity image.
+             min_max_area (tuple): A tuple (min_area, max_area) specifying the minimum and maximum area thresholds.
+             min_max_intensity (tuple): A tuple (min_intensity, max_intensity) specifying the minimum and maximum intensity thresholds.
+
+         Returns:
+             ndarray: The filtered mask.
+         """
+         original_dtype = mask.dtype
+         mask_int = mask.astype(np.int64)
+         intensity_image = intensity_image.astype(np.float64)
+         # Compute properties for each labeled object
+         unique_labels = np.unique(mask_int)
+         unique_labels = unique_labels[unique_labels != 0] # Exclude background
+         num_objects_before = len(unique_labels)
+
+         # Initialize lists to store area and intensity for each object
+         areas = []
+         mean_intensities = []
+         labels_to_keep = []
+
+         for label in unique_labels:
+             label_mask = (mask_int == label)
+             area = np.sum(label_mask)
+             mean_intensity = np.mean(intensity_image[label_mask])
+
+             areas.append(area)
+             mean_intensities.append(mean_intensity)
+
+             # Check if the object meets both area and intensity criteria
+             if (min_max_area[0] <= area <= min_max_area[1]) and (min_max_intensity[0] <= mean_intensity <= min_max_intensity[1]):
+                 labels_to_keep.append(label)
+
+         # Convert lists to numpy arrays for easier computation
+         areas = np.array(areas)
+         mean_intensities = np.array(mean_intensities)
+         num_objects_after = len(labels_to_keep)
+         # Compute average area and intensity before and after filtering
+         avg_area_before = areas.mean() if num_objects_before > 0 else 0
+         avg_intensity_before = mean_intensities.mean() if num_objects_before > 0 else 0
+         areas_after = areas[np.isin(unique_labels, labels_to_keep)]
+         mean_intensities_after = mean_intensities[np.isin(unique_labels, labels_to_keep)]
+         avg_area_after = areas_after.mean() if num_objects_after > 0 else 0
+         avg_intensity_after = mean_intensities_after.mean() if num_objects_after > 0 else 0
+         print(f"Before filtering {type_}: {num_objects_before} objects")
+         print(f"Average area {type_}: {avg_area_before:.2f} pixels, Average intensity: {avg_intensity_before:.2f}")
+         print(f"After filtering {type_}: {num_objects_after} objects")
+         print(f"Average area {type_}: {avg_area_after:.2f} pixels, Average intensity: {avg_intensity_after:.2f}")
+         mask_filtered = np.zeros_like(mask_int)
+         for label in labels_to_keep:
+             mask_filtered[mask_int == label] = label
+         mask_filtered = mask_filtered.astype(original_dtype)
+         return mask_filtered
+
+     stack = np.load(file)
+
+     if export_tiffs:
+         save_dir = os.path.join(
+             os.path.dirname(os.path.dirname(file)),
+             'results',
+             os.path.splitext(os.path.basename(file))[0],
+             'tiff'
+         )
+         filename = os.path.splitext(os.path.basename(file))[0]
+         _save_channels_as_tiff(stack, save_dir, filename)
+
+     # Convert to float for normalization and ensure correct handling of arrays
+     if stack.dtype in (np.uint16, np.uint8):
+         stack = stack.astype(np.float32)
+
+     image = stack[..., channels]
+     outlines = []
+     outline_colors = []
+
+     # Define variables to hold individual outlines
+     cell_outlines = None
+     nucleus_outlines = None
+     pathogen_outlines = None
+
+     if pathogen_channel is not None:
+         pathogen_mask_dim = -1
+         pathogen_outlines = np.take(stack, pathogen_mask_dim, axis=2)
+         if not filter_dict is None:
+             pathogen_intensity = np.take(stack, pathogen_channel, axis=2)
+             pathogen_outlines = _filter_object(pathogen_outlines, pathogen_intensity, filter_dict['pathogen'][0], filter_dict['pathogen'][1], type_='pathogen')
+
+         outlines.append(pathogen_outlines)
+         outline_colors.append('green')
+
+     if nucleus_channel is not None:
+         nucleus_mask_dim = -2 if pathogen_channel is not None else -1
+         nucleus_outlines = np.take(stack, nucleus_mask_dim, axis=2)
+         if not filter_dict is None:
+             nucleus_intensity = np.take(stack, nucleus_channel, axis=2)
+             nucleus_outlines = _filter_object(nucleus_outlines, nucleus_intensity, filter_dict['nucleus'][0], filter_dict['nucleus'][1], type_='nucleus')
+         outlines.append(nucleus_outlines)
+         outline_colors.append('blue')
+
+     if cell_channel is not None:
+         if nucleus_channel is not None and pathogen_channel is not None:
+             cell_mask_dim = -3
+         elif nucleus_channel is not None or pathogen_channel is not None:
+             cell_mask_dim = -2
+         else:
+             cell_mask_dim = -1
+         cell_outlines = np.take(stack, cell_mask_dim, axis=2)
+         if not filter_dict is None:
+             cell_intensity = np.take(stack, cell_channel, axis=2)
+             cell_outlines = _filter_object(cell_outlines, cell_intensity, filter_dict['cell'][0], filter_dict['cell'][1], type_='cell')
+         outlines.append(cell_outlines)
+         outline_colors.append('red')
+
+     fig = _plot_merged_plot(
+         image=image,
+         outlines=outlines,
+         outline_colors=outline_colors,
+         figuresize=figuresize,
+         thickness=thickness,
+         percentiles=percentiles, # Pass percentiles to the plotting function
+         mode=mode,
+         all_on_all=all_on_all,
+         all_outlines=all_outlines,
+         channels=channels,
+         cell_channel=cell_channel,
+         nucleus_channel=nucleus_channel,
+         pathogen_channel=pathogen_channel,
+         cell_outlines=cell_outlines,
+         nucleus_outlines=nucleus_outlines,
+         pathogen_outlines=pathogen_outlines,
+         save_pdf=save_pdf
+     )
+
+     return fig
+
+ def plot_image_mask_overlay_v1(file, channels, cell_channel, nucleus_channel, pathogen_channel, figuresize=10, percentiles=(2,98), thickness=3, save_pdf=True, mode='outlines', export_tiffs=False):
      """Plot image and mask overlays."""
 
      def _plot_merged_plot(image, outlines, outline_colors, figuresize, thickness, percentiles, mode='outlines'):
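The new `filter_dict` argument carries the same structure as the commented-out example above the rewritten function: for each object type, a pair of `(min, max)` bounds for area and for mean intensity. A minimal usage sketch, assuming a merged `.npy` stack laid out as image channels followed by cell/nucleus/pathogen masks (the path and channel indices here are hypothetical):

```python
# Hypothetical call; the path and channel assignments depend on your dataset layout.
filter_dict = {
    'cell': [(0, 100000), (0, 65000)],          # (min, max) area, (min, max) mean intensity
    'nucleus': [(3000, 100000), (1500, 65000)],
    'pathogen': [(500, 100000), (0, 65000)],
}

fig = plot_image_mask_overlay(
    file='/data/screen1/merged/field_001.npy',  # hypothetical merged stack
    channels=[0, 1, 2],
    cell_channel=0,
    nucleus_channel=1,
    pathogen_channel=2,
    mode='outlines',
    all_on_all=False,     # new: draw each outline only on its matching channel
    all_outlines=False,   # new: also draw all outlines on channels without their own
    filter_dict=filter_dict,
)
```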
@@ -1398,7 +1732,7 @@ def _plot_histograms_and_stats(df):
          print('-'*40)
 
          # Plot the histogram
-         plt.figure(figsize=(10,6))
+         plt.figure(figsize=(10,10))
          plt.hist(subset['pred'], bins=30, edgecolor='black')
          plt.axvline(mean_pred, color='red', linestyle='dashed', linewidth=1, label=f"Mean = {mean_pred:.2f}")
          plt.title(f'Histogram for pred - Condition: {condition}')
@@ -1455,12 +1789,16 @@ def _reg_v_plot(df, grouping, variable, plate_number):
      plt.show()
 
  def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_count):
+
+     if not isinstance(min_count, (int, float)):
+         min_count = 0
+
      df = df.copy() # Work on a copy to avoid SettingWithCopyWarning
      df['plate'], df['row'], df['col'] = zip(*df['prc'].str.split('_'))
 
      # Filtering the dataframe based on the plate_number
      df = df[df['plate'] == plate_number].copy() # Create another copy after filtering
-
+
      # Ensure proper ordering
      row_order = [f'r{i}' for i in range(1, 17)]
      col_order = [f'c{i}' for i in range(1, 28)] # Exclude c15 as per your earlier code
@@ -1496,7 +1834,6 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_co
          min_max = np.quantile(plate_map.values, [min_max[0], min_max[1]])
          if isinstance(min_max[0], (int)) and isinstance(min_max[1], (int)):
              min_max = [min_max[0], min_max[1]]
-
      return plate_map, min_max
 
  def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True, dst=None):
@@ -1516,10 +1853,14 @@ def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True
      plt.subplots_adjust(wspace=0.1, hspace=0.4)
 
      if not dst is None:
-         filename = os.path.join(dst, 'plate_heatmap.pdf')
-         fig.savefig(filename, format='pdf')
-         print(f'Saved heatmap to {filename}')
-
+         for i in range(0,1000):
+             filename = os.path.join(dst, f'plate_heatmap_{i}.pdf')
+             if os.path.exists(filename):
+                 continue
+             else:
+                 fig.savefig(filename, format='pdf')
+                 print(f'Saved heatmap to {filename}')
+                 break
      if verbose:
          plt.show()
      return fig
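The rewritten save block probes `plate_heatmap_0.pdf`, `plate_heatmap_1.pdf`, and so on, writing to the first name that does not exist, so repeated runs no longer overwrite earlier heatmaps. The same pattern as a standalone helper (a sketch, not part of spacr):

```python
import os

def next_free_path(dst, stem='plate_heatmap', ext='.pdf', limit=1000):
    """Return the first '{stem}_{i}{ext}' under dst that does not exist yet."""
    for i in range(limit):
        candidate = os.path.join(dst, f'{stem}_{i}{ext}')
        if not os.path.exists(candidate):
            return candidate
    raise FileExistsError(f'All {limit} candidate names are taken in {dst}')
```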
@@ -1886,22 +2227,77 @@ def volcano_plot(coef_df, filename='volcano_plot.pdf'):
      print(f'Saved Volcano plot: {filename}')
      plt.show()
 
- def plot_histogram(df, dependent_variable, dst=None):
+ def plot_histogram(df, column, dst=None):
      # Plot histogram of the dependent variable
-     plt.figure(figsize=(10, 6))
-     sns.histplot(df[dependent_variable], kde=True)
-     plt.title(f'Histogram of {dependent_variable}')
-     plt.xlabel(dependent_variable)
+     bar_color = (0/255, 155/255, 155/255)
+     plt.figure(figsize=(10, 10))
+     sns.histplot(df[column], kde=False, color=bar_color, edgecolor=None, alpha=0.6)
+     plt.title(f'Histogram of {column}')
+     plt.xlabel(column)
      plt.ylabel('Frequency')
 
      if not dst is None:
-         filename = os.path.join(dst, 'dependent_variable_histogram.pdf')
+         filename = os.path.join(dst, f'{column}_histogram.pdf')
          plt.savefig(filename, format='pdf')
          print(f'Saved histogram to {filename}')
 
      plt.show()
 
- def plot_lorenz_curves(csv_files, remove_keys=['TGGT1_220950_1', 'TGGT1_233460_4']):
+ def plot_lorenz_curves(csv_files, name_column='grna_name', value_column='count', remove_keys=['TGGT1_220950_1', 'TGGT1_233460_4'], x_lim=[0.0,1], y_lim=[0,1], save=True):
+
+     def lorenz_curve(data):
+         """Calculate Lorenz curve."""
+         sorted_data = np.sort(data)
+         cumulative_data = np.cumsum(sorted_data)
+         lorenz_curve = cumulative_data / cumulative_data[-1]
+         lorenz_curve = np.insert(lorenz_curve, 0, 0)
+         return lorenz_curve
+
+     combined_data = []
+
+     plt.figure(figsize=(10, 10))
+
+     for idx, csv_file in enumerate(csv_files):
+         if idx == 1:
+             save_fldr = os.path.dirname(csv_file)
+             save_path = os.path.join(save_fldr, 'lorenz_curve.pdf')
+
+         df = pd.read_csv(csv_file)
+         for remove in remove_keys:
+             df = df[df[name_column] != remove]
+
+         values = df[value_column].values
+         combined_data.extend(values)
+
+         lorenz = lorenz_curve(values)
+         name = f"plate {idx+1}"
+         plt.plot(np.linspace(0, 1, len(lorenz)), lorenz, label=name)
+
+     # Plot combined Lorenz curve
+     combined_lorenz = lorenz_curve(np.array(combined_data))
+     plt.plot(np.linspace(0, 1, len(combined_lorenz)), combined_lorenz, label="Combined", linestyle='--', color='black')
+
+     if x_lim != None:
+         plt.xlim(x_lim)
+
+     if y_lim != None:
+         plt.ylim(y_lim)
+
+     plt.title('Lorenz Curves')
+     plt.xlabel('Cumulative Share of Individuals')
+     plt.ylabel('Cumulative Share of Value')
+     plt.legend()
+     plt.grid(False)
+
+     if save:
+         save_path = os.path.join(os.path.dirname(csv_files[0]), 'results')
+         os.makedirs(save_path, exist_ok=True)
+         save_file_path = os.path.join(save_path, 'lorenz_curve.pdf')
+         plt.savefig(save_file_path, format='pdf', bbox_inches='tight')
+         print(f"Saved Lorenz Curve: {save_file_path}")
+     plt.show()
+
+ def plot_lorenz_curves_v1(csv_files, remove_keys=['TGGT1_220950_1', 'TGGT1_233460_4']):
 
      def lorenz_curve(data):
          """Calculate Lorenz curve."""
@@ -2358,22 +2754,33 @@ class spacrGraph:
          return filtered_df
 
      def perform_normality_tests(self):
-         """Perform normality tests for each group and each data column."""
+         """Perform normality tests for each group and data column."""
          unique_groups = self.df[self.grouping_column].unique()
          normality_results = []
 
          for column in self.data_column:
-             # Iterate over each group and its corresponding data
              for group in unique_groups:
-                 data = self.df.loc[self.df[self.grouping_column] == group, column]
+                 data = self.df.loc[self.df[self.grouping_column] == group, column].dropna()
                  n_samples = len(data)
 
+                 if n_samples < 3:
+                     # Skip test if there aren't enough data points
+                     print(f"Skipping normality test for group '{group}' on column '{column}' - Not enough data.")
+                     normality_results.append({
+                         'Comparison': f'Normality test for {group} on {column}',
+                         'Test Statistic': None,
+                         'p-value': None,
+                         'Test Name': 'Skipped',
+                         'Column': column,
+                         'n': n_samples
+                     })
+                     continue
+
+                 # Choose the appropriate normality test based on the sample size
                  if n_samples >= 8:
-                     # Use D'Agostino-Pearson test for larger samples
                      stat, p_value = normaltest(data)
                      test_name = "D'Agostino-Pearson test"
                  else:
-                     # Use Shapiro-Wilk test for smaller samples
                      stat, p_value = shapiro(data)
                      test_name = "Shapiro-Wilk test"
 
@@ -2384,11 +2791,11 @@ class spacrGraph:
                      'p-value': p_value,
                      'Test Name': test_name,
                      'Column': column,
-                     'n': n_samples # Sample size
+                     'n': n_samples
                  })
 
              # Check if all groups are normally distributed (p > 0.05)
-             normal_p_values = [result['p-value'] for result in normality_results if result['Column'] == column]
+             normal_p_values = [result['p-value'] for result in normality_results if result['Column'] == column and result['p-value'] is not None]
              is_normal = all(p > 0.05 for p in normal_p_values)
 
          return is_normal, normality_results
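The `n_samples >= 8` cutoff matches a SciPy constraint: `scipy.stats.normaltest` builds on `skewtest`, which raises a ValueError for fewer than 8 observations, and `shapiro` itself needs at least 3. A standalone sketch of the same dispatch:

```python
import numpy as np
from scipy.stats import normaltest, shapiro

def pick_normality_test(data):
    """Mirror the sample-size-based test selection used above (a sketch)."""
    data = np.asarray(data)
    if len(data) < 3:
        return None, None, 'Skipped'              # too few points for any test
    if len(data) >= 8:
        stat, p = normaltest(data)                # D'Agostino-Pearson, needs n >= 8
        return stat, p, "D'Agostino-Pearson test"
    stat, p = shapiro(data)                       # Shapiro-Wilk for small samples
    return stat, p, 'Shapiro-Wilk test'

print(pick_normality_test(np.random.default_rng(1).normal(size=30)))
```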
@@ -2438,9 +2845,13 @@ class spacrGraph:
                             len(self.df[self.df[self.grouping_column] == unique_groups[1]])})
 
          return test_results
-
+
      def perform_posthoc_tests(self, is_normal, unique_groups):
          """Perform post-hoc tests for multiple groups based on all_to_all flag."""
+
+         from .utils import choose_p_adjust_method
+
+         posthoc_results = []
          if is_normal and len(unique_groups) > 2 and self.all_to_all:
              tukey_result = pairwise_tukeyhsd(self.df[self.data_column], self.df[self.grouping_column], alpha=0.05)
              posthoc_results = []
@@ -2456,22 +2867,40 @@ class spacrGraph:
                      'n_object': len(raw_data1) + len(raw_data2),
                      'n_well': len(self.df[self.df[self.grouping_column] == comparison[0]]) + len(self.df[self.df[self.grouping_column] == comparison[1]])})
              return posthoc_results
-
-         elif len(unique_groups) > 2 and not self.all_to_all and self.compare_group:
-             dunn_result = pg.pairwise_tests(data=self.df, dv=self.data_column, between=self.grouping_column, padjust='bonf', test='dunn')
-             posthoc_results = []
-             for idx, row in dunn_result.iterrows():
-                 if row['A'] == self.compare_group or row['B'] == self.compare_group:
-                     posthoc_results.append({
-                         'Comparison': f"{row['A']} vs {row['B']}",
-                         'Test Statistic': row['T'], # Test statistic from Dunn's test
-                         'p-value': row['p-val'],
-                         'Test Name': 'Dunn’s Post-hoc',
-                         'n_object': None,
-                         'n_well': None})
-
+
+         elif len(unique_groups) > 2 and self.all_to_all:
+             print('performing_dunns')
+
+             # Prepare data for Dunn's test in long format
+             long_data = self.df[[self.data_column[0], self.grouping_column]].dropna()
+
+             p_adjust_method = choose_p_adjust_method(num_groups=len(long_data[self.grouping_column].unique()),num_data_points=len(long_data) // len(long_data[self.grouping_column].unique()))
+
+             # Perform Dunn's test with Bonferroni correction
+             dunn_result = sp.posthoc_dunn(
+                 long_data,
+                 val_col=self.data_column[0],
+                 group_col=self.grouping_column,
+                 p_adjust=p_adjust_method
+             )
+
+             for group_a, group_b in zip(*np.triu_indices_from(dunn_result, k=1)):
+                 raw_data1 = self.raw_df[self.raw_df[self.grouping_column] == dunn_result.index[group_a]][self.data_column]
+                 raw_data2 = self.raw_df[self.raw_df[self.grouping_column] == dunn_result.columns[group_b]][self.data_column]
+
+                 posthoc_results.append({
+                     'Comparison': f"{dunn_result.index[group_a]} vs {dunn_result.columns[group_b]}",
+                     'Test Statistic': None, # Dunn's test does not return a specific test statistic
+                     'p-value': dunn_result.iloc[group_a, group_b], # Extract the p-value from the matrix
+                     'Test Name': "Dunn's Post-hoc",
+                     'p_adjust_method': p_adjust_method,
+                     'n_object': len(raw_data1) + len(raw_data2), # Total objects
+                     'n_well': len(self.df[self.df[self.grouping_column] == dunn_result.index[group_a]]) +
+                               len(self.df[self.grouping_column] == dunn_result.columns[group_b])})
+
              return posthoc_results
-         return []
+
+         return posthoc_results
 
      def create_plot(self, ax=None):
          """Create and display the plot based on the chosen graph type."""
@@ -2507,7 +2936,43 @@ class spacrGraph:
          transposed_table = list(map(list, zip(*table_data)))
          return row_labels, transposed_table
 
-     def _place_symbols(row_labels, transposed_table, x_positions, ax):
+
+     def _place_symbols(row_labels, transposed_table, x_positions, ax):
+         """
+         Places symbols and row labels aligned under the bars or jitter points on the graph.
+
+         Parameters:
+         - row_labels: List of row titles to be displayed along the y-axis.
+         - transposed_table: Data to be placed under each bar/jitter as symbols.
+         - x_positions: X-axis positions for each group to align the symbols.
+         - ax: The matplotlib Axes object where the plot is drawn.
+         """
+         # Get plot dimensions and adjust for different plot sizes
+         y_axis_min = ax.get_ylim()[0] # Minimum y-axis value (usually 0)
+         symbol_start_y = y_axis_min - 0.05 * (ax.get_ylim()[1] - y_axis_min) # Adjust a bit below the x-axis
+
+         # Calculate spacing for the table rows (adjust as needed)
+         y_spacing = 0.04 # Adjust this for better spacing between rows
+
+         # Determine the leftmost x-position for row labels (align with the y-axis)
+         label_x_pos = ax.get_xlim()[0] - 0.3 # Adjust offset from the y-axis
+
+         # Place row labels vertically aligned with symbols
+         for row_idx, title in enumerate(row_labels):
+             y_pos = symbol_start_y - (row_idx * y_spacing) # Calculate vertical position for each label
+             ax.text(label_x_pos, y_pos, title, ha='right', va='center', fontsize=12, fontweight='regular')
+
+         # Place symbols under each bar or jitter point based on x-positions
+         for idx, (x_pos, column_data) in enumerate(zip(x_positions, transposed_table)):
+             for row_idx, text in enumerate(column_data):
+                 y_pos = symbol_start_y - (row_idx * y_spacing) # Adjust vertical spacing for symbols
+                 ax.text(x_pos, y_pos, text, ha='center', va='center', fontsize=12, fontweight='regular')
+
+         # Redraw to apply changes
+         ax.figure.canvas.draw()
+
+
+     def _place_symbols_v1(row_labels, transposed_table, x_positions, ax):
 
          # Get the bottom of the y-axis (y=0) in data coordinates and convert to display coordinates
          y_axis_min = ax.get_ylim()[0] # Minimum y-axis value (usually 0)
@@ -2642,6 +3107,10 @@ class spacrGraph:
          else:
              raise ValueError(f"Unknown graph type: {self.graph_type}")
 
+         if len(self.data_column) == 1:
+             num_groups = len(self.df[self.grouping_column].unique())
+             self._standerdize_figure_format(ax=ax, num_groups=num_groups, graph_type=self.graph_type)
+
          # Set y-axis start
          if isinstance(self.y_lim, list):
              if len(self.y_lim) == 2:
@@ -2676,7 +3145,73 @@ class spacrGraph:
          if self.save:
              self._save_results()
 
-         ax.margins(x=0.12)
+         ax.margins(x=0.12)
+
+     def _standerdize_figure_format(self, ax, num_groups, graph_type):
+         """
+         Adjusts the figure layout (size, bar width, jitter, and spacing) based on the number of groups.
+
+         Parameters:
+         - ax: The matplotlib Axes object.
+         - num_groups: Number of unique groups.
+         - graph_type: The type of graph (e.g., 'bar', 'jitter', 'box', etc.).
+
+         Returns:
+         - None. Modifies the figure and Axes in place.
+         """
+         if graph_type in ['line', 'line_std']:
+             print("Skipping layout adjustment for line graphs.")
+             return # Skip layout adjustment for line graphs
+
+         correction_factor = 4
+
+         # Set figure size to ensure it remains square with a minimum size
+         fig_size = max(6, num_groups * 2) / correction_factor
+         ax.figure.set_size_inches(fig_size, fig_size)
+
+         # Configure layout based on the number of groups
+         bar_width = min(0.8, 1.5 / num_groups) / correction_factor
+         jitter_amount = min(0.1, 0.2 / num_groups) / correction_factor
+         jitter_size = max(50 / num_groups, 200)
+
+         # Adjust axis limits to ensure bars are centered with respect to group labels
+         ax.set_xlim(-0.5, num_groups - 0.5)
+
+         # Set ticks to match the group labels in your DataFrame
+         group_labels = self.df[self.grouping_column].unique()
+         ax.set_xticks(range(len(group_labels)))
+         ax.set_xticklabels(group_labels, rotation=45, ha='right')
+
+         # Customize elements based on the graph type
+         if graph_type == 'bar':
+             # Adjust bars' width and position
+             for bar in ax.patches:
+                 bar.set_width(bar_width)
+                 bar.set_x(bar.get_x() - bar_width / 2)
+
+         elif graph_type in ['jitter', 'jitter_bar', 'jitter_box']:
+             # Adjust jitter points' position and size
+             for coll in ax.collections:
+                 offsets = coll.get_offsets()
+                 offsets[:, 0] += jitter_amount # Shift jitter points slightly
+                 coll.set_offsets(offsets)
+                 coll.set_sizes([jitter_size] * len(offsets)) # Adjust point size dynamically
+
+         elif graph_type in ['box', 'violin']:
+             # Adjust box width for consistent spacing
+             for artist in ax.artists:
+                 artist.set_width(bar_width)
+
+         # Adjust legend and axis labels
+         ax.tick_params(axis='x', labelsize=max(10, 15 - num_groups // 2))
+         ax.tick_params(axis='y', labelsize=max(10, 15 - num_groups // 2))
+
+         if ax.get_legend():
+             ax.get_legend().set_bbox_to_anchor((1.05, 1)) #loc='upper left',borderaxespad=0.
+             ax.get_legend().prop.set_size(max(8, 12 - num_groups // 3))
+
+         # Redraw the figure to apply changes
+         ax.figure.canvas.draw()
 
      def _create_bar_plot(self, ax):
          """Helper method to create a bar plot with consistent bar thickness and centered error bars."""
@@ -2895,11 +3430,11 @@ class spacrGraph:
              bar.set_x(bar.get_x() - target_width / 2)
 
          # Adjust error bars alignment with bars
-         bars = [bar for bar in ax.patches if isinstance(bar, plt.Rectangle)]
-         for bar, (_, row) in zip(bars, summary_df.iterrows()):
-             x_bar = bar.get_x() + bar.get_width() / 2
-             err = row[self.error_bar_type]
-             ax.errorbar(x=x_bar, y=bar.get_height(), yerr=err, fmt='none', c='black', capsize=5, lw=2)
+         #bars = [bar for bar in ax.patches if isinstance(bar, plt.Rectangle)]
+         #for bar, (_, row) in zip(bars, summary_df.iterrows()):
+         #    x_bar = bar.get_x() + bar.get_width() / 2
+         #    err = row[self.error_bar_type]
+         #    ax.errorbar(x=x_bar, y=bar.get_height(), yerr=err, fmt='none', c='black', capsize=5, lw=2)
 
          # Set legend and labels
          ax.set_xlabel(self.grouping_column)
@@ -3092,9 +3627,13 @@ def plot_data_from_csv(settings):
          dft = pd.read_csv(src)
          if 'plate' not in dft.columns:
              dft['plate'] = f"plate{i+1}"
+         dft['common'] = 'spacr'
          dfs.append(dft)
 
      df = pd.concat(dfs, axis=0)
+
+     display(df)
+
      df = df.dropna(subset=settings['data_column'])
      df = df.dropna(subset=settings['grouping_column'])
      src = srcs[0]
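Note that `display` is the IPython rich-display function: it is available as a builtin inside Jupyter but undefined in a plain interpreter. If the module needs to run outside a notebook, a guarded import is the usual workaround (a sketch, not what the package does):

```python
try:
    from IPython.display import display  # rich DataFrame rendering in notebooks
except ImportError:
    display = print  # plain-text fallback outside IPython
```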
@@ -3141,23 +3680,39 @@ def plot_region(settings):
          print(f"Saved {path}")
 
      from .io import _read_db
+     from .utils import correct_paths
      fov_path = os.path.join(settings['src'], 'merged', settings['name'])
      name = os.path.splitext(settings['name'])[0]
 
      db_path = os.path.join(settings['src'], 'measurements', 'measurements.db')
      paths_df = _read_db(db_path, tables=['png_list'])[0]
+     paths_df, _ = correct_paths(df=paths_df, base_path=settings['src'], folder='data')
      paths_df = paths_df[paths_df['png_path'].str.contains(name, na=False)]
 
      activation_mode = f"{settings['activation_mode']}_list"
      activation_db_path = os.path.join(settings['src'], 'measurements', settings['activation_db'])
      activation_paths_df = _read_db(activation_db_path, tables=[activation_mode])[0]
+     activation_db = os.path.splitext(settings['activation_db'])[0]
+     base_path=os.path.join(settings['src'], 'datasets',activation_db)
+     activation_paths_df, _ = correct_paths(df=activation_paths_df, base_path=base_path, folder=settings['activation_mode'])
      activation_paths_df = activation_paths_df[activation_paths_df['png_path'].str.contains(name, na=False)]
 
      png_paths = _sort_paths_by_basename(paths_df['png_path'].tolist())
      activation_paths = _sort_paths_by_basename(activation_paths_df['png_path'].tolist())
 
-     fig_3 = plot_image_grid(image_paths=activation_paths, percentiles=settings['percentiles'])
-     fig_2 = plot_image_grid(image_paths=png_paths, percentiles=settings['percentiles'])
+
+     if activation_paths:
+         fig_3 = plot_image_grid(image_paths=activation_paths, percentiles=settings['percentiles'])
+     else:
+         fig_3 = None
+         print(f"Could not find any cropped PNGs")
+     if png_paths:
+         fig_2 = plot_image_grid(image_paths=png_paths, percentiles=settings['percentiles'])
+     else:
+         fig_2 = None
+         print(f"Could not find any activation maps")
+
+     print('fov_path', fov_path)
      fig_1 = plot_image_mask_overlay(file=fov_path,
                                      channels=settings['channels'],
                                      cell_channel=settings['cell_channel'],
@@ -3166,14 +3721,18 @@ def plot_region(settings):
                                      figuresize=10,
                                      percentiles=settings['percentiles'],
                                      thickness=3,
-                                     save_pdf=False,
+                                     save_pdf=True,
                                      mode=settings['mode'],
                                      export_tiffs=settings['export_tiffs'])
 
      dst = os.path.join(settings['src'], 'results', name)
-     save_figure_as_pdf(fig_1, os.path.join(dst, f"{name}_mask_overlay.pdf"))
-     save_figure_as_pdf(fig_2, os.path.join(dst, f"{name}_png_grid.pdf"))
-     save_figure_as_pdf(fig_3, os.path.join(dst, f"{name}_activation_grid.pdf"))
+
+     if not fig_1 == None:
+         save_figure_as_pdf(fig_1, os.path.join(dst, f"{name}_mask_overlay.pdf"))
+     if not fig_2 == None:
+         save_figure_as_pdf(fig_2, os.path.join(dst, f"{name}_png_grid.pdf"))
+     if not fig_3 == None:
+         save_figure_as_pdf(fig_3, os.path.join(dst, f"{name}_activation_grid.pdf"))
 
      return fig_1, fig_2, fig_3
 
@@ -3337,4 +3896,5 @@ def overlay_masks_on_images(img_folder, normalize=True, resize=True, save=False,
          plt.imshow(blended)
          plt.title(f"Overlay: {filename}")
          plt.axis('off')
-         plt.show()
+         plt.show()
+