pylocuszoom 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,474 @@
1
+ """Plotly backend for pyLocusZoom.
2
+
3
+ Interactive backend with hover tooltips and zoom/pan capabilities.
4
+ """
5
+
6
+ from typing import Any, List, Optional, Tuple, Union
7
+
8
+ import pandas as pd
9
+ import plotly.graph_objects as go
10
+ from plotly.subplots import make_subplots
11
+
12
+
13
class PlotlyBackend:
    """Plotly backend for interactive plot generation.

    Produces interactive HTML plots with hover tooltips showing:
    - SNP RS ID
    - P-value
    - R² with lead SNP
    - Nearest gene

    Panel handles (the ``ax`` arguments) are ``(figure, row)`` tuples, where
    ``row`` is a 1-based subplot row returned by :meth:`create_figure`.

    The ``(figure, row, secondary_y)`` triples returned by
    :meth:`create_twin_axis` are also accepted:
    - traces, y-limits, y-labels, text, rectangles and horizontal lines are
      placed on the secondary y-axis;
    - x-axis settings still apply to the row's x-axis.
    """

    # Matplotlib linestyle -> plotly dash name, shared by line() and axhline().
    _DASH_STYLES = {"-": "solid", "--": "dash", ":": "dot", "-.": "dashdot"}

    def __init__(self) -> None:
        """Initialize the plotly backend."""
        # Matplotlib marker codes -> plotly marker symbol names.
        self._marker_symbols = {
            "o": "circle",
            "D": "diamond",
            "s": "square",
            "^": "triangle-up",
            "v": "triangle-down",
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _unpack(ax: Tuple[Any, ...]) -> Tuple[go.Figure, int, Optional[str]]:
        """Split a panel handle into (figure, row, secondary y-axis ref or None)."""
        fig, row, *extra = ax
        return fig, row, (extra[0] if extra else None)

    @staticmethod
    def _axis_ref(letter: str, index: int) -> str:
        """Trace/shape axis reference, e.g. ``"x"`` for 1 or ``"y3"`` for 3."""
        return f"{letter}{index}" if index > 1 else letter

    @staticmethod
    def _layout_key(letter: str, index: int) -> str:
        """Layout property name, e.g. ``"xaxis"`` for 1 or ``"yaxis3"`` for 3."""
        return f"{letter}axis{index}" if index > 1 else f"{letter}axis"

    def _y_layout_key(self, row: int, secondary: Optional[str]) -> str:
        """Layout key of the y-axis a handle targets (primary or secondary)."""
        if secondary is None:
            return self._layout_key("y", row)
        # Secondary refs look like "y4"; the matching layout key is "yaxis4".
        return "yaxis" + secondary[1:]

    @staticmethod
    def _axis_index(key: str, prefix: str) -> Optional[int]:
        """Parse ``"yaxis"`` -> 1 and ``"yaxis12"`` -> 12; None for other keys."""
        if not key.startswith(prefix):
            return None
        suffix = key[len(prefix):]
        if not suffix:
            return 1
        return int(suffix) if suffix.isdigit() else None

    def _add_trace(
        self, fig: go.Figure, row: int, secondary: Optional[str], trace: Any
    ) -> None:
        """Add *trace* to the row's primary axes or to its secondary y-axis."""
        if secondary is None:
            fig.add_trace(trace, row=row, col=1)
            return
        # add_trace(row=..., col=...) would bind the trace to the row's
        # primary y-axis, so set both axis references explicitly instead.
        trace.update(xaxis=self._axis_ref("x", row), yaxis=secondary)
        fig.add_trace(trace)

    def _placement(self, row: int, secondary: Optional[str]) -> dict:
        """Keyword arguments that place an annotation/shape on a panel."""
        if secondary is None:
            return {"row": row, "col": 1}
        return {"xref": self._axis_ref("x", row), "yref": secondary}

    @staticmethod
    def _axis_domain(fig: go.Figure, key: str) -> Tuple[float, float]:
        """Return the paper-coordinate domain of a layout axis (default 0..1)."""
        spec = fig.layout.to_plotly_json().get(key) or {}
        domain = spec.get("domain") or (0.0, 1.0)
        return float(domain[0]), float(domain[1])

    @staticmethod
    def _as_list(values: Any, length: int) -> List[Any]:
        """Return *values* as a list, repeating a scalar *length* times."""
        if pd.api.types.is_scalar(values):
            return [values] * length
        return list(values)

    @staticmethod
    def _hover_value_format(column: Any) -> str:
        """Choose a d3 number-format suffix for a hover column.

        Names are split into alphanumeric tokens, so that only real p-value
        columns ("p", "P-value", "p_value", "pval", ...) get scientific
        notation. Names that merely contain the letter "p" ("pos", "SNP")
        are left unformatted.
        """
        name = str(column).lower()
        tokens = "".join(ch if ch.isalnum() else " " for ch in name).split()
        if any(tok == "p" or tok.startswith("pval") for tok in tokens):
            return ":.2e"
        if "r2" in name or "r²" in name or any(tok.startswith("ld") for tok in tokens):
            return ":.3f"
        return ""

    def _hover_template(
        self, hover_data: Optional[pd.DataFrame]
    ) -> Tuple[Any, str]:
        """Build ``(customdata, hovertemplate)`` for a scatter trace.

        The first hover column is shown in bold as the point title. The
        remaining columns follow as ``name: value`` lines.
        """
        if hover_data is None or len(hover_data.columns) == 0:
            return None, "x: %{x}<br>y: %{y:.2f}<extra></extra>"
        parts = ["<b>%{customdata[0]}</b><br>"]
        for i, col in enumerate(hover_data.columns[1:], 1):
            fmt = self._hover_value_format(col)
            parts.append(f"{col}: %{{customdata[{i}]{fmt}}}<br>")
        # <extra></extra> suppresses plotly's secondary trace-name box.
        parts.append("<extra></extra>")
        return hover_data.values, "".join(parts)

    @staticmethod
    def _x_extent(fig: go.Figure, key: str) -> Optional[Tuple[float, float]]:
        """Best-known numeric x extent for tick generation, or None.

        The extent is taken from the first source available:
        1. The axis' own range.
        2. Any other x-axis range. NOTE(review): this assumes panels show the
           same positions, as with sharex=True.
        3. The min/max of all plotted trace x data.
        """
        layout = fig.layout.to_plotly_json()
        candidates = [layout.get(key) or {}] + [
            spec
            for name, spec in layout.items()
            if name != key and name.startswith("xaxis") and isinstance(spec, dict)
        ]
        for spec in candidates:
            bounds = spec.get("range")
            if bounds is not None and len(bounds) == 2 and None not in bounds:
                return float(bounds[0]), float(bounds[1])

        columns = [
            pd.to_numeric(pd.Series(list(trace.x)), errors="coerce")
            for trace in fig.data
            if getattr(trace, "x", None) is not None
        ]
        if not columns:
            return None
        values = pd.concat(columns, ignore_index=True).dropna()
        if values.empty:
            return None
        return float(values.min()), float(values.max())

    @staticmethod
    def _mb_ticks(
        low: float, high: float, target: int = 5
    ) -> Tuple[List[float], List[str]]:
        """Pick about *target* round tick positions in [low, high], labelled in Mb."""
        import math

        low, high = min(low, high), max(low, high)
        if not (math.isfinite(low) and math.isfinite(high)):
            return [], []
        span = high - low
        if span == 0:
            return [low], [f"{low / 1e6:.2f} Mb"]

        raw_step = span / target
        magnitude = 10.0 ** math.floor(math.log10(raw_step))
        # A multiplier of 10 always satisfies the bound, so next() cannot fail.
        step = next(
            m * magnitude for m in (1, 2, 2.5, 5, 10) if m * magnitude >= raw_step
        )
        first = math.ceil(low / step) * step
        count = int(math.floor((high - first) / step + 1e-9)) + 1
        ticks = [first + i * step for i in range(max(count, 0))]

        # Use at least 2 decimals. Use more when ticks are closer than 0.01 Mb,
        # so that neighbouring labels stay distinct.
        decimals = max(2, -math.floor(math.log10(step / 1e6)))
        labels = [f"{t / 1e6:.{decimals}f} Mb" for t in ticks]
        return ticks, labels

    # ------------------------------------------------------------------
    # Figure creation
    # ------------------------------------------------------------------

    def create_figure(
        self,
        n_panels: int,
        height_ratios: List[float],
        figsize: Tuple[float, float],
        sharex: bool = True,
    ) -> Tuple[go.Figure, List[Any]]:
        """Create a figure with multiple panels.

        Args:
            n_panels: Number of vertical panels.
            height_ratios: Relative heights for each panel.
            figsize: Figure size as (width, height) in inches.
            sharex: Whether panels share the x-axis.

        Returns:
            Tuple of (figure, list of row indices for each panel).
        """
        # Convert inches to pixels (assuming 100 dpi for web)
        width_px = int(figsize[0] * 100)
        height_px = int(figsize[1] * 100)

        # Normalize height ratios so they sum to 1 for make_subplots.
        total = sum(height_ratios)
        row_heights = [h / total for h in height_ratios]

        fig = make_subplots(
            rows=n_panels,
            cols=1,
            shared_xaxes=sharex,
            vertical_spacing=0.02,
            row_heights=row_heights,
        )

        fig.update_layout(
            width=width_px,
            height=height_px,
            showlegend=True,
            template="plotly_white",
        )

        # Return row indices (1-based for plotly)
        return fig, list(range(1, n_panels + 1))

    # ------------------------------------------------------------------
    # Drawing primitives
    # ------------------------------------------------------------------

    def scatter(
        self,
        ax: Tuple[go.Figure, int],
        x: pd.Series,
        y: pd.Series,
        colors: Union[str, List[str], pd.Series],
        sizes: Union[float, List[float], pd.Series] = 60,
        marker: str = "o",
        edgecolor: str = "black",
        linewidth: float = 0.5,
        zorder: int = 2,
        hover_data: Optional[pd.DataFrame] = None,
        label: Optional[str] = None,
    ) -> Any:
        """Create a scatter plot on the given panel.

        For plotly, ax is a tuple of (figure, row_number), or a twin-axis
        triple from :meth:`create_twin_axis`. ``zorder`` is accepted for API
        compatibility but ignored: plotly draws traces in insertion order.

        Args:
            ax: Panel handle.
            x, y: Point coordinates.
            colors: One color, or one color per point.
            sizes: Marker area(s) in matplotlib units (points squared).
            marker: Matplotlib marker code. Unknown codes become circles.
            edgecolor, linewidth: Marker outline.
            zorder: Ignored.
            hover_data: Optional per-point table. The first column is the
                bold hover title; the other columns are listed below it.
            label: Legend entry. No legend entry is shown when None.

        Returns:
            The ``go.Scatter`` trace that was added.
        """
        fig, row, secondary = self._unpack(ax)

        symbol = self._marker_symbols.get(marker, "circle")

        # Matplotlib sizes are areas, while plotly sizes are diameters. Use the
        # square root as an approximate conversion, floored at 6 px.
        # is_number also covers numpy scalars, which are not int/float.
        if pd.api.types.is_number(sizes):
            size = max(6, sizes ** 0.5)
        else:
            size = [max(6, s ** 0.5) for s in sizes]

        customdata, hovertemplate = self._hover_template(hover_data)

        # A color may be a single string or an array-like of colors.
        if isinstance(colors, str):
            marker_color = colors
        else:
            marker_color = list(colors) if hasattr(colors, "tolist") else colors

        trace = go.Scatter(
            x=x,
            y=y,
            mode="markers",
            marker=dict(
                color=marker_color,
                size=size,
                symbol=symbol,
                line=dict(color=edgecolor, width=linewidth),
            ),
            customdata=customdata,
            hovertemplate=hovertemplate,
            name=label or "",
            showlegend=label is not None,
        )

        self._add_trace(fig, row, secondary, trace)
        return trace

    def line(
        self,
        ax: Tuple[go.Figure, int],
        x: pd.Series,
        y: pd.Series,
        color: str = "blue",
        linewidth: float = 1.5,
        alpha: float = 1.0,
        linestyle: str = "-",
        zorder: int = 1,
        label: Optional[str] = None,
    ) -> Any:
        """Create a line plot on the given panel.

        ``linestyle`` takes matplotlib codes; unknown codes draw a solid line.
        ``zorder`` is ignored. Returns the ``go.Scatter`` trace that was added.
        """
        fig, row, secondary = self._unpack(ax)

        trace = go.Scatter(
            x=x,
            y=y,
            mode="lines",
            line=dict(
                color=color,
                width=linewidth,
                dash=self._DASH_STYLES.get(linestyle, "solid"),
            ),
            opacity=alpha,
            name=label or "",
            showlegend=label is not None,
        )

        self._add_trace(fig, row, secondary, trace)
        return trace

    def fill_between(
        self,
        ax: Tuple[go.Figure, int],
        x: pd.Series,
        y1: Union[float, pd.Series],
        y2: Union[float, pd.Series],
        color: str = "blue",
        alpha: float = 0.3,
        zorder: int = 0,
    ) -> Any:
        """Fill area between two y-values.

        Either ``y1`` or ``y2`` may be a scalar, which is broadcast to the
        length of ``x``. The area is drawn as a closed polygon: along ``y2``
        forwards, then back along ``y1``. ``zorder`` is ignored.
        """
        fig, row, secondary = self._unpack(ax)

        xs = list(x)
        y1_values = self._as_list(y1, len(xs))
        y2_values = self._as_list(y2, len(xs))

        trace = go.Scatter(
            x=xs + xs[::-1],
            y=y2_values + y1_values[::-1],
            fill="toself",
            fillcolor=color,
            opacity=alpha,
            line=dict(width=0),
            showlegend=False,
            hoverinfo="skip",
        )

        self._add_trace(fig, row, secondary, trace)
        return trace

    def axhline(
        self,
        ax: Tuple[go.Figure, int],
        y: float,
        color: str = "grey",
        linestyle: str = "--",
        linewidth: float = 1.0,
        zorder: int = 1,
    ) -> Any:
        """Add a horizontal line across the panel (unknown linestyles -> dashed)."""
        fig, row, secondary = self._unpack(ax)
        dash = self._DASH_STYLES.get(linestyle, "dash")

        if secondary is None:
            fig.add_hline(
                y=y,
                line_dash=dash,
                line_color=color,
                line_width=linewidth,
                row=row,
                col=1,
            )
        else:
            # add_hline's row/col targets the row's primary y-axis, so draw
            # a full-width line in the secondary axis' coordinates instead.
            fig.add_shape(
                type="line",
                xref=f"{self._axis_ref('x', row)} domain",
                x0=0,
                x1=1,
                yref=secondary,
                y0=y,
                y1=y,
                line=dict(color=color, width=linewidth, dash=dash),
            )

    def add_text(
        self,
        ax: Tuple[go.Figure, int],
        x: float,
        y: float,
        text: str,
        fontsize: int = 10,
        ha: str = "center",
        va: str = "bottom",
        rotation: float = 0,
        color: str = "black",
    ) -> Any:
        """Add text annotation to panel.

        ``ha``/``va`` use matplotlib alignment names. ``rotation`` is
        counter-clockwise degrees; it is negated because plotly's textangle
        rotates clockwise.
        """
        fig, row, secondary = self._unpack(ax)

        xanchor_map = {"center": "center", "left": "left", "right": "right"}
        yanchor_map = {"bottom": "bottom", "top": "top", "center": "middle"}

        fig.add_annotation(
            x=x,
            y=y,
            text=text,
            font=dict(size=fontsize, color=color),
            xanchor=xanchor_map.get(ha, "center"),
            yanchor=yanchor_map.get(va, "bottom"),
            textangle=-rotation,
            showarrow=False,
            **self._placement(row, secondary),
        )

    def add_rectangle(
        self,
        ax: Tuple[go.Figure, int],
        xy: Tuple[float, float],
        width: float,
        height: float,
        facecolor: str = "blue",
        edgecolor: str = "black",
        linewidth: float = 0.5,
        zorder: int = 2,
    ) -> Any:
        """Add a rectangle with lower-left corner ``xy`` to the panel (zorder ignored)."""
        fig, row, secondary = self._unpack(ax)

        x0, y0 = xy
        fig.add_shape(
            type="rect",
            x0=x0,
            y0=y0,
            x1=x0 + width,
            y1=y0 + height,
            fillcolor=facecolor,
            line=dict(color=edgecolor, width=linewidth),
            **self._placement(row, secondary),
        )

    # ------------------------------------------------------------------
    # Axes configuration
    # ------------------------------------------------------------------

    def set_xlim(self, ax: Tuple[go.Figure, int], left: float, right: float) -> None:
        """Set x-axis limits."""
        fig, row, _ = self._unpack(ax)
        fig.update_layout(**{self._layout_key("x", row): dict(range=[left, right])})

    def set_ylim(self, ax: Tuple[go.Figure, int], bottom: float, top: float) -> None:
        """Set y-axis limits (on the secondary axis for twin-axis handles)."""
        fig, row, secondary = self._unpack(ax)
        key = self._y_layout_key(row, secondary)
        fig.update_layout(**{key: dict(range=[bottom, top])})

    def set_xlabel(
        self, ax: Tuple[go.Figure, int], label: str, fontsize: int = 12
    ) -> None:
        """Set x-axis label."""
        fig, row, _ = self._unpack(ax)
        title = dict(text=label, font=dict(size=fontsize))
        fig.update_layout(**{self._layout_key("x", row): dict(title=title)})

    def set_ylabel(
        self, ax: Tuple[go.Figure, int], label: str, fontsize: int = 12
    ) -> None:
        """Set y-axis label (on the secondary axis for twin-axis handles)."""
        fig, row, secondary = self._unpack(ax)
        title = dict(text=label, font=dict(size=fontsize))
        fig.update_layout(**{self._y_layout_key(row, secondary): dict(title=title)})

    def set_title(
        self, ax: Tuple[go.Figure, int], title: str, fontsize: int = 14
    ) -> None:
        """Set figure title (only works for first panel)."""
        fig, row, _ = self._unpack(ax)
        if row == 1:
            fig.update_layout(title=dict(text=title, font=dict(size=fontsize)))

    def create_twin_axis(self, ax: Tuple[go.Figure, int]) -> Tuple[go.Figure, int, str]:
        """Create (or reuse) a secondary y-axis on the right side of a panel.

        The new axis gets the first y-axis number not already in the layout.
        Previously row 1 always used "y2". In a multi-panel figure that is
        row 2's primary axis, which was then overwritten. Calling this again
        for the same row returns the twin created earlier.

        Returns:
            Tuple of (figure, row, secondary_y). ``secondary_y`` is a trace
            axis reference such as ``"y4"``. The triple can be passed back to
            this backend's drawing and y-axis methods.
        """
        fig, row, _ = self._unpack(ax)
        primary = self._axis_ref("y", row)

        used = []
        for key, spec in fig.layout.to_plotly_json().items():
            index = self._axis_index(key, "yaxis")
            if index is None:
                continue
            if isinstance(spec, dict) and spec.get("overlaying") == primary:
                return (fig, row, self._axis_ref("y", index))
            used.append(index)

        index = max(used, default=1) + 1
        fig.update_layout(
            **{
                f"yaxis{index}": dict(
                    overlaying=primary,
                    side="right",
                    anchor=self._axis_ref("x", row),
                )
            }
        )
        return (fig, row, f"y{index}")

    def add_legend(
        self,
        ax: Tuple[go.Figure, int],
        handles: List[Any],
        labels: List[str],
        loc: str = "upper left",
        title: Optional[str] = None,
    ) -> Any:
        """Add a legend to the figure.

        Note: Plotly handles legends automatically from trace names, so
        ``handles`` and ``labels`` are ignored. This method only positions
        and styles the single figure legend. As with matplotlib's
        axes-relative ``loc``, the position is computed inside the panel that
        ``ax`` refers to. Unknown ``loc`` values fall back to "upper left".
        """
        fig, row, _ = self._unpack(ax)

        # (fraction across panel, fraction up panel, xanchor, yanchor)
        loc_map = {
            "upper left": (0.01, 0.99, "left", "top"),
            "upper right": (0.99, 0.99, "right", "top"),
            "lower left": (0.01, 0.01, "left", "bottom"),
            "lower right": (0.99, 0.01, "right", "bottom"),
        }
        fx, fy, xanchor, yanchor = loc_map.get(loc, loc_map["upper left"])

        # Convert panel-relative fractions to paper coordinates.
        x0, x1 = self._axis_domain(fig, self._layout_key("x", row))
        y0, y1 = self._axis_domain(fig, self._layout_key("y", row))

        fig.update_layout(
            legend=dict(
                x=x0 + fx * (x1 - x0),
                y=y0 + fy * (y1 - y0),
                xanchor=xanchor,
                yanchor=yanchor,
                title=dict(text=title) if title else None,
                bgcolor="rgba(255,255,255,0.9)",
                bordercolor="black",
                borderwidth=1,
            )
        )

    def hide_spines(self, ax: Tuple[go.Figure, int], spines: List[str]) -> None:
        """Hide specified axis spines (lines).

        Plotly has no spines. Each axis draws one line (bottom for x, left
        for y), and ``mirror`` adds the opposite side.
        - "top"/"right" turn ``mirror`` off.
        - "bottom"/"left" turn the axis line off. This also hides any
          mirrored line.
        The previous version relied on the plotly_white template not drawing
        these lines.
        """
        fig, row, _ = self._unpack(ax)

        x_updates = {}
        y_updates = {}
        if "top" in spines:
            x_updates["mirror"] = False
        if "bottom" in spines:
            x_updates["showline"] = False
        if "right" in spines:
            y_updates["mirror"] = False
        if "left" in spines:
            y_updates["showline"] = False

        layout_updates = {}
        if x_updates:
            layout_updates[self._layout_key("x", row)] = x_updates
        if y_updates:
            layout_updates[self._layout_key("y", row)] = y_updates
        if layout_updates:
            fig.update_layout(**layout_updates)

    def format_xaxis_mb(self, ax: Tuple[go.Figure, int]) -> None:
        """Format x-axis to show megabase values.

        Positions are assumed to be in base pairs (NOTE(review): confirm
        against callers). Plotly cannot rescale tick labels with a format
        string. The previous ``tickformat=".2f"`` plus " Mb" suffix rendered
        raw bp values such as "12345678.00 Mb".

        Explicit tick positions and labels are therefore computed from:
        - the axis range, or
        - if no range is set, the plotted x data.
        The ticks are fixed when this is called. If neither a range nor data
        is available, the axis is left unchanged.
        """
        fig, row, _ = self._unpack(ax)
        key = self._layout_key("x", row)

        extent = self._x_extent(fig, key)
        if extent is None:
            return
        tickvals, ticktext = self._mb_ticks(*extent)
        if not tickvals:
            return

        fig.update_layout(
            **{key: dict(tickmode="array", tickvals=tickvals, ticktext=ticktext)}
        )

    # ------------------------------------------------------------------
    # Output
    # ------------------------------------------------------------------

    def save(
        self,
        fig: go.Figure,
        path: str,
        dpi: int = 150,
        bbox_inches: str = "tight",
    ) -> None:
        """Save figure to file.

        Paths ending in .html/.htm (any case) are written as interactive
        HTML. Anything else is exported as a static image (.png/.pdf/...).
        Static export needs kaleido. ``pathlib.Path`` values are accepted,
        and ``bbox_inches`` is ignored.
        """
        if str(path).lower().endswith((".html", ".htm")):
            fig.write_html(path)
        else:
            # Sizes are computed at 100 px per inch, so scale to the requested dpi.
            fig.write_image(path, scale=dpi / 100)

    def show(self, fig: go.Figure) -> None:
        """Display the figure."""
        fig.show()

    def close(self, fig: go.Figure) -> None:
        """Close the figure (no-op for plotly)."""

    def finalize_layout(
        self,
        fig: go.Figure,
        left: float = 0.08,
        right: float = 0.95,
        top: float = 0.95,
        bottom: float = 0.1,
        hspace: float = 0.08,
    ) -> None:
        """Adjust layout margins.

        Args:
            fig: Figure object.
            left, right, top, bottom: Plot-area edges as fractions of the
                figure width/height (matplotlib ``subplots_adjust`` style).
                Each edge is converted to a pixel margin. Fixed pixel
                defaults are used when the figure has no width/height set.
            hspace: Ignored for plotly (use vertical_spacing in make_subplots).
        """
        width = fig.layout.width
        height = fig.layout.height
        fig.update_layout(
            margin=dict(
                l=int(left * width) if width else 80,
                r=int((1 - right) * width) if width else 50,
                t=int((1 - top) * height) if height else 50,
                b=int(bottom * height) if height else 80,
            )
        )
pylocuszoom/colors.py ADDED
@@ -0,0 +1,107 @@
1
+ """LD color schemes for regional association plots.
2
+
3
+ Implements LocusZoom-style coloring based on R² linkage disequilibrium values.
4
+ Colors match the locuszoomr R package color scheme.
5
+ """
6
+
7
+ import math
8
+ from typing import List, Optional, Tuple
9
+
10
+
11
+ def _is_missing(value: Optional[float]) -> bool:
12
+ """Check if value is None or NaN."""
13
+ return value is None or (isinstance(value, float) and math.isnan(value))
14
+
15
+
16
# LD bins, ordered from the highest lower bound down. An R² value belongs to
# the first bin whose lower bound it meets or exceeds.
# Each entry is (lower bound, legend label, hex color). The trailing comments
# give the corresponding R color names.
LD_BINS: List[Tuple[float, str, str]] = [
    (0.8, "0.8 - 1.0", "#FF0000"),  # red
    (0.6, "0.6 - 0.8", "#FFA500"),  # orange
    (0.4, "0.4 - 0.6", "#00CD00"),  # green3
    (0.2, "0.2 - 0.4", "#00EEEE"),  # cyan2
    (0.0, "0.0 - 0.2", "#4169E1"),  # royalblue
]

# Legend label and color for SNPs without LD information.
LD_NA_LABEL = "NA"
LD_NA_COLOR = "#BEBEBE"  # grey

# Color of the lead SNP, which is drawn as a purple diamond.
LEAD_SNP_COLOR = "#7D26CD"  # purple3
31
+
32
+
33
def get_ld_color(r2: Optional[float]) -> str:
    """Map an LD R² value to its LocusZoom-style hex color.

    Colors follow the locuszoomr R package scheme:
    - 0.8-1.0: red
    - 0.6-0.8: orange
    - 0.4-0.6: green
    - 0.2-0.4: cyan
    - 0.0-0.2: blue
    - NA: grey

    Args:
        r2: R² value between 0 and 1; None or NaN marks missing LD.

    Returns:
        Hex color code string. Values below every bin threshold (e.g.
        negative numbers) get the lowest bin's color.

    Example:
        >>> get_ld_color(0.85)
        '#FF0000'
        >>> get_ld_color(0.5)
        '#00CD00'
        >>> get_ld_color(float('nan'))
        '#BEBEBE'
    """
    if not _is_missing(r2):
        # LD_BINS runs from the highest threshold down, so the first match wins.
        return next(
            (color for threshold, _label, color in LD_BINS if r2 >= threshold),
            LD_BINS[-1][2],
        )
    return LD_NA_COLOR
66
+
67
+
68
def get_ld_bin(r2: Optional[float]) -> str:
    """Return the LD bin label used for categorical coloring.

    Args:
        r2: R² value between 0 and 1; None or NaN marks missing LD.

    Returns:
        Bin label string (e.g., "0.8 - 1.0" or "NA"). Values below every
        threshold get the lowest bin's label.

    Example:
        >>> get_ld_bin(0.85)
        '0.8 - 1.0'
        >>> get_ld_bin(float('nan'))
        'NA'
    """
    if _is_missing(r2):
        return LD_NA_LABEL

    for threshold, label, _color in LD_BINS:
        if r2 >= threshold:
            break
    else:
        # No threshold matched: fall back to the lowest bin.
        label = LD_BINS[-1][1]
    return label
91
+
92
+
93
def get_ld_color_palette() -> dict[str, str]:
    """Build a mapping from LD bin labels to hex colors.

    The bins appear in LD_BINS order, followed by the "NA" entry for
    missing LD. The result can be passed as a palette to seaborn or
    matplotlib.

    Returns:
        Dictionary mapping bin labels to hex colors.

    Example:
        >>> palette = get_ld_color_palette()
        >>> palette["0.8 - 1.0"]
        '#FF0000'
    """
    labels = [entry[1] for entry in LD_BINS]
    colors = [entry[2] for entry in LD_BINS]
    return {**dict(zip(labels, colors)), LD_NA_LABEL: LD_NA_COLOR}