pylocuszoom 1.1.2__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,252 @@
1
+ """LD heatmap generator for pairwise linkage disequilibrium visualization.
2
+
3
+ Provides triangular heatmap display of pairwise LD values (R² or D')
4
+ with colorbar legend and SNP highlighting support.
5
+ """
6
+
7
+ from typing import Any, List, Optional, Tuple, Union
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+
12
+ from .backends import BackendType, get_backend
13
+ from .colors import (
14
+ LD_HEATMAP_COLORS,
15
+ LEAD_SNP_HIGHLIGHT_COLOR,
16
+ SECONDARY_HIGHLIGHT_COLOR,
17
+ )
18
+
19
+
20
class LDHeatmapPlotter:
    """LD heatmap generator for pairwise LD visualization.

    Creates triangular heatmaps showing pairwise linkage disequilibrium
    between variants. Supports R² and D' metrics, lead SNP highlighting,
    and multiple backend renderers.

    Supports multiple rendering backends:
    - matplotlib (default): Static publication-quality plots
    - plotly: Interactive HTML with hover tooltips
    - bokeh: Interactive HTML for dashboards

    Args:
        species: Species name ('canine', 'feline', 'human', or None).
            Currently unused but kept for API consistency.
        backend: Plotting backend ('matplotlib', 'plotly', or 'bokeh').

    Example:
        >>> plotter = LDHeatmapPlotter()
        >>> fig = plotter.plot_ld_heatmap(ld_matrix, lead_snp="rs12345")
        >>> fig.savefig("ld_heatmap.png", dpi=150)
    """

    def __init__(
        self,
        species: Optional[str] = "canine",
        backend: BackendType = "matplotlib",
    ):
        """Initialize the LD heatmap plotter."""
        self.species = species  # Kept for backward compatibility, currently unused
        self._backend = get_backend(backend)
        # Remembered so _highlight_snp can choose backend-specific drawing calls.
        self.backend_name = backend

    def plot_ld_heatmap(
        self,
        ld_matrix: Union[pd.DataFrame, np.ndarray],
        snp_ids: Optional[List[str]] = None,
        lead_snp: Optional[str] = None,
        highlight_snps: Optional[List[str]] = None,
        metric: str = "r2",
        figsize: Tuple[float, float] = (8, 8),
        title: Optional[str] = None,
        show_colorbar: bool = True,
    ) -> Any:
        """Create triangular LD heatmap.

        Args:
            ld_matrix: Square DataFrame or numpy array with pairwise LD values.
                NaN values are displayed as grey (missing data).
            snp_ids: SNP IDs for axis labels (any sequence; it is copied into
                a list). If None, uses the DataFrame index, or "0".."n-1" for
                array input.
            lead_snp: SNP ID to highlight as lead variant (red highlight).
            highlight_snps: Additional SNP IDs to highlight (blue highlight).
                Duplicates and the lead SNP itself are ignored so the lead
                highlight is never overdrawn.
            metric: LD metric label for colorbar ("r2" or "dprime"). Matched
                case-insensitively; anything other than "r2" is labelled D'.
            figsize: Figure size as (width, height).
            title: Plot title.
            show_colorbar: Whether to show colorbar legend.

        Returns:
            Figure object (type depends on backend).

        Raises:
            ValueError: If ld_matrix is not a square 2-D matrix.
            ValueError: If snp_ids length does not match the matrix dimension.
            ValueError: If lead_snp not found in snp_ids.
            ValueError: If any highlight_snps not found in snp_ids.

        Example:
            >>> fig = plotter.plot_ld_heatmap(
            ...     ld_matrix,
            ...     snp_ids=["rs1", "rs2", "rs3"],
            ...     lead_snp="rs1",
            ...     metric="r2",
            ... )
        """
        # All validation happens before any figure is created, so bad input
        # never leaves a half-built figure behind.
        data, snp_ids = self._prepare_matrix(ld_matrix, snp_ids)
        n_snps = len(snp_ids)
        lead_idx, highlight_indices = self._resolve_highlights(
            snp_ids, lead_snp, highlight_snps
        )

        # Create figure with single panel
        fig, axes = self._backend.create_figure(
            n_panels=1,
            height_ratios=[1.0],
            figsize=figsize,
            sharex=False,
        )
        ax = axes[0]

        # Render triangular heatmap
        tick_positions = list(range(n_snps))
        mappable = self._backend.add_heatmap(
            ax,
            data=data,
            x_coords=tick_positions,
            y_coords=list(tick_positions),
            cmap_colors=LD_HEATMAP_COLORS,
            vmin=0.0,
            vmax=1.0,
            mask_upper=True,
        )

        # Add colorbar
        if show_colorbar:
            label = "R²" if str(metric).lower() == "r2" else "D'"
            self._backend.add_colorbar(ax, mappable, label=label)

        # Secondary highlights are drawn first so that the lead SNP outline,
        # added last, stays on top wherever the highlighted cells overlap.
        for idx in highlight_indices:
            self._highlight_snp(
                ax=ax,
                fig=fig,
                snp_idx=idx,
                n_snps=n_snps,
                color=SECONDARY_HIGHLIGHT_COLOR,
            )

        if lead_idx is not None:
            self._highlight_snp(
                ax=ax,
                fig=fig,
                snp_idx=lead_idx,
                n_snps=n_snps,
                color=LEAD_SNP_HIGHLIGHT_COLOR,
            )

        # Set axis ticks with SNP labels
        self._backend.set_xticks(ax, tick_positions, snp_ids, rotation=90)
        self._backend.set_yticks(ax, tick_positions, snp_ids)

        # Set title
        if title:
            self._backend.set_title(ax, title)

        # Finalize layout
        self._backend.finalize_layout(fig)

        return fig

    @staticmethod
    def _prepare_matrix(
        ld_matrix: Union[pd.DataFrame, np.ndarray],
        snp_ids: Optional[List[str]],
    ) -> Tuple[np.ndarray, List[str]]:
        """Return the LD values as a validated square array plus its labels.

        Raises:
            ValueError: If the matrix is not 2-D and square, or if the number
                of labels differs from the matrix dimension.
        """
        is_frame = isinstance(ld_matrix, pd.DataFrame)
        data = ld_matrix.to_numpy() if is_frame else np.asarray(ld_matrix)

        # Check the shape BEFORE touching data.shape[0] for default labels:
        # a 0-d input has an empty shape tuple and would otherwise raise
        # IndexError instead of the documented ValueError.
        if data.ndim != 2 or data.shape[0] != data.shape[1]:
            raise ValueError(f"ld_matrix must be square, got shape {data.shape}")

        if snp_ids is None:
            if is_frame:
                snp_ids = list(ld_matrix.index.astype(str))
            else:
                snp_ids = [str(i) for i in range(data.shape[0])]
        else:
            # Copy into a real list so tuples, arrays and pandas Index objects
            # all support the .index() lookups performed later.
            snp_ids = list(snp_ids)

        n_snps = len(snp_ids)
        if data.shape[0] != n_snps:
            raise ValueError(
                f"snp_ids length ({n_snps}) does not match matrix dimension ({data.shape[0]})"
            )
        return data, snp_ids

    @staticmethod
    def _resolve_highlights(
        snp_ids: List[str],
        lead_snp: Optional[str],
        highlight_snps: Optional[List[str]],
    ) -> Tuple[Optional[int], List[int]]:
        """Translate lead/secondary SNP IDs into matrix indices.

        Returns:
            ``(lead_idx, highlight_indices)`` where ``lead_idx`` is None when
            no lead SNP was requested and ``highlight_indices`` keeps first-seen
            order, without duplicates and without the lead index.

        Raises:
            ValueError: If a requested SNP ID is not present in ``snp_ids``.
        """
        lead_idx = None
        if lead_snp is not None:
            try:
                lead_idx = snp_ids.index(lead_snp)
            except ValueError:
                raise ValueError(
                    f"lead_snp '{lead_snp}' not found in snp_ids"
                ) from None

        highlight_indices: List[int] = []
        if highlight_snps:
            for snp in highlight_snps:
                try:
                    idx = snp_ids.index(snp)
                except ValueError:
                    raise ValueError(
                        f"highlight_snp '{snp}' not found in snp_ids"
                    ) from None
                # Every ID is still validated above; only the redundant
                # drawing work is skipped here.
                if idx != lead_idx and idx not in highlight_indices:
                    highlight_indices.append(idx)

        return lead_idx, highlight_indices

    def _highlight_snp(
        self,
        ax: Any,
        fig: Any,
        snp_idx: int,
        n_snps: int,
        color: str,
    ) -> None:
        """Add visual highlight for a SNP's row/column in the heatmap.

        Draws unfilled rectangle borders around the row and column cells for
        the given SNP in the lower triangle. Each cell is a unit square
        centred on its integer (x, y) position. If ``self.backend_name`` is
        not one of 'matplotlib', 'plotly' or 'bokeh', nothing is drawn.

        Args:
            ax: Axes object from backend.
            fig: Figure object from backend.
            snp_idx: Index of the SNP to highlight.
            n_snps: Total number of SNPs in the matrix.
            color: Highlight color.
        """
        # Cells are (x, y) pairs.
        # Row part: columns 0..snp_idx at row snp_idx (includes the diagonal).
        row_cells = [(col, snp_idx) for col in range(snp_idx + 1)]
        # Column part: rows below the diagonal in column snp_idx.
        col_cells = [(snp_idx, row) for row in range(snp_idx + 1, n_snps)]
        all_cells = row_cells + col_cells

        if self.backend_name == "matplotlib":
            # Imported lazily so matplotlib is only needed for this backend.
            from matplotlib.patches import Rectangle

            for x, y in all_cells:
                rect = Rectangle(
                    (x - 0.5, y - 0.5),
                    1.0,
                    1.0,
                    fill=False,
                    edgecolor=color,
                    linewidth=2,
                    zorder=10,
                )
                ax.add_patch(rect)

        elif self.backend_name == "plotly":
            for x, y in all_cells:
                fig.add_shape(
                    type="rect",
                    x0=x - 0.5,
                    x1=x + 0.5,
                    y0=y - 0.5,
                    y1=y + 0.5,
                    line=dict(color=color, width=2),
                    fillcolor="rgba(0,0,0,0)",
                )

        elif self.backend_name == "bokeh":
            for x, y in all_cells:
                ax.rect(
                    x=x,
                    y=y,
                    width=1,
                    height=1,
                    fill_alpha=0,
                    line_color=color,
                    line_width=2,
                )