spacr 0.3.38__py3-none-any.whl → 0.3.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/plot.py CHANGED
@@ -13,10 +13,11 @@ from IPython.display import display
13
13
  from skimage.segmentation import find_boundaries
14
14
  from skimage import measure
15
15
  from skimage.measure import find_contours, label, regionprops
16
+ import tifffile as tiff
16
17
 
17
18
  from scipy.stats import normaltest, ttest_ind, mannwhitneyu, f_oneway, kruskal
18
19
  from statsmodels.stats.multicomp import pairwise_tukeyhsd
19
- from scipy.stats import ttest_ind, mannwhitneyu, levene, wilcoxon, kruskal
20
+ from scipy.stats import ttest_ind, mannwhitneyu, levene, wilcoxon, kruskal, normaltest, shapiro
20
21
  import itertools
21
22
  import pingouin as pg
22
23
 
@@ -25,13 +26,26 @@ from IPython.display import Image as ipyimage
25
26
 
26
27
  import matplotlib.patches as patches
27
28
  from collections import defaultdict
29
+ from matplotlib.gridspec import GridSpec
28
30
 
29
- def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, pathogen_channel, figuresize=10, normalize=True, thickness=3, save_pdf=True):
31
+ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, pathogen_channel, figuresize=10, percentiles=(2,98), thickness=3, save_pdf=True, mode='outlines', export_tiffs=False):
30
32
  """Plot image and mask overlays."""
31
33
 
32
- def _plot_merged_plot(image, outlines, outline_colors, figuresize, thickness):
34
+ def _plot_merged_plot(image, outlines, outline_colors, figuresize, thickness, percentiles, mode='outlines'):
33
35
  """Plot the merged plot with overlay, image channels, and masks."""
34
36
 
37
def _generate_colored_mask(mask, alpha):
    """Colour a label mask with a random colormap and make the background transparent.

    Args:
        mask: Label image; pixels with value 0 are treated as background.
        alpha: Opacity given to every non-zero (labelled) pixel.

    Returns:
        The colormap output for the scaled labels, with channel 3 (alpha) set
        to ``alpha`` on labelled pixels and 0 elsewhere.
    """
    cmap = generate_mask_random_cmap(mask)
    max_label = mask.max()
    # An empty mask (no objects) would otherwise divide by zero and feed NaNs
    # to the colormap; every pixel ends up transparent in that case anyway.
    if max_label > 0:
        scaled = mask / max_label
    else:
        scaled = np.zeros(mask.shape, dtype=float)
    # NOTE(review): assumes cmap(...) returns RGBA along the last axis, as
    # indexed below; confirm against generate_mask_random_cmap.
    rgba_mask = cmap(scaled)
    rgba_mask[..., 3] = np.where(mask > 0, alpha, 0)  # only labelled pixels are visible
    return rgba_mask
43
+
44
+ def _overlay_mask(image, mask):
45
+ """Overlay the colored mask onto the original image."""
46
+ combined = np.clip(image + mask[..., :3] * mask[..., 3:4], 0, 1) # Ensure pixel values stay in [0, 1]
47
+ return combined
48
+
35
49
  def _normalize_image(image, percentiles=(2, 98)):
36
50
  """Normalize the image to the given percentiles."""
37
51
  v_min, v_max = np.percentile(image, percentiles)
@@ -61,11 +75,15 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
61
75
  # Plot each channel with its corresponding outlines
62
76
  for v in range(num_channels):
63
77
  channel_image = image[..., v]
64
- channel_image_normalized = _normalize_image(channel_image)
78
+ channel_image_normalized = _normalize_image(channel_image, percentiles)
65
79
  channel_image_rgb = np.dstack((channel_image_normalized, channel_image_normalized, channel_image_normalized))
66
80
 
67
81
  for outline, color in zip(outlines, outline_colors):
68
- channel_image_rgb = _apply_contours(channel_image_rgb, outline, color, thickness)
82
+ if mode == 'outlines':
83
+ channel_image_rgb = _apply_contours(channel_image_rgb, outline, color, thickness)
84
+ else:
85
+ mask = _generate_colored_mask(outline, alpha=0.5)
86
+ channel_image_rgb = _overlay_mask(channel_image_rgb, mask)
69
87
 
70
88
  ax[v].imshow(channel_image_rgb)
71
89
  ax[v].set_title(f'Image - Channel {v}')
@@ -75,11 +93,15 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
75
93
  rgb_channels = min(3, num_channels)
76
94
  for i in range(rgb_channels):
77
95
  channel_image = image[..., i]
78
- channel_image_normalized = _normalize_image(channel_image)
96
+ channel_image_normalized = _normalize_image(channel_image, percentiles)
79
97
  rgb_image[..., i] = channel_image_normalized
80
98
 
81
99
  for outline, color in zip(outlines, outline_colors):
