spacr 0.3.37__py3-none-any.whl → 0.3.41__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/plot.py CHANGED
@@ -13,10 +13,11 @@ from IPython.display import display
13
13
  from skimage.segmentation import find_boundaries
14
14
  from skimage import measure
15
15
  from skimage.measure import find_contours, label, regionprops
16
+ import tifffile as tiff
16
17
 
17
18
  from scipy.stats import normaltest, ttest_ind, mannwhitneyu, f_oneway, kruskal
18
19
  from statsmodels.stats.multicomp import pairwise_tukeyhsd
19
- from scipy.stats import ttest_ind, mannwhitneyu, levene, wilcoxon, kruskal
20
+ from scipy.stats import ttest_ind, mannwhitneyu, levene, wilcoxon, kruskal, normaltest, shapiro
20
21
  import itertools
21
22
  import pingouin as pg
22
23
 
@@ -25,13 +26,26 @@ from IPython.display import Image as ipyimage
25
26
 
26
27
  import matplotlib.patches as patches
27
28
  from collections import defaultdict
29
+ from matplotlib.gridspec import GridSpec
28
30
 
29
- def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, pathogen_channel, figuresize=10, normalize=True, thickness=3, save_pdf=True):
31
+ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, pathogen_channel, figuresize=10, percentiles=(2,98), thickness=3, save_pdf=True, mode='outlines', export_tiffs=False):
30
32
  """Plot image and mask overlays."""
31
33
 
32
- def _plot_merged_plot(image, outlines, outline_colors, figuresize, thickness):
34
+ def _plot_merged_plot(image, outlines, outline_colors, figuresize, thickness, percentiles, mode='outlines'):
33
35
  """Plot the merged plot with overlay, image channels, and masks."""
34
36
 
37
+ def _generate_colored_mask(mask, alpha):
38
+ """ Generate a colored mask with transparency using the given colormap. """
39
+ cmap = generate_mask_random_cmap(mask)
40
+ rgba_mask = cmap(mask / mask.max()) # Normalize mask and map to colormap (RGBA)
41
+ rgba_mask[..., 3] = np.where(mask > 0, alpha, 0) # Apply transparency only where mask is present
42
+ return rgba_mask
43
+
44
+ def _overlay_mask(image, mask):
45
+ """Overlay the colored mask onto the original image."""
46
+ combined = np.clip(image + mask[..., :3] * mask[..., 3:4], 0, 1) # Ensure pixel values stay in [0, 1]
47
+ return combined
48
+
35
49
  def _normalize_image(image, percentiles=(2, 98)):
36
50
  """Normalize the image to the given percentiles."""
37
51
  v_min, v_max = np.percentile(image, percentiles)
@@ -61,11 +75,15 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
61
75
  # Plot each channel with its corresponding outlines
62
76
  for v in range(num_channels):
63
77
  channel_image = image[..., v]
64
- channel_image_normalized = _normalize_image(channel_image)
78
+ channel_image_normalized = _normalize_image(channel_image, percentiles)
65
79
  channel_image_rgb = np.dstack((channel_image_normalized, channel_image_normalized, channel_image_normalized))
66
80
 
67
81
  for outline, color in zip(outlines, outline_colors):
68
- channel_image_rgb = _apply_contours(channel_image_rgb, outline, color, thickness)
82
+ if mode == 'outlines':
83
+ channel_image_rgb = _apply_contours(channel_image_rgb, outline, color, thickness)
84
+ else:
85
+ mask = _generate_colored_mask(outline, alpha=0.5)
86
+ channel_image_rgb = _overlay_mask(channel_image_rgb, mask)
69
87
 
70
88
  ax[v].imshow(channel_image_rgb)
71
89
  ax[v].set_title(f'Image - Channel {v}')
@@ -75,11 +93,15 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
75
93
  rgb_channels = min(3, num_channels)
76
94
  for i in range(rgb_channels):
77
95
  channel_image = image[..., i]
78
- channel_image_normalized = _normalize_image(channel_image)
96
+ channel_image_normalized = _normalize_image(channel_image, percentiles)
79
97
  rgb_image[..., i] = channel_image_normalized
80
98
 
81
99
  for outline, color in zip(outlines, outline_colors):
82
- rgb_image = _apply_contours(rgb_image, outline, color, thickness)
100
+ if mode == 'outlines':
101
+ rgb_image = _apply_contours(rgb_image, outline, color, thickness)
102
+ else:
103
+ mask = _generate_colored_mask(outline, alpha=0.5)
104
+ rgb_image = _overlay_mask(rgb_image, mask)
83
105
 
84
106
  ax[-1].imshow(rgb_image)
85
107
  ax[-1].set_title('Combined RGB Image')
