microarray 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- microarray/__init__.py +15 -0
- microarray/_version.py +3 -0
- microarray/datasets/__init__.py +3 -0
- microarray/datasets/_arrayexpress.py +1 -0
- microarray/datasets/_cdf_files.py +35 -0
- microarray/datasets/_geo.py +1 -0
- microarray/datasets/_utils.py +143 -0
- microarray/io/__init__.py +17 -0
- microarray/io/_anndata_converter.py +198 -0
- microarray/io/_cdf.py +575 -0
- microarray/io/_cel.py +591 -0
- microarray/io/_read.py +127 -0
- microarray/plotting/__init__.py +28 -0
- microarray/plotting/_base.py +253 -0
- microarray/plotting/_cel.py +75 -0
- microarray/plotting/_de_plots.py +239 -0
- microarray/plotting/_diagnostic_plots.py +268 -0
- microarray/plotting/_heatmap.py +279 -0
- microarray/plotting/_ma_plots.py +136 -0
- microarray/plotting/_pca.py +320 -0
- microarray/plotting/_qc_plots.py +335 -0
- microarray/plotting/_score.py +38 -0
- microarray/plotting/_top_table_heatmap.py +98 -0
- microarray/plotting/_utils.py +280 -0
- microarray/preprocessing/__init__.py +39 -0
- microarray/preprocessing/_background.py +862 -0
- microarray/preprocessing/_log2.py +77 -0
- microarray/preprocessing/_normalize.py +1292 -0
- microarray/preprocessing/_rma.py +243 -0
- microarray/preprocessing/_robust.py +170 -0
- microarray/preprocessing/_summarize.py +318 -0
- microarray/py.typed +0 -0
- microarray/tools/__init__.py +26 -0
- microarray/tools/_biomart.py +416 -0
- microarray/tools/_empirical_bayes.py +401 -0
- microarray/tools/_fdist.py +171 -0
- microarray/tools/_linear_models.py +387 -0
- microarray/tools/_mds.py +101 -0
- microarray/tools/_pca.py +88 -0
- microarray/tools/_score.py +86 -0
- microarray/tools/_toptable.py +360 -0
- microarray-0.1.0.dist-info/METADATA +75 -0
- microarray-0.1.0.dist-info/RECORD +44 -0
- microarray-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
"""Differential expression plot functions for microarray analysis."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
import numpy as np
|
|
7
|
+
from matplotlib.axes import Axes
|
|
8
|
+
from matplotlib.patches import Circle
|
|
9
|
+
|
|
10
|
+
from microarray.plotting._utils import add_reference_line, with_highlights
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def volcano(
    logfc: np.ndarray,
    pvalues: np.ndarray,
    logfc_threshold: float = 1.0,
    pvalue_threshold: float = 0.05,
    labels: list[str] | np.ndarray | None = None,
    top_n: int = 10,
    status: np.ndarray | None = None,
    xlab: str = "Log2 fold-change",
    ylab: str = "-Log10(p-value)",
    title: str = "Volcano Plot",
    ax: Axes | None = None,
    **kwargs: Any,
) -> Axes:
    """Volcano plot for differential expression results.

    Volcano plot displays log fold-changes vs statistical significance
    (-log10 p-values). Points in upper left/right corners represent genes
    with large fold-changes and high significance.

    Args:
        logfc: Array of log2 fold-changes (any array-like; converted to float).
        pvalues: Array of p-values, same shape as ``logfc``. P-values of 0 are
            clipped to 1e-300 so that -log10 stays finite.
        logfc_threshold: Fold-change threshold for significance lines. Default 1.0.
        pvalue_threshold: P-value threshold for significance line. Default 0.05.
        labels: Gene/probe labels. If provided with top_n, labels the ``top_n``
            most significant points. Points whose fold-change or p-value is
            NaN/inf are never labelled.
        top_n: Number of top genes to label (by significance). Default 10.
        status: Custom status labels for coloring. If None, automatically determines
            status based on thresholds ("up" / "down" / "not-significant").
        xlab: X-axis label
        ylab: Y-axis label
        title: Plot title
        ax: Existing Axes object. If None, creates new figure.
        **kwargs: Additional arguments passed to ``with_highlights``.

    Returns:
        Axes object with volcano plot

    Raises:
        ValueError: If ``logfc`` and ``pvalues`` do not have the same shape.

    Examples:
        >>> import numpy as np
        >>> from microarray.plotting import volcano
        >>> logfc = np.random.randn(1000) * 2
        >>> pvalues = np.random.uniform(0, 1, 1000)
        >>> ax = volcano(logfc, pvalues)
    """
    # Accept lists/Series as well as arrays; the threshold comparisons below
    # need element-wise semantics.
    logfc = np.asarray(logfc, dtype=float)
    pvalues = np.asarray(pvalues, dtype=float)
    if logfc.shape != pvalues.shape:
        raise ValueError(
            f"logfc and pvalues must have the same shape, got {logfc.shape} and {pvalues.shape}"
        )

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 7))

    # Calculate -log10(p-values); p-values of 0 are clipped to avoid inf.
    logp = -np.log10(np.maximum(pvalues, 1e-300))
    logp_threshold = -np.log10(pvalue_threshold)

    # Determine status if not provided. "down" is assigned after "up", so it
    # wins if both apply (only possible when logfc_threshold <= 0).
    if status is None:
        status = np.full(len(logfc), "not-significant")
        significant = logp >= logp_threshold
        status[significant & (logfc >= logfc_threshold)] = "up"
        status[significant & (logfc <= -logfc_threshold)] = "down"

    # Create scatter plot with highlighting
    ax = with_highlights(
        logfc, logp, status=status, xlab=xlab, ylab=ylab, title=title, ax=ax, legend="upper right", **kwargs
    )

    # Volcano-specific styling: no background grid and legend outside without frame.
    ax.grid(False)
    legend = ax.get_legend()
    if legend is not None:
        legend.set_loc("upper left")
        legend.set_bbox_to_anchor((1.02, 1.0))
        legend.set_frame_on(False)

    # Vertical lines for fold-change thresholds, horizontal line for p-value threshold.
    add_reference_line(ax, x=logfc_threshold, color="darkgray", linestyle="--", alpha=0.7)
    add_reference_line(ax, x=-logfc_threshold, color="darkgray", linestyle="--", alpha=0.7)
    add_reference_line(ax, y=logp_threshold, color="darkgray", linestyle="--", alpha=0.7)

    # Label top genes if requested
    if labels is not None and top_n > 0:
        # Rank only finite points: np.argsort places NaN last, which would
        # otherwise make NaN p-values look like the most significant genes.
        finite_idx = np.flatnonzero(np.isfinite(logp) & np.isfinite(logfc))
        ranked = finite_idx[np.argsort(logp[finite_idx])]
        top_indices = ranked[-top_n:]

        for idx in top_indices:
            ax.annotate(
                labels[idx],
                (logfc[idx], logp[idx]),
                xytext=(5, 5),
                textcoords="offset points",
                fontsize=8,
                alpha=0.7,
                bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7, edgecolor="none"),
            )

    return ax
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def venn(
    sets: dict[str, set] | list[set],
    labels: list[str] | None = None,
    colors: list[str] | None = None,
    alpha: float = 0.4,
    title: str = "Venn Diagram",
    ax: Axes | None = None,
) -> Axes:
    """Venn diagram for visualizing overlap between sets.

    Creates Venn diagram showing overlap between 2 or 3 sets, annotating each
    region with the number of elements it contains.
    Common use case: visualizing overlap of differentially expressed genes
    across multiple contrasts or conditions.

    Args:
        sets: Dictionary mapping labels to sets, or list of sets. Each collection
            is passed through ``set()``, so lists/tuples/arrays of hashable items
            are accepted too. If list, labels parameter must be provided.
        labels: Labels for each set. Required if sets is a list; replaced by the
            dict keys if sets is a dict.
        colors: Fill colors for each set. If None, uses default palette.
        alpha: Transparency of circles (0-1)
        title: Plot title
        ax: Existing Axes object. If None, creates new figure.

    Returns:
        Axes object with Venn diagram

    Raises:
        ValueError: If labels are missing for list input, if the number of sets
            is not 2 or 3, or if fewer labels/colors than sets are supplied.

    Examples:
        >>> from microarray.plotting import venn
        >>> set1 = set(["gene1", "gene2", "gene3", "gene4"])
        >>> set2 = set(["gene3", "gene4", "gene5", "gene6"])
        >>> ax = venn({"Control": set1, "Treatment": set2})
        >>> # Three-way Venn
        >>> set3 = set(["gene1", "gene5", "gene7"])
        >>> ax = venn({"A": set1, "B": set2, "C": set3})
    """
    # Parse and validate input before creating a figure, so invalid arguments
    # do not leave an empty pyplot figure open.
    if isinstance(sets, dict):
        labels = list(sets.keys())
        set_list = [set(members) for members in sets.values()]
    else:
        if labels is None:
            raise ValueError("labels must be provided when sets is a list")
        set_list = [set(members) for members in sets]

    n_sets = len(set_list)
    if n_sets < 2 or n_sets > 3:
        raise ValueError("Venn diagrams support 2 or 3 sets only")
    if len(labels) < n_sets:
        raise ValueError(f"Expected at least {n_sets} labels, got {len(labels)}")

    if colors is None:
        colors = ["#E41A1C", "#377EB8", "#4DAF4A"][:n_sets]
    elif len(colors) < n_sets:
        raise ValueError(f"Expected at least {n_sets} colors, got {len(colors)}")

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8))

    ax.set_aspect("equal")
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.axis("off")
    ax.set_title(title, fontsize=14, pad=20)

    # Explicit facecolor/edgecolor instead of mixing ``color=`` with ``ec=``,
    # which only yields black edges because of kwarg application order.
    circle_style = {"alpha": alpha, "edgecolor": "black", "linewidth": 2}
    count_style = {"ha": "center", "va": "center", "weight": "bold"}
    label_style = {"fontsize": 12, "ha": "center", "weight": "bold"}

    if n_sets == 2:
        set_a, set_b = set_list

        ax.add_patch(Circle((-0.5, 0), 1, facecolor=colors[0], **circle_style))
        ax.add_patch(Circle((0.5, 0), 1, facecolor=colors[1], **circle_style))

        # Region counts: A only, B only, A & B.
        ax.text(-0.9, 0, str(len(set_a - set_b)), fontsize=16, **count_style)
        ax.text(0.9, 0, str(len(set_b - set_a)), fontsize=16, **count_style)
        ax.text(0, 0, str(len(set_a & set_b)), fontsize=16, **count_style)

        ax.text(-0.6, 1.3, labels[0], **label_style)
        ax.text(0.6, 1.3, labels[1], **label_style)
    else:
        set_a, set_b, set_c = set_list

        r = 1  # radius
        d = 0.7  # distance between circle centres
        ax.add_patch(Circle((-d / 2, d / 2), r, facecolor=colors[0], **circle_style))
        ax.add_patch(Circle((d / 2, d / 2), r, facecolor=colors[1], **circle_style))
        ax.add_patch(Circle((0, -d / 2), r, facecolor=colors[2], **circle_style))

        # Region counts, positioned by eye for a typical 3-way layout.
        regions = [
            ((-0.8, 0.65), set_a - set_b - set_c),  # A only
            ((0.8, 0.65), set_b - set_a - set_c),  # B only
            ((0, -1.0), set_c - set_a - set_b),  # C only
            ((0, 0.78), (set_a & set_b) - set_c),  # A & B only
            ((-0.55, -0.3), (set_a & set_c) - set_b),  # A & C only
            ((0.55, -0.3), (set_b & set_c) - set_a),  # B & C only
            ((0, 0.2), set_a & set_b & set_c),  # all three
        ]
        for (x, y), members in regions:
            ax.text(x, y, str(len(members)), fontsize=14, **count_style)

        ax.text(-d, d / 2 + 1.3, labels[0], **label_style)
        ax.text(d, d / 2 + 1.3, labels[1], **label_style)
        ax.text(0, -d / 2 - 1.3, labels[2], **label_style)

    return ax
|
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
"""Diagnostic plot functions for microarray data analysis."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
import numpy as np
|
|
7
|
+
from anndata import AnnData
|
|
8
|
+
from matplotlib.axes import Axes
|
|
9
|
+
|
|
10
|
+
from microarray.plotting._utils import get_default_colors
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def mds(
    adata: AnnData,
    obsm_key: str = "X_mds",
    top: int = 500,
    gene_selection: str = "common",
    dimensions: int = 2,
    labels: list[str] | None = None,
    colors: list[str] | str | None = None,
    groups: np.ndarray | list | None = None,
    xlab: str | None = None,
    ylab: str | None = None,
    title: str = "MDS Plot",
    ax: Axes | None = None,
    **kwargs: Any,
) -> Axes:
    """Plot Multidimensional Scaling (MDS) embedding.

    Visualizes the MDS embedding stored in `.obsm` to show sample relationships
    in 2D space. Samples that are similar (highly correlated) appear close together,
    while dissimilar samples are far apart. Essential for quality control and
    identifying batch effects or outliers.

    Note:
        If MDS embedding is not found in `.obsm[obsm_key]`, it will be computed
        automatically using `microarray.tl.mds()` with the provided parameters.

    Args:
        adata: AnnData object with MDS embedding in .obsm or expression data in .X
        obsm_key: Key in .obsm where the MDS embedding is stored. Default "X_mds".
        top: Number of top varying probes to use if computing MDS. Default 500.
        gene_selection: Method for selecting genes if computing MDS. Default "common".
        dimensions: Number of dimensions to plot (must be 2). Default 2.
        labels: Custom labels for each sample. If None, uses obs_names.
        colors: Color(s) for points. Without groups: a single color or a list of
            colors per sample. With groups: a dict mapping each group to a color;
            any non-dict value is ignored and a default palette is used.
        groups: Group assignments (one per sample) for color coding and legend.
        xlab: X-axis label. If None, uses "Dimension 1".
        ylab: Y-axis label. If None, uses "Dimension 2".
        title: Plot title
        ax: Existing Axes object. If None, creates new figure.
        **kwargs: Additional arguments passed to ax.scatter(). They override the
            default marker style (s=100, alpha=0.7, edgecolors="black",
            linewidth=0.5).

    Returns:
        Axes object with MDS plot

    Raises:
        NotImplementedError: If ``dimensions`` is not 2.
        ValueError: If ``groups`` does not have one entry per sample.

    Examples:
        >>> import anndata as ad
        >>> import numpy as np
        >>> import microarray as ma
        >>> data = np.random.randn(1000, 6)
        >>> adata = ad.AnnData(data.T)
        >>> # Compute MDS first
        >>> ma.tl.mds(adata, top=500)
        >>> # Then plot it
        >>> ax = ma.pl.mds(adata)
        >>> # Or let the plot function compute it automatically
        >>> ax = ma.pl.mds(adata, top=500)
        >>> # With group coloring
        >>> groups = ["control", "control", "control", "treated", "treated", "treated"]
        >>> ax = ma.pl.mds(adata, groups=groups)
    """
    # Validate before any figure is created, so errors do not leave an empty
    # pyplot figure behind.
    if dimensions != 2:
        raise NotImplementedError("Only 2D MDS plots are currently supported")

    # Check if MDS embedding exists, if not compute it
    if obsm_key not in adata.obsm:
        # Import here to avoid circular dependency
        from microarray.tools import mds as compute_mds

        compute_mds(adata, top=top, gene_selection=gene_selection, n_components=dimensions, obsm_key=obsm_key)

    # np.asarray also handles a DataFrame stored in .obsm.
    coords = np.asarray(adata.obsm[obsm_key])
    n_samples = coords.shape[0]

    group_array = None
    if groups is not None:
        group_array = np.asarray(groups)
        if group_array.shape[0] != n_samples:
            raise ValueError(f"groups has {group_array.shape[0]} entries but there are {n_samples} samples")

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 7))

    if labels is None:
        labels = list(adata.obs_names) if adata.obs_names is not None else [f"Sample {i}" for i in range(n_samples)]

    # Default marker style; user kwargs take precedence instead of colliding.
    scatter_kwargs = {"s": 100, "alpha": 0.7, "edgecolors": "black", "linewidth": 0.5, **kwargs}

    if group_array is not None:
        unique_groups = np.unique(group_array)

        if isinstance(colors, dict):
            color_map = colors
        else:
            color_map = dict(zip(unique_groups, get_default_colors(len(unique_groups)), strict=False))

        # One scatter call per group so each group gets a legend entry.
        # ``color=`` (not ``c=``) keeps an RGB(A) tuple from being read as
        # per-point values when a group has exactly 3 or 4 samples.
        for group in unique_groups:
            mask = group_array == group
            ax.scatter(coords[mask, 0], coords[mask, 1], color=color_map[group], label=str(group), **scatter_kwargs)
        ax.legend(loc="best", frameon=True)
    elif colors is None or isinstance(colors, str):
        # Single color for all points (default palette entry if none given).
        single_color = get_default_colors(1)[0] if colors is None else colors
        ax.scatter(coords[:, 0], coords[:, 1], color=single_color, **scatter_kwargs)
    else:
        # List of colors, one per sample
        ax.scatter(coords[:, 0], coords[:, 1], c=colors, **scatter_kwargs)

    # Add labels to points
    for i, label in enumerate(labels):
        ax.annotate(
            label, (coords[i, 0], coords[i, 1]), xytext=(5, 5), textcoords="offset points", fontsize=9, alpha=0.8
        )

    ax.set_xlabel("Dimension 1" if xlab is None else xlab)
    ax.set_ylabel("Dimension 2" if ylab is None else ylab)
    ax.set_title(title)

    ax.grid(True, alpha=0.3, linestyle="--")
    ax.axhline(y=0, color="gray", linewidth=0.5, alpha=0.5)
    ax.axvline(x=0, color="gray", linewidth=0.5, alpha=0.5)

    return ax
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def sa(
    adata: AnnData,
    fit_values: np.ndarray | None = None,
    xlab: str = "Average log-expression",
    ylab: str = "Sqrt(standard deviation)",
    title: str = "SA Plot",
    show_trend: bool = True,
    ax: Axes | None = None,
    **kwargs: Any,
) -> Axes:
    """Sigma vs average plot for mean-variance relationship.

    SA plot (also called mean-variance plot) shows the relationship between
    average expression and variability. Used to assess variance stabilization
    and the appropriateness of statistical models.

    Plots sqrt(standard deviation) vs mean log-expression. If fit_values are
    provided (e.g., from limma's empirical Bayes estimation), shows the
    smoothed variance trend.

    Args:
        adata: AnnData object with probe-level expression data in .X
            (samples x probes; dense or scipy-sparse). Data that look
            unlogged (non-negative with range > 20) are transformed with
            log2(x + 1).
        fit_values: Fitted/smoothed values from statistical model, one per
            probe. If provided, the square root is overlaid as a trend line and
            ``show_trend`` is ignored.
        xlab: X-axis label
        ylab: Y-axis label
        title: Plot title
        show_trend: Whether to show a Savitzky-Golay smoothed trend line when
            ``fit_values`` is None. Skipped silently if scipy is unavailable or
            there are too few probes for the smoothing window. Default True.
        ax: Existing Axes object. If None, creates new figure.
        **kwargs: Additional arguments passed to ax.scatter(); they override the
            defaults alpha=0.5 and s=10.

    Returns:
        Axes object with SA plot

    Examples:
        >>> import anndata as ad
        >>> import numpy as np
        >>> from microarray.plotting import sa
        >>> data = np.random.randn(1000, 6)
        >>> adata = ad.AnnData(data.T)
        >>> ax = sa(adata)
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 6))

    # Get expression matrix (samples x probes). AnnData may store a scipy
    # sparse matrix, for which `expr + 1` is unsupported, so densify it.
    expr = adata.X
    if hasattr(expr, "toarray"):
        expr = expr.toarray()
    expr = np.asarray(expr)

    # Convert to log2 if the data do not already look log-scaled
    if expr.min() >= 0 and (expr.max() - expr.min()) > 20:
        log_expr = np.log2(expr + 1)
    else:
        log_expr = expr

    # Per-probe mean and sample standard deviation
    mean_expr = np.mean(log_expr, axis=0)
    std_expr = np.std(log_expr, axis=0, ddof=1)

    # Drop NaN/Inf and zero-variance probes
    mask = np.isfinite(mean_expr) & np.isfinite(std_expr) & (std_expr > 0)
    mean_expr = mean_expr[mask]
    sqrt_std = np.sqrt(std_expr[mask])

    # User kwargs take precedence over the default marker style.
    ax.scatter(mean_expr, sqrt_std, **{"alpha": 0.5, "s": 10, **kwargs})

    if fit_values is not None:
        # NOTE(review): the points are sqrt(std), but if fit_values are variances
        # (as e.g. limma's s2.prior), sqrt(fit) is std and the matching scale
        # would be fit ** 0.25. Behavior kept as-is — confirm against callers.
        sqrt_fit = np.sqrt(np.asarray(fit_values)[mask])

        sort_idx = np.argsort(mean_expr)
        ax.plot(mean_expr[sort_idx], sqrt_fit[sort_idx], color="red", linewidth=2, label="Fitted trend")
        ax.legend(loc="best")
    elif show_trend:
        try:
            from scipy.signal import savgol_filter
        except ImportError:
            savgol_filter = None  # Skip trend line if scipy not available

        if savgol_filter is not None:
            polyorder = 3
            sort_idx = np.argsort(mean_expr)
            sorted_mean = mean_expr[sort_idx]
            sorted_sqrt_std = sqrt_std[sort_idx]

            # Odd window of roughly 10% of the probes, capped at 51.
            window_length = min(51, len(sorted_mean) // 10)
            if window_length % 2 == 0:
                window_length += 1
            # savgol_filter requires polyorder < window_length; a window of 3
            # (30-39 probes) previously raised ValueError here.
            if window_length > polyorder:
                smoothed = savgol_filter(sorted_sqrt_std, window_length, polyorder)
                ax.plot(sorted_mean, smoothed, color="red", linewidth=2, label="Smoothed trend")
                ax.legend(loc="best")

    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)
    ax.set_title(title)

    ax.grid(True, alpha=0.3, linestyle="--")

    return ax
|