82
- rgb_image = _apply_contours(rgb_image, outline, color, thickness)
100
+ if mode == 'outlines':
101
+ rgb_image = _apply_contours(rgb_image, outline, color, thickness)
102
+ else:
103
+ mask = _generate_colored_mask(outline, alpha=0.5)
104
+ rgb_image = _overlay_mask(rgb_image, mask)
83
105
 
84
106
  ax[-1].imshow(rgb_image)
85
107
  ax[-1].set_title('Combined RGB Image')
@@ -96,8 +118,22 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
96
118
  plt.show()
97
119
  return fig
98
120
 
121
def _save_channels_as_tiff(stack, save_dir, filename):
    """Write each slice along the last axis of ``stack`` to its own grayscale TIFF.

    Files are named ``{filename}_channel_{i}.tiff`` inside ``save_dir``, which
    is created if it does not exist. Every written path is printed.
    """
    os.makedirs(save_dir, exist_ok=True)
    n_channels = stack.shape[-1]
    for idx in range(n_channels):
        out_path = os.path.join(save_dir, f"{filename}_channel_{idx}.tiff")
        tiff.imwrite(out_path, stack[..., idx], photometric='minisblack')
        print(f"Saved {out_path}")
129
+
99
130
  stack = np.load(file)
100
131
 
132
+ if export_tiffs:
133
+ save_dir = os.path.join(os.path.dirname(os.path.dirname(file)), 'results', os.path.splitext(os.path.basename(file))[0], 'tiff')
134
+ filename = os.path.splitext(os.path.basename(file))[0]
135
+ _save_channels_as_tiff(stack, save_dir, filename)
136
+
101
137
  # Convert to float for normalization and ensure correct handling of both 8-bit and 16-bit arrays
102
138
  if stack.dtype == np.uint16:
103
139
  stack = stack.astype(np.float32)
@@ -128,7 +164,7 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
128
164
  outlines.append(np.take(stack, cell_mask_dim, axis=2))
129
165
  outline_colors.append('red')
130
166
 
131
- fig = _plot_merged_plot(image=image, outlines=outlines, outline_colors=outline_colors, figuresize=figuresize, thickness=thickness)
167
+ fig = _plot_merged_plot(image=image, outlines=outlines, outline_colors=outline_colors, figuresize=figuresize, thickness=thickness, percentiles=percentiles, mode=mode)
132
168
 
133
169
  return fig
134
170
 
@@ -1691,17 +1727,25 @@ def plot_object_outlines(src, objects=['nucleus','cell','pathogen'], channels=[0
1691
1727
  overlay=True,
1692
1728
  max_nr=10,
1693
1729
  randomize=True)
1694
-
1730
+
1695
1731
  def volcano_plot(coef_df, filename='volcano_plot.pdf'):
1732
+ palette = {
1733
+ 'pc': 'red',
1734
+ 'nc': 'green',
1735
+ 'control': 'blue',
1736
+ 'other': 'gray'
1737
+ }
1738
+
1696
1739
  # Create the volcano plot
1697
1740
  plt.figure(figsize=(10, 6))
1698
1741
  sns.scatterplot(
1699
1742
  data=coef_df,
1700
1743
  x='coefficient',
1701
1744
  y='-log10(p_value)',
1702
- hue='highlight',
1703
- palette={True: 'red', False: 'blue'}
1745
+ hue='condition',
1746
+ palette=palette
1704
1747
  )
1748
+
1705
1749
  plt.title('Volcano Plot of Coefficients')
1706
1750
  plt.xlabel('Coefficient')
1707
1751
  plt.ylabel('-log10(p-value)')
@@ -2098,7 +2142,7 @@ class spacrGraph:
2098
2142
  def __init__(self, df, grouping_column, data_column, graph_type='bar', summary_func='mean',
2099
2143
  order=None, colors=None, output_dir='./output', save=False, y_lim=None,
2100
2144
  error_bar_type='std', remove_outliers=False, theme='pastel', representation='object',
2101
- paired=False, all_to_all=True, compare_group=None):
2145
+ paired=False, all_to_all=True, compare_group=None, graph_name=None):
2102
2146
 
