microarray 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. microarray/__init__.py +15 -0
  2. microarray/_version.py +3 -0
  3. microarray/datasets/__init__.py +3 -0
  4. microarray/datasets/_arrayexpress.py +1 -0
  5. microarray/datasets/_cdf_files.py +35 -0
  6. microarray/datasets/_geo.py +1 -0
  7. microarray/datasets/_utils.py +143 -0
  8. microarray/io/__init__.py +17 -0
  9. microarray/io/_anndata_converter.py +198 -0
  10. microarray/io/_cdf.py +575 -0
  11. microarray/io/_cel.py +591 -0
  12. microarray/io/_read.py +127 -0
  13. microarray/plotting/__init__.py +28 -0
  14. microarray/plotting/_base.py +253 -0
  15. microarray/plotting/_cel.py +75 -0
  16. microarray/plotting/_de_plots.py +239 -0
  17. microarray/plotting/_diagnostic_plots.py +268 -0
  18. microarray/plotting/_heatmap.py +279 -0
  19. microarray/plotting/_ma_plots.py +136 -0
  20. microarray/plotting/_pca.py +320 -0
  21. microarray/plotting/_qc_plots.py +335 -0
  22. microarray/plotting/_score.py +38 -0
  23. microarray/plotting/_top_table_heatmap.py +98 -0
  24. microarray/plotting/_utils.py +280 -0
  25. microarray/preprocessing/__init__.py +39 -0
  26. microarray/preprocessing/_background.py +862 -0
  27. microarray/preprocessing/_log2.py +77 -0
  28. microarray/preprocessing/_normalize.py +1292 -0
  29. microarray/preprocessing/_rma.py +243 -0
  30. microarray/preprocessing/_robust.py +170 -0
  31. microarray/preprocessing/_summarize.py +318 -0
  32. microarray/py.typed +0 -0
  33. microarray/tools/__init__.py +26 -0
  34. microarray/tools/_biomart.py +416 -0
  35. microarray/tools/_empirical_bayes.py +401 -0
  36. microarray/tools/_fdist.py +171 -0
  37. microarray/tools/_linear_models.py +387 -0
  38. microarray/tools/_mds.py +101 -0
  39. microarray/tools/_pca.py +88 -0
  40. microarray/tools/_score.py +86 -0
  41. microarray/tools/_toptable.py +360 -0
  42. microarray-0.1.0.dist-info/METADATA +75 -0
  43. microarray-0.1.0.dist-info/RECORD +44 -0
  44. microarray-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,239 @@
