spacr 1.0.9__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/plot.py CHANGED
@@ -26,24 +26,33 @@ from ipywidgets import IntSlider, interact
26
26
  from IPython.display import Image as ipyimage
27
27
  from matplotlib_venn import venn2
28
28
 
29
- def plot_image_mask_overlay(
30
- file,
31
- channels,
32
- cell_channel,
33
- nucleus_channel,
34
- pathogen_channel,
35
- figuresize=10,
36
- percentiles=(2, 98),
37
- thickness=3,
38
- save_pdf=True,
39
- mode='outlines',
40
- export_tiffs=False,
41
- all_on_all=False,
42
- all_outlines=False,
43
- filter_dict=None
44
- ):
45
- """Plot image and mask overlays."""
29
+ def plot_image_mask_overlay(file,channels,cell_channel,nucleus_channel,pathogen_channel,figuresize=10,percentiles=(2, 98),thickness=3,save_pdf=True,mode='outlines',export_tiffs=False,all_on_all=False,all_outlines=False,filter_dict=None):
30
+ """
31
+ Plot multi-channel microscopy image overlays with cell, nucleus, and pathogen masks.
32
+
33
+ This function visualizes microscopy images with optional mask overlays for different object types.
34
+ It supports contour and filled mask modes, object filtering based on size and intensity, and saving output.
46
35
 
36
+ Args:
37
+ file (str): Path to the `.npy` image stack (H x W x C).
38
+ channels (list): Indices of the image channels to display.
39
+ cell_channel (int or None): Channel index for the cell mask intensity.
40
+ nucleus_channel (int or None): Channel index for the nucleus mask intensity.
41
+ pathogen_channel (int or None): Channel index for the pathogen mask intensity.
42
+ figuresize (int): Base figure size for each subplot. Default is 10.
43
+ percentiles (tuple): Percentile range for image normalization. Default is (2, 98).
44
+ thickness (int): Thickness of mask contour outlines. Default is 3.
45
+ save_pdf (bool): If True, saves a PDF of the overlaid image. Default is True.
46
+ mode (str): 'outlines' or 'filled'. Determines how masks are rendered. Default is 'outlines'.
47
+ export_tiffs (bool): If True, exports grayscale TIFFs for each image channel. Default is False.
48
+ all_on_all (bool): If True, overlays all outlines on all channels. Default is False.
49
+ all_outlines (bool): If True, overlays all outlines on non-matching channels. Default is False.
50
+ filter_dict (dict or None): Dictionary of filtering thresholds for each object type.
51
+ For example: {"cell": [(min_area, max_area), (min_intensity, max_intensity)]}
52
+
53
+ Returns:
54
+ matplotlib.figure.Figure: The generated overlay figure.
55
+ """
47
56
  def random_color_cmap(n_labels, seed=None):
48
57
  """Generates a random color map for a given number of labels."""
49
58
  if seed is not None:
@@ -53,25 +62,7 @@ def plot_image_mask_overlay(
53
62
  cmap = ListedColormap(rand_colors)
54
63
  return cmap
55
64
 
56
- def _plot_merged_plot(
57
- image,
58
- outlines,
59
- outline_colors,
60
- figuresize,
61
- thickness,
62
- percentiles,
63
- mode='outlines',
64
- all_on_all=False,
65
- all_outlines=False,
66
- channels=None,
67
- cell_channel=None,
68
- nucleus_channel=None,
69
- pathogen_channel=None,
70
- cell_outlines=None,
71
- nucleus_outlines=None,
72
- pathogen_outlines=None,
73
- save_pdf=True
74
- ):
65
+ def _plot_merged_plot(image,outlines,outline_colors,figuresize,thickness,percentiles,mode='outlines',all_on_all=False,all_outlines=False,channels=None,cell_channel=None,nucleus_channel=None,pathogen_channel=None,cell_outlines=None,nucleus_outlines=None,pathogen_outlines=None,save_pdf=True):
75
66
  """Plot the merged plot with overlay, image channels, and masks."""
76
67
 
77
68
  def _generate_colored_mask(mask, cmap):
@@ -506,7 +497,7 @@ def generate_mask_random_cmap(mask):
506
497
  """
507
498
  Generate a random colormap based on the unique labels in the given mask.
508
499
 
509
- Parameters:
500
+ Args:
510
501
  mask (numpy.ndarray): The input mask array.
511
502
 
512
503
  Returns:
@@ -524,7 +515,7 @@ def random_cmap(num_objects=100):
524
515
  """
525
516
  Generate a random colormap.
526
517
 
527
- Parameters:
518
+ Args:
528
519
  num_objects (int): The number of objects to generate colors for. Default is 100.
529
520
 
530
521
  Returns:
@@ -540,7 +531,7 @@ def _generate_mask_random_cmap(mask):
540
531
  """
541
532
  Generate a random colormap based on the unique labels in the given mask.
542
533
 
543
- Parameters:
534
+ Args:
544
535
  mask (ndarray): The mask array containing unique labels.
545
536
 
546
537
  Returns:
@@ -559,7 +550,7 @@ def _get_colours_merged(outline_color):
559
550
  """
560
551
  Get the merged outline colors based on the specified outline color format.
561
552
 
562
- Parameters:
553
+ Args:
563
554
  outline_color (str): The outline color format. Can be one of 'rgb', 'bgr', 'gbr', or 'rbg'.
564
555
 
565
556
  Returns:
@@ -765,7 +756,7 @@ def plot_arrays(src, figuresize=10, cmap='inferno', nr=1, normalize=True, q1=1,
765
756
  """
766
757
  Plot randomly selected arrays from a given directory or a single .npz/.npy file.
767
758
 
768
- Parameters:
759
+ Args:
769
760
  - src (str): The directory path or file path containing the arrays.
770
761
  - figuresize (int): The size of the figure (default: 10).
771
762
  - cmap (str): The colormap to use for displaying the arrays (default: 'inferno').
@@ -820,7 +811,7 @@ def plot_arrays_v1(src, figuresize=10, cmap='inferno', nr=1, normalize=True, q1=
820
811
  """
821
812
  Plot randomly selected arrays from a given directory.
822
813
 
823
- Parameters:
814
+ Args:
824
815
  - src (str): The directory path containing the arrays.
825
816
  - figuresize (int): The size of the figure (default: 50).
826
817
  - cmap (str): The colormap to use for displaying the arrays (default: 'inferno').
@@ -1286,7 +1277,7 @@ def _visualize_and_save_timelapse_stack_with_tracks(masks, tracks_df, save, src,
1286
1277
  """
1287
1278
  Display the frame with tracks overlaid.
1288
1279
 
1289
- Parameters:
1280
+ Args:
1290
1281
  frame (int): The frame number to display.
1291
1282
 
1292
1283
  Returns:
@@ -1329,7 +1320,7 @@ def _display_gif(path):
1329
1320
  """
1330
1321
  Display a GIF image from the given path.
1331
1322
 
1332
- Parameters:
1323
+ Args:
1333
1324
  path (str): The path to the GIF image file.
1334
1325
 
1335
1326
  Returns:
@@ -1572,6 +1563,26 @@ def _imshow_gpu(img, labels, nrow=20, color='white', fontsize=12):
1572
1563
  return fig
1573
1564
 
1574
1565
  def _plot_histograms_and_stats(df):
1566
+ """
1567
+ Plot histograms and print summary statistics for prediction values grouped by condition.
1568
+
1569
+ Args:
1570
+ df (pd.DataFrame): DataFrame containing at least two columns:
1571
+ - 'condition': categorical variable indicating group labels.
1572
+ - 'pred': continuous prediction scores (e.g., probabilities).
1573
+
1574
+ Returns:
1575
+ None. Displays histograms and prints statistics for each condition.
1576
+
1577
+ Printed Statistics per Condition:
1578
+ - Number of rows
1579
+ - Mean of 'pred' values
1580
+ - Count and percentage of 'pred' > 0.5 (positive)
1581
+ - Count and percentage of 'pred' <= 0.5 (negative)
1582
+
1583
+ Plots:
1584
+ - Histogram of 'pred' values with mean indicated by a dashed red line.
1585
+ """
1575
1586
  conditions = df['condition'].unique()
1576
1587
 
1577
1588
  for condition in conditions:
@@ -1603,7 +1614,22 @@ def _plot_histograms_and_stats(df):
1603
1614
  plt.show()
1604
1615
 
1605
1616
  def _show_residules(model):
1617
+ """
1618
+ Display diagnostic plots and test for residuals of a fitted regression model.
1619
+
1620
+ Args:
1621
+ model (statsmodels.regression.linear_model.RegressionResultsWrapper):
1622
+ A fitted statsmodels regression model with `.resid` and `.fittedvalues` attributes.
1623
+
1624
+ Returns:
1625
+ None. Displays plots and prints Shapiro-Wilk test results.
1606
1626
 
1627
+ Diagnostics:
1628
+ - Histogram of residuals
1629
+ - QQ plot for normality
1630
+ - Residuals vs. Fitted values plot
1631
+ - Shapiro-Wilk test for normality (prints W-statistic and p-value)
1632
+ """
1607
1633
  # Get the residuals
1608
1634
  residuals = model.resid
1609
1635
 
@@ -1632,6 +1658,23 @@ def _show_residules(model):
1632
1658
  print(f'Shapiro-Wilk Test W-statistic: {W}, p-value: {p_value}')
1633
1659
 
1634
1660
  def _reg_v_plot(df, grouping, variable, plate_number):
1661
+ """
1662
+ Generate a volcano plot for visualizing effect size versus significance.
1663
+
1664
+ Args:
1665
+ df (pd.DataFrame): DataFrame containing regression results with 'effect' and 'p' columns.
1666
+ grouping (str): Unused in current function but reserved for grouping logic.
1667
+ variable (str): Unused in current function but could annotate variable of interest.
1668
+ plate_number (str or int): Unused in current function but could title or segregate plots.
1669
+
1670
+ Returns:
1671
+ None. Displays a volcano plot using matplotlib.
1672
+
1673
+ Notes:
1674
+ - Points with p < 0.05 are labeled by index.
1675
+ - Color indicates direction of effect (positive or negative).
1676
+ - Dashed line marks significance threshold (p = 0.05).
1677
+ """
1635
1678
  df['-log10(p)'] = -np.log10(df['p'])
1636
1679
 
1637
1680
  # Create the volcano plot
@@ -1650,7 +1693,27 @@ def _reg_v_plot(df, grouping, variable, plate_number):
1650
1693
  plt.show()
1651
1694
 
1652
1695
  def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_count):
1696
+ """
1697
+ Generate a heatmap matrix for a specific plate based on the specified variable.
1653
1698
 
1699
+ Args:
1700
+ df (pd.DataFrame): Input DataFrame containing 'prc' and variable to plot.
1701
+ plate_number (str): Identifier of the plate to generate the heatmap for.
1702
+ variable (str): Column name to be visualized.
1703
+ grouping (str): Aggregation method: 'mean', 'sum', or 'count'.
1704
+ min_max (str or tuple): Controls the colormap scaling:
1705
+ - 'all': use full data min/max
1706
+ - 'allq': use 2nd to 98th percentile
1707
+ - (low, high): percentiles (float) or fixed values (int)
1708
+ min_count (int or float): Minimum number of entries per well to be included.
1709
+
1710
+ Returns:
1711
+ plate_map (pd.DataFrame): Pivoted heatmap matrix indexed by row and column.
1712
+ min_max (list): The computed min/max values for the heatmap color scale.
1713
+
1714
+ Raises:
1715
+ ValueError: If `grouping` is not one of ['mean', 'sum', 'count'].
1716
+ """
1654
1717
  if not isinstance(min_count, (int, float)):
1655
1718
  min_count = 0
1656
1719
 
@@ -1713,6 +1776,27 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_co
1713
1776
  return plate_map, min_max
1714
1777
 
1715
1778
  def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True, dst=None):
1779
+ """
1780
+ Generate and optionally save heatmaps for one or more plates, showing the spatial distribution
1781
+ of a specified variable (e.g., unique_counts or signal intensity).
1782
+
1783
+ Args:
1784
+ df (pd.DataFrame): Input DataFrame containing 'prc' in the format 'plate_row_col' and the variable to plot.
1785
+ variable (str): Name of the column to visualize.
1786
+ grouping (str): Aggregation method for replicates (e.g., 'mean', 'median').
1787
+ min_max (str): Method to determine color scaling. One of {'global', 'local', 'allq'}.
1788
+ cmap (str or Colormap): Colormap to use for heatmap visualization.
1789
+ min_count (int, optional): Minimum count threshold for inclusion in the heatmap. Default is 0.
1790
+ verbose (bool, optional): If True, displays the heatmap(s) in an interactive window. Default is True.
1791
+ dst (str, optional): Directory to save the heatmaps. If None, the heatmaps are not saved.
1792
+
1793
+ Returns:
1794
+ matplotlib.figure.Figure: The figure containing the plotted heatmaps.
1795
+
1796
+ Side Effects:
1797
+ - Displays plate heatmaps using matplotlib/seaborn.
1798
+ - Saves the heatmap figure as 'plate_heatmap_#.pdf' in `dst` if provided and writable.
1799
+ """
1716
1800
  plates = df['prc'].str.split('_', expand=True)[0].unique()
1717
1801
  n_rows, n_cols = (len(plates) + 3) // 4, 4
1718
1802
  fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
@@ -1843,6 +1927,27 @@ def print_mask_and_flows(stack, mask, flows, overlay=True, max_size=1000, thickn
1843
1927
  plt.show()
1844
1928
 
1845
1929
  def plot_resize(images, resized_images, labels, resized_labels):
1930
+ """
1931
+ Plot original and resized images along with their corresponding labels for visual comparison.
1932
+
1933
+ Args:
1934
+ images (list or np.ndarray): List or array of original input images. Each image can be 2D (grayscale)
1935
+ or 3D (with 1, 3, or 4 channels).
1936
+ resized_images (list or np.ndarray): Corresponding resized input images.
1937
+ labels (list or np.ndarray): List or array of original label masks.
1938
+ resized_labels (list or np.ndarray): Corresponding resized label masks.
1939
+
1940
+ Returns:
1941
+ None. Displays a 2x2 grid of plots:
1942
+ - Top-left: Original image
1943
+ - Top-right: Resized image
1944
+ - Bottom-left: Original label
1945
+ - Bottom-right: Resized label
1946
+
1947
+ Notes:
1948
+ - Images with 1, 3, or 4 channels are supported.
1949
+ - Other channel numbers will be visualized by averaging across channels (grayscale fallback).
1950
+ """
1846
1951
  def prepare_image(img):
1847
1952
  if img.ndim == 2:
1848
1953
  return img, 'gray'
@@ -1884,7 +1989,23 @@ def plot_resize(images, resized_images, labels, resized_labels):
1884
1989
  plt.show()
1885
1990
 
1886
1991
  def normalize_and_visualize(image, normalized_image, title=""):
1887
- """Utility function for visualization"""
1992
+ """
1993
+ Display a side-by-side comparison of an original image and its normalized version.
1994
+
1995
+ Args:
1996
+ image (np.ndarray): The original image. Can be 2D (grayscale) or 3D (multi-channel).
1997
+ normalized_image (np.ndarray): The normalized image. Should have the same dimensions as `image`.
1998
+ title (str, optional): Optional string to append to the plot titles. Defaults to "".
1999
+
2000
+ Returns:
2001
+ None. Displays a matplotlib figure with two subplots:
2002
+ - Left: Original image
2003
+ - Right: Normalized image
2004
+
2005
+ Notes:
2006
+ - For multi-channel images, the mean across channels is visualized.
2007
+ - Axes are hidden for clarity.
2008
+ """
1888
2009
  fig, ax = plt.subplots(1, 2, figsize=(12, 6))
1889
2010
  if image.ndim == 3: # Multi-channel image
1890
2011
  ax[0].imshow(np.mean(image, axis=-1), cmap='gray') # Display the average over channels for visualization
@@ -1903,6 +2024,25 @@ def normalize_and_visualize(image, normalized_image, title=""):
1903
2024
  plt.show()
1904
2025
 
1905
2026
  def visualize_masks(mask1, mask2, mask3, title="Masks Comparison"):
2027
+ """
2028
+ Display three segmentation masks side-by-side for visual comparison.
2029
+
2030
+ Args:
2031
+ mask1 (np.ndarray): First mask array (2D).
2032
+ mask2 (np.ndarray): Second mask array (2D).
2033
+ mask3 (np.ndarray): Third mask array (2D).
2034
+ title (str, optional): Title for the entire figure. Defaults to "Masks Comparison".
2035
+
2036
+ Returns:
2037
+ None. Displays a matplotlib figure with three subplots:
2038
+ - Each subplot corresponds to one of the input masks.
2039
+ - Masks are visualized using a randomly generated colormap.
2040
+ - If a mask is binary (0 and 1 only), no normalization is applied.
2041
+ - Otherwise, intensity values are normalized to [0, max(mask)].
2042
+
2043
+ Notes:
2044
+ Requires the function `generate_mask_random_cmap(mask)` to generate a colormap.
2045
+ """
1906
2046
  fig, axs = plt.subplots(1, 3, figsize=(30, 10))
1907
2047
  for ax, mask, title in zip(axs, [mask1, mask2, mask3], ['Mask 1', 'Mask 2', 'Mask 3']):
1908
2048
  cmap = generate_mask_random_cmap(mask)
@@ -1922,7 +2062,7 @@ def visualize_cellpose_masks(masks, titles=None, filename=None, save=False, src=
1922
2062
  """
1923
2063
  Visualize multiple masks with optional titles.
1924
2064
 
1925
- Parameters:
2065
+ Args:
1926
2066
  masks (list of np.ndarray): A list of masks to visualize.
1927
2067
  titles (list of str, optional): A list of titles for the masks. If None, default titles will be used.
1928
2068
  comparison_title (str): Title for the entire figure.
@@ -1962,6 +2102,26 @@ def visualize_cellpose_masks(masks, titles=None, filename=None, save=False, src=
1962
2102
 
1963
2103
 
1964
2104
  def plot_comparison_results(comparison_results):
2105
+ """
2106
+ Visualize segmentation comparison metrics using boxplots with overlaid strip plots.
2107
+
2108
+ Args:
2109
+ comparison_results (list of dict): Each dictionary represents one sample, with keys:
2110
+ - 'filename': Name of the sample.
2111
+ - Metric keys such as 'jaccard_*', 'dice_*', 'boundary_f1_*', 'average_precision_*'.
2112
+
2113
+ Returns:
2114
+ matplotlib.figure.Figure: The resulting figure with 4 subplots:
2115
+ - Jaccard Index
2116
+ - Dice Coefficient
2117
+ - Boundary F1 Score
2118
+ - Average Precision
2119
+
2120
+ Notes:
2121
+ - Metrics are grouped by type using substring matching in column names.
2122
+ - Outliers and individual sample values are shown with strip plots.
2123
+ - Assumes that all metric columns contain numeric values.
2124
+ """
1965
2125
  df = pd.DataFrame(comparison_results)
1966
2126
  df_melted = pd.melt(df, id_vars=['filename'], var_name='metric', value_name='value')
1967
2127
  df_jaccard = df_melted[df_melted['metric'].str.contains('jaccard')]
@@ -2004,7 +2164,25 @@ def plot_comparison_results(comparison_results):
2004
2164
  return fig
2005
2165
 
2006
2166
  def plot_object_outlines(src, objects=['nucleus','cell','pathogen'], channels=[0,1,2], max_nr=10):
2007
-
2167
+ """
2168
+ Overlay mask outlines on image channels for visual inspection of segmentation quality.
2169
+
2170
+ Args:
2171
+ src (str): Source directory containing the 'masks' folder and image channels.
2172
+ objects (list of str): List of object names (e.g., 'nucleus', 'cell', 'pathogen') whose masks to visualize.
2173
+ channels (list of int): Corresponding channel indices (0-based) for each object.
2174
+ max_nr (int): Maximum number of overlays to display.
2175
+
2176
+ Returns:
2177
+ None. Displays overlays using `plot_images_and_arrays()`.
2178
+
2179
+ Notes:
2180
+ - For each object/channel pair, attempts to load:
2181
+ - The mask from 'masks/{object}_mask_stack'
2182
+ - The image from '{channel+1}' subfolder (1-based indexing)
2183
+ - Assumes `plot_images_and_arrays()` can handle paired overlay inputs.
2184
+ - Random selection and percentile contrast stretching applied for visualization.
2185
+ """
2008
2186
  for object_, channel in zip(objects, channels):
2009
2187
  folders = [os.path.join(src, 'masks', f'{object_}_mask_stack'),
2010
2188
  os.path.join(src,f'{channel+1}')]
@@ -2019,6 +2197,24 @@ def plot_object_outlines(src, objects=['nucleus','cell','pathogen'], channels=[0
2019
2197
  randomize=True)
2020
2198
 
2021
2199
  def volcano_plot(coef_df, filename='volcano_plot.pdf'):
2200
+ """
2201
+ Generate and save a volcano plot based on coefficient and p-value data.
2202
+
2203
+ Args:
2204
+ coef_df (pd.DataFrame): DataFrame containing columns:
2205
+ - 'coefficient': effect size for each term.
2206
+ - 'p_value': p-value for each coefficient.
2207
+ - 'condition': category for coloring (e.g., 'pc', 'nc', 'control', 'other').
2208
+ filename (str): File path for saving the plot as a PDF.
2209
+
2210
+ Returns:
2211
+ None. Displays the volcano plot and saves it to the specified file.
2212
+
2213
+ Notes:
2214
+ - Highlights p-value threshold at 0.05 with a horizontal dashed red line.
2215
+ - Uses pre-defined color palette for conditions.
2216
+ - Legend is removed for clarity.
2217
+ """
2022
2218
  palette = {
2023
2219
  'pc': 'red',
2024
2220
  'nc': 'green',
@@ -2046,6 +2242,21 @@ def volcano_plot(coef_df, filename='volcano_plot.pdf'):
2046
2242
  plt.show()
2047
2243
 
2048
2244
  def plot_histogram(df, column, dst=None):
2245
+ """
2246
+ Plot and optionally save a histogram for a specified column in a DataFrame.
2247
+
2248
+ Args:
2249
+ df (pd.DataFrame): DataFrame containing the data.
2250
+ column (str): Name of the column to plot.
2251
+ dst (str or None): Directory to save the figure as a PDF. If None, the figure is not saved.
2252
+
2253
+ Returns:
2254
+ None. Displays the histogram and optionally saves it to disk.
2255
+
2256
+ Notes:
2257
+ - The histogram uses a fixed turquoise bar color (RGB: 0,155,155).
2258
+ - Saved file is named '{column}_histogram.pdf' and placed in the specified `dst` directory.
2259
+ """
2049
2260
  # Plot histogram of the dependent variable
2050
2261
  bar_color = (0/255, 155/255, 155/255)
2051
2262
  plt.figure(figsize=(10, 10))
@@ -2061,10 +2272,28 @@ def plot_histogram(df, column, dst=None):
2061
2272
 
2062
2273
  plt.show()
2063
2274
 
2064
- def plot_lorenz_curves(csv_files, name_column='grna_name', value_column='count',
2065
- remove_keys=None,
2066
- x_lim=[0.0, 1], y_lim=[0, 1], remove_outliers=False, save=True):
2067
-
2275
+ def plot_lorenz_curves(csv_files, name_column='grna_name', value_column='count', remove_keys=None, x_lim=[0.0, 1], y_lim=[0, 1], remove_outliers=False, save=True):
2276
+ """
2277
+ Plot Lorenz curves and compute Gini coefficients for a set of CSV files.
2278
+
2279
+ Args:
2280
+ csv_files (list of str): Paths to CSV files containing count data.
2281
+ name_column (str): Column name containing the gRNA or entity names.
2282
+ value_column (str): Column name with the numerical values (e.g., counts).
2283
+ remove_keys (list or None): List of names to exclude from analysis.
2284
+ x_lim (list): X-axis limits for the plot.
2285
+ y_lim (list): Y-axis limits for the plot.
2286
+ remove_outliers (bool): Whether to remove outlier entities by well counts.
2287
+ save (bool): If True, saves the plot to a PDF.
2288
+
2289
+ Returns:
2290
+ None. Displays the plot and prints Gini coefficients.
2291
+
2292
+ Notes:
2293
+ - Gini coefficient is a measure of inequality (0 = perfect equality, 1 = maximal inequality).
2294
+ - Outlier removal is based on IQR of well counts per entity name.
2295
+ - Saves to 'lorenz_curve_with_gini.pdf' under a 'results' folder next to the first input CSV.
2296
+ """
2068
2297
  def lorenz_curve(data):
2069
2298
  """Calculate Lorenz curve."""
2070
2299
  sorted_data = np.sort(data)
@@ -2153,6 +2382,23 @@ def plot_lorenz_curves(csv_files, name_column='grna_name', value_column='count',
2153
2382
  print(f"{plate}: Gini Coefficient = {gini:.4f}")
2154
2383
 
2155
2384
  def plot_permutation(permutation_df):
2385
+ """
2386
+ Plot permutation feature importance as a horizontal bar chart.
2387
+
2388
+ Args:
2389
+ permutation_df (pd.DataFrame): DataFrame with columns:
2390
+ - 'feature': Feature names.
2391
+ - 'importance_mean': Mean permutation importance.
2392
+ - 'importance_std': Standard deviation of permutation importance.
2393
+
2394
+ Returns:
2395
+ matplotlib.figure.Figure: The resulting plot figure object.
2396
+
2397
+ Notes:
2398
+ - Dynamically adjusts figure size and font size based on number of features.
2399
+ - Error bars represent the standard deviation across permutation runs.
2400
+ """
2401
+
2156
2402
  num_features = len(permutation_df)
2157
2403
  fig_height = max(8, num_features * 0.3) # Set a minimum height of 8 and adjust height based on number of features
2158
2404
  fig_width = 10 # Width can be fixed or adjusted similarly
@@ -2166,6 +2412,21 @@ def plot_permutation(permutation_df):
2166
2412
  return fig
2167
2413
 
2168
2414
  def plot_feature_importance(feature_importance_df):
2415
+ """
2416
+ Plot feature importance as a horizontal bar chart.
2417
+
2418
+ Args:
2419
+ feature_importance_df (pd.DataFrame): DataFrame with columns:
2420
+ - 'feature': Feature names.
2421
+ - 'importance': Importance scores for each feature.
2422
+
2423
+ Returns:
2424
+ matplotlib.figure.Figure: The resulting plot figure object.
2425
+
2426
+ Notes:
2427
+ - Dynamically adjusts figure size and font size based on number of features.
2428
+ - Use for visualizing static (e.g., model-based) feature importance.
2429
+ """
2169
2430
  num_features = len(feature_importance_df)
2170
2431
  fig_height = max(8, num_features * 0.3) # Set a minimum height of 8 and adjust height based on number of features
2171
2432
  fig_width = 10 # Width can be fixed or adjusted similarly
@@ -2179,6 +2440,23 @@ def plot_feature_importance(feature_importance_df):
2179
2440
  return fig
2180
2441
 
2181
2442
  def read_and_plot__vision_results(base_dir, y_axis='accuracy', name_split='_time', y_lim=[0.8, 0.9]):
2443
+ """
2444
+ Read multiple vision model test result CSV files and generate a bar plot of average performance.
2445
+
2446
+ Args:
2447
+ base_dir (str): Base directory to recursively search for files ending in '_test_result.csv'.
2448
+ y_axis (str): The performance metric to plot (must be a column in each CSV). Default is 'accuracy'.
2449
+ name_split (str): String delimiter used to extract the model name from the filename. Default is '_time'.
2450
+ y_lim (list or tuple): Optional y-axis limits for the plot. Default is [0.8, 0.9].
2451
+
2452
+ Returns:
2453
+ None. Displays a matplotlib bar plot of average `y_axis` for each model.
2454
+
2455
+ Notes:
2456
+ - Assumes each file is named with format: '<model>_time<timestamp>_test_result.csv'.
2457
+ - Extracts 'model' and 'epoch' metadata from the path and filename.
2458
+ - Saves results to a 'result' folder inside `base_dir` (created if not present).
2459
+ """
2182
2460
  # List to store data from all CSV files
2183
2461
  data_frames = []
2184
2462
 
@@ -2321,7 +2599,7 @@ def create_grouped_plot(df, grouping_column, data_column, graph_type='bar', summ
2321
2599
  """
2322
2600
  Create a grouped plot, perform statistical tests, and optionally export the results along with the plot.
2323
2601
 
2324
- Parameters:
2602
+ Args:
2325
2603
  - df: DataFrame containing the data.
2326
2604
  - grouping_column: Column name for the categorical grouping.
2327
2605
  - data_column: Column name for the data to be grouped and plotted.
@@ -2479,15 +2757,66 @@ def create_grouped_plot(df, grouping_column, data_column, graph_type='bar', summ
2479
2757
  return plt.gcf(), results_df
2480
2758
 
2481
2759
  class spacrGraph:
2760
+ """
2761
+ Class for generating grouped plots with optional data preprocessing,
2762
+ statistical comparisons, theming, and output control.
2763
+
2764
+ This class is designed to support common grouped plotting tasks for CRISPR screens,
2765
+ enabling flexible summarization across different experimental representations
2766
+ (e.g., object-level, well-level, or plate-level).
2767
+
2768
+ Args:
2769
+ df (pd.DataFrame): Input DataFrame containing data to plot.
2770
+ grouping_column (str): Column to group by on the x-axis.
2771
+ data_column (str or list): One or more numeric columns to summarize and plot.
2772
+ graph_type (str): Type of plot to create (e.g., 'bar', 'box').
2773
+ summary_func (str or callable): Function used to summarize data (e.g., 'mean', 'median').
2774
+ order (list): Optional list to define the order of categories on the x-axis.
2775
+ colors (list): Optional custom list of colors for plotting.
2776
+ output_dir (str): Path to directory where output will be saved.
2777
+ save (bool): Whether to save the plot as a PDF.
2778
+ y_lim (tuple): Tuple defining y-axis limits.
2779
+ log_y (bool): If True, applies log scaling to y-axis.
2780
+ log_x (bool): If True, applies log scaling to x-axis.
2781
+ error_bar_type (str): Error bar type to display ('std' or 'sem').
2782
+ remove_outliers (bool): Whether to exclude statistical outliers (not implemented here).
2783
+ theme (str): Seaborn theme to use for the plot.
2784
+ representation (str): Level of summarization; one of {'object', 'well', 'plate'}.
2785
+ paired (bool): If True, assumes samples are paired for statistical tests.
2786
+ all_to_all (bool): If True, performs all-to-all comparisons.
2787
+ compare_group (str): Optional group to compare all others against.
2788
+ graph_name (str): Optional name used in output file naming.
2789
+ """
2482
2790
  def __init__(self, df, grouping_column, data_column, graph_type='bar', summary_func='mean',
2483
2791
  order=None, colors=None, output_dir='./output', save=False, y_lim=None, log_y=False,
2484
2792
  log_x=False, error_bar_type='std', remove_outliers=False, theme='pastel', representation='object',
2485
2793
  paired=False, all_to_all=True, compare_group=None, graph_name=None):
2486
-
2487
- """
2488
- Class for creating grouped plots with optional statistical tests and data preprocessing.
2489
2794
  """
2795
+ Initialize a spacrGraph instance for grouped data visualization.
2490
2796
 
2797
+ Args:
2798
+ df (pd.DataFrame): Input dataframe containing data to be plotted.
2799
+ grouping_column (str): Column name to group data along the x-axis.
2800
+ data_column (str or list): Column(s) containing values to summarize and plot.
2801
+ graph_type (str): Type of plot to generate ('bar', 'box', etc.). Default is 'bar'.
2802
+ summary_func (str or callable): Summary statistic to apply per group (e.g., 'mean', 'median').
2803
+ order (list or None): Optional order of groups on the x-axis. If None, will be inferred from the data.
2804
+ colors (list or None): List of colors to use for plotting. Defaults to seaborn palette.
2805
+ output_dir (str): Directory where plot will be saved if `save=True`.
2806
+ save (bool): Whether to save the plot as a PDF.
2807
+ y_lim (tuple or None): Y-axis limits. If None, determined automatically.
2808
+ log_y (bool): If True, use logarithmic scale on the y-axis.
2809
+ log_x (bool): If True, use logarithmic scale on the x-axis.
2810
+ error_bar_type (str): Type of error bars to include ('std' or 'sem').
2811
+ remove_outliers (bool): Whether to remove outliers from the data before plotting.
2812
+ theme (str): Seaborn color theme for plotting.
2813
+ representation (str): Level of summarization: 'object', 'well', or 'plate'.
2814
+ paired (bool): Whether samples are paired for statistical testing.
2815
+ all_to_all (bool): If True, perform all pairwise comparisons.
2816
+ compare_group (str or None): If provided, compares all groups against this one.
2817
+ graph_name (str or None): Optional string to include in saved file names.
2818
+ """
2819
+
2491
2820
  self.df = df
2492
2821
  self.grouping_column = grouping_column
2493
2822
  self.order = sorted(df[self.grouping_column].unique().tolist())
@@ -2522,13 +2851,29 @@ class spacrGraph:
2522
2851
  self.df = self.preprocess_data()
2523
2852
 
2524
2853
  def _set_theme(self):
2525
- """Set the Seaborn theme and reorder colors if necessary."""
2854
+ """
2855
+ Set the Seaborn color theme and define the plotting color order.
2856
+
2857
+ This method initializes the `sns_palette` using a reordered set of colors
2858
+ from the selected theme.
2859
+ """
2526
2860
  integer_list = list(range(1, 81))
2527
2861
  color_order = [7,9,4,0,3,6,2] + integer_list
2528
2862
  self.sns_palette = self._set_reordered_theme(self.theme, color_order, 100)
2529
2863
 
2530
2864
  def _set_reordered_theme(self, theme='deep', order=None, n_colors=100, show_theme=False):
2531
- """Set and reorder the Seaborn color palette."""
2865
+ """
2866
+ Generate a Seaborn color palette with an optional reordering.
2867
+
2868
+ Parameters:
2869
+ theme (str): Seaborn palette name (e.g., 'deep', 'pastel').
2870
+ order (list): List of integer indices defining color reorder.
2871
+ n_colors (int): Number of colors to include in the palette.
2872
+ show_theme (bool): If True, displays the palette visually.
2873
+
2874
+ Returns:
2875
+ list: Reordered color palette.
2876
+ """
2532
2877
  palette = sns.color_palette(theme, n_colors)
2533
2878
  if order:
2534
2879
  reordered_palette = [palette[i] for i in order]
@@ -2539,25 +2884,21 @@ class spacrGraph:
2539
2884
  plt.show()
2540
2885
  return reordered_palette
2541
2886
 
2542
- #def preprocess_data(self):
2543
- # """Preprocess the data: remove NaNs, sort/order the grouping column, and optionally group by 'prc'."""
2544
- # # Remove NaNs in both the grouping column and each data column
2545
- # df = self.df.dropna(subset=[self.grouping_column] + self.data_column)
2546
- # # Group by 'prc' column if representation is 'well'
2547
- # if self.representation == 'well':
2548
- # df = df.groupby(['prc', self.grouping_column])[self.data_column].agg(self.summary_func).reset_index()
2549
- # if self.representation == 'plateID':
2550
- # df = df.groupby(['plateID', self.grouping_column])[self.data_column].agg(self.summary_func).reset_index()
2551
- # if self.order:
2552
- # df[self.grouping_column] = pd.Categorical(df[self.grouping_column], categories=self.order, ordered=True)
2553
- # else:
2554
- # df[self.grouping_column] = pd.Categorical(df[self.grouping_column], categories=sorted(df[self.grouping_column].unique()), ordered=True)
2555
- # return df
2556
-
2557
2887
  def preprocess_data(self):
2558
2888
  """
2559
- Preprocess the data: remove NaNs, optionally ensure 'plateID' column is created,
2560
- then group by either 'prc', 'plateID', or do no grouping at all if representation == 'object'.
2889
+ Preprocess the input DataFrame by removing NaNs and optionally aggregating data.
2890
+
2891
+ Aggregation is controlled by the `representation` parameter:
2892
+ - 'object': No aggregation; raw data is retained.
2893
+ - 'well': Data is grouped by ['prc', grouping_column].
2894
+ - 'plate': Data is grouped by ['plateID', grouping_column],
2895
+ extracting plateID from the 'prc' column if needed.
2896
+
2897
+ Ordering of the x-axis categories is set using the `order` parameter
2898
+ or automatically sorted if not provided.
2899
+
2900
+ Returns:
2901
+ pd.DataFrame: Cleaned and optionally aggregated DataFrame.
2561
2902
  """
2562
2903
  # 1) Remove NaNs in both the grouping column and each data column
2563
2904
  df = self.df.dropna(subset=[self.grouping_column] + self.data_column)
@@ -2613,7 +2954,13 @@ class spacrGraph:
2613
2954
  return df
2614
2955
 
2615
2956
  def remove_outliers_from_plot(self):
2616
- """Remove outliers from the plot but keep them in the data."""
2957
+ """
2958
+ Remove outliers from each group in the dataset for plotting purposes.
2959
+ This method applies the IQR method to filter extreme values per group.
2960
+
2961
+ Returns:
2962
+ pd.DataFrame: Filtered dataframe with outliers removed for plotting.
2963
+ """
2617
2964
  filtered_df = self.df.copy()
2618
2965
  unique_groups = filtered_df[self.grouping_column].unique()
2619
2966
  for group in unique_groups:
@@ -2627,7 +2974,18 @@ class spacrGraph:
2627
2974
  return filtered_df
2628
2975
 
2629
2976
  def perform_normality_tests(self):
2630
- """Perform normality tests for each group and data column."""
2977
+ """
2978
+ Perform normality tests on each group and data column in the dataframe.
2979
+
2980
+ Uses:
2981
+ - Shapiro-Wilk test for sample sizes < 8
2982
+ - D'Agostino-Pearson test for sample sizes ≥ 8
2983
+
2984
+ Returns:
2985
+ tuple:
2986
+ - is_normal (bool): True if all groups pass normality test (p > 0.05)
2987
+ - normality_results (list): List of dictionaries with test results
2988
+ """
2631
2989
  unique_groups = self.df[self.grouping_column].unique()
2632
2990
  normality_results = []
2633
2991
 
@@ -2674,13 +3032,40 @@ class spacrGraph:
2674
3032
  return is_normal, normality_results
2675
3033
 
2676
3034
  def perform_levene_test(self, unique_groups):
2677
- """Perform Levene's test for equal variance."""
3035
+ """
3036
+ Perform Levene’s test for homogeneity of variance across groups.
3037
+
3038
+ Args:
3039
+ unique_groups (list): List of group identifiers to compare.
3040
+
3041
+ Returns:
3042
+ tuple:
3043
+ - stat (float): Levene test statistic
3044
+ - p_value (float): p-value for the test
3045
+ """
2678
3046
  grouped_data = [self.df.loc[self.df[self.grouping_column] == group, self.data_column] for group in unique_groups]
2679
3047
  stat, p_value = levene(*grouped_data)
2680
3048
  return stat, p_value
2681
3049
 
2682
3050
  def perform_statistical_tests(self, unique_groups, is_normal):
2683
- """Perform statistical tests separately for each data column."""
3051
+ """
3052
+ Perform appropriate statistical tests based on normality and group count.
3053
+
3054
+ If 2 groups:
3055
+ - Paired or unpaired t-test (if normal)
3056
+ - Wilcoxon or Mann-Whitney U test (if non-normal)
3057
+
3058
+ If >2 groups:
3059
+ - ANOVA (if normal)
3060
+ - Kruskal-Wallis test (if non-normal)
3061
+
3062
+ Args:
3063
+ unique_groups (list): List of unique group names.
3064
+ is_normal (bool): Whether the data passes normality tests.
3065
+
3066
+ Returns:
3067
+ list: List of dictionaries with test results for each data column.
3068
+ """
2684
3069
  test_results = []
2685
3070
  for column in self.data_column: # Iterate over each data column
2686
3071
  grouped_data = [self.df.loc[self.df[self.grouping_column] == group, column] for group in unique_groups]
@@ -2720,8 +3105,22 @@ class spacrGraph:
2720
3105
  return test_results
2721
3106
 
2722
3107
  def perform_posthoc_tests(self, is_normal, unique_groups):
2723
- """Perform post-hoc tests for multiple groups based on all_to_all flag."""
3108
+ """
3109
+ Perform post-hoc tests for multiple groups, depending on normality and group count.
3110
+
3111
+ If data are normally distributed and `all_to_all` is True:
3112
+ - Tukey HSD test is performed for pairwise comparisons.
3113
+
3114
+ If not normal and `all_to_all` is True:
3115
+ - Dunn's test is performed using appropriate p-value adjustment.
2724
3116
 
3117
+ Args:
3118
+ is_normal (bool): Whether the data passed normality tests.
3119
+ unique_groups (list): List of group identifiers.
3120
+
3121
+ Returns:
3122
+ list: List of dictionaries with post-hoc test results per pairwise comparison.
3123
+ """
2725
3124
  from .sp_stats import choose_p_adjust_method
2726
3125
 
2727
3126
  posthoc_results = []
@@ -2776,10 +3175,32 @@ class spacrGraph:
2776
3175
  return posthoc_results
2777
3176
 
2778
3177
  def create_plot(self, ax=None):
2779
- """Create and display the plot based on the chosen graph type."""
3178
+ """
3179
+ Create the plot based on the selected `graph_type`.
2780
3180
 
3181
+ If `graph_type` supports it and multiple `data_column` values are used,
3182
+ symbols can be placed below the plot to indicate groupings and correspondence.
3183
+
3184
+ Args:
3185
+ ax (matplotlib.axes.Axes, optional): An existing axis to draw the plot on.
3186
+ If None, a new figure and axis are created.
3187
+
3188
+ Returns:
3189
+ matplotlib.figure.Figure: The resulting figure.
3190
+ """
2781
3191
  def _generate_tabels(unique_groups):
2782
- """Generate row labels and a symbol table for multi-level grouping."""
3192
+ """
3193
+ Generate symbolic representation of the grouping and data columns
3194
+ for annotation under plots (used internally by create_plot).
3195
+
3196
+ Args:
3197
+ unique_groups (list): List of group identifiers.
3198
+
3199
+ Returns:
3200
+ tuple:
3201
+ - row_labels (list): List of label strings for each row.
3202
+ - transposed_table (list of lists): Symbolic table aligned with x-axis.
3203
+ """
2783
3204
  # Create row labels: Include the grouping column and data columns
2784
3205
  row_labels = [self.grouping_column] + self.data_column
2785
3206
 
@@ -2812,13 +3233,14 @@ class spacrGraph:
2812
3233
 
2813
3234
  def _place_symbols(row_labels, transposed_table, x_positions, ax):
2814
3235
  """
2815
- Places symbols and row labels aligned under the bars or jitter points on the graph.
2816
-
2817
- Parameters:
2818
- - row_labels: List of row titles to be displayed along the y-axis.
2819
- - transposed_table: Data to be placed under each bar/jitter as symbols.
2820
- - x_positions: X-axis positions for each group to align the symbols.
2821
- - ax: The matplotlib Axes object where the plot is drawn.
3236
+ Place text-based symbols under plot elements to annotate categories
3237
+ and features in multi-group, multi-feature plots.
3238
+
3239
+ Args:
3240
+ row_labels (list): Row label names.
3241
+ transposed_table (list of lists): Symbolic table to annotate.
3242
+ x_positions (list of float): X-coordinates for symbol placement.
3243
+ ax (matplotlib.axes.Axes): Axis object where symbols are placed.
2822
3244
  """
2823
3245
  # Get plot dimensions and adjust for different plot sizes
2824
3246
  y_axis_min = ax.get_ylim()[0] # Minimum y-axis value (usually 0)
@@ -2845,6 +3267,17 @@ class spacrGraph:
2845
3267
  ax.figure.canvas.draw()
2846
3268
 
2847
3269
  def _get_positions(self, ax):
3270
+ """
3271
+ Get x-axis positions of plotted elements depending on graph type.
3272
+
3273
+ Used to align symbols and annotations below plot elements.
3274
+
3275
+ Args:
3276
+ ax (matplotlib.axes.Axes): Axis object from which positions are extracted.
3277
+
3278
+ Returns:
3279
+ list of float: X-axis coordinates for each group/data element.
3280
+ """
2848
3281
  if self.graph_type in ['bar','jitter_bar']:
2849
3282
  x_positions = [np.mean(bar.get_paths()[0].vertices[:, 0]) for bar in ax.collections if hasattr(bar, 'get_paths')]
2850
3283
 
@@ -2863,7 +3296,24 @@ class spacrGraph:
2863
3296
  return x_positions
2864
3297
 
2865
3298
  def _draw_comparison_lines(ax, x_positions):
2866
- """Draw comparison lines and annotate significance based on results_df."""
3299
+ """
3300
+ Draw horizontal lines between group pairs to indicate statistical comparisons and significance.
3301
+
3302
+ This function annotates the plot using entries from `self.results_df`, which must contain
3303
+ a 'Comparison' column with group pair strings formatted as 'group1 vs group2' and a 'p-value' column.
3304
+
3305
+ Args:
3306
+ ax (matplotlib.axes.Axes): The axis on which to draw the lines and annotations.
3307
+ x_positions (list of float): The x-axis positions of each group, in order matching `unique_groups`.
3308
+
3309
+ Notes:
3310
+ - Significance markers:
3311
+ '***' for p ≤ 0.001
3312
+ '**' for p ≤ 0.01
3313
+ '*' for p ≤ 0.05
3314
+ 'ns' for p > 0.05
3315
+ - Lines are staggered vertically to avoid overlap.
3316
+ """
2867
3317
  if self.results_df.empty:
2868
3318
  print("No comparisons available to annotate.")
2869
3319
  return
@@ -2996,15 +3446,18 @@ class spacrGraph:
2996
3446
 
2997
3447
  def _standerdize_figure_format(self, ax, num_groups, graph_type):
2998
3448
  """
2999
- Adjusts the figure layout (size, bar width, jitter, and spacing) based on the number of groups.
3000
-
3001
- Parameters:
3002
- - ax: The matplotlib Axes object.
3003
- - num_groups: Number of unique groups.
3004
- - graph_type: The type of graph (e.g., 'bar', 'jitter', 'box', etc.).
3449
+ Standardize and adjust the visual format of the plot based on the number of groups and plot type.
3005
3450
 
3006
- Returns:
3007
- - None. Modifies the figure and Axes in place.
3451
+ Args:
3452
+ ax (matplotlib.axes.Axes): The axis object of the plot.
3453
+ num_groups (int): Number of groups/categories on the x-axis.
3454
+ graph_type (str): Type of plot, e.g., 'bar', 'jitter', 'box', 'line', etc.
3455
+
3456
+ Notes:
3457
+ - Dynamically sets figure size based on group count.
3458
+ - Adjusts bar widths, jitter offsets, and legend font sizes.
3459
+ - Skips formatting for line plots.
3460
+ - Intended to maintain a visually consistent appearance across graphs.
3008
3461
  """
3009
3462
  if graph_type in ['line', 'line_std']:
3010
3463
  print("Skipping layout adjustment for line graphs.")
@@ -3068,7 +3521,19 @@ class spacrGraph:
3068
3521
  ax.figure.canvas.draw()
3069
3522
 
3070
3523
  def _create_bar_plot(self, ax):
3071
- """Helper method to create a bar plot with consistent bar thickness and centered error bars."""
3524
+ """
3525
+ Create a bar plot using seaborn, with error bars and standardized layout.
3526
+
3527
+ Args:
3528
+ ax (matplotlib.axes.Axes): Axis object to draw the bar plot.
3529
+
3530
+ Notes:
3531
+ - Supports multiple or single data columns.
3532
+ - Uses melted DataFrame (self.df_melted).
3533
+ - Adds error bars based on `self.error_bar_type` (std or sem).
3534
+ - Bar width and position are adjusted manually.
3535
+ - Log scaling is applied if enabled.
3536
+ """
3072
3537
  # Flatten DataFrame: Combine grouping column and data column into one group if needed
3073
3538
  if len(self.data_column) > 1:
3074
3539
  self.df_melted['Combined Group'] = (self.df_melted[self.grouping_column].astype(str) + " - " + self.df_melted['Data Column'].astype(str))
@@ -3109,7 +3574,19 @@ class spacrGraph:
3109
3574
  ax.set_xscale('log')
3110
3575
 
3111
3576
  def _create_jitter_plot(self, ax):
3112
- """Helper method to create a jitter plot (strip plot) with consistent spacing."""
3577
+ """
3578
+ Create a jitter (strip) plot using seaborn with consistent spacing.
3579
+
3580
+ Args:
3581
+ ax (matplotlib.axes.Axes): Axis object to draw the jitter plot.
3582
+
3583
+ Notes:
3584
+ - Supports multiple or single data columns.
3585
+ - Uses melted DataFrame (self.df_melted).
3586
+ - Point size and jitter are controlled via class attributes.
3587
+ - Legend is deduplicated and managed manually.
3588
+ - Log scaling is applied if enabled.
3589
+ """
3113
3590
  # Combine grouping column and data column if needed
3114
3591
  if len(self.data_column) > 1:
3115
3592
  self.df_melted['Combined Group'] = (self.df_melted[self.grouping_column].astype(str) + " - " + self.df_melted['Data Column'].astype(str))
@@ -3138,7 +3615,18 @@ class spacrGraph:
3138
3615
  ax.set_xscale('log')
3139
3616
 
3140
3617
  def _create_line_graph(self, ax):
3141
- """Helper method to create a line graph with one line per group based on epochs and accuracy."""
3618
+ """
3619
+ Create a line plot where each group is represented by a separate line.
3620
+
3621
+ Args:
3622
+ ax (matplotlib.axes.Axes): Axis object to draw the line graph.
3623
+
3624
+ Notes:
3625
+ - Expects `self.data_column` to contain two columns: x and y axes.
3626
+ - One line per group in `self.grouping_column`.
3627
+ - Log scaling applied as specified.
3628
+ - Raises ValueError if required columns are missing.
3629
+ """
3142
3630
  #display(self.df)
3143
3631
  # Ensure epoch is used on the x-axis and accuracy on the y-axis
3144
3632
  x_axis_column = self.data_column[0]
@@ -3167,7 +3655,18 @@ class spacrGraph:
3167
3655
  ax.set_ylabel(f"{y_axis_column}")
3168
3656
 
3169
3657
  def _create_line_with_std_area(self, ax):
3170
- """Helper method to create a line graph with shaded area representing standard deviation."""
3658
+ """
3659
+ Create a line plot with shaded standard deviation region around the mean.
3660
+
3661
+ Args:
3662
+ ax (matplotlib.axes.Axes): Axis object to draw the plot.
3663
+
3664
+ Notes:
3665
+ - Expects `self.data_column = [x_column, y_column]`.
3666
+ - Computes mean and std of `y_column` per `x_column`.
3667
+ - Plots the mean as a line and ±1 std as a shaded region.
3668
+ - Applies log scaling if enabled.
3669
+ """
3171
3670
 
3172
3671
  x_axis_column = self.data_column[0]
3173
3672
  y_axis_column = self.data_column[1]
@@ -3198,7 +3697,18 @@ class spacrGraph:
3198
3697
  ax.set_ylabel(f"{y_axis_column}")
3199
3698
 
3200
3699
  def _create_box_plot(self, ax):
3201
- """Helper method to create a box plot with consistent spacing."""
3700
+ """
3701
+ Create a box plot for each group with optional grouping by multiple columns.
3702
+
3703
+ Args:
3704
+ ax (matplotlib.axes.Axes): Axis object to draw the box plot.
3705
+
3706
+ Notes:
3707
+ - Supports multiple data columns by combining into 'Combined Group'.
3708
+ - Uses `self.df_melted`.
3709
+ - Log scaling applied if enabled.
3710
+ - Legend is deduplicated.
3711
+ """
3202
3712
  # Combine grouping column and data column if needed
3203
3713
  if len(self.data_column) > 1:
3204
3714
  self.df_melted['Combined Group'] = (self.df_melted[self.grouping_column].astype(str) + " - " + self.df_melted['Data Column'].astype(str))
@@ -3227,7 +3737,18 @@ class spacrGraph:
3227
3737
  ax.set_xscale('log')
3228
3738
 
3229
3739
  def _create_violin_plot(self, ax):
3230
- """Helper method to create a violin plot with consistent spacing."""
3740
+ """
3741
+ Create a violin plot for each group, showing data distribution.
3742
+
3743
+ Args:
3744
+ ax (matplotlib.axes.Axes): Axis object to draw the violin plot.
3745
+
3746
+ Notes:
3747
+ - Supports multiple data columns via combined group label.
3748
+ - Uses `self.df_melted`.
3749
+ - Log scaling applied if enabled.
3750
+ - Legend is deduplicated.
3751
+ """
3231
3752
  # Combine grouping column and data column if needed
3232
3753
  if len(self.data_column) > 1:
3233
3754
  self.df_melted['Combined Group'] = (self.df_melted[self.grouping_column].astype(str) + " - " + self.df_melted['Data Column'].astype(str))
@@ -3257,7 +3778,19 @@ class spacrGraph:
3257
3778
  ax.set_xscale('log')
3258
3779
 
3259
3780
  def _create_jitter_bar_plot(self, ax):
3260
- """Helper method to create a bar plot with consistent bar thickness and centered error bars."""
3781
+ """
3782
+ Create a combined jitter and bar plot with optional error bars.
3783
+
3784
+ Args:
3785
+ ax (matplotlib.axes.Axes): Axis object to draw the plot.
3786
+
3787
+ Notes:
3788
+ - Bars show central tendency; jitter adds point-level detail.
3789
+ - Supports multiple data columns via combined group label.
3790
+ - Uses `self.df_melted` and `self.error_bar_type`.
3791
+ - Log scaling applied if enabled.
3792
+ - Error bars currently commented out; customize if needed.
3793
+ """
3261
3794
  # Flatten DataFrame: Combine grouping column and data column into one group if needed
3262
3795
  if len(self.data_column) > 1:
3263
3796
  self.df_melted['Combined Group'] = (self.df_melted[self.grouping_column].astype(str) + " - " + self.df_melted['Data Column'].astype(str))
@@ -3299,7 +3832,19 @@ class spacrGraph:
3299
3832
  ax.set_xscale('log')
3300
3833
 
3301
3834
  def _create_jitter_box_plot(self, ax):
3302
- """Helper method to create a box plot with consistent spacing."""
3835
+ """
3836
+ Create a combined jitter and box plot.
3837
+
3838
+ Args:
3839
+ ax (matplotlib.axes.Axes): Axis object to draw the plot.
3840
+
3841
+ Notes:
3842
+ - Uses seaborn `boxplot` and overlays `stripplot`.
3843
+ - Supports multiple data columns via combined group label.
3844
+ - Uses `self.df_melted`.
3845
+ - Legend is deduplicated.
3846
+ - Log scaling applied if enabled.
3847
+ """
3303
3848
  # Combine grouping column and data column if needed
3304
3849
  if len(self.data_column) > 1:
3305
3850
  self.df_melted['Combined Group'] = (self.df_melted[self.grouping_column].astype(str) + " - " + self.df_melted['Data Column'].astype(str))
@@ -3329,7 +3874,14 @@ class spacrGraph:
3329
3874
  ax.set_xscale('log')
3330
3875
 
3331
3876
  def _save_results(self):
3332
- """Helper method to save the plot and results."""
3877
+ """
3878
+ Save the figure and results DataFrame to disk.
3879
+
3880
+ Notes:
3881
+ - PDF saved to `self.output_dir` with name based on `self.results_name`.
3882
+ - CSV of results saved in the same directory.
3883
+ - Ensures output directory exists.
3884
+ """
3333
3885
  os.makedirs(self.output_dir, exist_ok=True)
3334
3886
  plot_path = os.path.join(self.output_dir, f"{self.results_name}.pdf")
3335
3887
  self.fig.savefig(plot_path, bbox_inches='tight', dpi=600, transparent=True, format='pdf')
@@ -3339,30 +3891,65 @@ class spacrGraph:
3339
3891
  print(f"Test results saved to {results_path}")
3340
3892
 
3341
3893
  def get_results(self):
3342
- """Return the results dataframe."""
3894
+ """
3895
+ Return the internal results DataFrame with computed statistics or comparisons.
3896
+
3897
+ Returns:
3898
+ pd.DataFrame: Results table.
3899
+ """
3343
3900
  return self.results_df
3344
3901
 
3345
3902
  def get_figure(self):
3346
- """Return the generated figure."""
3903
+ """
3904
+ Return the Matplotlib figure associated with the most recent plot.
3905
+
3906
+ Returns:
3907
+ matplotlib.figure.Figure: The plot figure.
3908
+ """
3347
3909
  return self.fig
3348
3910
 
3349
3911
  def plot_data_from_db(settings):
3350
-
3351
- from .io import _read_db, _read_and_merge_data
3352
- from .utils import annotate_conditions, save_settings
3353
- from .settings import set_default_plot_data_from_db
3354
-
3355
3912
  """
3356
- Extracts the specified table from the SQLite database and plots a specified column.
3913
+ Load measurement data from SQL databases, annotate with experimental conditions,
3914
+ and generate grouped plots using spacrGraph.
3357
3915
 
3358
3916
  Args:
3359
- db_path (str): The path to the SQLite database.
3360
- table_names (str): The name of the table to extract.
3361
- data_column (str): The column to plot from the table.
3917
+ settings (dict): A dictionary containing the following keys:
3918
+ - 'src' (str or list of str): Source directories containing measurement databases.
3919
+ - 'database' (str or list of str): Corresponding database filenames.
3920
+ - 'table_names' (str): Name of the SQL table to extract data from.
3921
+ - 'graph_name' (str): Name for saving the output plot and settings.
3922
+ - 'grouping_column' (str): Column to group by on the x-axis.
3923
+ - 'data_column' (str): Column to plot on the y-axis.
3924
+ - 'channel_of_interest' (int): Channel number for recruitment calculation (if applicable).
3925
+ - 'cell_types' (list): List of expected host cell types.
3926
+ - 'pathogen_types' (list): List of expected pathogen types.
3927
+ - 'treatments' (list): List of expected treatments.
3928
+ - 'cell_plate_metadata' (str or None): Path to metadata for host cells.
3929
+ - 'pathogen_plate_metadata' (str or None): Path to metadata for pathogens.
3930
+ - 'treatment_plate_metadata' (str or None): Path to metadata for treatments.
3931
+ - 'nuclei_limit' (int): Max number of nuclei to load per plate (optional).
3932
+ - 'pathogen_limit' (int): Max number of pathogens to load per plate (optional).
3933
+ - 'verbose' (bool): Whether to print loading information.
3934
+ - 'graph_type' (str): Type of plot to create (e.g., 'bar', 'box', 'jitter', etc.).
3935
+ - 'representation' (str): Data representation style for plotting (optional).
3936
+ - 'theme' (str): Seaborn color palette theme.
3937
+ - 'y_lim' (tuple or None): Limits for y-axis scaling.
3938
+ - 'save' (bool): Whether to save the plot and results to disk.
3362
3939
 
3363
3940
  Returns:
3364
- df (pd.DataFrame): The extracted table as a DataFrame.
3941
+ fig (matplotlib.figure.Figure): The plotted figure.
3942
+ results_df (pd.DataFrame): Statistical results or summary data used in the plot.
3943
+
3944
+ Notes:
3945
+ - Automatically handles multi-source input.
3946
+ - Computes recruitment ratio if specified.
3947
+ - Drops rows with missing grouping or data values.
3948
+ - Creates and saves plot and settings to disk if 'save' is True.
3365
3949
  """
3950
+ from .io import _read_db, _read_and_merge_data
3951
+ from .utils import annotate_conditions, save_settings
3952
+ from .settings import set_default_plot_data_from_db
3366
3953
 
3367
3954
  settings = set_default_plot_data_from_db(settings)
3368
3955
 
@@ -3472,21 +4059,40 @@ def plot_data_from_db(settings):
3472
4059
  return fig, results_df
3473
4060
 
3474
4061
  def plot_data_from_csv(settings):
3475
- from .io import _read_db, _read_and_merge_data
3476
- from .utils import annotate_conditions, save_settings, remove_outliers_by_group
3477
- from .settings import get_plot_data_from_csv_default_settings
3478
4062
  """
3479
- Extracts the specified table from the SQLite database and plots a specified column.
4063
+ Load measurement data from one or more CSV files, optionally filter and clean the data,
4064
+ and generate grouped plots using spacrGraph.
3480
4065
 
3481
4066
  Args:
3482
- db_path (str): The path to the SQLite database.
3483
- table_names (str): The name of the table to extract.
3484
- data_column (str): The column to plot from the table.
4067
+ settings (dict): A dictionary containing the following keys:
4068
+ - 'src' (str or list of str): Path(s) to the CSV file(s) containing measurement data.
4069
+ - 'grouping_column' (str): Column to group by on the x-axis.
4070
+ - 'data_column' (str): Column to plot on the y-axis.
4071
+ - 'graph_name' (str): Name for saving the output plot and settings.
4072
+ - 'graph_type' (str): Type of plot to create (e.g., 'bar', 'box', 'violin', 'jitter').
4073
+ - 'theme' (str): Seaborn color palette theme.
4074
+ - 'log_y' (bool): Whether to log-transform the y-axis.
4075
+ - 'log_x' (bool): Whether to log-transform the x-axis.
4076
+ - 'y_lim' (tuple or None): Limits for y-axis scaling (optional).
4077
+ - 'save' (bool): Whether to save the plot and results to disk.
4078
+ - 'verbose' (bool): Whether to print and display the DataFrame before plotting.
4079
+ - 'representation' (str): Plot style to use (e.g., 'jitter_box', 'violin', etc.).
4080
+ - 'keep_groups' (list or str, optional): Restrict plot to a subset of group labels.
4081
+ - 'remove_outliers' (bool): Whether to remove outliers using IQR filtering.
3485
4082
 
3486
4083
  Returns:
3487
- df (pd.DataFrame): The extracted table as a DataFrame.
4084
+ fig (matplotlib.figure.Figure): The plotted figure.
4085
+ results_df (pd.DataFrame): DataFrame with statistical results or plot summary.
4086
+
4087
+ Notes:
4088
+ - Merges multiple CSVs and tags them with a default 'plateID' if not provided.
4089
+ - Attempts to split 'prc' column into 'plateID', 'rowID', and 'columnID' if applicable.
4090
+ - Handles missing values in the grouping and data columns by dropping those rows.
4091
+ - Automatically creates the output directory and saves results if `save` is True.
3488
4092
  """
3489
-
4093
+ from .io import _read_db, _read_and_merge_data
4094
+ from .utils import annotate_conditions, save_settings, remove_outliers_by_group
4095
+ from .settings import get_plot_data_from_csv_default_settings
3490
4096
 
3491
4097
  def filter_rows_by_column_values(df: pd.DataFrame, column: str, values: list) -> pd.DataFrame:
3492
4098
  """Return a filtered DataFrame where only rows with the column value in the list are kept."""
@@ -3569,7 +4175,44 @@ def plot_data_from_csv(settings):
3569
4175
  return fig, results_df
3570
4176
 
3571
4177
  def plot_region(settings):
4178
+ """
4179
+ Generate and save region-specific plots including: mask overlay, raw PNG grid, and activation map grid.
4180
+
4181
+ This function loads image paths and metadata from database tables, filters for the specified region (field of view),
4182
+ and plots the following:
4183
+
4184
+ - A mask overlay of the full field image with Cellpose masks.
4185
+ - A grid of raw cropped PNGs corresponding to the region.
4186
+ - A grid of activation maps (e.g., saliency, Grad-CAM) for the same region.
4187
+
4188
+ Args:
4189
+ settings (dict): A dictionary containing the following keys:
4190
+
4191
+ - 'src' (str): Source folder containing subfolders like 'merged', 'measurements', and 'datasets'.
4192
+ - 'name' (str): Filename (e.g., 'plate1_A01_01.tif') identifying the region/FOV of interest.
4193
+ - 'activation_db' (str): Filename of the activation measurement database (e.g., 'activations.db').
4194
+ - 'activation_mode' (str): Mode of activation ('saliency', 'gradcam', etc.); used to find the correct table.
4195
+ - 'channels' (list of int): Indices of input channels to use for the mask overlay.
4196
+ - 'cell_channel' (int): Channel index used to generate the cell mask.
4197
+ - 'nucleus_channel' (int): Channel index used to generate the nucleus mask.
4198
+ - 'pathogen_channel' (int): Channel index used to generate the pathogen mask.
4199
+ - 'percentiles' (list): Two-element list (e.g., [2, 99]) specifying intensity percentiles for contrast scaling.
4200
+ - 'mode' (str): Image display mode for overlay ('rgb', 'stack', etc.).
4201
+ - 'export_tiffs' (bool): Whether to export TIFF images alongside overlays.
4202
+
4203
+ Returns:
4204
+ tuple: A 3-element tuple of matplotlib Figure objects or None
3572
4205
 
4206
+ - fig_1 (matplotlib.figure.Figure or None): Mask overlay figure.
4207
+ - fig_2 (matplotlib.figure.Figure or None): Grid of raw cropped PNGs.
4208
+ - fig_3 (matplotlib.figure.Figure or None): Grid of activation maps.
4209
+
4210
+ Notes:
4211
+ - Figures are saved as PDFs under ``<src>/results/<name>/``.
4212
+ - If no relevant PNGs or activations are found, the corresponding figure will be ``None``.
4213
+ - Paths are automatically corrected using ``correct_paths``.
4214
+ - The figure layout uses ``plot_image_grid`` and ``plot_image_mask_overlay``.
4215
+ """
3573
4216
  def _sort_paths_by_basename(paths):
3574
4217
  return sorted(paths, key=lambda path: os.path.basename(path))
3575
4218
 
@@ -3640,7 +4283,7 @@ def plot_image_grid(image_paths, percentiles):
3640
4283
  Plots a square grid of images from a list of image paths.
3641
4284
  Unused subplots are filled with black, and padding is minimized.
3642
4285
 
3643
- Parameters:
4286
+ Args:
3644
4287
  - image_paths: List of paths to images to be displayed.
3645
4288
 
3646
4289
  Returns:
@@ -3798,7 +4441,30 @@ def overlay_masks_on_images(img_folder, normalize=True, resize=True, save=False,
3798
4441
  plt.show()
3799
4442
 
3800
4443
  def graph_importance(settings):
3801
-
4444
+ """
4445
+ Generate and display a bar, box, or violin plot of importance values across grouped categories.
4446
+
4447
+ This function reads one or more CSV files containing importance scores (e.g., feature importances,
4448
+ saliency values, or other metrics) and visualizes them grouped by a specified column. It wraps
4449
+ the `spacrGraph` plotting class to create the plot and saves the settings used.
4450
+
4451
+ Args:
4452
+ settings (dict): A dictionary containing the following keys:
4453
+ - 'csvs' (str or list of str): Path(s) to CSV file(s) containing importance data.
4454
+ - 'grouping_column' (str): Column name used for grouping on the x-axis.
4455
+ - 'data_column' (str): Column name containing the data values to be plotted (e.g., importance scores).
4456
+ - 'graph_type' (str): Type of plot to generate ('bar', 'box', or 'violin').
4457
+ - 'save' (bool): Whether to save the plot as a PDF in the same directory as the input CSVs.
4458
+
4459
+ Returns:
4460
+ None
4461
+
4462
+ Notes:
4463
+ - If the required columns are missing from the input data, the function will print a warning and exit.
4464
+ - The plot is created using the `spacrGraph` class and is shown with `matplotlib.pyplot.show()`.
4465
+ - All input CSVs are concatenated before plotting.
4466
+ - Settings are saved to disk using `save_settings`.
4467
+ """
3802
4468
  from .settings import set_graph_importance_defaults
3803
4469
  from .utils import save_settings
3804
4470
 
@@ -3850,7 +4516,7 @@ def plot_proportion_stacked_bars(settings, df, group_column, bin_column, prc_col
3850
4516
  """
3851
4517
  Generate a stacked bar plot for proportions and perform chi-squared and pairwise tests.
3852
4518
 
3853
- Parameters:
4519
+ Args:
3854
4520
  - settings (dict): Analysis settings.
3855
4521
  - df (DataFrame): Input data.
3856
4522
  - group_column (str): Column indicating the groups.
@@ -3927,7 +4593,7 @@ def create_venn_diagram(file1, file2, gene_column="gene", filter_coeff=0.1, save
3927
4593
  Reads two CSV files, extracts the `gene` column, and creates a Venn diagram
3928
4594
  to show overlapping and non-overlapping genes.
3929
4595
 
3930
- Parameters:
4596
+ Args:
3931
4597
  file1 (str): Path to the first CSV file.
3932
4598
  file2 (str): Path to the second CSV file.
3933
4599
  gene_column (str): Name of the column containing gene data (default: "gene").