2103
2147
  """
2104
2148
  Class for creating grouped plots with optional statistical tests and data preprocessing.
@@ -2121,11 +2165,14 @@ class spacrGraph:
2121
2165
  self.all_to_all = all_to_all
2122
2166
  self.compare_group = compare_group
2123
2167
  self.y_lim = y_lim
2168
+ self.graph_name = graph_name
2169
+
2170
+
2124
2171
  self.results_df = pd.DataFrame()
2125
2172
  self.sns_palette = None
2126
2173
  self.fig = None
2127
2174
 
2128
- self.results_name = str(self.data_column[0])+'_'+str(self.grouping_column)+'_'+str(self.graph_type)
2175
+ self.results_name = str(self.graph_name)+'_'+str(self.data_column[0])+'_'+str(self.grouping_column)+'_'+str(self.graph_type)
2129
2176
 
2130
2177
  self._set_theme()
2131
2178
  self.raw_df = self.df.copy()
@@ -2134,10 +2181,10 @@ class spacrGraph:
2134
2181
  def _set_theme(self):
2135
2182
  """Set the Seaborn theme and reorder colors if necessary."""
2136
2183
  integer_list = list(range(1, 81))
2137
- color_order = [0, 3, 9, 4, 6, 7, 9, 2] + integer_list
2184
+ color_order = [7,9,4,0,3,6,2] + integer_list
2138
2185
  self.sns_palette = self._set_reordered_theme(self.theme, color_order, 100)
2139
2186
 
2140
- def _set_reordered_theme(self, theme='muted', order=None, n_colors=100, show_theme=False):
2187
+ def _set_reordered_theme(self, theme='deep', order=None, n_colors=100, show_theme=False):
2141
2188
  """Set and reorder the Seaborn color palette."""
2142
2189
  palette = sns.color_palette(theme, n_colors)
2143
2190
  if order:
@@ -2182,20 +2229,36 @@ class spacrGraph:
2182
2229
  """Perform normality tests for each group and each data column."""
2183
2230
  unique_groups = self.df[self.grouping_column].unique()
2184
2231
  normality_results = []
2232
+
2185
2233
  for column in self.data_column:
2186
- grouped_data = [self.df.loc[self.df[self.grouping_column] == group, column] for group in unique_groups]
2187
- normal_p_values = [normaltest(data).pvalue for data in grouped_data]
2188
- normal_stats = [normaltest(data).statistic for data in grouped_data]
2189
- is_normal = all(p > 0.05 for p in normal_p_values) # Test if all groups are normal
2190
- for group, stat, p_value in zip(unique_groups, normal_stats, normal_p_values):
2234
+ # Iterate over each group and its corresponding data
2235
+ for group in unique_groups:
2236
+ data = self.df.loc[self.df[self.grouping_column] == group, column]
2237
+ n_samples = len(data)
2238
+
2239
+ if n_samples >= 8:
2240
+ # Use D'Agostino-Pearson test for larger samples
2241
+ stat, p_value = normaltest(data)
2242
+ test_name = "D'Agostino-Pearson test"
2243
+ else:
2244
+ # Use Shapiro-Wilk test for smaller samples
2245
+ stat, p_value = shapiro(data)
2246
+ test_name = "Shapiro-Wilk test"
2247
+
2248
+ # Store the result for this group and column
2191
2249
  normality_results.append({
2192
2250
  'Comparison': f'Normality test for {group} on {column}',
2193
2251
  'Test Statistic': stat,
2194
2252
  'p-value': p_value,
2195
- 'Test Name': 'Normality test',
2253
+ 'Test Name': test_name,
2196
2254
  'Column': column,
2197
- 'n': len(self.df[self.df[self.grouping_column] == group]) # Sample size
2255
+ 'n': n_samples # Sample size
2198
2256
  })
2257
+
2258
+ # Check if all groups are normally distributed (p > 0.05)
2259
+ normal_p_values = [result['p-value'] for result in normality_results if result['Column'] == column]
2260
+ is_normal = all(p > 0.05 for p in normal_p_values)
2261
+
2199
2262
  return is_normal, normality_results
2200
2263
 
2201
2264
  def perform_levene_test(self, unique_groups):
@@ -2339,17 +2402,21 @@ class spacrGraph:
2339
2402
  ax.text(x_pos, y_pos, text, ha='center', va='center', fontsize=12)
2340
2403
 
2341
2404
  def _get_positions(self, ax):
2342
- if self.graph_type == 'bar':
2405
+ if self.graph_type in ['bar','jitter_bar']:
2343
2406
  x_positions = [np.mean(bar.get_paths()[0].vertices[:, 0]) for bar in ax.collections if hasattr(bar, 'get_paths')]
2344
2407
 
2345
2408
  elif self.graph_type == 'violin':
2346
2409
  x_positions = [np.mean(violin.get_paths()[0].vertices[:, 0]) for violin in ax.collections if hasattr(violin, 'get_paths')]
2347
2410
 
2348
- elif self.graph_type == 'box':
2411
+ elif self.graph_type in ['box', 'jitter_box']:
2349
2412
  x_positions = list(set(line.get_xdata().mean() for line in ax.lines if line.get_linestyle() == '-'))
2350
2413
 
2351
2414
  elif self.graph_type == 'jitter':
2352
2415
  x_positions = [np.mean(collection.get_offsets()[:, 0]) for collection in ax.collections if collection.get_offsets().size > 0]
2416
+
2417
+ elif self.graph_type in ['line', 'line_std']:
2418
+ x_positions = []
2419
+
2353
2420
  return x_positions
2354
2421
 
2355
2422
  def _draw_comparison_lines(ax, x_positions):
@@ -2367,7 +2434,7 @@ class spacrGraph:
2367
2434
 
2368
2435
  # Determine significance marker
2369
2436
  if p_value <= 0.001:
2370
- significance = '***'
2437
+ signiresults_namecance = '***'
2371
2438
  elif p_value <= 0.01:
2372
2439
  significance = '**'
2373
2440
  elif p_value <= 0.05:
@@ -2408,6 +2475,9 @@ class spacrGraph:
2408
2475
  self.fig_width = (num_groups * self.bar_width) + (spacing_between_groups * num_groups)
2409
2476
  self.fig_height = self.fig_width/2
2410
2477
 
2478
+ if self.graph_type in ['line','line_std']:
2479
+ self.fig_height, self.fig_width = 10, 10
2480
+
2411
2481
  if ax is None:
2412
2482
  self.fig, ax = plt.subplots(figsize=(self.fig_height, self.fig_width))
2413
2483
  else:
@@ -2429,6 +2499,14 @@ class spacrGraph:
2429
2499
  self._create_box_plot(ax)
2430
2500
  elif self.graph_type == 'violin':
2431
2501
  self._create_violin_plot(ax)
2502
+ elif self.graph_type == 'jitter_box':
2503
+ self._create_jitter_box_plot(ax)
2504
+ elif self.graph_type == 'jitter_bar':
2505
+ self._create_jitter_bar_plot(ax)
2506
+ elif self.graph_type == 'line':
2507
+ self._create_line_graph(ax)
2508
+ elif self.graph_type == 'line_std':
2509
+ self._create_line_with_std_area(ax)
2432
2510
  else:
2433
2511
  raise ValueError(f"Unknown graph type: {self.graph_type}")
2434
2512
 
@@ -2441,14 +2519,17 @@ class spacrGraph:
2441
2519
 
2442
2520
  sns.despine(ax=ax, top=True, right=True)
2443
2521
  ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), title='Data Column') # Move the legend outside the plot
2444
- ax.set_xlabel('')
2522
+
2523
+ if not self.graph_type in ['line','line_std']:
2524
+ ax.set_xlabel('')
2525
+
2445
2526
  x_positions = _get_positions(self, ax)
2446
2527
 
2447
- if len(self.data_column) == 1:
2528
+ if len(self.data_column) == 1 and not self.graph_type in ['line','line_std']:
2448
2529
  ax.legend().remove()
2449
2530
  ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
2450
2531
 
2451
- elif len(self.data_column) > 1:
2532
+ elif len(self.data_column) > 1 and not self.graph_type in ['line','line_std']:
2452
2533
  ax.set_xticks([])
2453
2534
  ax.tick_params(bottom=False)
2454
2535
  ax.set_xticklabels([])
@@ -2524,7 +2605,54 @@ class spacrGraph:
2524
2605
  handles, labels = ax.get_legend_handles_labels()
2525
2606
  unique_labels = dict(zip(labels, handles))
2526
2607
  ax.legend(unique_labels.values(), unique_labels.keys(), loc='best')
2527
-
2608
+
2609
+ def _create_line_graph(self, ax):
2610
+ """Helper method to create a line graph with one line per group based on epochs and accuracy."""
2611
+ #display(self.df)
2612
+ # Ensure epoch is used on the x-axis and accuracy on the y-axis
2613
+ x_axis_column = self.data_column[0]
2614
+ y_axis_column = self.data_column[1]
2615
+
2616
+ # Set hue to the grouping column to get one line per group
2617
+ hue = self.grouping_column
2618
+
2619
+ # Check if the required columns exist in the DataFrame
2620
+ required_columns = [x_axis_column, y_axis_column, self.grouping_column]
2621
+ for col in required_columns:
2622
+ if col not in self.df.columns:
2623
+ raise ValueError(f"Column '{col}' not found in DataFrame.")
2624
+
2625
+ # Create the line graph with one line per group
2626
+ sns.lineplot(data=self.df,x=x_axis_column,y=y_axis_column,hue=hue,palette=self.sns_palette,ax=ax,marker='o',linewidth=1,markersize=6)
2627
+
2628
+ # Adjust axis labels
2629
+ ax.set_xlabel(f"{x_axis_column}")
2630
+ ax.set_ylabel(f"{y_axis_column}")
2631
+
2632
+ def _create_line_with_std_area(self, ax):
2633
+ """Helper method to create a line graph with shaded area representing standard deviation."""
2634
+
2635
+ x_axis_column = self.data_column[0]
2636
+ y_axis_column = self.data_column[1]
2637
+ y_axis_column_mean = f"mean_{y_axis_column}"
2638
+ y_axis_column_std = f"std_{y_axis_column_mean}"
2639
+
2640
+ # Pivot the DataFrame to get mean and std for each epoch across plates
2641
+ summary_df = self.df.pivot_table(index=x_axis_column,values=y_axis_column,aggfunc=['mean', 'std']).reset_index()
2642
+
2643
+ # Flatten MultiIndex columns (result of pivoting)
2644
+ summary_df.columns = [x_axis_column, y_axis_column_mean, y_axis_column_std]
2645
+
2646
+ # Plot the mean accuracy as a line
2647
+ sns.lineplot(data=summary_df,x=x_axis_column,y=y_axis_column_mean,ax=ax,marker='o',linewidth=1,markersize=0,color='blue',label=y_axis_column_mean)
2648
+
2649
+ # Fill the area representing the standard deviation
2650
+ ax.fill_between(summary_df[x_axis_column],summary_df[y_axis_column_mean] - summary_df[y_axis_column_std],summary_df[y_axis_column_mean] + summary_df[y_axis_column_std],color='blue', alpha=0.1 )
2651
+
2652
+ # Adjust axis labels
2653
+ ax.set_xlabel(f"{x_axis_column}")
2654
+ ax.set_ylabel(f"{y_axis_column}")
2655
+
2528
2656
  def _create_box_plot(self, ax):
2529
2657
  """Helper method to create a box plot with consistent spacing."""
2530
2658
  # Combine grouping column and data column if needed
@@ -2574,6 +2702,68 @@ class spacrGraph:
2574
2702
  unique_labels = dict(zip(labels, handles))
2575
2703
  ax.legend(unique_labels.values(), unique_labels.keys(), loc='best')
2576
2704
 
2705
def _create_jitter_bar_plot(self, ax):
    """Draw mean bars per group with the individual observations jittered on top.

    With more than one data column, every (group, data column) pair becomes its
    own x-axis category labelled "<group> - <data column>" and bars are widened
    to ``2 * self.bar_width``. When ``self.error_bar_type`` is 'std' or 'sem',
    a black error bar of that size is drawn at the centre of each bar; other
    values draw no error bars.
    """
    if len(self.data_column) > 1:
        self.df_melted['Combined Group'] = (self.df_melted[self.grouping_column].astype(str) + " - " + self.df_melted['Data Column'].astype(str))
        x_axis_column = 'Combined Group'
        ax.set_ylabel('Value')
    else:
        x_axis_column = self.grouping_column
        ax.set_ylabel(self.data_column[0])

    # Bars are matched to summary rows by position below, so both must use the
    # same category order. Summarise in seaborn's default order (sorted for
    # numeric/categorical data, order of first appearance otherwise) and pass
    # it to seaborn explicitly, so the x-axis layout stays the same while the
    # alignment is guaranteed.
    groups = self.df_melted[x_axis_column]
    sort_groups = pd.api.types.is_numeric_dtype(groups) or isinstance(groups.dtype, pd.CategoricalDtype)
    summary_df = (self.df_melted.groupby(x_axis_column, sort=sort_groups, observed=True)
                  .agg(mean=('Value', 'mean'), std=('Value', 'std'), sem=('Value', 'sem'))
                  .reset_index())
    order = summary_df[x_axis_column].tolist()

    # NOTE(review): `ci=None` is deprecated in seaborn >= 0.12 (use
    # `errorbar=None`); kept for compatibility with older seaborn releases.
    sns.barplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, order=order,
                palette=self.sns_palette, ax=ax, dodge=self.jitter_bar_dodge, ci=None)
    sns.stripplot(data=self.df_melted, x=x_axis_column, y='Value', hue=self.hue, order=order,
                  palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width,
                  ax=ax, alpha=0.6, edgecolor='white', linewidth=1, size=16)

    if len(self.data_column) > 1:
        target_width = self.bar_width * 2
        for bar in [p for p in ax.patches if isinstance(p, plt.Rectangle)]:
            # get_x() is the left edge: record the centre before resizing so
            # the widened bar stays centred on its category.
            centre = bar.get_x() + bar.get_width() / 2
            bar.set_width(target_width)
            bar.set_x(centre - target_width / 2)

    if self.error_bar_type in ('std', 'sem'):
        bars = [p for p in ax.patches if isinstance(p, plt.Rectangle)]
        # NOTE(review): if self.hue dodges bars there are more bars than
        # summary rows; zip() then only annotates the first len(summary_df) bars.
        for bar, err in zip(bars, summary_df[self.error_bar_type]):
            x_bar = bar.get_x() + bar.get_width() / 2
            ax.errorbar(x=x_bar, y=bar.get_height(), yerr=err, fmt='none', c='black', capsize=5, lw=2)

    ax.set_xlabel(self.grouping_column)
2741
+
2742
def _create_jitter_box_plot(self, ax):
    """Draw box plots per group with the raw observations jittered over them.

    When more than one data column is plotted, each (group, data column) pair
    becomes one x-axis category named "<group> - <data column>". Duplicate
    legend labels are collapsed into one entry each at the end.
    """
    if len(self.data_column) > 1:
        self.df_melted['Combined Group'] = (self.df_melted[self.grouping_column].astype(str) + " - " + self.df_melted['Data Column'].astype(str))
        x_col = 'Combined Group'
        y_label = 'Value'
    else:
        x_col = self.grouping_column
        y_label = self.data_column[0]
    ax.set_ylabel(y_label)

    # Arguments common to the box layer and the jittered point layer.
    shared = dict(data=self.df_melted, x=x_col, y='Value', hue=self.hue,
                  palette=self.sns_palette, ax=ax)
    sns.boxplot(**shared)
    sns.stripplot(**shared, dodge=self.jitter_bar_dodge, jitter=self.bar_width,
                  alpha=0.6, edgecolor='white', linewidth=1, size=12)

    ax.set_xlabel(self.grouping_column)

    # Keep one legend entry per label (the last handle seen for a label wins).
    handles, labels = ax.get_legend_handles_labels()
    deduped = {}
    for handle, label in zip(handles, labels):
        deduped[label] = handle
    ax.legend(deduped.values(), deduped.keys(), loc='best')
2766
+
2577
2767
  def _save_results(self):
2578
2768
  """Helper method to save the plot and results."""
2579
2769
  os.makedirs(self.output_dir, exist_ok=True)
@@ -2594,14 +2784,14 @@ class spacrGraph:
2594
2784
 
2595
2785
  def plot_data_from_db(settings):
2596
2786
  from .io import _read_db, _read_and_merge_data
2597
- from .utils import annotate_conditions
2787
+ from .utils import annotate_conditions, save_settings
2598
2788
  """