1
+ """Differential expression plot functions for microarray analysis."""
2
+
3
+ from typing import Any
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from matplotlib.axes import Axes
8
+ from matplotlib.patches import Circle
9
+
10
+ from microarray.plotting._utils import add_reference_line, with_highlights
11
+
12
+
13
def volcano(
    logfc: np.ndarray,
    pvalues: np.ndarray,
    logfc_threshold: float = 1.0,
    pvalue_threshold: float = 0.05,
    labels: list[str] | np.ndarray | None = None,
    top_n: int = 10,
    status: np.ndarray | None = None,
    xlab: str = "Log2 fold-change",
    ylab: str = "-Log10(p-value)",
    title: str = "Volcano Plot",
    ax: Axes | None = None,
    **kwargs: Any,
) -> Axes:
    """Volcano plot for differential expression results.

    Plots log2 fold-change (x) against -log10(p-value) (y). Points in the
    upper left/right corners are genes with both a large fold-change and a
    small p-value. Dashed reference lines mark ``+/-logfc_threshold`` and
    ``-log10(pvalue_threshold)``.

    Args:
        logfc: 1-D array-like of log2 fold-changes.
        pvalues: 1-D array-like of p-values, same length as ``logfc``.
            Values below 1e-300 (including 0) are clamped to 1e-300 so the
            transformed value stays finite.
        logfc_threshold: Fold-change threshold for the vertical lines and the
            automatic up/down classification. Default 1.0.
        pvalue_threshold: P-value threshold for the horizontal line and the
            automatic classification. Default 0.05.
        labels: Gene/probe labels aligned element-wise with ``logfc``. When
            given together with ``top_n > 0`` the most significant points are
            annotated.
        top_n: Number of points to annotate, ranked by -log10(p-value).
            Points whose fold-change or p-value is NaN/inf are never
            annotated. Default 10.
        status: Custom per-point status used for colouring. If None, each
            point is classified as "up", "down" or "not-significant" from the
            two thresholds.
        xlab: X-axis label.
        ylab: Y-axis label.
        title: Plot title.
        ax: Existing Axes object. If None, a new 8x7 inch figure is created.
        **kwargs: Additional arguments forwarded to ``with_highlights``.

    Returns:
        Axes object with the volcano plot.

    Raises:
        ValueError: If ``logfc`` and ``pvalues`` are not 1-D arrays of equal
            length, or if ``labels`` does not have one entry per point.

    Examples:
        >>> import numpy as np
        >>> from microarray.plotting import volcano
        >>> logfc = np.random.randn(1000) * 2
        >>> pvalues = np.random.uniform(0, 1, 1000)
        >>> ax = volcano(logfc, pvalues)
    """
    # Coerce to plain float arrays so the positional indexing further down is
    # well defined for lists and for pandas objects with a non-default index.
    logfc = np.asarray(logfc, dtype=float)
    pvalues = np.asarray(pvalues, dtype=float)

    # Validate before touching Matplotlib so that invalid input never leaves
    # an empty figure behind.
    if logfc.ndim != 1 or logfc.shape != pvalues.shape:
        raise ValueError(
            "logfc and pvalues must be 1-D arrays of equal length, "
            f"got shapes {logfc.shape} and {pvalues.shape}"
        )
    n_points = logfc.shape[0]
    if labels is not None and len(labels) != n_points:
        raise ValueError(f"labels must have one entry per point: expected {n_points}, got {len(labels)}")

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 7))

    # Clamp so that p == 0 maps to a large finite value instead of +inf.
    logp = -np.log10(np.maximum(pvalues, 1e-300))
    logp_cutoff = -np.log10(pvalue_threshold)

    # Determine status if not provided
    if status is None:
        significant = logp >= logp_cutoff
        status = np.full(n_points, "not-significant")
        status[significant & (logfc >= logfc_threshold)] = "up"
        # Assigned last so "down" wins when both masks match, which can only
        # happen for a threshold <= 0.
        status[significant & (logfc <= -logfc_threshold)] = "down"

    # Create scatter plot with highlighting
    ax = with_highlights(
        logfc, logp, status=status, xlab=xlab, ylab=ylab, title=title, ax=ax, legend="upper right", **kwargs
    )

    # Volcano-specific styling: no background grid, legend outside the axes
    # and without a frame.
    ax.grid(False)
    legend = ax.get_legend()
    if legend is not None:
        # NOTE(review): Legend.set_loc appears to need Matplotlib >= 3.8 -
        # confirm against the project's minimum supported version.
        legend.set_loc("upper left")
        legend.set_bbox_to_anchor((1.02, 1.0))
        legend.set_frame_on(False)

    # Threshold guides: two vertical fold-change lines, one horizontal p-value line.
    guide_style = {"color": "darkgray", "linestyle": "--", "alpha": 0.7}
    add_reference_line(ax, x=logfc_threshold, **guide_style)
    add_reference_line(ax, x=-logfc_threshold, **guide_style)
    add_reference_line(ax, y=logp_cutoff, **guide_style)

    # Label top genes if requested
    if labels is not None and top_n > 0:
        label_values = np.asarray(labels, dtype=object)

        # argsort places NaN last, so ranking the raw array would pick genes
        # with missing p-values as the "most significant". Rank only points
        # that can actually be drawn.
        drawable = np.flatnonzero(np.isfinite(logp) & np.isfinite(logfc))
        ranked = drawable[np.argsort(logp[drawable], kind="stable")]

        for idx in ranked[-top_n:]:
            ax.annotate(
                label_values[idx],
                (logfc[idx], logp[idx]),
                xytext=(5, 5),
                textcoords="offset points",
                fontsize=8,
                alpha=0.7,
                bbox={"boxstyle": "round,pad=0.3", "facecolor": "white", "alpha": 0.7, "edgecolor": "none"},
            )

    return ax