@@ -96,8 +118,22 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
96
118
  plt.show()
97
119
  return fig
98
120
 
121
def _save_channels_as_tiff(stack, save_dir, filename):
    """Write each slice along the last axis of ``stack`` to its own TIFF.

    Files are named ``<filename>_channel_<i>.tiff`` inside ``save_dir``,
    which is created if it does not exist. Each path is printed after it
    is written.
    """
    os.makedirs(save_dir, exist_ok=True)
    n_channels = stack.shape[-1]
    for idx in range(n_channels):
        tiff_path = os.path.join(save_dir, f"{filename}_channel_{idx}.tiff")
        tiff.imwrite(tiff_path, stack[..., idx], photometric='minisblack')
        print(f"Saved {tiff_path}")
129
+
99
130
  stack = np.load(file)
100
131
 
132
+ if export_tiffs:
133
+ save_dir = os.path.join(os.path.dirname(os.path.dirname(file)), 'results', os.path.splitext(os.path.basename(file))[0], 'tiff')
134
+ filename = os.path.splitext(os.path.basename(file))[0]
135
+ _save_channels_as_tiff(stack, save_dir, filename)
136
+
101
137
  # Convert to float for normalization and ensure correct handling of both 8-bit and 16-bit arrays
102
138
  if stack.dtype == np.uint16:
103
139
  stack = stack.astype(np.float32)
@@ -128,7 +164,7 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
128
164
  outlines.append(np.take(stack, cell_mask_dim, axis=2))
129
165
  outline_colors.append('red')
130
166
 
131
- fig = _plot_merged_plot(image=image, outlines=outlines, outline_colors=outline_colors, figuresize=figuresize, thickness=thickness)
167
+ fig = _plot_merged_plot(image=image, outlines=outlines, outline_colors=outline_colors, figuresize=figuresize, thickness=thickness, percentiles=percentiles, mode=mode)
132
168
 
133
169
  return fig
134
170
 
@@ -1691,17 +1727,25 @@ def plot_object_outlines(src, objects=['nucleus','cell','pathogen'], channels=[0
1691
1727
  overlay=True,
1692
1728
  max_nr=10,
1693
1729
  randomize=True)
1694
-
1730
+
1695
1731
  def volcano_plot(coef_df, filename='volcano_plot.pdf'):
1732
+ palette = {
1733
+ 'pc': 'red',
1734
+ 'nc': 'green',
1735
+ 'control': 'blue',
1736
+ 'other': 'gray'
1737
+ }
1738
+
1696
1739
  # Create the volcano plot
1697
1740
  plt.figure(figsize=(10, 6))
1698
1741
  sns.scatterplot(
1699
1742
  data=coef_df,
1700
1743
  x='coefficient',
1701
1744
  y='-log10(p_value)',
1702
- hue='highlight',
1703
- palette={True: 'red', False: 'blue'}
1745
+ hue='condition',
1746
+ palette=palette
1704
1747
  )
1748
+
1705
1749
  plt.title('Volcano Plot of Coefficients')
1706
1750
  plt.xlabel('Coefficient')
1707
1751
  plt.ylabel('-log10(p-value)')
@@ -2098,7 +2142,7 @@ class spacrGraph:
2098
2142
  def __init__(self, df, grouping_column, data_column, graph_type='bar', summary_func='mean',
2099
2143
  order=None, colors=None, output_dir='./output', save=False, y_lim=None,
2100
2144
  error_bar_type='std', remove_outliers=False, theme='pastel', representation='object',
2101
- paired=False, all_to_all=True, compare_group=None):
2145
+ paired=False, all_to_all=True, compare_group=None, graph_name=None):
2102
2146
 