2599
2789
  Extracts the specified table from the SQLite database and plots a specified column.
2600
2790
 
2601
2791
  Args:
2602
2792
  db_path (str): The path to the SQLite database.
2603
2793
  table_names (str): The name of the table to extract.
2604
- column_name (str): The column to plot from the table.
2794
+ data_column (str): The column to plot from the table.
2605
2795
 
2606
2796
  Returns:
2607
2797
  df (pd.DataFrame): The extracted table as a DataFrame.
@@ -2616,6 +2806,8 @@ def plot_data_from_db(settings):
2616
2806
  else:
2617
2807
  raise ValueError("src must be a string or a list of strings.")
2618
2808
 
2809
+ save_settings(settings, name=f"{settings['graph_name']}_plot_settings_db", show=True)
2810
+
2619
2811
  dfs = []
2620
2812
  for i, src in enumerate(srcs):
2621
2813
 
@@ -2643,6 +2835,7 @@ def plot_data_from_db(settings):
2643
2835
  df = pd.concat(dfs, axis=0)
2644
2836
  df['prc'] = df['plate'].astype(str) + '_' + df['row'].astype(str) + '_' + df['col'].astype(str)
2645
2837
  df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
2838
+ df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
2646
2839
 
2647
2840
  if settings['cell_plate_metadata'] != None:
2648
2841
  df = df.dropna(subset='host_cell')
@@ -2653,24 +2846,91 @@ def plot_data_from_db(settings):
2653
2846
  if settings['treatment_plate_metadata'] != None:
2654
2847
  df = df.dropna(subset='treatment')
2655
2848
 
2656
- df = df.dropna(subset=settings['column_name'])
2849
+ df = df.dropna(subset=settings['data_column'])
2657
2850
  df = df.dropna(subset=settings['grouping_column'])
2658
2851
 
2852
+ #df['class'] = df['png_path'].apply(lambda x: 'class_1' if 'class_1' in x else ('class_0' if 'class_0' in x else None))
2853
+ src = srcs[0]
2854
+ dst = os.path.join(src, 'results', settings['graph_name'])
2855
+ os.makedirs(dst, exist_ok=True)
2856
+
2857
+ spacr_graph = spacrGraph(
2858
+ df=df, # Your DataFrame
2859
+ grouping_column=settings['grouping_column'], # Column for grouping the data (x-axis)
2860
+ data_column=settings['data_column'], # Column for the data (y-axis)
2861
+ graph_type=settings['graph_type'], # Type of plot ('bar', 'box', 'violin', 'jitter')
2862
+ graph_name=settings['graph_name'], # Name of the plot
2863
+ summary_func='mean', # Function to summarize data (e.g., 'mean', 'median')
2864
+ colors=None, # Custom colors for the plot (optional)
2865
+ output_dir=dst, # Directory to save the plot and results
2866
+ save=settings['save'], # Whether to save the plot and results
2867
+ y_lim=settings['y_lim'], # Starting point for y-axis (optional)
2868
+ error_bar_type='std', # Type of error bar ('std' or 'sem')
2869
+ representation=settings['representation'],
2870
+ theme=settings['theme'], # Seaborn color palette theme (e.g., 'pastel', 'muted')
2871
+ )
2872
+
2873
+ # Create the plot
2874
+ spacr_graph.create_plot()
2875
+
2876
+ # Get the figure object if needed
2877
+ fig = spacr_graph.get_figure()
2878
+ plt.show()
2879
+
2880
+ # Optional: Get the results DataFrame containing statistical test results
2881
+ results_df = spacr_graph.get_results()
2882
+ return fig, results_df
2883
+
2884
+ def plot_data_from_csv(settings):
2885
+ from .io import _read_db, _read_and_merge_data
2886
+ from .utils import annotate_conditions, save_settings
2887
+ """
2888
+ Extracts the specified table from the SQLite database and plots a specified column.
2889
+
2890
+ Args:
2891
+ db_path (str): The path to the SQLite database.
2892
+ table_names (str): The name of the table to extract.
2893
+ data_column (str): The column to plot from the table.
2894
+
2895
+ Returns:
2896
+ df (pd.DataFrame): The extracted table as a DataFrame.
2897
+ """
2659
2898
 