113
+
114
+
115
def venn(
    sets: dict[str, set] | list[set],
    labels: list[str] | None = None,
    colors: list[str] | None = None,
    alpha: float = 0.4,
    title: str = "Venn Diagram",
    ax: Axes | None = None,
) -> Axes:
    """Venn diagram for visualizing overlap between sets.

    Draws two or three overlapping unit circles and writes the size of every
    region (exclusive members, pairwise overlaps, common core) inside it.
    Circle sizes are fixed; they are not proportional to the set sizes.
    Common use case: overlap of differentially expressed genes across
    multiple contrasts or conditions.

    Args:
        sets: Dictionary mapping labels to sets, or a list of sets. Any
            iterable of hashable items is accepted for each member and is
            converted to a ``set``.
        labels: Labels for each set. Required if ``sets`` is a list; ignored
            (replaced by the dictionary keys) if ``sets`` is a dict.
        colors: Fill colors, at least one per set. If None, uses a default
            red/blue/green palette.
        alpha: Transparency of circles (0-1).
        title: Plot title.
        ax: Existing Axes object. If None, a new 8x8 inch figure is created.

    Returns:
        Axes object with the Venn diagram.

    Raises:
        ValueError: If ``sets`` is a list and ``labels`` is None, if the
            number of sets is not 2 or 3, if ``labels`` does not have one
            entry per set, or if ``colors`` has fewer entries than sets.

    Examples:
        >>> from microarray.plotting import venn
        >>> set1 = set(["gene1", "gene2", "gene3", "gene4"])
        >>> set2 = set(["gene3", "gene4", "gene5", "gene6"])
        >>> ax = venn({"Control": set1, "Treatment": set2})
        >>> # Three-way Venn
        >>> set3 = set(["gene1", "gene5", "gene7"])
        >>> ax = venn({"A": set1, "B": set2, "C": set3})
    """
    # Parse and validate everything before touching Matplotlib so that invalid
    # input never leaves an empty figure behind.
    if isinstance(sets, dict):
        labels = list(sets.keys())
        members = list(sets.values())
    else:
        if labels is None:
            raise ValueError("labels must be provided when sets is a list")
        members = list(sets)

    n_sets = len(members)
    if n_sets < 2 or n_sets > 3:
        raise ValueError("Venn diagrams support 2 or 3 sets only")
    if len(labels) != n_sets:
        raise ValueError(f"labels must have one entry per set: expected {n_sets}, got {len(labels)}")

    if colors is None:
        colors = ["#E41A1C", "#377EB8", "#4DAF4A"][:n_sets]
    elif len(colors) < n_sets:
        raise ValueError(f"colors must provide at least {n_sets} entries, got {len(colors)}")

    # Accept lists, tuples, frozensets, ... - the set algebra below needs real sets.
    members = [set(member) for member in members]

    # Layout tables: circle centres, (x, y, count) for every region, and label
    # anchors. Text positions were chosen by eye for unit circles.
    if n_sets == 2:
        set_a, set_b = members
        centers = [(-0.5, 0), (0.5, 0)]
        regions = [
            (-0.9, 0, len(set_a - set_b)),
            (0.9, 0, len(set_b - set_a)),
            (0, 0, len(set_a & set_b)),
        ]
        label_anchors = [(-0.6, 1.3), (0.6, 1.3)]
        count_fontsize = 16
    else:
        set_a, set_b, set_c = members
        d = 0.7  # offset of the circle centres from the origin
        centers = [(-d / 2, d / 2), (d / 2, d / 2), (0, -d / 2)]
        regions = [
            (-0.8, 0.65, len(set_a - set_b - set_c)),
            (0.8, 0.65, len(set_b - set_a - set_c)),
            (0, -1.0, len(set_c - set_a - set_b)),
            (0, 0.78, len((set_a & set_b) - set_c)),
            (-0.55, -0.3, len((set_a & set_c) - set_b)),
            (0.55, -0.3, len((set_b & set_c) - set_a)),
            (0, 0.2, len(set_a & set_b & set_c)),
        ]
        label_anchors = [(-d, d / 2 + 1.3), (d, d / 2 + 1.3), (0, -d / 2 - 1.3)]
        count_fontsize = 14

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8))

    ax.set_aspect("equal")
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.axis("off")
    ax.set_title(title, fontsize=14, pad=20)

    # Draw circles: coloured, semi-transparent fill with a black outline.
    for center, fill in zip(centers, colors):
        ax.add_patch(Circle(center, 1, facecolor=fill, edgecolor="black", alpha=alpha, linewidth=2))

    # Region counts
    for x, y, count in regions:
        ax.text(x, y, str(count), fontsize=count_fontsize, ha="center", va="center", weight="bold")

    # Set labels
    for (x, y), label in zip(label_anchors, labels):
        ax.text(x, y, label, fontsize=12, ha="center", weight="bold")

    return ax
