microarray 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. microarray/__init__.py +15 -0
  2. microarray/_version.py +3 -0
  3. microarray/datasets/__init__.py +3 -0
  4. microarray/datasets/_arrayexpress.py +1 -0
  5. microarray/datasets/_cdf_files.py +35 -0
  6. microarray/datasets/_geo.py +1 -0
  7. microarray/datasets/_utils.py +143 -0
  8. microarray/io/__init__.py +17 -0
  9. microarray/io/_anndata_converter.py +198 -0
  10. microarray/io/_cdf.py +575 -0
  11. microarray/io/_cel.py +591 -0
  12. microarray/io/_read.py +127 -0
  13. microarray/plotting/__init__.py +28 -0
  14. microarray/plotting/_base.py +253 -0
  15. microarray/plotting/_cel.py +75 -0
  16. microarray/plotting/_de_plots.py +239 -0
  17. microarray/plotting/_diagnostic_plots.py +268 -0
  18. microarray/plotting/_heatmap.py +279 -0
  19. microarray/plotting/_ma_plots.py +136 -0
  20. microarray/plotting/_pca.py +320 -0
  21. microarray/plotting/_qc_plots.py +335 -0
  22. microarray/plotting/_score.py +38 -0
  23. microarray/plotting/_top_table_heatmap.py +98 -0
  24. microarray/plotting/_utils.py +280 -0
  25. microarray/preprocessing/__init__.py +39 -0
  26. microarray/preprocessing/_background.py +862 -0
  27. microarray/preprocessing/_log2.py +77 -0
  28. microarray/preprocessing/_normalize.py +1292 -0
  29. microarray/preprocessing/_rma.py +243 -0
  30. microarray/preprocessing/_robust.py +170 -0
  31. microarray/preprocessing/_summarize.py +318 -0
  32. microarray/py.typed +0 -0
  33. microarray/tools/__init__.py +26 -0
  34. microarray/tools/_biomart.py +416 -0
  35. microarray/tools/_empirical_bayes.py +401 -0
  36. microarray/tools/_fdist.py +171 -0
  37. microarray/tools/_linear_models.py +387 -0
  38. microarray/tools/_mds.py +101 -0
  39. microarray/tools/_pca.py +88 -0
  40. microarray/tools/_score.py +86 -0
  41. microarray/tools/_toptable.py +360 -0
  42. microarray-0.1.0.dist-info/METADATA +75 -0
  43. microarray-0.1.0.dist-info/RECORD +44 -0
  44. microarray-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,98 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from typing import Any
5
+
6
+ from anndata import AnnData
7
+ from matplotlib.axes import Axes
8
+ from matplotlib.figure import Figure
9
+
10
+ from microarray.plotting._heatmap import heatmap
11
+ from microarray.tools._toptable import top_table
12
+
13
# Keyword names that ``top_table_heatmap`` may forward to ``top_table``.
# "group" and "number" are supplied explicitly by the wrapper, so callers must
# not override them; "data" is presumably the name of ``top_table``'s first
# (positional) parameter -- TODO confirm against its signature.
_TOP_TABLE_KWARGS = {
    name for name in inspect.signature(top_table).parameters if name not in ("data", "group", "number")
}