2899
+ if isinstance(settings['src'], str):
2900
+ srcs = [settings['src']]
2901
+ elif isinstance(settings['src'], list):
2902
+ srcs = settings['src']
2903
+ else:
2904
+ raise ValueError("src must be a string or a list of strings.")
2905
+
2906
+ #save_settings(settings, name=f"{settings['graph_name']}_plot_settings_csv", show=True)
2660
2907
 
2908
+ dfs = []
2909
+ for i, src in enumerate(srcs):
2661
2910
 
2911
+ dft = pd.read_csv(src)
2912
+ if 'plate' not in dft.columns:
2913
+ dft['plate'] = f"plate{i+1}"
2914
+ dfs.append(dft)
2915
+
2916
+ df = pd.concat(dfs, axis=0)
2662
2917
  #display(df)
2663
2918
 
2664
- #df['class'] = df['png_path'].apply(lambda x: 'class_1' if 'class_1' in x else ('class_0' if 'class_0' in x else None))
2919
+ df = df.dropna(subset=settings['data_column'])
2920
+ df = df.dropna(subset=settings['grouping_column'])
2921
+ src = srcs[0]
2922
+ dst = os.path.join(os.path.dirname(src), 'results', settings['graph_name'])
2923
+ os.makedirs(dst, exist_ok=True)
2665
2924
 