@@ -0,0 +1,268 @@
1
+ """Diagnostic plot functions for microarray data analysis."""
2
+
3
+ from typing import Any
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from anndata import AnnData
8
+ from matplotlib.axes import Axes
9
+
10
+ from microarray.plotting._utils import get_default_colors
11
+
12
+
13
def mds(
    adata: AnnData,
    obsm_key: str = "X_mds",
    top: int = 500,
    gene_selection: str = "common",
    dimensions: int = 2,
    labels: list[str] | None = None,
    colors: list[str] | str | dict[Any, str] | None = None,
    groups: np.ndarray | list | None = None,
    xlab: str | None = None,
    ylab: str | None = None,
    title: str = "MDS Plot",
    ax: Axes | None = None,
    **kwargs: Any,
) -> Axes:
    """Plot Multidimensional Scaling (MDS) embedding.

    Scatter-plots the first two columns of the embedding stored in
    ``adata.obsm[obsm_key]`` and annotates every point with its sample label.
    Useful for quality control and for spotting batch effects or outliers.

    Note:
        If ``obsm_key`` is missing from ``adata.obsm`` the embedding is
        computed first by calling ``microarray.tools.mds`` with ``top``,
        ``gene_selection``, ``dimensions`` and ``obsm_key``.

    Args:
        adata: AnnData object with MDS embedding in .obsm or expression data in .X
        obsm_key: Key in .obsm where the MDS embedding is stored. Default "X_mds".
        top: Number of top varying probes to use if computing MDS. Default 500.
        gene_selection: Method for selecting genes if computing MDS. Default "common".
        dimensions: Number of dimensions to plot (must be 2). Default 2.
        labels: Custom labels, one per sample. If None, uses obs_names.
        colors: Without ``groups``: a single color or a list of colors per
            sample. With ``groups``: a dict mapping each group value to a
            color; any non-dict value is ignored and a default palette is
            used instead.
        groups: Group assignment per sample. When given, each group is drawn
            as its own scatter series and a legend is added.
        xlab: X-axis label. If None, uses "Dimension 1".
        ylab: Y-axis label. If None, uses "Dimension 2".
        title: Plot title
        ax: Existing Axes object. If None, a new 8x7 inch figure is created.
        **kwargs: Additional arguments passed to ax.scatter(). They override
            the defaults ``s=100``, ``alpha=0.7``, ``edgecolors="black"`` and
            ``linewidth=0.5``.

    Returns:
        Axes object with MDS plot

    Raises:
        NotImplementedError: If ``dimensions`` is not 2.
        ValueError: If the embedding has fewer than two columns, or if
            ``labels`` / ``groups`` do not have one entry per sample.

    Examples:
        >>> import anndata as ad
        >>> import numpy as np
        >>> import microarray as ma
        >>> data = np.random.randn(1000, 6)
        >>> adata = ad.AnnData(data.T)
        >>> # Compute MDS first
        >>> ma.tl.mds(adata, top=500)
        >>> # Then plot it
        >>> ax = ma.pl.mds(adata)
        >>> # Or let the plot function compute it automatically
        >>> ax = ma.pl.mds(adata, top=500)
        >>> # With group coloring
        >>> groups = ["control", "control", "control", "treated", "treated", "treated"]
        >>> ax = ma.pl.mds(adata, groups=groups)
    """
    # All validation (and the possibly expensive MDS computation) happens
    # before a figure is created, so a failure never leaves an empty figure.
    if dimensions != 2:
        raise NotImplementedError("Only 2D MDS plots are currently supported")

    # Check if MDS embedding exists, if not compute it
    if obsm_key not in adata.obsm:
        # Import here to avoid circular dependency
        from microarray.tools import mds as compute_mds

        compute_mds(adata, top=top, gene_selection=gene_selection, n_components=dimensions, obsm_key=obsm_key)

    # np.asarray so that positional slicing also works for array-likes.
    coords = np.asarray(adata.obsm[obsm_key])
    if coords.ndim != 2 or coords.shape[1] < 2:
        raise ValueError(f"adata.obsm[{obsm_key!r}] must be a 2-D array with at least two columns")
    n_samples = coords.shape[0]

    # Prepare labels
    if labels is None:
        labels = list(adata.obs_names) if adata.obs_names is not None else [f"Sample {i}" for i in range(n_samples)]
    elif len(labels) != n_samples:
        raise ValueError(f"labels must have one entry per sample: expected {n_samples}, got {len(labels)}")

    group_values = None
    if groups is not None:
        group_values = np.asarray(groups)
        if group_values.shape != (n_samples,):
            raise ValueError(f"groups must have one entry per sample: expected {n_samples}, got {len(groups)}")

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 7))

    # Defaults first, caller's kwargs last: a caller can override e.g. ``s`` or
    # ``alpha`` instead of hitting "got multiple values for keyword argument".
    scatter_style = {"s": 100, "alpha": 0.7, "edgecolors": "black", "linewidth": 0.5, **kwargs}

    if group_values is not None:
        unique_groups = np.unique(group_values)

        if isinstance(colors, dict):
            color_map = colors
        else:
            # Generate default colors for groups
            palette = get_default_colors(len(unique_groups))
            color_map = dict(zip(unique_groups, palette, strict=False))

        # One scatter call per group so that each group gets a legend entry.
        for group in unique_groups:
            member = group_values == group
            ax.scatter(coords[member, 0], coords[member, 1], c=color_map[group], label=str(group), **scatter_style)
        ax.legend(loc="best", frameon=True)
    else:
        if colors is None:
            colors = get_default_colors(1)[0]
        # A single color and a per-sample list are both passed straight to ``c``.
        ax.scatter(coords[:, 0], coords[:, 1], c=colors, **scatter_style)

    # Add labels to points
    for label, (x, y) in zip(labels, coords[:, :2]):
        ax.annotate(label, (x, y), xytext=(5, 5), textcoords="offset points", fontsize=9, alpha=0.8)

    ax.set_xlabel("Dimension 1" if xlab is None else xlab)
    ax.set_ylabel("Dimension 2" if ylab is None else ylab)
    ax.set_title(title)

    # Light grid plus faint axes through the origin.
    ax.grid(True, alpha=0.3, linestyle="--")
    ax.axhline(y=0, color="gray", linewidth=0.5, alpha=0.5)
    ax.axvline(x=0, color="gray", linewidth=0.5, alpha=0.5)

    return ax