2103
2147
  """
2104
2148
  Class for creating grouped plots with optional statistical tests and data preprocessing.
@@ -2121,11 +2165,14 @@ class spacrGraph:
2121
2165
  self.all_to_all = all_to_all
2122
2166
  self.compare_group = compare_group
2123
2167
  self.y_lim = y_lim
2168
+ self.graph_name = graph_name
2169
+
2170
+
2124
2171
  self.results_df = pd.DataFrame()
2125
2172
  self.sns_palette = None
2126
2173
  self.fig = None
2127
2174
 
2128
- self.results_name = str(self.data_column[0])+'_'+str(self.grouping_column)+'_'+str(self.graph_type)
2175
+ self.results_name = str(self.graph_name)+'_'+str(self.data_column[0])+'_'+str(self.grouping_column)+'_'+str(self.graph_type)
2129
2176
 
2130
2177
  self._set_theme()
2131
2178
  self.raw_df = self.df.copy()
@@ -2134,10 +2181,10 @@ class spacrGraph:
2134
2181
  def _set_theme(self):
2135
2182
  """Set the Seaborn theme and reorder colors if necessary."""
2136
2183
  integer_list = list(range(1, 81))
2137
- color_order = [0, 3, 9, 4, 6, 7, 9, 2] + integer_list
2184
+ color_order = [7,9,4,0,3,6,2] + integer_list
2138
2185
  self.sns_palette = self._set_reordered_theme(self.theme, color_order, 100)
2139
2186
 
2140
- def _set_reordered_theme(self, theme='muted', order=None, n_colors=100, show_theme=False):
2187
+ def _set_reordered_theme(self, theme='deep', order=None, n_colors=100, show_theme=False):
2141
2188
  """Set and reorder the Seaborn color palette."""
2142
2189
  palette = sns.color_palette(theme, n_colors)
2143
2190
  if order:
@@ -2152,10 +2199,12 @@ class spacrGraph:
2152
2199
  def preprocess_data(self):
2153
2200
  """Preprocess the data: remove NaNs, sort/order the grouping column, and optionally group by 'prc'."""
2154
2201
  # Remove NaNs in both the grouping column and each data column
2155
- df = self.df.dropna(subset=[self.grouping_column] + self.data_column) # Handle multiple data columns
2202
+ df = self.df.dropna(subset=[self.grouping_column] + self.data_column)
2156
2203
  # Group by 'prc' column if representation is 'well'
2157
2204
  if self.representation == 'well':
2158
2205
  df = df.groupby(['prc', self.grouping_column])[self.data_column].agg(self.summary_func).reset_index()
2206
+ if self.representation == 'plate':
2207
+ df = df.groupby(['plate', self.grouping_column])[self.data_column].agg(self.summary_func).reset_index()
2159
2208
  if self.order:
2160
2209
  df[self.grouping_column] = pd.Categorical(df[self.grouping_column], categories=self.order, ordered=True)
2161
2210
  else:
@@ -2180,20 +2229,36 @@ class spacrGraph:
2180
2229
  """Perform normality tests for each group and each data column."""
2181
2230
  unique_groups = self.df[self.grouping_column].unique()
2182
2231
  normality_results = []
2232
+
2183
2233
  for column in self.data_column:
2184
- grouped_data = [self.df.loc[self.df[self.grouping_column] == group, column] for group in unique_groups]
2185
- normal_p_values = [normaltest(data).pvalue for data in grouped_data]
2186
- normal_stats = [normaltest(data).statistic for data in grouped_data]
2187
- is_normal = all(p > 0.05 for p in normal_p_values) # Test if all groups are normal
2188
- for group, stat, p_value in zip(unique_groups, normal_stats, normal_p_values):
2234
+ # Iterate over each group and its corresponding data
2235
+ for group in unique_groups:
2236
+ data = self.df.loc[self.df[self.grouping_column] == group, column]
2237
+ n_samples = len(data)
2238
+
2239
+ if n_samples >= 8:
2240
+ # Use D'Agostino-Pearson test for larger samples
2241
+ stat, p_value = normaltest(data)
2242
+ test_name = "D'Agostino-Pearson test"
2243
+ else:
2244
+ # Use Shapiro-Wilk test for smaller samples
2245
+ stat, p_value = shapiro(data)
2246
+ test_name = "Shapiro-Wilk test"
2247
+
2248
+ # Store the result for this group and column
2189
2249
  normality_results.append({
2190
2250
  'Comparison': f'Normality test for {group} on {column}',
2191
2251
  'Test Statistic': stat,
2192
2252
  'p-value': p_value,
2193
- 'Test Name': 'Normality test',
2253
+ 'Test Name': test_name,
2194
2254
  'Column': column,
2195
- 'n': len(self.df[self.df[self.grouping_column] == group]) # Sample size
2255
+ 'n': n_samples # Sample size
2196
2256
  })
2257
+
2258
+ # Check if all groups are normally distributed (p > 0.05)
2259
+ normal_p_values = [result['p-value'] for result in normality_results if result['Column'] == column]
2260
+ is_normal = all(p > 0.05 for p in normal_p_values)
2261
+
2197
2262
  return is_normal, normality_results
2198
2263
 
2199
2264
  def perform_levene_test(self, unique_groups):
@@ -2337,17 +2402,21 @@ class spacrGraph:
2337
2402
  ax.text(x_pos, y_pos, text, ha='center', va='center', fontsize=12)
2338
2403
 
2339
2404
  def _get_positions(self, ax):
2340
- if self.graph_type == 'bar':
2405
+ if self.graph_type in ['bar','jitter_bar']:
2341
2406
  x_positions = [np.mean(bar.get_paths()[0].vertices[:, 0]) for bar in ax.collections if hasattr(bar, 'get_paths')]
2342
2407
 
2343
2408
  elif self.graph_type == 'violin':
2344
2409
  x_positions = [np.mean(violin.get_paths()[0].vertices[:, 0]) for violin in ax.collections if hasattr(violin, 'get_paths')]
2345
2410
 
2346
- elif self.graph_type == 'box':
2411
+ elif self.graph_type in ['box', 'jitter_box']:
2347
2412
  x_positions = list(set(line.get_xdata().mean() for line in ax.lines if line.get_linestyle() == '-'))
2348
2413
 
2349
2414
  elif self.graph_type == 'jitter':
2350
2415
  x_positions = [np.mean(collection.get_offsets()[:, 0]) for collection in ax.collections if collection.get_offsets().size > 0]
2416
+
2417
+ elif self.graph_type in ['line', 'line_std']:
2418
+ x_positions = []
2419
+
2351
2420
  return x_positions
2352
2421
 
2353
2422
  def _draw_comparison_lines(ax, x_positions):
@@ -2365,7 +2434,7 @@ class spacrGraph:
2365
2434
 
2366
2435
  # Determine significance marker
2367
2436
  if p_value <= 0.001:
2368
- significance = '***'
2437
+ signiresults_namecance = '***'
2369
2438
  elif p_value <= 0.01:
2370
2439
  significance = '**'
2371
2440
  elif p_value <= 0.05:
@@ -2406,6 +2475,9 @@ class spacrGraph:
2406
2475
  self.fig_width = (num_groups * self.bar_width) + (spacing_between_groups * num_groups)
2407
2476
  self.fig_height = self.fig_width/2
2408
2477
 
2478
+ if self.graph_type in ['line','line_std']:
2479
+ self.fig_height, self.fig_width = 10, 10
2480
+
2409
2481
  if ax is None:
2410
2482
  self.fig, ax = plt.subplots(figsize=(self.fig_height, self.fig_width))
2411
2483
  else:
@@ -2427,6 +2499,14 @@ class spacrGraph:
2427
2499
  self._create_box_plot(ax)
2428
2500
  elif self.graph_type == 'violin':
2429
2501
  self._create_violin_plot(ax)
2502
+ elif self.graph_type == 'jitter_box':
2503
+ self._create_jitter_box_plot(ax)
2504
+ elif self.graph_type == 'jitter_bar':
2505
+ self._create_jitter_bar_plot(ax)
2506
+ elif self.graph_type == 'line':
2507
+ self._create_line_graph(ax)
2508
+ elif self.graph_type == 'line_std':
2509
+ self._create_line_with_std_area(ax)
2430
2510
  else:
2431
2511
  raise ValueError(f"Unknown graph type: {self.graph_type}")
2432
2512
 
@@ -2439,14 +2519,17 @@ class spacrGraph:
2439
2519
 
2440
2520
  sns.despine(ax=ax, top=True, right=True)
2441
2521
  ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), title='Data Column') # Move the legend outside the plot
2442
- ax.set_xlabel('')
2522
+
2523
+ if not self.graph_type in ['line','line_std']:
2524
+ ax.set_xlabel('')
2525
+
2443
2526
  x_positions = _get_positions(self, ax)
2444
2527
 
2445
- if len(self.data_column) == 1:
2528
+ if len(self.data_column) == 1 and not self.graph_type in ['line','line_std']:
2446
2529
  ax.legend().remove()
2447
2530
  ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
2448
2531
 
2449
- elif len(self.data_column) > 1:
2532
+ elif len(self.data_column) > 1 and not self.graph_type in ['line','line_std']:
2450
2533
  ax.set_xticks([])
2451
2534
  ax.tick_params(bottom=False)
2452
2535
  ax.set_xticklabels([])
@@ -2522,7 +2605,54 @@ class spacrGraph:
2522
2605
  handles, labels = ax.get_legend_handles_labels()
2523
2606
  unique_labels = dict(zip(labels, handles))
2524
2607
  ax.legend(unique_labels.values(), unique_labels.keys(), loc='best')
2525
-
2608
+
2609
+ def _create_line_graph(self, ax):
2610
+ """Helper method to create a line graph with one line per group based on epochs and accuracy."""
2611
+ #display(self.df)
2612
+ # Ensure epoch is used on the x-axis and accuracy on the y-axis
2613
+ x_axis_column = self.data_column[0]
2614
+ y_axis_column = self.data_column[1]
2615
+
2616
+ # Set hue to the grouping column to get one line per group
2617
+ hue = self.grouping_column
2618
+
2619
+ # Check if the required columns exist in the DataFrame
2620
+ required_columns = [x_axis_column, y_axis_column, self.grouping_column]
2621
+ for col in required_columns:
2622
+ if col not in self.df.columns:
2623
+ raise ValueError(f"Column '{col}' not found in DataFrame.")
2624
+
2625
+ # Create the line graph with one line per group
2626
+ sns.lineplot(data=self.df,x=x_axis_column,y=y_axis_column,hue=hue,palette=self.sns_palette,ax=ax,marker='o',linewidth=1,markersize=6)
2627
+
2628
+ # Adjust axis labels
2629
+ ax.set_xlabel(f"{x_axis_column}")
2630
+ ax.set_ylabel(f"{y_axis_column}")
2631
+
2632
+ def _create_line_with_std_area(self, ax):
2633
+ """Helper method to create a line graph with shaded area representing standard deviation."""
2634
+
2635
+ x_axis_column = self.data_column[0]
2636
+ y_axis_column = self.data_column[1]
2637
+ y_axis_column_mean = f"mean_{y_axis_column}"
2638
+ y_axis_column_std = f"std_{y_axis_column_mean}"
2639
+
2640
+ # Pivot the DataFrame to get mean and std for each epoch across plates
2641
+ summary_df = self.df.pivot_table(index=x_axis_column,values=y_axis_column,aggfunc=['mean', 'std']).reset_index()
2642
+
2643
+ # Flatten MultiIndex columns (result of pivoting)
2644
+ summary_df.columns = [x_axis_column, y_axis_column_mean, y_axis_column_std]
2645
+
2646
+ # Plot the mean accuracy as a line
2647
+ sns.lineplot(data=summary_df,x=x_axis_column,y=y_axis_column_mean,ax=ax,marker='o',linewidth=1,markersize=0,color='blue',label=y_axis_column_mean)
2648
+
2649
+ # Fill the area representing the standard deviation
2650
+ ax.fill_between(summary_df[x_axis_column],summary_df[y_axis_column_mean] - summary_df[y_axis_column_std],summary_df[y_axis_column_mean] + summary_df[y_axis_column_std],color='blue', alpha=0.1 )
2651
+
2652
+ # Adjust axis labels
2653
+ ax.set_xlabel(f"{x_axis_column}")
2654
+ ax.set_ylabel(f"{y_axis_column}")
2655
+
2526
2656
  def _create_box_plot(self, ax):
2527
2657
  """Helper method to create a box plot with consistent spacing."""
2528
2658
  # Combine grouping column and data column if needed
@@ -2572,6 +2702,68 @@ class spacrGraph:
2572
2702
  unique_labels = dict(zip(labels, handles))
2573
2703
  ax.legend(unique_labels.values(), unique_labels.keys(), loc='best')
2574
2704
 
2705
def _create_jitter_bar_plot(self, ax):
    """Draw a bar plot with the individual data points overlaid as a strip plot.

    Reads ``self.df_melted``, ``self.hue``, ``self.sns_palette``,
    ``self.jitter_bar_dodge``, ``self.bar_width`` and
    ``self.error_bar_type``. With more than one data column, a
    'Combined Group' column is added to ``self.df_melted`` and used as
    the x-axis.

    Error bars are drawn only when ``self.error_bar_type`` is 'std' or
    'sem'; any other value now skips them instead of raising KeyError.

    Args:
        ax: Matplotlib axes to draw on.
    """
    multi_column = len(self.data_column) > 1
    if multi_column:
        # Flatten group x data column into one categorical axis.
        self.df_melted['Combined Group'] = (self.df_melted[self.grouping_column].astype(str) + " - " + self.df_melted['Data Column'].astype(str))
        x_axis_column = 'Combined Group'
        ax.set_ylabel('Value')
    else:
        x_axis_column = self.grouping_column
        ax.set_ylabel(self.data_column[0])

    summary_df = self.df_melted.groupby([x_axis_column]).agg(mean=('Value', 'mean'), std=('Value', 'std'), sem=('Value', 'sem')).reset_index()

    # NOTE(review): ``ci=None`` is the older seaborn spelling of "no built-in
    # error bars"; confirm against the seaborn version this project pins.
    sns.barplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, ax=ax, dodge=self.jitter_bar_dodge, ci=None)
    sns.stripplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax, alpha=0.6, edgecolor='white', linewidth=1, size=16)

    # Adjust the bar width manually.
    if multi_column:
        target_width = self.bar_width * 2
        for bar in [patch for patch in ax.patches if isinstance(patch, plt.Rectangle)]:
            bar.set_width(target_width)
            # NOTE(review): this shifts the left edge by half the new width
            # without accounting for the old width; kept as-is, but check
            # that the bars end up centred where intended.
            bar.set_x(bar.get_x() - target_width / 2)

    # Manual error bars, one per bar, paired with the summary rows in order.
    # NOTE(review): assumes the bar order matches the sorted group order of
    # summary_df and that there is one bar per group - confirm when hue is set.
    if self.error_bar_type in ('std', 'sem'):
        bars = [patch for patch in ax.patches if isinstance(patch, plt.Rectangle)]
        for bar, (_, row) in zip(bars, summary_df.iterrows()):
            x_bar = bar.get_x() + bar.get_width() / 2
            err = row[self.error_bar_type]
            ax.errorbar(x=x_bar, y=bar.get_height(), yerr=err, fmt='none', c='black', capsize=5, lw=2)

    ax.set_xlabel(self.grouping_column)
2741
+
2742
def _create_jitter_box_plot(self, ax):
    """Draw a box plot with the individual data points overlaid as a strip plot.

    With more than one data column, a 'Combined Group' column is added to
    ``self.df_melted`` and used as the x-axis; otherwise the grouping
    column is used directly.
    """
    if len(self.data_column) > 1:
        combined = self.df_melted[self.grouping_column].astype(str) + " - " + self.df_melted['Data Column'].astype(str)
        self.df_melted['Combined Group'] = combined
        x_axis_column, y_label = 'Combined Group', 'Value'
    else:
        x_axis_column, y_label = self.grouping_column, self.data_column[0]
    ax.set_ylabel(y_label)

    # Boxes first, then the raw points on top.
    sns.boxplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette, ax=ax)
    sns.stripplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, palette=self.sns_palette,
                  dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax, alpha=0.6,
                  edgecolor='white', linewidth=1, size=12)

    ax.set_xlabel(self.grouping_column)

    # Collapse duplicate legend entries (later handles win for a repeated label).
    handles, labels = ax.get_legend_handles_labels()
    by_label = {}
    for label, handle in zip(labels, handles):
        by_label[label] = handle
    ax.legend(by_label.values(), by_label.keys(), loc='best')
2766
+
2575
2767
  def _save_results(self):
2576
2768
  """Helper method to save the plot and results."""
2577
2769
  os.makedirs(self.output_dir, exist_ok=True)
@@ -2592,14 +2784,14 @@ class spacrGraph:
2592
2784
 
2593
2785
  def plot_data_from_db(settings):
2594
2786
  from .io import _read_db, _read_and_merge_data
2595
- from .utils import annotate_conditions
2787
+ from .utils import annotate_conditions, save_settings
2596
2788
  """