# Keyword names that ``top_table_heatmap`` may forward to ``heatmap``.
# The excluded names are the arguments the wrapper passes to ``heatmap`` itself.
_HEATMAP_KWARGS = {
    name
    for name in inspect.signature(heatmap).parameters
    if name not in ("adata", "genes", "groupby", "title", "show")
}
15
+
16
+
17
def top_table_heatmap(
    adata: AnnData,
    n_top: int = 10,
    groupby: str | None = None,
    title: str | None = None,
    show: bool = True,
    **kwargs: Any,
) -> tuple[Figure, dict[str, Axes | None]]:
    """Draw a heatmap of the top-ranked genes of every condition.

    For each group, :func:`microarray.tools.top_table` is called with
    ``number=n_top``; the gene identifiers of all groups are merged (first
    occurrence wins, order preserved) and handed to
    :func:`microarray.plotting.heatmap`.

    Args:
        adata: AnnData object with a fit stored in ``adata.uns['lm_fit']``.
        n_top: Number of genes requested from ``top_table`` per group.
        groupby: Column of ``adata.obs`` used for grouping. When ``None``,
            ``adata.uns['lm_fit']['groupby']`` is used instead.
        title: Heatmap title. When ``None`` a default of the form
            ``"Top {n_top} marker genes per {groupby}"`` is used.
        show: Forwarded to ``heatmap`` as its ``show`` argument.
        **kwargs: Extra keyword arguments. Each one is forwarded to
            ``top_table`` and/or ``heatmap`` according to which of the two
            signatures contains that parameter name.

    Returns:
        Whatever :func:`heatmap` returns (annotated as a
        ``(figure, axes_dict)`` tuple).

    Raises:
        ValueError: If ``n_top`` is not positive, no fit is stored, no usable
            ``groupby`` column can be determined, or no genes are returned.
        TypeError: If a keyword argument is accepted by neither ``top_table``
            nor ``heatmap``.
    """
    if n_top <= 0:
        raise ValueError("n_top must be a positive integer")

    fit = adata.uns.get("lm_fit")
    if fit is None:
        raise ValueError("No fit object found in adata.uns['lm_fit']. Run lm_fit and ebayes first.")

    # Fall back to the grouping column recorded alongside the fit.
    groupby = fit.get("groupby") if groupby is None else groupby
    if not groupby:
        raise ValueError("groupby must be provided or available in adata.uns['lm_fit']['groupby']")
    if groupby not in adata.obs:
        raise ValueError(f"Column '{groupby}' not found in adata.obs")

    # Prefer the group names stored with the fit; otherwise take the distinct
    # values of the grouping column in order of first appearance.
    fitted_groups = fit.get("group_to_column")
    if isinstance(fitted_groups, dict) and fitted_groups:
        groups = list(fitted_groups)
    else:
        groups = list(dict.fromkeys(adata.obs[groupby].astype(str).tolist()))

    # Reject keywords that neither target function understands before routing.
    unrecognized = sorted(
        name for name in kwargs if name not in _TOP_TABLE_KWARGS and name not in _HEATMAP_KWARGS
    )
    if unrecognized:
        unknown_str = ", ".join(unrecognized)
        raise TypeError(f"Unknown keyword argument(s): {unknown_str}")

    # A keyword known to both functions is deliberately sent to both.
    top_kwargs: dict[str, Any] = {name: val for name, val in kwargs.items() if name in _TOP_TABLE_KWARGS}
    heatmap_kwargs: dict[str, Any] = {name: val for name, val in kwargs.items() if name in _HEATMAP_KWARGS}

    # dict keys act as an insertion-ordered set: re-inserting a known gene
    # keeps its original position.
    ordered_genes: dict[str, None] = {}
    for group in groups:
        table = top_table(adata, group=str(group), number=n_top, **top_kwargs)
        ordered_genes.update(dict.fromkeys(table.index.astype(str)))
    marker_genes: list[str] = list(ordered_genes)

    if not marker_genes:
        raise ValueError("No marker genes found. Adjust filtering parameters for top_table.")

    if title is None:
        heatmap_title = f"Top {n_top} marker genes per {groupby}"
    else:
        heatmap_title = title
    return heatmap(adata, genes=marker_genes, groupby=groupby, title=heatmap_title, show=show, **heatmap_kwargs)
@@ -0,0 +1,280 @@
1
+ """Utility functions for plotting."""
2
+
3
+ from typing import Any
4
+ from warnings import warn
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ from matplotlib.axes import Axes
9
+
10
+
11
def _status_lookup(mapping: dict, stat: Any, default: Any) -> Any:
    """Return ``mapping[stat]``, else ``mapping[str(stat)]``, else ``default``."""
    if stat in mapping:
        return mapping[stat]
    return mapping.get(str(stat), default)


