spacr 1.0.9__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +3 -2
- spacr/app_classify.py +10 -0
- spacr/app_mask.py +9 -0
- spacr/app_measure.py +9 -0
- spacr/app_sequencing.py +9 -0
- spacr/core.py +172 -1
- spacr/deep_spacr.py +296 -7
- spacr/gui.py +68 -0
- spacr/gui_core.py +319 -10
- spacr/gui_elements.py +772 -13
- spacr/gui_utils.py +304 -12
- spacr/io.py +887 -71
- spacr/logger.py +36 -0
- spacr/measure.py +206 -28
- spacr/ml.py +606 -142
- spacr/plot.py +797 -131
- spacr/sequencing.py +363 -8
- spacr/settings.py +1158 -38
- spacr/sp_stats.py +80 -12
- spacr/spacr_cellpose.py +115 -2
- spacr/submodules.py +747 -19
- spacr/timelapse.py +237 -53
- spacr/toxo.py +132 -6
- spacr/utils.py +2422 -80
- {spacr-1.0.9.dist-info → spacr-1.1.1.dist-info}/METADATA +31 -17
- {spacr-1.0.9.dist-info → spacr-1.1.1.dist-info}/RECORD +30 -30
- {spacr-1.0.9.dist-info → spacr-1.1.1.dist-info}/LICENSE +0 -0
- {spacr-1.0.9.dist-info → spacr-1.1.1.dist-info}/WHEEL +0 -0
- {spacr-1.0.9.dist-info → spacr-1.1.1.dist-info}/entry_points.txt +0 -0
- {spacr-1.0.9.dist-info → spacr-1.1.1.dist-info}/top_level.txt +0 -0
spacr/plot.py
CHANGED
@@ -26,24 +26,33 @@ from ipywidgets import IntSlider, interact
|
|
26
26
|
from IPython.display import Image as ipyimage
|
27
27
|
from matplotlib_venn import venn2
|
28
28
|
|
29
|
-
def plot_image_mask_overlay(
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
figuresize=10,
|
36
|
-
percentiles=(2, 98),
|
37
|
-
thickness=3,
|
38
|
-
save_pdf=True,
|
39
|
-
mode='outlines',
|
40
|
-
export_tiffs=False,
|
41
|
-
all_on_all=False,
|
42
|
-
all_outlines=False,
|
43
|
-
filter_dict=None
|
44
|
-
):
|
45
|
-
"""Plot image and mask overlays."""
|
29
|
+
def plot_image_mask_overlay(file,channels,cell_channel,nucleus_channel,pathogen_channel,figuresize=10,percentiles=(2, 98),thickness=3,save_pdf=True,mode='outlines',export_tiffs=False,all_on_all=False,all_outlines=False,filter_dict=None):
|
30
|
+
"""
|
31
|
+
Plot multi-channel microscopy image overlays with cell, nucleus, and pathogen masks.
|
32
|
+
|
33
|
+
This function visualizes microscopy images with optional mask overlays for different object types.
|
34
|
+
It supports contour and filled mask modes, object filtering based on size and intensity, and saving output.
|
46
35
|
|
36
|
+
Args:
|
37
|
+
file (str): Path to the `.npy` image stack (H x W x C).
|
38
|
+
channels (list): Indices of the image channels to display.
|
39
|
+
cell_channel (int or None): Channel index for the cell mask intensity.
|
40
|
+
nucleus_channel (int or None): Channel index for the nucleus mask intensity.
|
41
|
+
pathogen_channel (int or None): Channel index for the pathogen mask intensity.
|
42
|
+
figuresize (int): Base figure size for each subplot. Default is 10.
|
43
|
+
percentiles (tuple): Percentile range for image normalization. Default is (2, 98).
|
44
|
+
thickness (int): Thickness of mask contour outlines. Default is 3.
|
45
|
+
save_pdf (bool): If True, saves a PDF of the overlaid image. Default is True.
|
46
|
+
mode (str): 'outlines' or 'filled'. Determines how masks are rendered. Default is 'outlines'.
|
47
|
+
export_tiffs (bool): If True, exports grayscale TIFFs for each image channel. Default is False.
|
48
|
+
all_on_all (bool): If True, overlays all outlines on all channels. Default is False.
|
49
|
+
all_outlines (bool): If True, overlays all outlines on non-matching channels. Default is False.
|
50
|
+
filter_dict (dict or None): Dictionary of filtering thresholds for each object type.
|
51
|
+
For example: {"cell": [(min_area, max_area), (min_intensity, max_intensity)]}
|
52
|
+
|
53
|
+
Returns:
|
54
|
+
matplotlib.figure.Figure: The generated overlay figure.
|
55
|
+
"""
|
47
56
|
def random_color_cmap(n_labels, seed=None):
|
48
57
|
"""Generates a random color map for a given number of labels."""
|
49
58
|
if seed is not None:
|
@@ -53,25 +62,7 @@ def plot_image_mask_overlay(
|
|
53
62
|
cmap = ListedColormap(rand_colors)
|
54
63
|
return cmap
|
55
64
|
|
56
|
-
def _plot_merged_plot(
|
57
|
-
image,
|
58
|
-
outlines,
|
59
|
-
outline_colors,
|
60
|
-
figuresize,
|
61
|
-
thickness,
|
62
|
-
percentiles,
|
63
|
-
mode='outlines',
|
64
|
-
all_on_all=False,
|
65
|
-
all_outlines=False,
|
66
|
-
channels=None,
|
67
|
-
cell_channel=None,
|
68
|
-
nucleus_channel=None,
|
69
|
-
pathogen_channel=None,
|
70
|
-
cell_outlines=None,
|
71
|
-
nucleus_outlines=None,
|
72
|
-
pathogen_outlines=None,
|
73
|
-
save_pdf=True
|
74
|
-
):
|
65
|
+
def _plot_merged_plot(image,outlines,outline_colors,figuresize,thickness,percentiles,mode='outlines',all_on_all=False,all_outlines=False,channels=None,cell_channel=None,nucleus_channel=None,pathogen_channel=None,cell_outlines=None,nucleus_outlines=None,pathogen_outlines=None,save_pdf=True):
|
75
66
|
"""Plot the merged plot with overlay, image channels, and masks."""
|
76
67
|
|
77
68
|
def _generate_colored_mask(mask, cmap):
|
@@ -506,7 +497,7 @@ def generate_mask_random_cmap(mask):
|
|
506
497
|
"""
|
507
498
|
Generate a random colormap based on the unique labels in the given mask.
|
508
499
|
|
509
|
-
|
500
|
+
Args:
|
510
501
|
mask (numpy.ndarray): The input mask array.
|
511
502
|
|
512
503
|
Returns:
|
@@ -524,7 +515,7 @@ def random_cmap(num_objects=100):
|
|
524
515
|
"""
|
525
516
|
Generate a random colormap.
|
526
517
|
|
527
|
-
|
518
|
+
Args:
|
528
519
|
num_objects (int): The number of objects to generate colors for. Default is 100.
|
529
520
|
|
530
521
|
Returns:
|
@@ -540,7 +531,7 @@ def _generate_mask_random_cmap(mask):
|
|
540
531
|
"""
|
541
532
|
Generate a random colormap based on the unique labels in the given mask.
|
542
533
|
|
543
|
-
|
534
|
+
Args:
|
544
535
|
mask (ndarray): The mask array containing unique labels.
|
545
536
|
|
546
537
|
Returns:
|
@@ -559,7 +550,7 @@ def _get_colours_merged(outline_color):
|
|
559
550
|
"""
|
560
551
|
Get the merged outline colors based on the specified outline color format.
|
561
552
|
|
562
|
-
|
553
|
+
Args:
|
563
554
|
outline_color (str): The outline color format. Can be one of 'rgb', 'bgr', 'gbr', or 'rbg'.
|
564
555
|
|
565
556
|
Returns:
|
@@ -765,7 +756,7 @@ def plot_arrays(src, figuresize=10, cmap='inferno', nr=1, normalize=True, q1=1,
|
|
765
756
|
"""
|
766
757
|
Plot randomly selected arrays from a given directory or a single .npz/.npy file.
|
767
758
|
|
768
|
-
|
759
|
+
Args:
|
769
760
|
- src (str): The directory path or file path containing the arrays.
|
770
761
|
- figuresize (int): The size of the figure (default: 10).
|
771
762
|
- cmap (str): The colormap to use for displaying the arrays (default: 'inferno').
|
@@ -820,7 +811,7 @@ def plot_arrays_v1(src, figuresize=10, cmap='inferno', nr=1, normalize=True, q1=
|
|
820
811
|
"""
|
821
812
|
Plot randomly selected arrays from a given directory.
|
822
813
|
|
823
|
-
|
814
|
+
Args:
|
824
815
|
- src (str): The directory path containing the arrays.
|
825
816
|
- figuresize (int): The size of the figure (default: 50).
|
826
817
|
- cmap (str): The colormap to use for displaying the arrays (default: 'inferno').
|
@@ -1286,7 +1277,7 @@ def _visualize_and_save_timelapse_stack_with_tracks(masks, tracks_df, save, src,
|
|
1286
1277
|
"""
|
1287
1278
|
Display the frame with tracks overlaid.
|
1288
1279
|
|
1289
|
-
|
1280
|
+
Args:
|
1290
1281
|
frame (int): The frame number to display.
|
1291
1282
|
|
1292
1283
|
Returns:
|
@@ -1329,7 +1320,7 @@ def _display_gif(path):
|
|
1329
1320
|
"""
|
1330
1321
|
Display a GIF image from the given path.
|
1331
1322
|
|
1332
|
-
|
1323
|
+
Args:
|
1333
1324
|
path (str): The path to the GIF image file.
|
1334
1325
|
|
1335
1326
|
Returns:
|
@@ -1572,6 +1563,26 @@ def _imshow_gpu(img, labels, nrow=20, color='white', fontsize=12):
|
|
1572
1563
|
return fig
|
1573
1564
|
|
1574
1565
|
def _plot_histograms_and_stats(df):
|
1566
|
+
"""
|
1567
|
+
Plot histograms and print summary statistics for prediction values grouped by condition.
|
1568
|
+
|
1569
|
+
Args:
|
1570
|
+
df (pd.DataFrame): DataFrame containing at least two columns:
|
1571
|
+
- 'condition': categorical variable indicating group labels.
|
1572
|
+
- 'pred': continuous prediction scores (e.g., probabilities).
|
1573
|
+
|
1574
|
+
Returns:
|
1575
|
+
None. Displays histograms and prints statistics for each condition.
|
1576
|
+
|
1577
|
+
Printed Statistics per Condition:
|
1578
|
+
- Number of rows
|
1579
|
+
- Mean of 'pred' values
|
1580
|
+
- Count and percentage of 'pred' > 0.5 (positive)
|
1581
|
+
- Count and percentage of 'pred' <= 0.5 (negative)
|
1582
|
+
|
1583
|
+
Plots:
|
1584
|
+
- Histogram of 'pred' values with mean indicated by a dashed red line.
|
1585
|
+
"""
|
1575
1586
|
conditions = df['condition'].unique()
|
1576
1587
|
|
1577
1588
|
for condition in conditions:
|
@@ -1603,7 +1614,22 @@ def _plot_histograms_and_stats(df):
|
|
1603
1614
|
plt.show()
|
1604
1615
|
|
1605
1616
|
def _show_residules(model):
|
1617
|
+
"""
|
1618
|
+
Display diagnostic plots and test for residuals of a fitted regression model.
|
1619
|
+
|
1620
|
+
Args:
|
1621
|
+
model (statsmodels.regression.linear_model.RegressionResultsWrapper):
|
1622
|
+
A fitted statsmodels regression model with `.resid` and `.fittedvalues` attributes.
|
1623
|
+
|
1624
|
+
Returns:
|
1625
|
+
None. Displays plots and prints Shapiro-Wilk test results.
|
1606
1626
|
|
1627
|
+
Diagnostics:
|
1628
|
+
- Histogram of residuals
|
1629
|
+
- QQ plot for normality
|
1630
|
+
- Residuals vs. Fitted values plot
|
1631
|
+
- Shapiro-Wilk test for normality (prints W-statistic and p-value)
|
1632
|
+
"""
|
1607
1633
|
# Get the residuals
|
1608
1634
|
residuals = model.resid
|
1609
1635
|
|
@@ -1632,6 +1658,23 @@ def _show_residules(model):
|
|
1632
1658
|
print(f'Shapiro-Wilk Test W-statistic: {W}, p-value: {p_value}')
|
1633
1659
|
|
1634
1660
|
def _reg_v_plot(df, grouping, variable, plate_number):
|
1661
|
+
"""
|
1662
|
+
Generate a volcano plot for visualizing effect size versus significance.
|
1663
|
+
|
1664
|
+
Args:
|
1665
|
+
df (pd.DataFrame): DataFrame containing regression results with 'effect' and 'p' columns.
|
1666
|
+
grouping (str): Unused in current function but reserved for grouping logic.
|
1667
|
+
variable (str): Unused in current function but could annotate variable of interest.
|
1668
|
+
plate_number (str or int): Unused in current function but could title or segregate plots.
|
1669
|
+
|
1670
|
+
Returns:
|
1671
|
+
None. Displays a volcano plot using matplotlib.
|
1672
|
+
|
1673
|
+
Notes:
|
1674
|
+
- Points with p < 0.05 are labeled by index.
|
1675
|
+
- Color indicates direction of effect (positive or negative).
|
1676
|
+
- Dashed line marks significance threshold (p = 0.05).
|
1677
|
+
"""
|
1635
1678
|
df['-log10(p)'] = -np.log10(df['p'])
|
1636
1679
|
|
1637
1680
|
# Create the volcano plot
|
@@ -1650,7 +1693,27 @@ def _reg_v_plot(df, grouping, variable, plate_number):
|
|
1650
1693
|
plt.show()
|
1651
1694
|
|
1652
1695
|
def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_count):
|
1696
|
+
"""
|
1697
|
+
Generate a heatmap matrix for a specific plate based on the specified variable.
|
1653
1698
|
|
1699
|
+
Args:
|
1700
|
+
df (pd.DataFrame): Input DataFrame containing 'prc' and variable to plot.
|
1701
|
+
plate_number (str): Identifier of the plate to generate the heatmap for.
|
1702
|
+
variable (str): Column name to be visualized.
|
1703
|
+
grouping (str): Aggregation method: 'mean', 'sum', or 'count'.
|
1704
|
+
min_max (str or tuple): Controls the colormap scaling:
|
1705
|
+
- 'all': use full data min/max
|
1706
|
+
- 'allq': use 2nd to 98th percentile
|
1707
|
+
- (low, high): percentiles (float) or fixed values (int)
|
1708
|
+
min_count (int or float): Minimum number of entries per well to be included.
|
1709
|
+
|
1710
|
+
Returns:
|
1711
|
+
plate_map (pd.DataFrame): Pivoted heatmap matrix indexed by row and column.
|
1712
|
+
min_max (list): The computed min/max values for the heatmap color scale.
|
1713
|
+
|
1714
|
+
Raises:
|
1715
|
+
ValueError: If `grouping` is not one of ['mean', 'sum', 'count'].
|
1716
|
+
"""
|
1654
1717
|
if not isinstance(min_count, (int, float)):
|
1655
1718
|
min_count = 0
|
1656
1719
|
|
@@ -1713,6 +1776,27 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_co
|
|
1713
1776
|
return plate_map, min_max
|
1714
1777
|
|
1715
1778
|
def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True, dst=None):
|
1779
|
+
"""
|
1780
|
+
Generate and optionally save heatmaps for one or more plates, showing the spatial distribution
|
1781
|
+
of a specified variable (e.g., unique_counts or signal intensity).
|
1782
|
+
|
1783
|
+
Args:
|
1784
|
+
df (pd.DataFrame): Input DataFrame containing 'prc' in the format 'plate_row_col' and the variable to plot.
|
1785
|
+
variable (str): Name of the column to visualize.
|
1786
|
+
grouping (str): Aggregation method for replicates (e.g., 'mean', 'median').
|
1787
|
+
min_max (str): Method to determine color scaling. One of {'global', 'local', 'allq'}.
|
1788
|
+
cmap (str or Colormap): Colormap to use for heatmap visualization.
|
1789
|
+
min_count (int, optional): Minimum count threshold for inclusion in the heatmap. Default is 0.
|
1790
|
+
verbose (bool, optional): If True, displays the heatmap(s) in an interactive window. Default is True.
|
1791
|
+
dst (str, optional): Directory to save the heatmaps. If None, the heatmaps are not saved.
|
1792
|
+
|
1793
|
+
Returns:
|
1794
|
+
matplotlib.figure.Figure: The figure containing the plotted heatmaps.
|
1795
|
+
|
1796
|
+
Side Effects:
|
1797
|
+
- Displays plate heatmaps using matplotlib/seaborn.
|
1798
|
+
- Saves the heatmap figure as 'plate_heatmap_#.pdf' in `dst` if provided and writable.
|
1799
|
+
"""
|
1716
1800
|
plates = df['prc'].str.split('_', expand=True)[0].unique()
|
1717
1801
|
n_rows, n_cols = (len(plates) + 3) // 4, 4
|
1718
1802
|
fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
|
@@ -1843,6 +1927,27 @@ def print_mask_and_flows(stack, mask, flows, overlay=True, max_size=1000, thickn
|
|
1843
1927
|
plt.show()
|
1844
1928
|
|
1845
1929
|
def plot_resize(images, resized_images, labels, resized_labels):
|
1930
|
+
"""
|
1931
|
+
Plot original and resized images along with their corresponding labels for visual comparison.
|
1932
|
+
|
1933
|
+
Args:
|
1934
|
+
images (list or np.ndarray): List or array of original input images. Each image can be 2D (grayscale)
|
1935
|
+
or 3D (with 1, 3, or 4 channels).
|
1936
|
+
resized_images (list or np.ndarray): Corresponding resized input images.
|
1937
|
+
labels (list or np.ndarray): List or array of original label masks.
|
1938
|
+
resized_labels (list or np.ndarray): Corresponding resized label masks.
|
1939
|
+
|
1940
|
+
Returns:
|
1941
|
+
None. Displays a 2x2 grid of plots:
|
1942
|
+
- Top-left: Original image
|
1943
|
+
- Top-right: Resized image
|
1944
|
+
- Bottom-left: Original label
|
1945
|
+
- Bottom-right: Resized label
|
1946
|
+
|
1947
|
+
Notes:
|
1948
|
+
- Images with 1, 3, or 4 channels are supported.
|
1949
|
+
- Other channel numbers will be visualized by averaging across channels (grayscale fallback).
|
1950
|
+
"""
|
1846
1951
|
def prepare_image(img):
|
1847
1952
|
if img.ndim == 2:
|
1848
1953
|
return img, 'gray'
|
@@ -1884,7 +1989,23 @@ def plot_resize(images, resized_images, labels, resized_labels):
|
|
1884
1989
|
plt.show()
|
1885
1990
|
|
1886
1991
|
def normalize_and_visualize(image, normalized_image, title=""):
|
1887
|
-
"""
|
1992
|
+
"""
|
1993
|
+
Display a side-by-side comparison of an original image and its normalized version.
|
1994
|
+
|
1995
|
+
Args:
|
1996
|
+
image (np.ndarray): The original image. Can be 2D (grayscale) or 3D (multi-channel).
|
1997
|
+
normalized_image (np.ndarray): The normalized image. Should have the same dimensions as `image`.
|
1998
|
+
title (str, optional): Optional string to append to the plot titles. Defaults to "".
|
1999
|
+
|
2000
|
+
Returns:
|
2001
|
+
None. Displays a matplotlib figure with two subplots:
|
2002
|
+
- Left: Original image
|
2003
|
+
- Right: Normalized image
|
2004
|
+
|
2005
|
+
Notes:
|
2006
|
+
- For multi-channel images, the mean across channels is visualized.
|
2007
|
+
- Axes are hidden for clarity.
|
2008
|
+
"""
|
1888
2009
|
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
|
1889
2010
|
if image.ndim == 3: # Multi-channel image
|
1890
2011
|
ax[0].imshow(np.mean(image, axis=-1), cmap='gray') # Display the average over channels for visualization
|
@@ -1903,6 +2024,25 @@ def normalize_and_visualize(image, normalized_image, title=""):
|
|
1903
2024
|
plt.show()
|
1904
2025
|
|
1905
2026
|
def visualize_masks(mask1, mask2, mask3, title="Masks Comparison"):
|
2027
|
+
"""
|
2028
|
+
Display three segmentation masks side-by-side for visual comparison.
|
2029
|
+
|
2030
|
+
Args:
|
2031
|
+
mask1 (np.ndarray): First mask array (2D).
|
2032
|
+
mask2 (np.ndarray): Second mask array (2D).
|
2033
|
+
mask3 (np.ndarray): Third mask array (2D).
|
2034
|
+
title (str, optional): Title for the entire figure. Defaults to "Masks Comparison".
|
2035
|
+
|
2036
|
+
Returns:
|
2037
|
+
None. Displays a matplotlib figure with three subplots:
|
2038
|
+
- Each subplot corresponds to one of the input masks.
|
2039
|
+
- Masks are visualized using a randomly generated colormap.
|
2040
|
+
- If a mask is binary (0 and 1 only), no normalization is applied.
|
2041
|
+
- Otherwise, intensity values are normalized to [0, max(mask)].
|
2042
|
+
|
2043
|
+
Notes:
|
2044
|
+
Requires the function `generate_mask_random_cmap(mask)` to generate a colormap.
|
2045
|
+
"""
|
1906
2046
|
fig, axs = plt.subplots(1, 3, figsize=(30, 10))
|
1907
2047
|
for ax, mask, title in zip(axs, [mask1, mask2, mask3], ['Mask 1', 'Mask 2', 'Mask 3']):
|
1908
2048
|
cmap = generate_mask_random_cmap(mask)
|
@@ -1922,7 +2062,7 @@ def visualize_cellpose_masks(masks, titles=None, filename=None, save=False, src=
|
|
1922
2062
|
"""
|
1923
2063
|
Visualize multiple masks with optional titles.
|
1924
2064
|
|
1925
|
-
|
2065
|
+
Args:
|
1926
2066
|
masks (list of np.ndarray): A list of masks to visualize.
|
1927
2067
|
titles (list of str, optional): A list of titles for the masks. If None, default titles will be used.
|
1928
2068
|
comparison_title (str): Title for the entire figure.
|
@@ -1962,6 +2102,26 @@ def visualize_cellpose_masks(masks, titles=None, filename=None, save=False, src=
|
|
1962
2102
|
|
1963
2103
|
|
1964
2104
|
def plot_comparison_results(comparison_results):
|
2105
|
+
"""
|
2106
|
+
Visualize segmentation comparison metrics using boxplots with overlaid strip plots.
|
2107
|
+
|
2108
|
+
Args:
|
2109
|
+
comparison_results (list of dict): Each dictionary represents one sample, with keys:
|
2110
|
+
- 'filename': Name of the sample.
|
2111
|
+
- Metric keys such as 'jaccard_*', 'dice_*', 'boundary_f1_*', 'average_precision_*'.
|
2112
|
+
|
2113
|
+
Returns:
|
2114
|
+
matplotlib.figure.Figure: The resulting figure with 4 subplots:
|
2115
|
+
- Jaccard Index
|
2116
|
+
- Dice Coefficient
|
2117
|
+
- Boundary F1 Score
|
2118
|
+
- Average Precision
|
2119
|
+
|
2120
|
+
Notes:
|
2121
|
+
- Metrics are grouped by type using substring matching in column names.
|
2122
|
+
- Outliers and individual sample values are shown with strip plots.
|
2123
|
+
- Assumes that all metric columns contain numeric values.
|
2124
|
+
"""
|
1965
2125
|
df = pd.DataFrame(comparison_results)
|
1966
2126
|
df_melted = pd.melt(df, id_vars=['filename'], var_name='metric', value_name='value')
|
1967
2127
|
df_jaccard = df_melted[df_melted['metric'].str.contains('jaccard')]
|
@@ -2004,7 +2164,25 @@ def plot_comparison_results(comparison_results):
|
|
2004
2164
|
return fig
|
2005
2165
|
|
2006
2166
|
def plot_object_outlines(src, objects=['nucleus','cell','pathogen'], channels=[0,1,2], max_nr=10):
|
2007
|
-
|
2167
|
+
"""
|
2168
|
+
Overlay mask outlines on image channels for visual inspection of segmentation quality.
|
2169
|
+
|
2170
|
+
Args:
|
2171
|
+
src (str): Source directory containing the 'masks' folder and image channels.
|
2172
|
+
objects (list of str): List of object names (e.g., 'nucleus', 'cell', 'pathogen') whose masks to visualize.
|
2173
|
+
channels (list of int): Corresponding channel indices (0-based) for each object.
|
2174
|
+
max_nr (int): Maximum number of overlays to display.
|
2175
|
+
|
2176
|
+
Returns:
|
2177
|
+
None. Displays overlays using `plot_images_and_arrays()`.
|
2178
|
+
|
2179
|
+
Notes:
|
2180
|
+
- For each object/channel pair, attempts to load:
|
2181
|
+
- The mask from 'masks/{object}_mask_stack'
|
2182
|
+
- The image from '{channel+1}' subfolder (1-based indexing)
|
2183
|
+
- Assumes `plot_images_and_arrays()` can handle paired overlay inputs.
|
2184
|
+
- Random selection and percentile contrast stretching applied for visualization.
|
2185
|
+
"""
|
2008
2186
|
for object_, channel in zip(objects, channels):
|
2009
2187
|
folders = [os.path.join(src, 'masks', f'{object_}_mask_stack'),
|
2010
2188
|
os.path.join(src,f'{channel+1}')]
|
@@ -2019,6 +2197,24 @@ def plot_object_outlines(src, objects=['nucleus','cell','pathogen'], channels=[0
|
|
2019
2197
|
randomize=True)
|
2020
2198
|
|
2021
2199
|
def volcano_plot(coef_df, filename='volcano_plot.pdf'):
|
2200
|
+
"""
|
2201
|
+
Generate and save a volcano plot based on coefficient and p-value data.
|
2202
|
+
|
2203
|
+
Args:
|
2204
|
+
coef_df (pd.DataFrame): DataFrame containing columns:
|
2205
|
+
- 'coefficient': effect size for each term.
|
2206
|
+
- 'p_value': p-value for each coefficient.
|
2207
|
+
- 'condition': category for coloring (e.g., 'pc', 'nc', 'control', 'other').
|
2208
|
+
filename (str): File path for saving the plot as a PDF.
|
2209
|
+
|
2210
|
+
Returns:
|
2211
|
+
None. Displays the volcano plot and saves it to the specified file.
|
2212
|
+
|
2213
|
+
Notes:
|
2214
|
+
- Highlights p-value threshold at 0.05 with a horizontal dashed red line.
|
2215
|
+
- Uses pre-defined color palette for conditions.
|
2216
|
+
- Legend is removed for clarity.
|
2217
|
+
"""
|
2022
2218
|
palette = {
|
2023
2219
|
'pc': 'red',
|
2024
2220
|
'nc': 'green',
|
@@ -2046,6 +2242,21 @@ def volcano_plot(coef_df, filename='volcano_plot.pdf'):
|
|
2046
2242
|
plt.show()
|
2047
2243
|
|
2048
2244
|
def plot_histogram(df, column, dst=None):
|
2245
|
+
"""
|
2246
|
+
Plot and optionally save a histogram for a specified column in a DataFrame.
|
2247
|
+
|
2248
|
+
Args:
|
2249
|
+
df (pd.DataFrame): DataFrame containing the data.
|
2250
|
+
column (str): Name of the column to plot.
|
2251
|
+
dst (str or None): Directory to save the figure as a PDF. If None, the figure is not saved.
|
2252
|
+
|
2253
|
+
Returns:
|
2254
|
+
None. Displays the histogram and optionally saves it to disk.
|
2255
|
+
|
2256
|
+
Notes:
|
2257
|
+
- The histogram uses a fixed turquoise bar color (RGB: 0,155,155).
|
2258
|
+
- Saved file is named '{column}_histogram.pdf' and placed in the specified `dst` directory.
|
2259
|
+
"""
|
2049
2260
|
# Plot histogram of the dependent variable
|
2050
2261
|
bar_color = (0/255, 155/255, 155/255)
|
2051
2262
|
plt.figure(figsize=(10, 10))
|
@@ -2061,10 +2272,28 @@ def plot_histogram(df, column, dst=None):
|
|
2061
2272
|
|
2062
2273
|
plt.show()
|
2063
2274
|
|
2064
|
-
def plot_lorenz_curves(csv_files, name_column='grna_name', value_column='count',
|
2065
|
-
|
2066
|
-
|
2067
|
-
|
2275
|
+
def plot_lorenz_curves(csv_files, name_column='grna_name', value_column='count', remove_keys=None, x_lim=[0.0, 1], y_lim=[0, 1], remove_outliers=False, save=True):
|
2276
|
+
"""
|
2277
|
+
Plot Lorenz curves and compute Gini coefficients for a set of CSV files.
|
2278
|
+
|
2279
|
+
Args:
|
2280
|
+
csv_files (list of str): Paths to CSV files containing count data.
|
2281
|
+
name_column (str): Column name containing the gRNA or entity names.
|
2282
|
+
value_column (str): Column name with the numerical values (e.g., counts).
|
2283
|
+
remove_keys (list or None): List of names to exclude from analysis.
|
2284
|
+
x_lim (list): X-axis limits for the plot.
|
2285
|
+
y_lim (list): Y-axis limits for the plot.
|
2286
|
+
remove_outliers (bool): Whether to remove outlier entities by well counts.
|
2287
|
+
save (bool): If True, saves the plot to a PDF.
|
2288
|
+
|
2289
|
+
Returns:
|
2290
|
+
None. Displays the plot and prints Gini coefficients.
|
2291
|
+
|
2292
|
+
Notes:
|
2293
|
+
- Gini coefficient is a measure of inequality (0 = perfect equality, 1 = maximal inequality).
|
2294
|
+
- Outlier removal is based on IQR of well counts per entity name.
|
2295
|
+
- Saves to 'lorenz_curve_with_gini.pdf' under a 'results' folder next to the first input CSV.
|
2296
|
+
"""
|
2068
2297
|
def lorenz_curve(data):
|
2069
2298
|
"""Calculate Lorenz curve."""
|
2070
2299
|
sorted_data = np.sort(data)
|
@@ -2153,6 +2382,23 @@ def plot_lorenz_curves(csv_files, name_column='grna_name', value_column='count',
|
|
2153
2382
|
print(f"{plate}: Gini Coefficient = {gini:.4f}")
|
2154
2383
|
|
2155
2384
|
def plot_permutation(permutation_df):
|
2385
|
+
"""
|
2386
|
+
Plot permutation feature importance as a horizontal bar chart.
|
2387
|
+
|
2388
|
+
Args:
|
2389
|
+
permutation_df (pd.DataFrame): DataFrame with columns:
|
2390
|
+
- 'feature': Feature names.
|
2391
|
+
- 'importance_mean': Mean permutation importance.
|
2392
|
+
- 'importance_std': Standard deviation of permutation importance.
|
2393
|
+
|
2394
|
+
Returns:
|
2395
|
+
matplotlib.figure.Figure: The resulting plot figure object.
|
2396
|
+
|
2397
|
+
Notes:
|
2398
|
+
- Dynamically adjusts figure size and font size based on number of features.
|
2399
|
+
- Error bars represent the standard deviation across permutation runs.
|
2400
|
+
"""
|
2401
|
+
|
2156
2402
|
num_features = len(permutation_df)
|
2157
2403
|
fig_height = max(8, num_features * 0.3) # Set a minimum height of 8 and adjust height based on number of features
|
2158
2404
|
fig_width = 10 # Width can be fixed or adjusted similarly
|
@@ -2166,6 +2412,21 @@ def plot_permutation(permutation_df):
|
|
2166
2412
|
return fig
|
2167
2413
|
|
2168
2414
|
def plot_feature_importance(feature_importance_df):
|
2415
|
+
"""
|
2416
|
+
Plot feature importance as a horizontal bar chart.
|
2417
|
+
|
2418
|
+
Args:
|
2419
|
+
feature_importance_df (pd.DataFrame): DataFrame with columns:
|
2420
|
+
- 'feature': Feature names.
|
2421
|
+
- 'importance': Importance scores for each feature.
|
2422
|
+
|
2423
|
+
Returns:
|
2424
|
+
matplotlib.figure.Figure: The resulting plot figure object.
|
2425
|
+
|
2426
|
+
Notes:
|
2427
|
+
- Dynamically adjusts figure size and font size based on number of features.
|
2428
|
+
- Use for visualizing static (e.g., model-based) feature importance.
|
2429
|
+
"""
|
2169
2430
|
num_features = len(feature_importance_df)
|
2170
2431
|
fig_height = max(8, num_features * 0.3) # Set a minimum height of 8 and adjust height based on number of features
|
2171
2432
|
fig_width = 10 # Width can be fixed or adjusted similarly
|
@@ -2179,6 +2440,23 @@ def plot_feature_importance(feature_importance_df):
|
|
2179
2440
|
return fig
|
2180
2441
|
|
2181
2442
|
def read_and_plot__vision_results(base_dir, y_axis='accuracy', name_split='_time', y_lim=[0.8, 0.9]):
|
2443
|
+
"""
|
2444
|
+
Read multiple vision model test result CSV files and generate a bar plot of average performance.
|
2445
|
+
|
2446
|
+
Args:
|
2447
|
+
base_dir (str): Base directory to recursively search for files ending in '_test_result.csv'.
|
2448
|
+
y_axis (str): The performance metric to plot (must be a column in each CSV). Default is 'accuracy'.
|
2449
|
+
name_split (str): String delimiter used to extract the model name from the filename. Default is '_time'.
|
2450
|
+
y_lim (list or tuple): Optional y-axis limits for the plot. Default is [0.8, 0.9].
|
2451
|
+
|
2452
|
+
Returns:
|
2453
|
+
None. Displays a matplotlib bar plot of average `y_axis` for each model.
|
2454
|
+
|
2455
|
+
Notes:
|
2456
|
+
- Assumes each file is named with format: '<model>_time<timestamp>_test_result.csv'.
|
2457
|
+
- Extracts 'model' and 'epoch' metadata from the path and filename.
|
2458
|
+
- Saves results to a 'result' folder inside `base_dir` (created if not present).
|
2459
|
+
"""
|
2182
2460
|
# List to store data from all CSV files
|
2183
2461
|
data_frames = []
|
2184
2462
|
|
@@ -2321,7 +2599,7 @@ def create_grouped_plot(df, grouping_column, data_column, graph_type='bar', summ
|
|
2321
2599
|
"""
|
2322
2600
|
Create a grouped plot, perform statistical tests, and optionally export the results along with the plot.
|
2323
2601
|
|
2324
|
-
|
2602
|
+
Args:
|
2325
2603
|
- df: DataFrame containing the data.
|
2326
2604
|
- grouping_column: Column name for the categorical grouping.
|
2327
2605
|
- data_column: Column name for the data to be grouped and plotted.
|
@@ -2479,15 +2757,66 @@ def create_grouped_plot(df, grouping_column, data_column, graph_type='bar', summ
|
|
2479
2757
|
return plt.gcf(), results_df
|
2480
2758
|
|
2481
2759
|
class spacrGraph:
|
2760
|
+
"""
|
2761
|
+
Class for generating grouped plots with optional data preprocessing,
|
2762
|
+
statistical comparisons, theming, and output control.
|
2763
|
+
|
2764
|
+
This class is designed to support common grouped plotting tasks for CRISPR screens,
|
2765
|
+
enabling flexible summarization across different experimental representations
|
2766
|
+
(e.g., object-level, well-level, or plate-level).
|
2767
|
+
|
2768
|
+
Args:
|
2769
|
+
df (pd.DataFrame): Input DataFrame containing data to plot.
|
2770
|
+
grouping_column (str): Column to group by on the x-axis.
|
2771
|
+
data_column (str or list): One or more numeric columns to summarize and plot.
|
2772
|
+
graph_type (str): Type of plot to create (e.g., 'bar', 'box').
|
2773
|
+
summary_func (str or callable): Function used to summarize data (e.g., 'mean', 'median').
|
2774
|
+
order (list): Optional list to define the order of categories on the x-axis.
|
2775
|
+
colors (list): Optional custom list of colors for plotting.
|
2776
|
+
output_dir (str): Path to directory where output will be saved.
|
2777
|
+
save (bool): Whether to save the plot as a PDF.
|
2778
|
+
y_lim (tuple): Tuple defining y-axis limits.
|
2779
|
+
log_y (bool): If True, applies log scaling to y-axis.
|
2780
|
+
log_x (bool): If True, applies log scaling to x-axis.
|
2781
|
+
error_bar_type (str): Error bar type to display ('std' or 'sem').
|
2782
|
+
remove_outliers (bool): Whether to exclude statistical outliers (not implemented here).
|
2783
|
+
theme (str): Seaborn theme to use for the plot.
|
2784
|
+
representation (str): Level of summarization; one of {'object', 'well', 'plate'}.
|
2785
|
+
paired (bool): If True, assumes samples are paired for statistical tests.
|
2786
|
+
all_to_all (bool): If True, performs all-to-all comparisons.
|
2787
|
+
compare_group (str): Optional group to compare all others against.
|
2788
|
+
graph_name (str): Optional name used in output file naming.
|
2789
|
+
"""
|
2482
2790
|
def __init__(self, df, grouping_column, data_column, graph_type='bar', summary_func='mean',
|
2483
2791
|
order=None, colors=None, output_dir='./output', save=False, y_lim=None, log_y=False,
|
2484
2792
|
log_x=False, error_bar_type='std', remove_outliers=False, theme='pastel', representation='object',
|
2485
2793
|
paired=False, all_to_all=True, compare_group=None, graph_name=None):
|
2486
|
-
|
2487
|
-
"""
|
2488
|
-
Class for creating grouped plots with optional statistical tests and data preprocessing.
|
2489
2794
|
"""
|
2795
|
+
Initialize a spacrGraph instance for grouped data visualization.
|
2490
2796
|
|
2797
|
+
Args:
|
2798
|
+
df (pd.DataFrame): Input dataframe containing data to be plotted.
|
2799
|
+
grouping_column (str): Column name to group data along the x-axis.
|
2800
|
+
data_column (str or list): Column(s) containing values to summarize and plot.
|
2801
|
+
graph_type (str): Type of plot to generate ('bar', 'box', etc.). Default is 'bar'.
|
2802
|
+
summary_func (str or callable): Summary statistic to apply per group (e.g., 'mean', 'median').
|
2803
|
+
order (list or None): Optional order of groups on the x-axis. If None, will be inferred from the data.
|
2804
|
+
colors (list or None): List of colors to use for plotting. Defaults to seaborn palette.
|
2805
|
+
output_dir (str): Directory where plot will be saved if `save=True`.
|
2806
|
+
save (bool): Whether to save the plot as a PDF.
|
2807
|
+
y_lim (tuple or None): Y-axis limits. If None, determined automatically.
|
2808
|
+
log_y (bool): If True, use logarithmic scale on the y-axis.
|
2809
|
+
log_x (bool): If True, use logarithmic scale on the x-axis.
|
2810
|
+
error_bar_type (str): Type of error bars to include ('std' or 'sem').
|
2811
|
+
remove_outliers (bool): Whether to remove outliers from the data before plotting.
|
2812
|
+
theme (str): Seaborn color theme for plotting.
|
2813
|
+
representation (str): Level of summarization: 'object', 'well', or 'plate'.
|
2814
|
+
paired (bool): Whether samples are paired for statistical testing.
|
2815
|
+
all_to_all (bool): If True, perform all pairwise comparisons.
|
2816
|
+
compare_group (str or None): If provided, compares all groups against this one.
|
2817
|
+
graph_name (str or None): Optional string to include in saved file names.
|
2818
|
+
"""
|
2819
|
+
|
2491
2820
|
self.df = df
|
2492
2821
|
self.grouping_column = grouping_column
|
2493
2822
|
self.order = sorted(df[self.grouping_column].unique().tolist())
|
@@ -2522,13 +2851,29 @@ class spacrGraph:
|
|
2522
2851
|
self.df = self.preprocess_data()
|
2523
2852
|
|
2524
2853
|
def _set_theme(self):
|
2525
|
-
"""
|
2854
|
+
"""
|
2855
|
+
Set the Seaborn color theme and define the plotting color order.
|
2856
|
+
|
2857
|
+
This method initializes the `sns_palette` using a reordered set of colors
|
2858
|
+
from the selected theme.
|
2859
|
+
"""
|
2526
2860
|
integer_list = list(range(1, 81))
|
2527
2861
|
color_order = [7,9,4,0,3,6,2] + integer_list
|
2528
2862
|
self.sns_palette = self._set_reordered_theme(self.theme, color_order, 100)
|
2529
2863
|
|
2530
2864
|
def _set_reordered_theme(self, theme='deep', order=None, n_colors=100, show_theme=False):
|
2531
|
-
"""
|
2865
|
+
"""
|
2866
|
+
Generate a Seaborn color palette with an optional reordering.
|
2867
|
+
|
2868
|
+
Parameters:
|
2869
|
+
theme (str): Seaborn palette name (e.g., 'deep', 'pastel').
|
2870
|
+
order (list): List of integer indices defining color reorder.
|
2871
|
+
n_colors (int): Number of colors to include in the palette.
|
2872
|
+
show_theme (bool): If True, displays the palette visually.
|
2873
|
+
|
2874
|
+
Returns:
|
2875
|
+
list: Reordered color palette.
|
2876
|
+
"""
|
2532
2877
|
palette = sns.color_palette(theme, n_colors)
|
2533
2878
|
if order:
|
2534
2879
|
reordered_palette = [palette[i] for i in order]
|
@@ -2539,25 +2884,21 @@ class spacrGraph:
|
|
2539
2884
|
plt.show()
|
2540
2885
|
return reordered_palette
|
2541
2886
|
|
2542
|
-
#def preprocess_data(self):
|
2543
|
-
# """Preprocess the data: remove NaNs, sort/order the grouping column, and optionally group by 'prc'."""
|
2544
|
-
# # Remove NaNs in both the grouping column and each data column
|
2545
|
-
# df = self.df.dropna(subset=[self.grouping_column] + self.data_column)
|
2546
|
-
# # Group by 'prc' column if representation is 'well'
|
2547
|
-
# if self.representation == 'well':
|
2548
|
-
# df = df.groupby(['prc', self.grouping_column])[self.data_column].agg(self.summary_func).reset_index()
|
2549
|
-
# if self.representation == 'plateID':
|
2550
|
-
# df = df.groupby(['plateID', self.grouping_column])[self.data_column].agg(self.summary_func).reset_index()
|
2551
|
-
# if self.order:
|
2552
|
-
# df[self.grouping_column] = pd.Categorical(df[self.grouping_column], categories=self.order, ordered=True)
|
2553
|
-
# else:
|
2554
|
-
# df[self.grouping_column] = pd.Categorical(df[self.grouping_column], categories=sorted(df[self.grouping_column].unique()), ordered=True)
|
2555
|
-
# return df
|
2556
|
-
|
2557
2887
|
def preprocess_data(self):
|
2558
2888
|
"""
|
2559
|
-
Preprocess the
|
2560
|
-
|
2889
|
+
Preprocess the input DataFrame by removing NaNs and optionally aggregating data.
|
2890
|
+
|
2891
|
+
Aggregation is controlled by the `representation` parameter:
|
2892
|
+
- 'object': No aggregation; raw data is retained.
|
2893
|
+
- 'well': Data is grouped by ['prc', grouping_column].
|
2894
|
+
- 'plate': Data is grouped by ['plateID', grouping_column],
|
2895
|
+
extracting plateID from the 'prc' column if needed.
|
2896
|
+
|
2897
|
+
Ordering of the x-axis categories is set using the `order` parameter
|
2898
|
+
or automatically sorted if not provided.
|
2899
|
+
|
2900
|
+
Returns:
|
2901
|
+
pd.DataFrame: Cleaned and optionally aggregated DataFrame.
|
2561
2902
|
"""
|
2562
2903
|
# 1) Remove NaNs in both the grouping column and each data column
|
2563
2904
|
df = self.df.dropna(subset=[self.grouping_column] + self.data_column)
|
@@ -2613,7 +2954,13 @@ class spacrGraph:
|
|
2613
2954
|
return df
|
2614
2955
|
|
2615
2956
|
def remove_outliers_from_plot(self):
|
2616
|
-
"""
|
2957
|
+
"""
|
2958
|
+
Remove outliers from each group in the dataset for plotting purposes.
|
2959
|
+
This method applies the IQR method to filter extreme values per group.
|
2960
|
+
|
2961
|
+
Returns:
|
2962
|
+
pd.DataFrame: Filtered dataframe with outliers removed for plotting.
|
2963
|
+
"""
|
2617
2964
|
filtered_df = self.df.copy()
|
2618
2965
|
unique_groups = filtered_df[self.grouping_column].unique()
|
2619
2966
|
for group in unique_groups:
|
@@ -2627,7 +2974,18 @@ class spacrGraph:
|
|
2627
2974
|
return filtered_df
|
2628
2975
|
|
2629
2976
|
def perform_normality_tests(self):
|
2630
|
-
"""
|
2977
|
+
"""
|
2978
|
+
Perform normality tests on each group and data column in the dataframe.
|
2979
|
+
|
2980
|
+
Uses:
|
2981
|
+
- Shapiro-Wilk test for sample sizes < 8
|
2982
|
+
- D'Agostino-Pearson test for sample sizes ≥ 8
|
2983
|
+
|
2984
|
+
Returns:
|
2985
|
+
tuple:
|
2986
|
+
- is_normal (bool): True if all groups pass normality test (p > 0.05)
|
2987
|
+
- normality_results (list): List of dictionaries with test results
|
2988
|
+
"""
|
2631
2989
|
unique_groups = self.df[self.grouping_column].unique()
|
2632
2990
|
normality_results = []
|
2633
2991
|
|
@@ -2674,13 +3032,40 @@ class spacrGraph:
|
|
2674
3032
|
return is_normal, normality_results
|
2675
3033
|
|
2676
3034
|
def perform_levene_test(self, unique_groups):
|
2677
|
-
"""
|
3035
|
+
"""
|
3036
|
+
Perform Levene’s test for homogeneity of variance across groups.
|
3037
|
+
|
3038
|
+
Args:
|
3039
|
+
unique_groups (list): List of group identifiers to compare.
|
3040
|
+
|
3041
|
+
Returns:
|
3042
|
+
tuple:
|
3043
|
+
- stat (float): Levene test statistic
|
3044
|
+
- p_value (float): p-value for the test
|
3045
|
+
"""
|
2678
3046
|
grouped_data = [self.df.loc[self.df[self.grouping_column] == group, self.data_column] for group in unique_groups]
|
2679
3047
|
stat, p_value = levene(*grouped_data)
|
2680
3048
|
return stat, p_value
|
2681
3049
|
|
2682
3050
|
def perform_statistical_tests(self, unique_groups, is_normal):
|
2683
|
-
"""
|
3051
|
+
"""
|
3052
|
+
Perform appropriate statistical tests based on normality and group count.
|
3053
|
+
|
3054
|
+
If 2 groups:
|
3055
|
+
- Paired or unpaired t-test (if normal)
|
3056
|
+
- Wilcoxon or Mann-Whitney U test (if non-normal)
|
3057
|
+
|
3058
|
+
If >2 groups:
|
3059
|
+
- ANOVA (if normal)
|
3060
|
+
- Kruskal-Wallis test (if non-normal)
|
3061
|
+
|
3062
|
+
Args:
|
3063
|
+
unique_groups (list): List of unique group names.
|
3064
|
+
is_normal (bool): Whether the data passes normality tests.
|
3065
|
+
|
3066
|
+
Returns:
|
3067
|
+
list: List of dictionaries with test results for each data column.
|
3068
|
+
"""
|
2684
3069
|
test_results = []
|
2685
3070
|
for column in self.data_column: # Iterate over each data column
|
2686
3071
|
grouped_data = [self.df.loc[self.df[self.grouping_column] == group, column] for group in unique_groups]
|
@@ -2720,8 +3105,22 @@ class spacrGraph:
|
|
2720
3105
|
return test_results
|
2721
3106
|
|
2722
3107
|
def perform_posthoc_tests(self, is_normal, unique_groups):
|
2723
|
-
"""
|
3108
|
+
"""
|
3109
|
+
Perform post-hoc tests for multiple groups, depending on normality and group count.
|
3110
|
+
|
3111
|
+
If data are normally distributed and `all_to_all` is True:
|
3112
|
+
- Tukey HSD test is performed for pairwise comparisons.
|
3113
|
+
|
3114
|
+
If not normal and `all_to_all` is True:
|
3115
|
+
- Dunn's test is performed using appropriate p-value adjustment.
|
2724
3116
|
|
3117
|
+
Args:
|
3118
|
+
is_normal (bool): Whether the data passed normality tests.
|
3119
|
+
unique_groups (list): List of group identifiers.
|
3120
|
+
|
3121
|
+
Returns:
|
3122
|
+
list: List of dictionaries with post-hoc test results per pairwise comparison.
|
3123
|
+
"""
|
2725
3124
|
from .sp_stats import choose_p_adjust_method
|
2726
3125
|
|
2727
3126
|
posthoc_results = []
|
@@ -2776,10 +3175,32 @@ class spacrGraph:
|
|
2776
3175
|
return posthoc_results
|
2777
3176
|
|
2778
3177
|
def create_plot(self, ax=None):
|
2779
|
-
"""
|
3178
|
+
"""
|
3179
|
+
Create the plot based on the selected `graph_type`.
|
2780
3180
|
|
3181
|
+
If `graph_type` supports it and multiple `data_column` values are used,
|
3182
|
+
symbols can be placed below the plot to indicate groupings and correspondence.
|
3183
|
+
|
3184
|
+
Args:
|
3185
|
+
ax (matplotlib.axes.Axes, optional): An existing axis to draw the plot on.
|
3186
|
+
If None, a new figure and axis are created.
|
3187
|
+
|
3188
|
+
Returns:
|
3189
|
+
matplotlib.figure.Figure: The resulting figure.
|
3190
|
+
"""
|
2781
3191
|
def _generate_tabels(unique_groups):
|
2782
|
-
"""
|
3192
|
+
"""
|
3193
|
+
Generate symbolic representation of the grouping and data columns
|
3194
|
+
for annotation under plots (used internally by create_plot).
|
3195
|
+
|
3196
|
+
Args:
|
3197
|
+
unique_groups (list): List of group identifiers.
|
3198
|
+
|
3199
|
+
Returns:
|
3200
|
+
tuple:
|
3201
|
+
- row_labels (list): List of label strings for each row.
|
3202
|
+
- transposed_table (list of lists): Symbolic table aligned with x-axis.
|
3203
|
+
"""
|
2783
3204
|
# Create row labels: Include the grouping column and data columns
|
2784
3205
|
row_labels = [self.grouping_column] + self.data_column
|
2785
3206
|
|
@@ -2812,13 +3233,14 @@ class spacrGraph:
|
|
2812
3233
|
|
2813
3234
|
def _place_symbols(row_labels, transposed_table, x_positions, ax):
|
2814
3235
|
"""
|
2815
|
-
|
2816
|
-
|
2817
|
-
|
2818
|
-
|
2819
|
-
|
2820
|
-
|
2821
|
-
|
3236
|
+
Place text-based symbols under plot elements to annotate categories
|
3237
|
+
and features in multi-group, multi-feature plots.
|
3238
|
+
|
3239
|
+
Args:
|
3240
|
+
row_labels (list): Row label names.
|
3241
|
+
transposed_table (list of lists): Symbolic table to annotate.
|
3242
|
+
x_positions (list of float): X-coordinates for symbol placement.
|
3243
|
+
ax (matplotlib.axes.Axes): Axis object where symbols are placed.
|
2822
3244
|
"""
|
2823
3245
|
# Get plot dimensions and adjust for different plot sizes
|
2824
3246
|
y_axis_min = ax.get_ylim()[0] # Minimum y-axis value (usually 0)
|
@@ -2845,6 +3267,17 @@ class spacrGraph:
|
|
2845
3267
|
ax.figure.canvas.draw()
|
2846
3268
|
|
2847
3269
|
def _get_positions(self, ax):
|
3270
|
+
"""
|
3271
|
+
Get x-axis positions of plotted elements depending on graph type.
|
3272
|
+
|
3273
|
+
Used to align symbols and annotations below plot elements.
|
3274
|
+
|
3275
|
+
Args:
|
3276
|
+
ax (matplotlib.axes.Axes): Axis object from which positions are extracted.
|
3277
|
+
|
3278
|
+
Returns:
|
3279
|
+
list of float: X-axis coordinates for each group/data element.
|
3280
|
+
"""
|
2848
3281
|
if self.graph_type in ['bar','jitter_bar']:
|
2849
3282
|
x_positions = [np.mean(bar.get_paths()[0].vertices[:, 0]) for bar in ax.collections if hasattr(bar, 'get_paths')]
|
2850
3283
|
|
@@ -2863,7 +3296,24 @@ class spacrGraph:
|
|
2863
3296
|
return x_positions
|
2864
3297
|
|
2865
3298
|
def _draw_comparison_lines(ax, x_positions):
|
2866
|
-
"""
|
3299
|
+
"""
|
3300
|
+
Draw horizontal lines between group pairs to indicate statistical comparisons and significance.
|
3301
|
+
|
3302
|
+
This function annotates the plot using entries from `self.results_df`, which must contain
|
3303
|
+
a 'Comparison' column with group pair strings formatted as 'group1 vs group2' and a 'p-value' column.
|
3304
|
+
|
3305
|
+
Args:
|
3306
|
+
ax (matplotlib.axes.Axes): The axis on which to draw the lines and annotations.
|
3307
|
+
x_positions (list of float): The x-axis positions of each group, in order matching `unique_groups`.
|
3308
|
+
|
3309
|
+
Notes:
|
3310
|
+
- Significance markers:
|
3311
|
+
'***' for p ≤ 0.001
|
3312
|
+
'**' for p ≤ 0.01
|
3313
|
+
'*' for p ≤ 0.05
|
3314
|
+
'ns' for p > 0.05
|
3315
|
+
- Lines are staggered vertically to avoid overlap.
|
3316
|
+
"""
|
2867
3317
|
if self.results_df.empty:
|
2868
3318
|
print("No comparisons available to annotate.")
|
2869
3319
|
return
|
@@ -2996,15 +3446,18 @@ class spacrGraph:
|
|
2996
3446
|
|
2997
3447
|
def _standerdize_figure_format(self, ax, num_groups, graph_type):
|
2998
3448
|
"""
|
2999
|
-
|
3000
|
-
|
3001
|
-
Parameters:
|
3002
|
-
- ax: The matplotlib Axes object.
|
3003
|
-
- num_groups: Number of unique groups.
|
3004
|
-
- graph_type: The type of graph (e.g., 'bar', 'jitter', 'box', etc.).
|
3449
|
+
Standardize and adjust the visual format of the plot based on the number of groups and plot type.
|
3005
3450
|
|
3006
|
-
|
3007
|
-
|
3451
|
+
Args:
|
3452
|
+
ax (matplotlib.axes.Axes): The axis object of the plot.
|
3453
|
+
num_groups (int): Number of groups/categories on the x-axis.
|
3454
|
+
graph_type (str): Type of plot, e.g., 'bar', 'jitter', 'box', 'line', etc.
|
3455
|
+
|
3456
|
+
Notes:
|
3457
|
+
- Dynamically sets figure size based on group count.
|
3458
|
+
- Adjusts bar widths, jitter offsets, and legend font sizes.
|
3459
|
+
- Skips formatting for line plots.
|
3460
|
+
- Intended to maintain a visually consistent appearance across graphs.
|
3008
3461
|
"""
|
3009
3462
|
if graph_type in ['line', 'line_std']:
|
3010
3463
|
print("Skipping layout adjustment for line graphs.")
|
@@ -3068,7 +3521,19 @@ class spacrGraph:
|
|
3068
3521
|
ax.figure.canvas.draw()
|
3069
3522
|
|
3070
3523
|
def _create_bar_plot(self, ax):
|
3071
|
-
"""
|
3524
|
+
"""
|
3525
|
+
Create a bar plot using seaborn, with error bars and standardized layout.
|
3526
|
+
|
3527
|
+
Args:
|
3528
|
+
ax (matplotlib.axes.Axes): Axis object to draw the bar plot.
|
3529
|
+
|
3530
|
+
Notes:
|
3531
|
+
- Supports multiple or single data columns.
|
3532
|
+
- Uses melted DataFrame (self.df_melted).
|
3533
|
+
- Adds error bars based on `self.error_bar_type` (std or sem).
|
3534
|
+
- Bar width and position are adjusted manually.
|
3535
|
+
- Log scaling is applied if enabled.
|
3536
|
+
"""
|
3072
3537
|
# Flatten DataFrame: Combine grouping column and data column into one group if needed
|
3073
3538
|
if len(self.data_column) > 1:
|
3074
3539
|
self.df_melted['Combined Group'] = (self.df_melted[self.grouping_column].astype(str) + " - " + self.df_melted['Data Column'].astype(str))
|
@@ -3109,7 +3574,19 @@ class spacrGraph:
|
|
3109
3574
|
ax.set_xscale('log')
|
3110
3575
|
|
3111
3576
|
def _create_jitter_plot(self, ax):
|
3112
|
-
"""
|
3577
|
+
"""
|
3578
|
+
Create a jitter (strip) plot using seaborn with consistent spacing.
|
3579
|
+
|
3580
|
+
Args:
|
3581
|
+
ax (matplotlib.axes.Axes): Axis object to draw the jitter plot.
|
3582
|
+
|
3583
|
+
Notes:
|
3584
|
+
- Supports multiple or single data columns.
|
3585
|
+
- Uses melted DataFrame (self.df_melted).
|
3586
|
+
- Point size and jitter are controlled via class attributes.
|
3587
|
+
- Legend is deduplicated and managed manually.
|
3588
|
+
- Log scaling is applied if enabled.
|
3589
|
+
"""
|
3113
3590
|
# Combine grouping column and data column if needed
|
3114
3591
|
if len(self.data_column) > 1:
|
3115
3592
|
self.df_melted['Combined Group'] = (self.df_melted[self.grouping_column].astype(str) + " - " + self.df_melted['Data Column'].astype(str))
|
@@ -3138,7 +3615,18 @@ class spacrGraph:
|
|
3138
3615
|
ax.set_xscale('log')
|
3139
3616
|
|
3140
3617
|
def _create_line_graph(self, ax):
|
3141
|
-
"""
|
3618
|
+
"""
|
3619
|
+
Create a line plot where each group is represented by a separate line.
|
3620
|
+
|
3621
|
+
Args:
|
3622
|
+
ax (matplotlib.axes.Axes): Axis object to draw the line graph.
|
3623
|
+
|
3624
|
+
Notes:
|
3625
|
+
- Expects `self.data_column` to contain two columns: x and y axes.
|
3626
|
+
- One line per group in `self.grouping_column`.
|
3627
|
+
- Log scaling applied as specified.
|
3628
|
+
- Raises ValueError if required columns are missing.
|
3629
|
+
"""
|
3142
3630
|
#display(self.df)
|
3143
3631
|
# Ensure epoch is used on the x-axis and accuracy on the y-axis
|
3144
3632
|
x_axis_column = self.data_column[0]
|
@@ -3167,7 +3655,18 @@ class spacrGraph:
|
|
3167
3655
|
ax.set_ylabel(f"{y_axis_column}")
|
3168
3656
|
|
3169
3657
|
def _create_line_with_std_area(self, ax):
|
3170
|
-
"""
|
3658
|
+
"""
|
3659
|
+
Create a line plot with shaded standard deviation region around the mean.
|
3660
|
+
|
3661
|
+
Args:
|
3662
|
+
ax (matplotlib.axes.Axes): Axis object to draw the plot.
|
3663
|
+
|
3664
|
+
Notes:
|
3665
|
+
- Expects `self.data_column = [x_column, y_column]`.
|
3666
|
+
- Computes mean and std of `y_column` per `x_column`.
|
3667
|
+
- Plots the mean as a line and ±1 std as a shaded region.
|
3668
|
+
- Applies log scaling if enabled.
|
3669
|
+
"""
|
3171
3670
|
|
3172
3671
|
x_axis_column = self.data_column[0]
|
3173
3672
|
y_axis_column = self.data_column[1]
|
@@ -3198,7 +3697,18 @@ class spacrGraph:
|
|
3198
3697
|
ax.set_ylabel(f"{y_axis_column}")
|
3199
3698
|
|
3200
3699
|
def _create_box_plot(self, ax):
|
3201
|
-
"""
|
3700
|
+
"""
|
3701
|
+
Create a box plot for each group with optional grouping by multiple columns.
|
3702
|
+
|
3703
|
+
Args:
|
3704
|
+
ax (matplotlib.axes.Axes): Axis object to draw the box plot.
|
3705
|
+
|
3706
|
+
Notes:
|
3707
|
+
- Supports multiple data columns by combining into 'Combined Group'.
|
3708
|
+
- Uses `self.df_melted`.
|
3709
|
+
- Log scaling applied if enabled.
|
3710
|
+
- Legend is deduplicated.
|
3711
|
+
"""
|
3202
3712
|
# Combine grouping column and data column if needed
|
3203
3713
|
if len(self.data_column) > 1:
|
3204
3714
|
self.df_melted['Combined Group'] = (self.df_melted[self.grouping_column].astype(str) + " - " + self.df_melted['Data Column'].astype(str))
|
@@ -3227,7 +3737,18 @@ class spacrGraph:
|
|
3227
3737
|
ax.set_xscale('log')
|
3228
3738
|
|
3229
3739
|
def _create_violin_plot(self, ax):
|
3230
|
-
"""
|
3740
|
+
"""
|
3741
|
+
Create a violin plot for each group, showing data distribution.
|
3742
|
+
|
3743
|
+
Args:
|
3744
|
+
ax (matplotlib.axes.Axes): Axis object to draw the violin plot.
|
3745
|
+
|
3746
|
+
Notes:
|
3747
|
+
- Supports multiple data columns via combined group label.
|
3748
|
+
- Uses `self.df_melted`.
|
3749
|
+
- Log scaling applied if enabled.
|
3750
|
+
- Legend is deduplicated.
|
3751
|
+
"""
|
3231
3752
|
# Combine grouping column and data column if needed
|
3232
3753
|
if len(self.data_column) > 1:
|
3233
3754
|
self.df_melted['Combined Group'] = (self.df_melted[self.grouping_column].astype(str) + " - " + self.df_melted['Data Column'].astype(str))
|
@@ -3257,7 +3778,19 @@ class spacrGraph:
|
|
3257
3778
|
ax.set_xscale('log')
|
3258
3779
|
|
3259
3780
|
def _create_jitter_bar_plot(self, ax):
|
3260
|
-
"""
|
3781
|
+
"""
|
3782
|
+
Create a combined jitter and bar plot with optional error bars.
|
3783
|
+
|
3784
|
+
Args:
|
3785
|
+
ax (matplotlib.axes.Axes): Axis object to draw the plot.
|
3786
|
+
|
3787
|
+
Notes:
|
3788
|
+
- Bars show central tendency; jitter adds point-level detail.
|
3789
|
+
- Supports multiple data columns via combined group label.
|
3790
|
+
- Uses `self.df_melted` and `self.error_bar_type`.
|
3791
|
+
- Log scaling applied if enabled.
|
3792
|
+
- Error bars currently commented out; customize if needed.
|
3793
|
+
"""
|
3261
3794
|
# Flatten DataFrame: Combine grouping column and data column into one group if needed
|
3262
3795
|
if len(self.data_column) > 1:
|
3263
3796
|
self.df_melted['Combined Group'] = (self.df_melted[self.grouping_column].astype(str) + " - " + self.df_melted['Data Column'].astype(str))
|
@@ -3299,7 +3832,19 @@ class spacrGraph:
|
|
3299
3832
|
ax.set_xscale('log')
|
3300
3833
|
|
3301
3834
|
def _create_jitter_box_plot(self, ax):
|
3302
|
-
"""
|
3835
|
+
"""
|
3836
|
+
Create a combined jitter and box plot.
|
3837
|
+
|
3838
|
+
Args:
|
3839
|
+
ax (matplotlib.axes.Axes): Axis object to draw the plot.
|
3840
|
+
|
3841
|
+
Notes:
|
3842
|
+
- Uses seaborn `boxplot` and overlays `stripplot`.
|
3843
|
+
- Supports multiple data columns via combined group label.
|
3844
|
+
- Uses `self.df_melted`.
|
3845
|
+
- Legend is deduplicated.
|
3846
|
+
- Log scaling applied if enabled.
|
3847
|
+
"""
|
3303
3848
|
# Combine grouping column and data column if needed
|
3304
3849
|
if len(self.data_column) > 1:
|
3305
3850
|
self.df_melted['Combined Group'] = (self.df_melted[self.grouping_column].astype(str) + " - " + self.df_melted['Data Column'].astype(str))
|
@@ -3329,7 +3874,14 @@ class spacrGraph:
|
|
3329
3874
|
ax.set_xscale('log')
|
3330
3875
|
|
3331
3876
|
def _save_results(self):
|
3332
|
-
"""
|
3877
|
+
"""
|
3878
|
+
Save the figure and results DataFrame to disk.
|
3879
|
+
|
3880
|
+
Notes:
|
3881
|
+
- PDF saved to `self.output_dir` with name based on `self.results_name`.
|
3882
|
+
- CSV of results saved in the same directory.
|
3883
|
+
- Ensures output directory exists.
|
3884
|
+
"""
|
3333
3885
|
os.makedirs(self.output_dir, exist_ok=True)
|
3334
3886
|
plot_path = os.path.join(self.output_dir, f"{self.results_name}.pdf")
|
3335
3887
|
self.fig.savefig(plot_path, bbox_inches='tight', dpi=600, transparent=True, format='pdf')
|
@@ -3339,30 +3891,65 @@ class spacrGraph:
|
|
3339
3891
|
print(f"Test results saved to {results_path}")
|
3340
3892
|
|
3341
3893
|
def get_results(self):
|
3342
|
-
"""
|
3894
|
+
"""
|
3895
|
+
Return the internal results DataFrame with computed statistics or comparisons.
|
3896
|
+
|
3897
|
+
Returns:
|
3898
|
+
pd.DataFrame: Results table.
|
3899
|
+
"""
|
3343
3900
|
return self.results_df
|
3344
3901
|
|
3345
3902
|
def get_figure(self):
|
3346
|
-
"""
|
3903
|
+
"""
|
3904
|
+
Return the Matplotlib figure associated with the most recent plot.
|
3905
|
+
|
3906
|
+
Returns:
|
3907
|
+
matplotlib.figure.Figure: The plot figure.
|
3908
|
+
"""
|
3347
3909
|
return self.fig
|
3348
3910
|
|
3349
3911
|
def plot_data_from_db(settings):
|
3350
|
-
|
3351
|
-
from .io import _read_db, _read_and_merge_data
|
3352
|
-
from .utils import annotate_conditions, save_settings
|
3353
|
-
from .settings import set_default_plot_data_from_db
|
3354
|
-
|
3355
3912
|
"""
|
3356
|
-
|
3913
|
+
Load measurement data from SQL databases, annotate with experimental conditions,
|
3914
|
+
and generate grouped plots using spacrGraph.
|
3357
3915
|
|
3358
3916
|
Args:
|
3359
|
-
|
3360
|
-
|
3361
|
-
|
3917
|
+
settings (dict): A dictionary containing the following keys:
|
3918
|
+
- 'src' (str or list of str): Source directories containing measurement databases.
|
3919
|
+
- 'database' (str or list of str): Corresponding database filenames.
|
3920
|
+
- 'table_names' (str): Name of the SQL table to extract data from.
|
3921
|
+
- 'graph_name' (str): Name for saving the output plot and settings.
|
3922
|
+
- 'grouping_column' (str): Column to group by on the x-axis.
|
3923
|
+
- 'data_column' (str): Column to plot on the y-axis.
|
3924
|
+
- 'channel_of_interest' (int): Channel number for recruitment calculation (if applicable).
|
3925
|
+
- 'cell_types' (list): List of expected host cell types.
|
3926
|
+
- 'pathogen_types' (list): List of expected pathogen types.
|
3927
|
+
- 'treatments' (list): List of expected treatments.
|
3928
|
+
- 'cell_plate_metadata' (str or None): Path to metadata for host cells.
|
3929
|
+
- 'pathogen_plate_metadata' (str or None): Path to metadata for pathogens.
|
3930
|
+
- 'treatment_plate_metadata' (str or None): Path to metadata for treatments.
|
3931
|
+
- 'nuclei_limit' (int): Max number of nuclei to load per plate (optional).
|
3932
|
+
- 'pathogen_limit' (int): Max number of pathogens to load per plate (optional).
|
3933
|
+
- 'verbose' (bool): Whether to print loading information.
|
3934
|
+
- 'graph_type' (str): Type of plot to create (e.g., 'bar', 'box', 'jitter', etc.).
|
3935
|
+
- 'representation' (str): Data representation style for plotting (optional).
|
3936
|
+
- 'theme' (str): Seaborn color palette theme.
|
3937
|
+
- 'y_lim' (tuple or None): Limits for y-axis scaling.
|
3938
|
+
- 'save' (bool): Whether to save the plot and results to disk.
|
3362
3939
|
|
3363
3940
|
Returns:
|
3364
|
-
|
3941
|
+
fig (matplotlib.figure.Figure): The plotted figure.
|
3942
|
+
results_df (pd.DataFrame): Statistical results or summary data used in the plot.
|
3943
|
+
|
3944
|
+
Notes:
|
3945
|
+
- Automatically handles multi-source input.
|
3946
|
+
- Computes recruitment ratio if specified.
|
3947
|
+
- Drops rows with missing grouping or data values.
|
3948
|
+
- Creates and saves plot and settings to disk if 'save' is True.
|
3365
3949
|
"""
|
3950
|
+
from .io import _read_db, _read_and_merge_data
|
3951
|
+
from .utils import annotate_conditions, save_settings
|
3952
|
+
from .settings import set_default_plot_data_from_db
|
3366
3953
|
|
3367
3954
|
settings = set_default_plot_data_from_db(settings)
|
3368
3955
|
|
@@ -3472,21 +4059,40 @@ def plot_data_from_db(settings):
|
|
3472
4059
|
return fig, results_df
|
3473
4060
|
|
3474
4061
|
def plot_data_from_csv(settings):
|
3475
|
-
from .io import _read_db, _read_and_merge_data
|
3476
|
-
from .utils import annotate_conditions, save_settings, remove_outliers_by_group
|
3477
|
-
from .settings import get_plot_data_from_csv_default_settings
|
3478
4062
|
"""
|
3479
|
-
|
4063
|
+
Load measurement data from one or more CSV files, optionally filter and clean the data,
|
4064
|
+
and generate grouped plots using spacrGraph.
|
3480
4065
|
|
3481
4066
|
Args:
|
3482
|
-
|
3483
|
-
|
3484
|
-
|
4067
|
+
settings (dict): A dictionary containing the following keys:
|
4068
|
+
- 'src' (str or list of str): Path(s) to the CSV file(s) containing measurement data.
|
4069
|
+
- 'grouping_column' (str): Column to group by on the x-axis.
|
4070
|
+
- 'data_column' (str): Column to plot on the y-axis.
|
4071
|
+
- 'graph_name' (str): Name for saving the output plot and settings.
|
4072
|
+
- 'graph_type' (str): Type of plot to create (e.g., 'bar', 'box', 'violin', 'jitter').
|
4073
|
+
- 'theme' (str): Seaborn color palette theme.
|
4074
|
+
- 'log_y' (bool): Whether to log-transform the y-axis.
|
4075
|
+
- 'log_x' (bool): Whether to log-transform the x-axis.
|
4076
|
+
- 'y_lim' (tuple or None): Limits for y-axis scaling (optional).
|
4077
|
+
- 'save' (bool): Whether to save the plot and results to disk.
|
4078
|
+
- 'verbose' (bool): Whether to print and display the DataFrame before plotting.
|
4079
|
+
- 'representation' (str): Plot style to use (e.g., 'jitter_box', 'violin', etc.).
|
4080
|
+
- 'keep_groups' (list or str, optional): Restrict plot to a subset of group labels.
|
4081
|
+
- 'remove_outliers' (bool): Whether to remove outliers using IQR filtering.
|
3485
4082
|
|
3486
4083
|
Returns:
|
3487
|
-
|
4084
|
+
fig (matplotlib.figure.Figure): The plotted figure.
|
4085
|
+
results_df (pd.DataFrame): DataFrame with statistical results or plot summary.
|
4086
|
+
|
4087
|
+
Notes:
|
4088
|
+
- Merges multiple CSVs and tags them with a default 'plateID' if not provided.
|
4089
|
+
- Attempts to split 'prc' column into 'plateID', 'rowID', and 'columnID' if applicable.
|
4090
|
+
- Handles missing values in the grouping and data columns by dropping those rows.
|
4091
|
+
- Automatically creates the output directory and saves results if `save` is True.
|
3488
4092
|
"""
|
3489
|
-
|
4093
|
+
from .io import _read_db, _read_and_merge_data
|
4094
|
+
from .utils import annotate_conditions, save_settings, remove_outliers_by_group
|
4095
|
+
from .settings import get_plot_data_from_csv_default_settings
|
3490
4096
|
|
3491
4097
|
def filter_rows_by_column_values(df: pd.DataFrame, column: str, values: list) -> pd.DataFrame:
|
3492
4098
|
"""Return a filtered DataFrame where only rows with the column value in the list are kept."""
|
@@ -3569,7 +4175,44 @@ def plot_data_from_csv(settings):
|
|
3569
4175
|
return fig, results_df
|
3570
4176
|
|
3571
4177
|
def plot_region(settings):
|
4178
|
+
"""
|
4179
|
+
Generate and save region-specific plots including: mask overlay, raw PNG grid, and activation map grid.
|
4180
|
+
|
4181
|
+
This function loads image paths and metadata from database tables, filters for the specified region (field of view),
|
4182
|
+
and plots the following:
|
4183
|
+
|
4184
|
+
- A mask overlay of the full field image with Cellpose masks.
|
4185
|
+
- A grid of raw cropped PNGs corresponding to the region.
|
4186
|
+
- A grid of activation maps (e.g., saliency, Grad-CAM) for the same region.
|
4187
|
+
|
4188
|
+
Args:
|
4189
|
+
settings (dict): A dictionary containing the following keys:
|
4190
|
+
|
4191
|
+
- 'src' (str): Source folder containing subfolders like 'merged', 'measurements', and 'datasets'.
|
4192
|
+
- 'name' (str): Filename (e.g., 'plate1_A01_01.tif') identifying the region/FOV of interest.
|
4193
|
+
- 'activation_db' (str): Filename of the activation measurement database (e.g., 'activations.db').
|
4194
|
+
- 'activation_mode' (str): Mode of activation ('saliency', 'gradcam', etc.); used to find the correct table.
|
4195
|
+
- 'channels' (list of int): Indices of input channels to use for the mask overlay.
|
4196
|
+
- 'cell_channel' (int): Channel index used to generate the cell mask.
|
4197
|
+
- 'nucleus_channel' (int): Channel index used to generate the nucleus mask.
|
4198
|
+
- 'pathogen_channel' (int): Channel index used to generate the pathogen mask.
|
4199
|
+
- 'percentiles' (list): Two-element list (e.g., [2, 99]) specifying intensity percentiles for contrast scaling.
|
4200
|
+
- 'mode' (str): Image display mode for overlay ('rgb', 'stack', etc.).
|
4201
|
+
- 'export_tiffs' (bool): Whether to export TIFF images alongside overlays.
|
4202
|
+
|
4203
|
+
Returns:
|
4204
|
+
tuple: A 3-element tuple of matplotlib Figure objects or None
|
3572
4205
|
|
4206
|
+
- fig_1 (matplotlib.figure.Figure or None): Mask overlay figure.
|
4207
|
+
- fig_2 (matplotlib.figure.Figure or None): Grid of raw cropped PNGs.
|
4208
|
+
- fig_3 (matplotlib.figure.Figure or None): Grid of activation maps.
|
4209
|
+
|
4210
|
+
Notes:
|
4211
|
+
- Figures are saved as PDFs under ``<src>/results/<name>/``.
|
4212
|
+
- If no relevant PNGs or activations are found, the corresponding figure will be ``None``.
|
4213
|
+
- Paths are automatically corrected using ``correct_paths``.
|
4214
|
+
- The figure layout uses ``plot_image_grid`` and ``plot_image_mask_overlay``.
|
4215
|
+
"""
|
3573
4216
|
def _sort_paths_by_basename(paths):
|
3574
4217
|
return sorted(paths, key=lambda path: os.path.basename(path))
|
3575
4218
|
|
@@ -3640,7 +4283,7 @@ def plot_image_grid(image_paths, percentiles):
|
|
3640
4283
|
Plots a square grid of images from a list of image paths.
|
3641
4284
|
Unused subplots are filled with black, and padding is minimized.
|
3642
4285
|
|
3643
|
-
|
4286
|
+
Args:
|
3644
4287
|
- image_paths: List of paths to images to be displayed.
|
3645
4288
|
|
3646
4289
|
Returns:
|
@@ -3798,7 +4441,30 @@ def overlay_masks_on_images(img_folder, normalize=True, resize=True, save=False,
|
|
3798
4441
|
plt.show()
|
3799
4442
|
|
3800
4443
|
def graph_importance(settings):
|
3801
|
-
|
4444
|
+
"""
|
4445
|
+
Generate and display a bar, box, or violin plot of importance values across grouped categories.
|
4446
|
+
|
4447
|
+
This function reads one or more CSV files containing importance scores (e.g., feature importances,
|
4448
|
+
saliency values, or other metrics) and visualizes them grouped by a specified column. It wraps
|
4449
|
+
the `spacrGraph` plotting class to create the plot and saves the settings used.
|
4450
|
+
|
4451
|
+
Args:
|
4452
|
+
settings (dict): A dictionary containing the following keys:
|
4453
|
+
- 'csvs' (str or list of str): Path(s) to CSV file(s) containing importance data.
|
4454
|
+
- 'grouping_column' (str): Column name used for grouping on the x-axis.
|
4455
|
+
- 'data_column' (str): Column name containing the data values to be plotted (e.g., importance scores).
|
4456
|
+
- 'graph_type' (str): Type of plot to generate ('bar', 'box', or 'violin').
|
4457
|
+
- 'save' (bool): Whether to save the plot as a PDF in the same directory as the input CSVs.
|
4458
|
+
|
4459
|
+
Returns:
|
4460
|
+
None
|
4461
|
+
|
4462
|
+
Notes:
|
4463
|
+
- If the required columns are missing from the input data, the function will print a warning and exit.
|
4464
|
+
- The plot is created using the `spacrGraph` class and is shown with `matplotlib.pyplot.show()`.
|
4465
|
+
- All input CSVs are concatenated before plotting.
|
4466
|
+
- Settings are saved to disk using `save_settings`.
|
4467
|
+
"""
|
3802
4468
|
from .settings import set_graph_importance_defaults
|
3803
4469
|
from .utils import save_settings
|
3804
4470
|
|
@@ -3850,7 +4516,7 @@ def plot_proportion_stacked_bars(settings, df, group_column, bin_column, prc_col
|
|
3850
4516
|
"""
|
3851
4517
|
Generate a stacked bar plot for proportions and perform chi-squared and pairwise tests.
|
3852
4518
|
|
3853
|
-
|
4519
|
+
Args:
|
3854
4520
|
- settings (dict): Analysis settings.
|
3855
4521
|
- df (DataFrame): Input data.
|
3856
4522
|
- group_column (str): Column indicating the groups.
|
@@ -3927,7 +4593,7 @@ def create_venn_diagram(file1, file2, gene_column="gene", filter_coeff=0.1, save
|
|
3927
4593
|
Reads two CSV files, extracts the `gene` column, and creates a Venn diagram
|
3928
4594
|
to show overlapping and non-overlapping genes.
|
3929
4595
|
|
3930
|
-
|
4596
|
+
Args:
|
3931
4597
|
file1 (str): Path to the first CSV file.
|
3932
4598
|
file2 (str): Path to the second CSV file.
|
3933
4599
|
gene_column (str): Name of the column containing gene data (default: "gene").
|