2597
2789
  Extracts the specified table from the SQLite database and plots a specified column.
2598
2790
 
2599
2791
  Args:
2600
2792
  db_path (str): The path to the SQLite database.
2601
2793
  table_names (str): The name of the table to extract.
2602
- column_name (str): The column to plot from the table.
2794
+ data_column (str): The column to plot from the table.
2603
2795
 
2604
2796
  Returns:
2605
2797
  df (pd.DataFrame): The extracted table as a DataFrame.
@@ -2614,6 +2806,8 @@ def plot_data_from_db(settings):
2614
2806
  else:
2615
2807
  raise ValueError("src must be a string or a list of strings.")
2616
2808
 
2809
+ save_settings(settings, name=f"{settings['graph_name']}_plot_settings_db", show=True)
2810
+
2617
2811
  dfs = []
2618
2812
  for i, src in enumerate(srcs):
2619
2813
 
@@ -2641,6 +2835,7 @@ def plot_data_from_db(settings):
2641
2835
  df = pd.concat(dfs, axis=0)
2642
2836
  df['prc'] = df['plate'].astype(str) + '_' + df['row'].astype(str) + '_' + df['col'].astype(str)
2643
2837
  df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
2838
+ df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
2644
2839
 
2645
2840
  if settings['cell_plate_metadata'] != None:
2646
2841
  df = df.dropna(subset='host_cell')
@@ -2651,20 +2846,23 @@ def plot_data_from_db(settings):
2651
2846
  if settings['treatment_plate_metadata'] != None:
2652
2847
  df = df.dropna(subset='treatment')
2653
2848
 
2654
- df = df.dropna(subset=settings['column_name'])
2849
+ df = df.dropna(subset=settings['data_column'])
2655
2850
  df = df.dropna(subset=settings['grouping_column'])
2656
- #display(df)
2657
2851
 
2658
2852
  #df['class'] = df['png_path'].apply(lambda x: 'class_1' if 'class_1' in x else ('class_0' if 'class_0' in x else None))
2853
+ src = srcs[0]
2854
+ dst = os.path.join(src, 'results', settings['graph_name'])
2855
+ os.makedirs(dst, exist_ok=True)
2659
2856
 
2660
2857
  spacr_graph = spacrGraph(
2661
2858
  df=df, # Your DataFrame
2662
2859
  grouping_column=settings['grouping_column'], # Column for grouping the data (x-axis)
2663
- data_column=settings['column_name'], # Column for the data (y-axis)
2860
+ data_column=settings['data_column'], # Column for the data (y-axis)
2664
2861
  graph_type=settings['graph_type'], # Type of plot ('bar', 'box', 'violin', 'jitter')
2862
+ graph_name=settings['graph_name'], # Name of the plot
2665
2863
  summary_func='mean', # Function to summarize data (e.g., 'mean', 'median')
2666
2864
  colors=None, # Custom colors for the plot (optional)
2667
- output_dir=settings['dst'], # Directory to save the plot and results
2865
+ output_dir=dst, # Directory to save the plot and results
2668
2866
  save=settings['save'], # Whether to save the plot and results
2669
2867
  y_lim=settings['y_lim'], # Starting point for y-axis (optional)
2670
2868
  error_bar_type='std', # Type of error bar ('std' or 'sem')
@@ -2681,5 +2879,197 @@ def plot_data_from_db(settings):
2681
2879
 
2682
2880
  # Optional: Get the results DataFrame containing statistical test results
2683
2881
  results_df = spacr_graph.get_results()
2882
+ return fig, results_df
2883
+
2884
def plot_data_from_csv(settings):
    """Load one or more CSV files, concatenate them and plot them with spacrGraph.

    Args:
        settings (dict): Reads the keys 'src' (a CSV path or a list of CSV
            paths), 'data_column', 'grouping_column', 'graph_name',
            'graph_type', 'save', 'y_lim', 'representation' and 'theme'.

    Returns:
        tuple: ``(fig, results_df)`` as returned by
        ``spacrGraph.get_figure()`` and ``spacrGraph.get_results()``.

    Raises:
        ValueError: If ``settings['src']`` is neither a string nor a list,
            or is an empty list.
    """
    src_setting = settings['src']
    if isinstance(src_setting, str):
        srcs = [src_setting]
    elif isinstance(src_setting, list):
        srcs = src_setting
    else:
        raise ValueError("src must be a string or a list of strings.")
    if not srcs:
        raise ValueError("src must contain at least one CSV path.")

    dfs = []
    for i, src in enumerate(srcs):
        dft = pd.read_csv(src)
        # Files without a plate column get a synthetic one based on input order.
        if 'plate' not in dft.columns:
            dft['plate'] = f"plate{i+1}"
        dfs.append(dft)

    df = pd.concat(dfs, axis=0)
    df = df.dropna(subset=settings['data_column'])
    df = df.dropna(subset=settings['grouping_column'])

    # Results are written next to the first CSV: <dir>/results/<graph_name>.
    dst = os.path.join(os.path.dirname(srcs[0]), 'results', settings['graph_name'])
    os.makedirs(dst, exist_ok=True)

    spacr_graph = spacrGraph(
        df=df,
        grouping_column=settings['grouping_column'],
        data_column=settings['data_column'],
        graph_type=settings['graph_type'],
        graph_name=settings['graph_name'],
        summary_func='mean',
        colors=None,
        output_dir=dst,
        save=settings['save'],
        y_lim=settings['y_lim'],
        error_bar_type='std',
        representation=settings['representation'],
        theme=settings['theme'],
    )

    spacr_graph.create_plot()

    fig = spacr_graph.get_figure()
    plt.show()

    # Statistical test results gathered while plotting.
    results_df = spacr_graph.get_results()
    return fig, results_df
2951
+
2952
def plot_region(settings):
    """Plot a field of view with mask overlays plus grids of its object images.

    Three figures are produced and saved as PDFs under
    ``<src>/results/<name>/``: the mask overlay of the merged array, a grid
    of the PNGs listed in the measurements database, and a grid of the
    PNGs listed in the activation database.

    Args:
        settings (dict): Reads the keys 'src', 'name', 'activation_mode',
            'activation_db', 'percentiles', 'channels', 'cell_channel',
            'nucleus_channel', 'pathogen_channel', 'mode' and
            'export_tiffs'.

    Returns:
        tuple: ``(fig_1, fig_2, fig_3)`` - mask overlay, PNG grid,
        activation grid.
    """
    from .io import _read_db

    def _sort_paths_by_basename(paths):
        return sorted(paths, key=os.path.basename)

    def save_figure_as_pdf(fig, path):
        os.makedirs(os.path.dirname(path), exist_ok=True)  # Create directory if it doesn't exist
        fig.savefig(path, format='pdf', dpi=600, bbox_inches='tight')
        print(f"Saved {path}")

    def _paths_matching(frame, stem):
        # Literal substring match: the stem comes from a file name and may
        # contain regex metacharacters such as '.', '+' or '('.
        keep = frame['png_path'].str.contains(stem, na=False, regex=False)
        return _sort_paths_by_basename(frame.loc[keep, 'png_path'].tolist())

    fov_path = os.path.join(settings['src'], 'merged', settings['name'])
    name = os.path.splitext(settings['name'])[0]

    db_path = os.path.join(settings['src'], 'measurements', 'measurements.db')
    paths_df = _read_db(db_path, tables=['png_list'])[0]

    activation_mode = f"{settings['activation_mode']}_list"
    activation_db_path = os.path.join(settings['src'], 'measurements', settings['activation_db'])
    activation_paths_df = _read_db(activation_db_path, tables=[activation_mode])[0]

    png_paths = _paths_matching(paths_df, name)
    activation_paths = _paths_matching(activation_paths_df, name)

    fig_3 = plot_image_grid(image_paths=activation_paths, percentiles=settings['percentiles'])
    fig_2 = plot_image_grid(image_paths=png_paths, percentiles=settings['percentiles'])
    fig_1 = plot_image_mask_overlay(file=fov_path,
                                    channels=settings['channels'],
                                    cell_channel=settings['cell_channel'],
                                    nucleus_channel=settings['nucleus_channel'],
                                    pathogen_channel=settings['pathogen_channel'],
                                    figuresize=10,
                                    percentiles=settings['percentiles'],
                                    thickness=3,
                                    save_pdf=False,
                                    mode=settings['mode'],
                                    export_tiffs=settings['export_tiffs'])

    dst = os.path.join(settings['src'], 'results', name)
    save_figure_as_pdf(fig_1, os.path.join(dst, f"{name}_mask_overlay.pdf"))
    save_figure_as_pdf(fig_2, os.path.join(dst, f"{name}_png_grid.pdf"))
    save_figure_as_pdf(fig_3, os.path.join(dst, f"{name}_activation_grid.pdf"))

    return fig_1, fig_2, fig_3
2998
+
2999
def plot_image_grid(image_paths, percentiles):
    """Plot a square grid of images from a list of image paths.

    Unused subplots are filled with black, and padding is removed.

    Parameters:
    - image_paths: List of paths to images to be displayed. An empty list
      yields a single black cell.
    - percentiles: (low, high) percentiles used to rescale each image
      channel to [0, 1] before display.

    Returns:
    - fig: The generated matplotlib figure.
    """

    from PIL import Image
    import matplotlib.pyplot as plt
    import math

    def _rescale(channel, bounds):
        """Clip one 2-D channel to the given percentile bounds, scaled to [0, 1]."""
        v_min, v_max = np.percentile(channel, bounds)
        span = v_max - v_min
        if span == 0:
            # Flat channel: avoid 0/0 and render it as black.
            return np.zeros(channel.shape, dtype=np.float32)
        return np.clip((channel - v_min) / span, 0, 1)

    def _normalize_image(image, percentiles=(2, 98)):
        """Normalize each channel independently, preserving the input type (PIL.Image or numpy.ndarray)."""
        is_pil_image = isinstance(image, Image.Image)
        if is_pil_image:
            image = np.array(image)

        if image.ndim == 2:
            normalized_image = _rescale(image, percentiles)
        else:
            normalized_image = np.zeros_like(image, dtype=np.float32)
            for c in range(image.shape[-1]):
                normalized_image[..., c] = _rescale(image[..., c], percentiles)

        if is_pil_image:
            # Back to the 8-bit range (0-255) for PIL.
            return Image.fromarray((normalized_image * 255).astype(np.uint8))
        return normalized_image

    n_images = len(image_paths)
    # Smallest square grid that fits all images; at least 1x1 so an empty
    # list still produces a valid figure.
    grid_size = max(1, math.ceil(math.sqrt(n_images)))

    # squeeze=False keeps ``axs`` a 2-D array even for a 1x1 grid, where
    # plt.subplots would otherwise return a bare Axes with no .flatten().
    fig, axs = plt.subplots(
        grid_size, grid_size,
        figsize=(grid_size * 2, grid_size * 2),
        facecolor='black',  # Set figure background to black
        squeeze=False,
    )
    axs = axs.flatten()

    for ax, img_path in zip(axs, image_paths):
        # The context manager closes the file handle; the pixel data is
        # copied out during normalisation.
        with Image.open(img_path) as img:
            shown = _normalize_image(img, percentiles)
        ax.imshow(shown)
        ax.axis('off')

    # Fill any unused subplots with black.
    for ax in axs[n_images:]:
        ax.imshow([[0, 0, 0]], cmap='gray')
        ax.axis('off')

    # Remove all spacing between and around the subplots.
    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, top=1, bottom=0)

    return fig