def _marker_for(code: Any) -> Any:
    """Translate an R-style ``pch`` integer to a matplotlib marker; pass other values through."""
    pch_markers = {
        15: "s",  # square
        16: "o",  # circle
        17: "^",  # triangle up
        18: "D",  # diamond
        19: "o",  # filled circle
    }
    if isinstance(code, int):
        return pch_markers.get(code, "o")
    return code


def with_highlights(
    x: np.ndarray,
    y: np.ndarray,
    status: np.ndarray | None = None,
    colors: dict[str, str] | None = None,
    labels: dict[str, str] | None = None,
    pch: int | dict[str, str] = 16,
    cex: float | dict[str, float] = 1.0,
    alpha: float = 0.6,
    xlab: str = "",
    ylab: str = "",
    title: str = "",
    legend: bool | str = "best",
    ax: Axes | None = None,
    **kwargs: Any,
) -> Axes:
    """Create scatter plot with status-based highlighting.

    Core plotting utility inspired by limma's plotWithHighlights.
    Supports color coding and symbol customization based on status groups.

    Args:
        x: X-axis values (any array-like; converted with ``np.asarray``).
        y: Y-axis values (any array-like; converted with ``np.asarray``).
        status: Status labels for each point. If None, all points have same appearance.
        colors: Dictionary mapping status values to colors. Statuses without an
            entry fall back to a built-in table ("up", "down", "NotSig", "-1",
            "0", "1", ...) and then to a cycling palette.
        labels: Dictionary mapping status values to legend labels. Defaults to status values.
        pch: Point marker style. An integer is treated as an R-style code
            (15-19, anything else becomes "o"); a string is used as a
            matplotlib marker; a dict maps status values to either.
        cex: Point size multiplier (base size 20). Can be single value or dict
            mapping status to size.
        alpha: Point transparency (0-1)
        xlab: X-axis label
        ylab: Y-axis label
        title: Plot title
        legend: Legend position ('best', 'upper right', etc.), True for 'best',
            or False to disable. A legend is only drawn when there is more than
            one distinct status.
        ax: Existing Axes object. If None, creates new figure.
        **kwargs: Additional arguments passed to ax.scatter()

    Returns:
        Axes object with the plot

    Notes:
        ``colors``, ``labels``, ``pch`` and ``cex`` dictionaries are looked up
        by the status value itself and, failing that, by ``str(status)``. This
        lets integer statuses such as ``-1/0/1`` use string-keyed tables.

    Examples:
        >>> import numpy as np
        >>> from microarray.plotting import with_highlights
        >>> x = np.random.randn(100)
        >>> y = np.random.randn(100)
        >>> status = np.where(np.abs(y) > 1, "significant", "not-significant")
        >>> ax = with_highlights(x, y, status=status)
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 6))

    # Boolean-mask indexing below requires real arrays, not lists.
    x = np.asarray(x)
    y = np.asarray(y)

    # Marker used whenever ``pch`` is a single value rather than a per-status dict.
    scalar_marker = _marker_for(pch) if isinstance(pch, (int, str)) else "o"

    if status is None:
        # Single group: plot all points with same style
        size = cex * 20 if isinstance(cex, (int, float)) else 20
        ax.scatter(x, y, marker=scalar_marker, s=size, alpha=alpha, **kwargs)
    else:
        # Multiple groups: plot by status
        status = np.asarray(status)
        unique_statuses = np.unique(status)

        # Default colors for well-known status names (similar to R's defaults)
        default_colors = {
            "up": "#E41A1C",
            "down": "#377EB8",
            "not-significant": "#999999",
            "NotSig": "#999999",
            "Sig": "#E41A1C",
            "-1": "#377EB8",
            "0": "#999999",
            "1": "#E41A1C",
        }
        color_palette = ["#E41A1C", "#377EB8", "#4DAF4A", "#984EA3", "#FF7F00", "#FFFF33"]

        if colors is None:
            colors = {}
        if labels is None:
            labels = {}

        for i, stat in enumerate(unique_statuses):
            # Priority: user colors > named defaults > cycling palette.
            palette_color = color_palette[i % len(color_palette)]
            color = _status_lookup(colors, stat, _status_lookup(default_colors, stat, palette_color))
            label = _status_lookup(labels, stat, str(stat))

            if isinstance(pch, dict):
                marker = _marker_for(_status_lookup(pch, stat, 16))
            else:
                marker = scalar_marker

            if isinstance(cex, dict):
                size = _status_lookup(cex, stat, 1.0) * 20
            else:
                size = cex * 20

            mask = status == stat
            ax.scatter(x[mask], y[mask], c=color, marker=marker, s=size, alpha=alpha, label=label, **kwargs)

        # Add legend if requested; pointless with a single group.
        if legend and len(unique_statuses) > 1:
            ax.legend(loc=legend if isinstance(legend, str) else "best", frameon=True)

    # Set labels and title (empty strings leave the axes untouched)
    if xlab:
        ax.set_xlabel(xlab)
    if ylab:
        ax.set_ylabel(ylab)
    if title:
        ax.set_title(title)

    ax.grid(True, alpha=0.3, linestyle="--")

    return ax
145
+
146
+
147
def add_loess_curve(
    ax: Axes,
    x: np.ndarray,
    y: np.ndarray,
    span: float = 0.3,
    color: str = "blue",
    linewidth: float = 2,
    linestyle: str = "-",
    label: str | None = None,
) -> Axes:
    """Add LOESS (locally weighted scatterplot smoothing) curve to existing plot.

    Points where ``x`` or ``y`` is NaN are dropped first. If fewer than three
    points remain, nothing is drawn. When statsmodels cannot be imported, a
    centred moving average over the x-sorted points is drawn instead and an
    ``ImportWarning`` is issued.

    Args:
        ax: Axes object to add curve to
        x: X-axis values
        y: Y-axis values
        span: Smoothing span (fraction of data to use for smoothing). Default 0.3.
        color: Line color
        linewidth: Line width
        linestyle: Line style ('-', '--', '-.', ':')
        label: Legend label for the curve

    Returns:
        Axes object with added curve

    Examples:
        >>> import numpy as np
        >>> import matplotlib.pyplot as plt
        >>> from microarray.plotting._utils import add_loess_curve
        >>> fig, ax = plt.subplots()
        >>> x = np.linspace(0, 10, 100)
        >>> y = np.sin(x) + np.random.randn(100) * 0.1
        >>> ax.scatter(x, y, alpha=0.5)
        >>> add_loess_curve(ax, x, y, span=0.2)
    """
    # Work in float so integer input is not truncated by the moving-average
    # fallback (uniform_filter1d keeps the input dtype by default).
    x_arr = np.asarray(x, dtype=float)
    y_arr = np.asarray(y, dtype=float)

    # Remove NaN values up front so both smoothing paths see the same data.
    keep = ~(np.isnan(x_arr) | np.isnan(y_arr))
    x_clean = x_arr[keep]
    y_clean = y_arr[keep]

    if len(x_clean) < 3:
        return ax  # Need at least 3 points

    try:
        from statsmodels.nonparametric.smoothers_lowess import lowess
    except ImportError:
        # Fallback: use simple moving average if statsmodels not available
        warn(
            "statsmodels not available, using moving average instead of LOESS",
            ImportWarning,
            stacklevel=2,
        )
        from scipy.ndimage import uniform_filter1d

        order = np.argsort(x_clean)
        curve_x = x_clean[order]

        # Window covers roughly ``span`` of the points; forced odd for symmetry.
        window = max(3, int(len(x_clean) * span))
        if window % 2 == 0:
            window += 1
        curve_y = uniform_filter1d(y_clean[order], size=window, mode="nearest")
    else:
        # With return_sorted=True, lowess returns (x, fitted y) rows sorted by x.
        smoothed = lowess(y_clean, x_clean, frac=span, return_sorted=True)
        curve_x = smoothed[:, 0]
        curve_y = smoothed[:, 1]

    ax.plot(curve_x, curve_y, color=color, linewidth=linewidth, linestyle=linestyle, label=label)

    return ax
222
+
223
+
224
def add_reference_line(
    ax: Axes,
    y: float = 0,
    x: float | None = None,
    color: str = "gray",
    linewidth: float = 1,
    linestyle: str = "--",
    alpha: float = 0.7,
) -> Axes:
    """Draw a single straight reference line across the axes.

    A vertical line at ``x`` is drawn when ``x`` is given; otherwise a
    horizontal line at ``y`` is drawn.

    Args:
        ax: Axes object to draw on
        y: Position of the horizontal line (ignored when ``x`` is not None)
        x: Position of the vertical line; takes precedence over ``y``
        color: Line color
        linewidth: Line width
        linestyle: Line style
        alpha: Line transparency

    Returns:
        The same Axes object that was passed in
    """
    line_style = {
        "color": color,
        "linewidth": linewidth,
        "linestyle": linestyle,
        "alpha": alpha,
    }

    # Guard clause: no x given means the horizontal variant is wanted.
    if x is None:
        ax.axhline(y=y, **line_style)
        return ax

    ax.axvline(x=x, **line_style)
    return ax
253
+
254
+
255
def get_default_colors(n: int) -> list[str]:
    """Get default color palette for n categories.

    The eight base colors are reused cyclically when ``n`` exceeds the palette
    size.

    Args:
        n: Number of colors needed. Zero or a negative value yields an empty list.

    Returns:
        List of ``n`` color hex codes
    """
    # R-like default colors
    colors = [
        "#E41A1C",  # red
        "#377EB8",  # blue
        "#4DAF4A",  # green
        "#984EA3",  # purple
        "#FF7F00",  # orange
        "#FFFF33",  # yellow
        "#A65628",  # brown
        "#F781BF",  # pink
    ]

    # range(n) is empty for n <= 0, so negative requests no longer fall into a
    # negative slice (colors[:-1] used to return seven colors).
    return [colors[i % len(colors)] for i in range(n)]
@@ -0,0 +1,39 @@
1
+ """Preprocessing functions for microarray data analysis.
2
+
3
+ This module provides preprocessing methods including RMA (Robust Multi-array Average),
4
+ MAS5 (MicroArray Suite 5.0), Li-Wong (dChip), and a flexible expresso pipeline for
5
+ custom preprocessing workflows.
6
+ """
7
+
8
+ from ._background import background_correct, rma_background_correct
9
+ from ._log2 import log2
10
+ from ._normalize import (
11
+ normalize_constant,
12
+ normalize_contrasts,
13
+ normalize_invariantset,
14
+ normalize_loess,
15
+ normalize_qspline,
16
+ normalize_quantile,
17
+ normalize_quantile_robust,
18
+ )
19
+ from ._rma import rma
20
+ from ._robust import tukey_biweight, tukey_biweight_summary
21
+ from ._summarize import median_polish, summarize_probesets
22
+
23
+ __all__ = [
24
+ "background_correct",
25
+ "rma_background_correct",
26
+ "log2",
27
+ "normalize_constant",
28
+ "normalize_contrasts",
29
+ "normalize_invariantset",
30
+ "normalize_loess",
31
+ "normalize_qspline",
32
+ "normalize_quantile",
33
+ "normalize_quantile_robust",
34
+ "rma",
35
+ "tukey_biweight",
36
+ "tukey_biweight_summary",
37
+ "median_polish",
38
+ "summarize_probesets",
39
+ ]