2666
2925
  spacr_graph = spacrGraph(
2667
2926
  df=df, # Your DataFrame
2668
2927
  grouping_column=settings['grouping_column'], # Column for grouping the data (x-axis)
2669
- data_column=settings['column_name'], # Column for the data (y-axis)
2928
+ data_column=settings['data_column'], # Column for the data (y-axis)
2670
2929
  graph_type=settings['graph_type'], # Type of plot ('bar', 'box', 'violin', 'jitter')
2930
+ graph_name=settings['graph_name'], # Name of the plot
2671
2931
  summary_func='mean', # Function to summarize data (e.g., 'mean', 'median')
2672
2932
  colors=None, # Custom colors for the plot (optional)
2673
- output_dir=settings['dst'], # Directory to save the plot and results
2933
+ output_dir=dst, # Directory to save the plot and results
2674
2934
  save=settings['save'], # Whether to save the plot and results
2675
2935
  y_lim=settings['y_lim'], # Starting point for y-axis (optional)
2676
2936
  error_bar_type='std', # Type of error bar ('std' or 'sem')
@@ -2687,5 +2947,129 @@ def plot_data_from_db(settings):
2687
2947
 
2688
2948
  # Optional: Get the results DataFrame containing statistical test results
2689
2949
  results_df = spacr_graph.get_results()
2690
-
2691
- return fig, results_df
2950
+ return fig, results_df
2951
+
2952
def plot_region(settings):
    """Plot one field of view (FOV) together with the PNG images recorded for it.

    Three figures are produced and saved as 600-dpi PDFs under
    ``<src>/results/<name>/``, where ``name`` is ``settings['name']`` without
    its extension:

    1. ``plot_image_mask_overlay`` of ``<src>/merged/<settings['name']>``;
    2. a grid of the ``png_path`` entries of the ``png_list`` table in
       ``<src>/measurements/measurements.db``;
    3. a grid of the ``png_path`` entries of the ``<activation_mode>_list``
       table in ``<src>/measurements/<activation_db>``.

    Only PNG paths that contain ``name`` (matched literally) are used.

    Args:
        settings (dict): reads ``src``, ``name``, ``activation_mode``,
            ``activation_db``, ``percentiles``, ``channels``, ``cell_channel``,
            ``nucleus_channel``, ``pathogen_channel``, ``mode`` and ``export_tiffs``.

    Returns:
        tuple: ``(fig_1, fig_2, fig_3)`` - mask overlay, PNG grid, activation grid.
    """
    from .io import _read_db

    def _sort_paths_by_basename(paths):
        """Sort paths by their file name, ignoring the directory part."""
        return sorted(paths, key=os.path.basename)

    def save_figure_as_pdf(fig, path):
        """Save ``fig`` as a 600-dpi PDF at ``path``, creating missing folders."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        fig.savefig(path, format='pdf', dpi=600, bbox_inches='tight')
        print(f"Saved {path}")

    def _paths_for_fov(db_path, table, fov_name):
        """Return the basename-sorted ``png_path`` values of ``table`` that contain ``fov_name``."""
        paths_df = _read_db(db_path, tables=[table])[0]
        # regex=False: the FOV name is matched literally, so characters such as
        # '.', '+', '(' or '[' in a file name cannot break or widen the match.
        # NOTE(review): this is still a substring match (e.g. 'A01_1' would also
        # match 'A01_10'); tighten once the PNG naming scheme is confirmed.
        hits = paths_df['png_path'].str.contains(fov_name, regex=False, na=False)
        return _sort_paths_by_basename(paths_df.loc[hits, 'png_path'].tolist())

    src = settings['src']
    fov_path = os.path.join(src, 'merged', settings['name'])
    name = os.path.splitext(settings['name'])[0]

    png_paths = _paths_for_fov(os.path.join(src, 'measurements', 'measurements.db'), 'png_list', name)
    activation_table = f"{settings['activation_mode']}_list"
    activation_paths = _paths_for_fov(os.path.join(src, 'measurements', settings['activation_db']),
                                      activation_table, name)

    fig_3 = plot_image_grid(image_paths=activation_paths, percentiles=settings['percentiles'])
    fig_2 = plot_image_grid(image_paths=png_paths, percentiles=settings['percentiles'])
    fig_1 = plot_image_mask_overlay(file=fov_path,
                                    channels=settings['channels'],
                                    cell_channel=settings['cell_channel'],
                                    nucleus_channel=settings['nucleus_channel'],
                                    pathogen_channel=settings['pathogen_channel'],
                                    figuresize=10,
                                    percentiles=settings['percentiles'],
                                    thickness=3,
                                    save_pdf=False,
                                    mode=settings['mode'],
                                    export_tiffs=settings['export_tiffs'])

    dst = os.path.join(src, 'results', name)
    save_figure_as_pdf(fig_1, os.path.join(dst, f"{name}_mask_overlay.pdf"))
    save_figure_as_pdf(fig_2, os.path.join(dst, f"{name}_png_grid.pdf"))
    save_figure_as_pdf(fig_3, os.path.join(dst, f"{name}_activation_grid.pdf"))

    return fig_1, fig_2, fig_3
2998
+
2999
def plot_image_grid(image_paths, percentiles):
    """
    Plots a square grid of images from a list of image paths.
    The grid is the smallest square holding all images; unused cells stay
    black and spacing between cells is removed. Each image is contrast-stretched
    per channel to the given percentiles before display.

    Parameters:
    - image_paths: List of paths to images to be displayed (opened with PIL).
      An empty list yields a single black cell.
    - percentiles: (low, high) percentiles used for contrast stretching.

    Returns:
    - fig: The generated matplotlib figure.
    """

    from PIL import Image
    import matplotlib.pyplot as plt
    import math

    def _normalize_image(image, percentiles=(2, 98)):
        """Stretch each channel to [0, 1] between the given percentiles.

        PIL input is returned as an 8-bit PIL image, array input as a float
        array. The alpha band of 'RGBA'/'LA' PIL images is kept (rescaled from
        0-255 to [0, 1]) rather than stretched, and a flat channel (high
        percentile == low percentile) becomes 0 instead of NaN.
        """
        is_pil_image = isinstance(image, Image.Image)
        has_alpha = is_pil_image and image.mode in ('RGBA', 'LA')
        if is_pil_image:
            image = np.array(image)

        def _stretch(channel):
            v_min, v_max = np.percentile(channel, percentiles)
            if v_max <= v_min:
                # No contrast to stretch; avoid 0/0 -> NaN (which previously
                # turned an opaque alpha band fully transparent).
                return np.zeros(channel.shape, dtype=np.float32)
            return np.clip((channel - v_min) / (v_max - v_min), 0, 1)

        if image.ndim == 2:
            normalized_image = _stretch(image)
        else:
            normalized_image = np.zeros(image.shape, dtype=np.float32)
            n_colour = image.shape[-1] - 1 if has_alpha else image.shape[-1]
            for c in range(n_colour):
                normalized_image[..., c] = _stretch(image[..., c])
            if has_alpha:
                normalized_image[..., -1] = image[..., -1] / 255.0

        if is_pil_image:
            # Back to the 8-bit range PIL expects.
            return Image.fromarray((normalized_image * 255).astype(np.uint8))
        return normalized_image

    image_paths = list(image_paths)
    # At least a 1x1 grid so an empty list does not crash plt.subplots.
    grid_size = max(1, math.ceil(math.sqrt(len(image_paths))))

    # squeeze=False always returns a 2-D array of Axes, so flatten() also
    # works for a 1x1 grid (a single image).
    fig, axs = plt.subplots(
        grid_size, grid_size,
        figsize=(grid_size * 2, grid_size * 2),
        facecolor='black',  # figure background shows through empty cells
        squeeze=False,
    )
    axs = axs.flatten()

    for ax, img_path in zip(axs, image_paths):
        with Image.open(img_path) as raw:
            img = _normalize_image(raw, percentiles)
        ax.imshow(img)
        ax.axis('off')

    # Fill any unused cells with black.
    for ax in axs[len(image_paths):]:
        ax.imshow([[0, 0, 0]], cmap='gray')
        ax.axis('off')

    fig.subplots_adjust(wspace=0, hspace=0, left=0, right=1, top=1, bottom=0)

    return fig