160
+
161
+
162
def sa(
    adata: AnnData,
    fit_values: np.ndarray | None = None,
    xlab: str = "Average log-expression",
    ylab: str = "Sqrt(standard deviation)",
    title: str = "SA Plot",
    show_trend: bool = True,
    ax: Axes | None = None,
    **kwargs: Any,
) -> Axes:
    """Sigma vs average plot for mean-variance relationship.

    For every column of ``adata.X`` the mean and the sample standard
    deviation (``ddof=1``) across rows are computed, and
    sqrt(standard deviation) is plotted against the mean. Columns with a
    non-finite mean, or a non-finite or zero standard deviation, are dropped.

    If the matrix is non-negative and spans a range greater than 20 it is
    assumed to be on the raw scale and ``log2(x + 1)`` is applied first.

    Args:
        adata: AnnData object with probe-level expression data in .X
            (samples x probes). Sparse matrices are densified.
        fit_values: Fitted/smoothed values from a statistical model, one per
            probe. If provided, ``sqrt(fit_values)`` is drawn as a red trend
            line (regardless of ``show_trend``).
        xlab: X-axis label
        ylab: Y-axis label
        title: Plot title
        show_trend: When ``fit_values`` is None, whether to overlay a
            Savitzky-Golay smoothed trend. The trend is skipped silently if
            SciPy is unavailable or there are too few points. Default True.
        ax: Existing Axes object. If None, a new 8x6 inch figure is created.
        **kwargs: Additional arguments passed to ax.scatter(). They override
            the defaults ``alpha=0.5`` and ``s=10``.

    Returns:
        Axes object with SA plot

    Raises:
        ValueError: If ``fit_values`` does not have one entry per probe.

    Examples:
        >>> import anndata as ad
        >>> import numpy as np
        >>> from microarray.plotting import sa
        >>> data = np.random.randn(1000, 6)
        >>> adata = ad.AnnData(data.T)
        >>> ax = sa(adata)
    """
    # Get expression matrix (samples x probes)
    expr = adata.X
    if hasattr(expr, "toarray"):
        # Sparse matrices do not support the NumPy reductions used below.
        expr = expr.toarray()
    expr = np.asarray(expr)

    # Heuristic scale detection: non-negative data spanning more than 20
    # units is treated as raw intensities and log2-transformed.
    if expr.min() >= 0 and (expr.max() - expr.min()) > 20:
        log_expr = np.log2(expr + 1)
    else:
        log_expr = expr

    # Per-probe mean and sample standard deviation.
    mean_expr = np.mean(log_expr, axis=0)
    std_expr = np.std(log_expr, axis=0, ddof=1)

    # Validate before creating a figure so a bad call leaves nothing behind
    # and fails with a clear message instead of an indexing error.
    if fit_values is not None:
        fit_values = np.asarray(fit_values, dtype=float)
        if fit_values.shape != mean_expr.shape:
            raise ValueError(
                f"fit_values must have one entry per probe: expected shape {mean_expr.shape}, "
                f"got {fit_values.shape}"
            )

    # Drop probes that cannot be plotted (NaN/inf, or zero spread).
    keep = np.isfinite(mean_expr) & np.isfinite(std_expr) & (std_expr > 0)
    mean_expr = mean_expr[keep]
    sqrt_std = np.sqrt(std_expr[keep])

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 6))

    # Defaults first, caller's kwargs last, so ``alpha``/``s`` can be overridden.
    ax.scatter(mean_expr, sqrt_std, **{"alpha": 0.5, "s": 10, **kwargs})

    order = np.argsort(mean_expr)
    if fit_values is not None:
        # NOTE(review): the y axis is sqrt(SD). If fit_values are variances,
        # as the parameter description suggests, the matching transform would
        # be the fourth root - confirm the expected scale with callers.
        sqrt_fit = np.sqrt(fit_values[keep])
        ax.plot(mean_expr[order], sqrt_fit[order], color="red", linewidth=2, label="Fitted trend")
        ax.legend(loc="best")
    elif show_trend:
        try:
            from scipy.signal import savgol_filter
        except ImportError:
            savgol_filter = None  # the trend is best-effort; plot without it

        if savgol_filter is not None:
            polyorder = 3
            # Roughly 10% of the points, capped at 51, forced to be odd.
            window_length = min(51, mean_expr.size // 10)
            if window_length % 2 == 0:
                window_length += 1
            # savgol_filter requires polyorder < window_length; a window of 3
            # (20-39 points) used to raise ValueError here, so skip the trend
            # when there are too few points for a cubic fit.
            if window_length > polyorder:
                smoothed = savgol_filter(sqrt_std[order], window_length, polyorder)
                ax.plot(mean_expr[order], smoothed, color="red", linewidth=2, label="Smoothed trend")
                ax.legend(loc="best")

    # Set labels and title
    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)
    ax.set_title(title)

    ax.grid(True, alpha=0.3, linestyle="--")

    return ax