pylocuszoom 0.6.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pylocuszoom/plotter.py CHANGED
@@ -15,12 +15,10 @@ from typing import Any, List, Optional, Tuple
15
15
  import matplotlib.pyplot as plt
16
16
  import numpy as np
17
17
  import pandas as pd
18
- from matplotlib.axes import Axes
19
- from matplotlib.figure import Figure
20
- from matplotlib.lines import Line2D
21
- from matplotlib.patches import Patch
18
+ import requests
22
19
 
23
20
  from .backends import BackendType, get_backend
21
+ from .backends.hover import HoverConfig, HoverDataBuilder
24
22
  from .colors import (
25
23
  EQTL_NEGATIVE_BINS,
26
24
  EQTL_POSITIVE_BINS,
@@ -33,6 +31,8 @@ from .colors import (
33
31
  get_ld_color_palette,
34
32
  get_phewas_category_palette,
35
33
  )
34
+ from .config import PlotConfig, StackedPlotConfig
35
+ from .ensembl import get_genes_for_region
36
36
  from .eqtl import validate_eqtl_df
37
37
  from .finemapping import (
38
38
  get_credible_sets,
@@ -41,16 +41,13 @@ from .finemapping import (
41
41
  from .forest import validate_forest_df
42
42
  from .gene_track import (
43
43
  assign_gene_positions,
44
- plot_gene_track,
45
44
  plot_gene_track_generic,
46
45
  )
47
- from .labels import add_snp_labels
48
46
  from .ld import calculate_ld, find_plink
49
47
  from .logging import enable_logging, logger
50
48
  from .phewas import validate_phewas_df
51
49
  from .recombination import (
52
50
  RECOMB_COLOR,
53
- add_recombination_overlay,
54
51
  download_canine_recombination_maps,
55
52
  get_default_data_dir,
56
53
  get_recombination_rate_for_region,
@@ -119,8 +116,21 @@ class LocusZoomPlotter:
119
116
  recomb_data_dir: Optional[str] = None,
120
117
  genomewide_threshold: float = DEFAULT_GENOMEWIDE_THRESHOLD,
121
118
  log_level: Optional[str] = "INFO",
119
+ auto_genes: bool = False,
122
120
  ):
123
- """Initialize the plotter."""
121
+ """Initialize the plotter.
122
+
123
+ Args:
124
+ species: Species name ('canine', 'feline', or None for custom).
125
+ genome_build: Genome build for coordinate system.
126
+ backend: Plotting backend ('matplotlib', 'plotly', or 'bokeh').
127
+ plink_path: Path to PLINK executable for LD calculation.
128
+ recomb_data_dir: Directory containing recombination maps.
129
+ genomewide_threshold: P-value threshold for significance line.
130
+ log_level: Logging level.
131
+ auto_genes: If True, automatically fetch genes from Ensembl when
132
+ genes_df is not provided. Default False for backward compatibility.
133
+ """
124
134
  # Configure logging
125
135
  if log_level is not None:
126
136
  enable_logging(log_level)
@@ -129,12 +139,12 @@ class LocusZoomPlotter:
129
139
  self.genome_build = (
130
140
  genome_build if genome_build else self._default_build(species)
131
141
  )
132
- self.backend_name = backend
133
142
  self._backend = get_backend(backend)
134
143
  self.plink_path = plink_path or find_plink()
135
144
  self.recomb_data_dir = recomb_data_dir
136
145
  self.genomewide_threshold = genomewide_threshold
137
146
  self._genomewide_line = -np.log10(genomewide_threshold)
147
+ self._auto_genes = auto_genes
138
148
 
139
149
  # Cache for loaded data
140
150
  self._recomb_cache = {}
@@ -163,9 +173,17 @@ class LocusZoomPlotter:
163
173
  # Download
164
174
  try:
165
175
  return download_canine_recombination_maps()
166
- except Exception as e:
176
+ except (requests.RequestException, OSError, IOError) as e:
177
+ # Expected network/file errors - graceful fallback
167
178
  logger.warning(f"Could not download recombination maps: {e}")
168
179
  return None
180
+ except Exception as e:
181
+ # JUSTIFICATION: Download failure should not prevent plotting.
182
+ # We catch broadly here because graceful degradation is acceptable
183
+ # for optional recombination map downloads. Error-level logging
184
+ # ensures the issue is visible.
185
+ logger.error(f"Unexpected error downloading recombination maps: {e}")
186
+ return None
169
187
  elif self.recomb_data_dir:
170
188
  return Path(self.recomb_data_dir)
171
189
  return None
@@ -199,55 +217,94 @@ class LocusZoomPlotter:
199
217
  def plot(
200
218
  self,
201
219
  gwas_df: pd.DataFrame,
220
+ *,
202
221
  chrom: int,
203
222
  start: int,
204
223
  end: int,
224
+ pos_col: str = "ps",
225
+ p_col: str = "p_wald",
226
+ rs_col: str = "rs",
227
+ snp_labels: bool = True,
228
+ label_top_n: int = 5,
229
+ show_recombination: bool = True,
230
+ figsize: Tuple[float, float] = (12.0, 8.0),
205
231
  lead_pos: Optional[int] = None,
206
232
  ld_reference_file: Optional[str] = None,
207
233
  ld_col: Optional[str] = None,
208
234
  genes_df: Optional[pd.DataFrame] = None,
209
235
  exons_df: Optional[pd.DataFrame] = None,
210
236
  recomb_df: Optional[pd.DataFrame] = None,
211
- show_recombination: bool = True,
212
- snp_labels: bool = True,
213
- label_top_n: int = 5,
214
- pos_col: str = "ps",
215
- p_col: str = "p_wald",
216
- rs_col: str = "rs",
217
- figsize: Tuple[int, int] = (12, 8),
218
237
  ) -> Any:
219
238
  """Create a regional association plot.
220
239
 
221
240
  Args:
222
241
  gwas_df: GWAS results DataFrame.
223
242
  chrom: Chromosome number.
224
- start: Start position of the region.
225
- end: End position of the region.
226
- lead_pos: Position of the lead/index SNP to highlight.
227
- ld_reference_file: PLINK binary fileset for LD calculation.
228
- If provided with lead_pos, calculates LD on the fly.
229
- ld_col: Column name for pre-computed LD (R²) values.
230
- Use this if LD was calculated externally.
243
+ start: Start position in base pairs.
244
+ end: End position in base pairs.
245
+ pos_col: Column name for genomic position.
246
+ p_col: Column name for p-value.
247
+ rs_col: Column name for SNP identifier.
248
+ snp_labels: Whether to show SNP labels on plot.
249
+ label_top_n: Number of top SNPs to label.
250
+ show_recombination: Whether to show recombination rate overlay.
251
+ figsize: Figure size as (width, height) in inches.
252
+ lead_pos: Position of lead/index SNP to highlight.
253
+ ld_reference_file: Path to PLINK binary fileset for LD calculation.
254
+ ld_col: Column name for pre-computed LD (R^2) values.
231
255
  genes_df: Gene annotations with chr, start, end, gene_name.
232
256
  exons_df: Exon annotations with chr, start, end, gene_name.
233
257
  recomb_df: Pre-loaded recombination rate data.
234
258
  If None and show_recombination=True, loads from species default.
235
- show_recombination: Whether to show recombination rate overlay.
236
- snp_labels: Whether to label top SNPs.
237
- label_top_n: Number of top SNPs to label.
238
- pos_col: Column name for position.
239
- p_col: Column name for p-value.
240
- rs_col: Column name for SNP ID.
241
- figsize: Figure size.
242
259
 
243
260
  Returns:
244
- Matplotlib Figure object.
261
+ Figure object (type depends on backend).
245
262
 
246
263
  Raises:
247
- ValidationError: If required DataFrame columns are missing.
264
+ ValidationError: If parameters or DataFrame columns are invalid.
265
+
266
+ Example:
267
+ >>> fig = plotter.plot(
268
+ ... gwas_df,
269
+ ... chrom=1, start=1000000, end=2000000,
270
+ ... lead_pos=1500000, snp_labels=True,
271
+ ... )
248
272
  """
273
+ # Validate parameters via Pydantic
274
+ PlotConfig.from_kwargs(
275
+ chrom=chrom,
276
+ start=start,
277
+ end=end,
278
+ pos_col=pos_col,
279
+ p_col=p_col,
280
+ rs_col=rs_col,
281
+ snp_labels=snp_labels,
282
+ label_top_n=label_top_n,
283
+ show_recombination=show_recombination,
284
+ figsize=figsize,
285
+ lead_pos=lead_pos,
286
+ ld_reference_file=ld_reference_file,
287
+ ld_col=ld_col,
288
+ )
289
+
249
290
  # Validate inputs
250
291
  validate_gwas_df(gwas_df, pos_col=pos_col, p_col=p_col)
292
+
293
+ # Auto-fetch genes if enabled and not provided
294
+ if genes_df is None and self._auto_genes:
295
+ logger.debug(
296
+ f"auto_genes enabled, fetching genes for chr{chrom}:{start}-{end}"
297
+ )
298
+ genes_df = get_genes_for_region(
299
+ species=self.species,
300
+ chrom=chrom,
301
+ start=start,
302
+ end=end,
303
+ )
304
+ if genes_df.empty:
305
+ logger.debug("No genes found in region from Ensembl")
306
+ genes_df = None
307
+
251
308
  if genes_df is not None:
252
309
  validate_genes_df(genes_df)
253
310
 
@@ -258,6 +315,23 @@ class LocusZoomPlotter:
258
315
 
259
316
  # Prepare data
260
317
  df = gwas_df.copy()
318
+
319
+ # Validate p-values and warn about issues
320
+ p_values = df[p_col]
321
+ nan_count = p_values.isna().sum()
322
+ if nan_count > 0:
323
+ logger.warning(
324
+ f"GWAS data contains {nan_count} NaN p-values which will be excluded"
325
+ )
326
+ invalid_count = ((p_values < 0) | (p_values > 1)).sum()
327
+ if invalid_count > 0:
328
+ logger.warning(
329
+ f"GWAS data contains {invalid_count} p-values outside [0, 1] range"
330
+ )
331
+ clipped_count = (p_values < 1e-300).sum()
332
+ if clipped_count > 0:
333
+ logger.debug(f"Clipping {clipped_count} p-values below 1e-300 to 1e-300")
334
+
261
335
  df["neglog10p"] = -np.log10(df[p_col].clip(lower=1e-300))
262
336
 
263
337
  # Calculate LD if reference file provided
@@ -305,10 +379,10 @@ class LocusZoomPlotter:
305
379
  zorder=1,
306
380
  )
307
381
 
308
- # Add SNP labels (matplotlib only - interactive backends use hover tooltips)
382
+ # Add SNP labels (capability check - interactive backends use hover tooltips)
309
383
  if snp_labels and rs_col in df.columns and label_top_n > 0 and not df.empty:
310
- if self.backend_name == "matplotlib":
311
- add_snp_labels(
384
+ if self._backend.supports_snp_labels:
385
+ self._backend.add_snp_labels(
312
386
  ax,
313
387
  df,
314
388
  pos_col=pos_col,
@@ -319,12 +393,10 @@ class LocusZoomPlotter:
319
393
  chrom=chrom,
320
394
  )
321
395
 
322
- # Add recombination overlay (all backends)
396
+ # Add recombination overlay (all backends with secondary axis support)
323
397
  if recomb_df is not None and not recomb_df.empty:
324
- if self.backend_name == "matplotlib":
325
- add_recombination_overlay(ax, recomb_df, start, end)
326
- else:
327
- self._add_recombination_overlay_generic(ax, recomb_df, start, end)
398
+ if self._backend.supports_secondary_axis:
399
+ self._add_recombination_overlay(ax, recomb_df, start, end)
328
400
 
329
401
  # Format axes
330
402
  self._backend.set_ylabel(ax, r"$-\log_{10}$ P")
@@ -333,25 +405,21 @@ class LocusZoomPlotter:
333
405
 
334
406
  # Add LD legend (all backends)
335
407
  if ld_col is not None and ld_col in df.columns:
336
- if self.backend_name == "matplotlib":
337
- self._add_ld_legend(ax)
338
- else:
339
- self._backend.add_ld_legend(ax, LD_BINS, LEAD_SNP_COLOR)
408
+ self._backend.add_ld_legend(ax, LD_BINS, LEAD_SNP_COLOR)
340
409
 
341
- # Plot gene track (all backends)
410
+ # Plot gene track (all backends use generic function)
342
411
  if genes_df is not None and gene_ax is not None:
343
- if self.backend_name == "matplotlib":
344
- plot_gene_track(gene_ax, genes_df, chrom, start, end, exons_df)
345
- else:
346
- plot_gene_track_generic(
347
- gene_ax, self._backend, genes_df, chrom, start, end, exons_df
348
- )
412
+ plot_gene_track_generic(
413
+ gene_ax, self._backend, genes_df, chrom, start, end, exons_df
414
+ )
349
415
  self._backend.set_xlabel(gene_ax, f"Chromosome {chrom} (Mb)")
350
416
  self._backend.hide_spines(gene_ax, ["top", "right", "left"])
417
+ # Format both axes for interactive backends (they don't share x-axis)
418
+ self._backend.format_xaxis_mb(gene_ax)
351
419
  else:
352
420
  self._backend.set_xlabel(ax, f"Chromosome {chrom} (Mb)")
353
421
 
354
- # Format x-axis with Mb labels
422
+ # Format x-axis with Mb labels (association axis always needs formatting)
355
423
  self._backend.format_xaxis_mb(ax)
356
424
 
357
425
  # Adjust layout
@@ -366,7 +434,7 @@ class LocusZoomPlotter:
366
434
  start: int,
367
435
  end: int,
368
436
  figsize: Tuple[int, int],
369
- ) -> Tuple[Figure, Axes, Optional[Axes]]:
437
+ ) -> Tuple[Any, Any, Optional[Any]]:
370
438
  """Create figure with optional gene track."""
371
439
  if genes_df is not None:
372
440
  # Calculate dynamic height based on gene rows
@@ -410,7 +478,7 @@ class LocusZoomPlotter:
410
478
 
411
479
  def _plot_association(
412
480
  self,
413
- ax: Axes,
481
+ ax: Any,
414
482
  df: pd.DataFrame,
415
483
  pos_col: str,
416
484
  ld_col: Optional[str],
@@ -419,23 +487,14 @@ class LocusZoomPlotter:
419
487
  p_col: Optional[str] = None,
420
488
  ) -> None:
421
489
  """Plot association scatter with LD coloring."""
422
-
423
- def _build_hover_data(subset_df: pd.DataFrame) -> Optional[pd.DataFrame]:
424
- """Build hover data for interactive backends."""
425
- hover_cols = {}
426
- # RS ID first (will be bold in hover)
427
- if rs_col and rs_col in subset_df.columns:
428
- hover_cols["SNP"] = subset_df[rs_col].values
429
- # Position
430
- if pos_col in subset_df.columns:
431
- hover_cols["Position"] = subset_df[pos_col].values
432
- # P-value
433
- if p_col and p_col in subset_df.columns:
434
- hover_cols["P-value"] = subset_df[p_col].values
435
- # LD
436
- if ld_col and ld_col in subset_df.columns:
437
- hover_cols["R²"] = subset_df[ld_col].values
438
- return pd.DataFrame(hover_cols) if hover_cols else None
490
+ # Build hover data using HoverDataBuilder
491
+ hover_config = HoverConfig(
492
+ snp_col=rs_col if rs_col and rs_col in df.columns else None,
493
+ pos_col=pos_col if pos_col in df.columns else None,
494
+ p_col=p_col if p_col and p_col in df.columns else None,
495
+ ld_col=ld_col if ld_col and ld_col in df.columns else None,
496
+ )
497
+ hover_builder = HoverDataBuilder(hover_config)
439
498
 
440
499
  # LD-based coloring
441
500
  if ld_col is not None and ld_col in df.columns:
@@ -454,7 +513,7 @@ class LocusZoomPlotter:
454
513
  edgecolor="black",
455
514
  linewidth=0.5,
456
515
  zorder=2,
457
- hover_data=_build_hover_data(bin_data),
516
+ hover_data=hover_builder.build_dataframe(bin_data),
458
517
  )
459
518
  else:
460
519
  # Default: grey points
@@ -467,7 +526,7 @@ class LocusZoomPlotter:
467
526
  edgecolor="black",
468
527
  linewidth=0.5,
469
528
  zorder=2,
470
- hover_data=_build_hover_data(df),
529
+ hover_data=hover_builder.build_dataframe(df),
471
530
  )
472
531
 
473
532
  # Highlight lead SNP with larger, more prominent marker
@@ -484,57 +543,21 @@ class LocusZoomPlotter:
484
543
  edgecolor="black",
485
544
  linewidth=1.5,
486
545
  zorder=10,
487
- hover_data=_build_hover_data(lead_snp),
546
+ hover_data=hover_builder.build_dataframe(lead_snp),
488
547
  )
489
548
 
490
- def _add_ld_legend(self, ax: Axes) -> None:
491
- """Add LD color legend to plot."""
492
- palette = get_ld_color_palette()
493
- legend_elements = [
494
- Line2D(
495
- [0],
496
- [0],
497
- marker="D",
498
- color="w",
499
- markerfacecolor=LEAD_SNP_COLOR,
500
- markeredgecolor="black",
501
- markersize=6,
502
- label="Lead SNP",
503
- ),
504
- ]
505
-
506
- for threshold, label, _ in LD_BINS:
507
- legend_elements.append(
508
- Patch(
509
- facecolor=palette[label],
510
- edgecolor="black",
511
- label=label,
512
- )
513
- )
514
-
515
- ax.legend(
516
- handles=legend_elements,
517
- loc="upper right",
518
- fontsize=9,
519
- frameon=True,
520
- framealpha=0.9,
521
- title=r"$r^2$",
522
- title_fontsize=10,
523
- handlelength=1.5,
524
- handleheight=1.0,
525
- labelspacing=0.4,
526
- )
527
-
528
- def _add_recombination_overlay_generic(
549
+ def _add_recombination_overlay(
529
550
  self,
530
551
  ax: Any,
531
552
  recomb_df: pd.DataFrame,
532
553
  start: int,
533
554
  end: int,
534
555
  ) -> None:
535
- """Add recombination overlay for interactive backends (plotly/bokeh).
556
+ """Add recombination overlay for all backends.
536
557
 
537
558
  Creates a secondary y-axis with recombination rate line and fill.
559
+ Uses backend-agnostic secondary axis methods that work across
560
+ matplotlib, plotly, and bokeh.
538
561
  """
539
562
  # Filter to region
540
563
  region_recomb = recomb_df[
@@ -545,18 +568,29 @@ class LocusZoomPlotter:
545
568
  return
546
569
 
547
570
  # Create secondary y-axis
548
- yaxis_name = self._backend.create_twin_axis(ax)
549
-
550
- # For plotly, yaxis_name is a tuple (fig, row, secondary_y)
551
- # For bokeh, yaxis_name is just a string
552
- if isinstance(yaxis_name, tuple):
553
- _, _, secondary_y = yaxis_name
571
+ twin_result = self._backend.create_twin_axis(ax)
572
+
573
+ # Matplotlib returns the twin Axes object itself - use it for drawing
574
+ # Plotly returns tuple (fig, row, secondary_y_name)
575
+ # Bokeh returns string "secondary"
576
+ from matplotlib.axes import Axes
577
+
578
+ if isinstance(twin_result, Axes):
579
+ # Matplotlib: use the twin axis for all secondary axis operations
580
+ secondary_ax = twin_result
581
+ secondary_y = None # Not used for matplotlib
582
+ elif isinstance(twin_result, tuple):
583
+ # Plotly: use original ax, specify y-axis via yaxis_name
584
+ secondary_ax = ax
585
+ _, _, secondary_y = twin_result
554
586
  else:
555
- secondary_y = yaxis_name
587
+ # Bokeh: use original ax, specify y-axis via yaxis_name
588
+ secondary_ax = ax
589
+ secondary_y = twin_result
556
590
 
557
591
  # Plot fill under curve
558
592
  self._backend.fill_between_secondary(
559
- ax,
593
+ secondary_ax,
560
594
  region_recomb["pos"],
561
595
  0,
562
596
  region_recomb["rate"],
@@ -567,7 +601,7 @@ class LocusZoomPlotter:
567
601
 
568
602
  # Plot recombination rate line
569
603
  self._backend.line_secondary(
570
- ax,
604
+ secondary_ax,
571
605
  region_recomb["pos"],
572
606
  region_recomb["rate"],
573
607
  color=RECOMB_COLOR,
@@ -579,10 +613,10 @@ class LocusZoomPlotter:
579
613
  # Set y-axis limits and label
580
614
  max_rate = region_recomb["rate"].max()
581
615
  self._backend.set_secondary_ylim(
582
- ax, 0, max(max_rate * 1.2, 20), yaxis_name=secondary_y
616
+ secondary_ax, 0, max(max_rate * 1.2, 20), yaxis_name=secondary_y
583
617
  )
584
618
  self._backend.set_secondary_ylabel(
585
- ax,
619
+ secondary_ax,
586
620
  "Recombination rate (cM/Mb)",
587
621
  color=RECOMB_COLOR,
588
622
  fontsize=9,
@@ -591,7 +625,7 @@ class LocusZoomPlotter:
591
625
 
592
626
  def _plot_finemapping(
593
627
  self,
594
- ax: Axes,
628
+ ax: Any,
595
629
  df: pd.DataFrame,
596
630
  pos_col: str = "pos",
597
631
  pip_col: str = "pip",
@@ -610,22 +644,15 @@ class LocusZoomPlotter:
610
644
  show_credible_sets: Whether to color points by credible set.
611
645
  pip_threshold: Minimum PIP to display as scatter point.
612
646
  """
613
-
614
- def _build_finemapping_hover_data(
615
- subset_df: pd.DataFrame,
616
- ) -> Optional[pd.DataFrame]:
617
- """Build hover data for interactive backends."""
618
- hover_cols = {}
619
- # Position
620
- if pos_col in subset_df.columns:
621
- hover_cols["Position"] = subset_df[pos_col].values
622
- # PIP
623
- if pip_col in subset_df.columns:
624
- hover_cols["PIP"] = subset_df[pip_col].values
625
- # Credible set
626
- if cs_col and cs_col in subset_df.columns:
627
- hover_cols["Credible Set"] = subset_df[cs_col].values
628
- return pd.DataFrame(hover_cols) if hover_cols else None
647
+ # Build hover data using HoverDataBuilder
648
+ extra_cols = {pip_col: "PIP"}
649
+ if cs_col and cs_col in df.columns:
650
+ extra_cols[cs_col] = "Credible Set"
651
+ hover_config = HoverConfig(
652
+ pos_col=pos_col if pos_col in df.columns else None,
653
+ extra_cols=extra_cols,
654
+ )
655
+ hover_builder = HoverDataBuilder(hover_config)
629
656
 
630
657
  # Sort by position for line plotting
631
658
  df = df.sort_values(pos_col)
@@ -660,7 +687,7 @@ class LocusZoomPlotter:
660
687
  edgecolor="black",
661
688
  linewidth=0.5,
662
689
  zorder=3,
663
- hover_data=_build_finemapping_hover_data(cs_data),
690
+ hover_data=hover_builder.build_dataframe(cs_data),
664
691
  )
665
692
  # Plot variants not in any credible set
666
693
  non_cs_data = df[(df[cs_col].isna()) | (df[cs_col] == 0)]
@@ -677,7 +704,7 @@ class LocusZoomPlotter:
677
704
  edgecolor="black",
678
705
  linewidth=0.3,
679
706
  zorder=2,
680
- hover_data=_build_finemapping_hover_data(non_cs_data),
707
+ hover_data=hover_builder.build_dataframe(non_cs_data),
681
708
  )
682
709
  else:
683
710
  # No credible sets - show all points above threshold
@@ -694,20 +721,28 @@ class LocusZoomPlotter:
694
721
  edgecolor="black",
695
722
  linewidth=0.5,
696
723
  zorder=3,
697
- hover_data=_build_finemapping_hover_data(high_pip),
724
+ hover_data=hover_builder.build_dataframe(high_pip),
698
725
  )
699
726
 
700
727
  def plot_stacked(
701
728
  self,
702
729
  gwas_dfs: List[pd.DataFrame],
730
+ *,
703
731
  chrom: int,
704
732
  start: int,
705
733
  end: int,
734
+ pos_col: str = "ps",
735
+ p_col: str = "p_wald",
736
+ rs_col: str = "rs",
737
+ snp_labels: bool = True,
738
+ label_top_n: int = 3,
739
+ show_recombination: bool = True,
740
+ figsize: Tuple[float, float] = (12.0, 8.0),
741
+ ld_reference_file: Optional[str] = None,
742
+ ld_col: Optional[str] = None,
706
743
  lead_positions: Optional[List[int]] = None,
707
744
  panel_labels: Optional[List[str]] = None,
708
- ld_reference_file: Optional[str] = None,
709
745
  ld_reference_files: Optional[List[str]] = None,
710
- ld_col: Optional[str] = None,
711
746
  genes_df: Optional[pd.DataFrame] = None,
712
747
  exons_df: Optional[pd.DataFrame] = None,
713
748
  eqtl_df: Optional[pd.DataFrame] = None,
@@ -715,13 +750,6 @@ class LocusZoomPlotter:
715
750
  finemapping_df: Optional[pd.DataFrame] = None,
716
751
  finemapping_cs_col: Optional[str] = "cs",
717
752
  recomb_df: Optional[pd.DataFrame] = None,
718
- show_recombination: bool = True,
719
- snp_labels: bool = True,
720
- label_top_n: int = 3,
721
- pos_col: str = "ps",
722
- p_col: str = "p_wald",
723
- rs_col: str = "rs",
724
- figsize: Tuple[float, Optional[float]] = (12, None),
725
753
  ) -> Any:
726
754
  """Create stacked regional association plots for multiple GWAS.
727
755
 
@@ -731,30 +759,28 @@ class LocusZoomPlotter:
731
759
  Args:
732
760
  gwas_dfs: List of GWAS results DataFrames to stack.
733
761
  chrom: Chromosome number.
734
- start: Start position of the region.
735
- end: End position of the region.
736
- lead_positions: List of lead SNP positions (one per GWAS).
737
- If None, auto-detects from lowest p-value.
738
- panel_labels: Labels for each panel (e.g., phenotype names).
739
- ld_reference_file: Single PLINK fileset for all panels.
762
+ start: Start position in base pairs.
763
+ end: End position in base pairs.
764
+ pos_col: Column name for genomic position.
765
+ p_col: Column name for p-value.
766
+ rs_col: Column name for SNP identifier.
767
+ snp_labels: Whether to show SNP labels on plot.
768
+ label_top_n: Number of top SNPs to label (default 3 for stacked).
769
+ show_recombination: Whether to show recombination rate overlay.
770
+ figsize: Figure size as (width, height) in inches.
771
+ ld_reference_file: Single PLINK fileset (broadcast to all panels).
772
+ ld_col: Column name for pre-computed LD (R^2) values.
773
+ lead_positions: List of lead SNP positions (one per panel).
774
+ panel_labels: List of panel labels (one per panel).
740
775
  ld_reference_files: List of PLINK filesets (one per panel).
741
- ld_col: Column name for pre-computed LD (R²) values in each DataFrame.
742
- Use this if LD was calculated externally.
743
776
  genes_df: Gene annotations for bottom track.
744
777
  exons_df: Exon annotations for gene track.
745
778
  eqtl_df: eQTL data to display as additional panel.
746
779
  eqtl_gene: Filter eQTL data to this target gene.
747
780
  finemapping_df: Fine-mapping/SuSiE results with pos and pip columns.
748
781
  Displayed as PIP line with optional credible set coloring.
749
- finemapping_cs_col: Column name for credible set assignment in finemapping_df.
782
+ finemapping_cs_col: Column name for credible set assignment.
750
783
  recomb_df: Pre-loaded recombination rate data.
751
- show_recombination: Whether to show recombination overlay.
752
- snp_labels: Whether to label top SNPs.
753
- label_top_n: Number of top SNPs to label per panel.
754
- pos_col: Column name for position.
755
- p_col: Column name for p-value.
756
- rs_col: Column name for SNP ID.
757
- figsize: Figure size (width, height). If height is None, auto-calculates.
758
784
 
759
785
  Returns:
760
786
  Figure object (type depends on backend).
@@ -764,9 +790,27 @@ class LocusZoomPlotter:
764
790
  ... [gwas_height, gwas_bmi, gwas_whr],
765
791
  ... chrom=1, start=1000000, end=2000000,
766
792
  ... panel_labels=["Height", "BMI", "WHR"],
767
- ... genes_df=genes_df,
768
793
  ... )
769
794
  """
795
+ # Validate parameters via Pydantic
796
+ StackedPlotConfig.from_kwargs(
797
+ chrom=chrom,
798
+ start=start,
799
+ end=end,
800
+ pos_col=pos_col,
801
+ p_col=p_col,
802
+ rs_col=rs_col,
803
+ snp_labels=snp_labels,
804
+ label_top_n=label_top_n,
805
+ show_recombination=show_recombination,
806
+ figsize=figsize,
807
+ ld_reference_file=ld_reference_file,
808
+ ld_col=ld_col,
809
+ lead_positions=lead_positions,
810
+ panel_labels=panel_labels,
811
+ ld_reference_files=ld_reference_files,
812
+ )
813
+
770
814
  n_gwas = len(gwas_dfs)
771
815
  if n_gwas == 0:
772
816
  raise ValueError("At least one GWAS DataFrame required")
@@ -802,8 +846,16 @@ class LocusZoomPlotter:
802
846
  for df in gwas_dfs:
803
847
  region_df = df[(df[pos_col] >= start) & (df[pos_col] <= end)]
804
848
  if not region_df.empty:
805
- lead_idx = region_df[p_col].idxmin()
806
- lead_positions.append(int(region_df.loc[lead_idx, pos_col]))
849
+ # Filter out NaN p-values for lead SNP detection
850
+ valid_p = region_df[p_col].dropna()
851
+ if valid_p.empty:
852
+ logger.warning(
853
+ "All p-values in region are NaN, cannot determine lead SNP"
854
+ )
855
+ lead_positions.append(None)
856
+ else:
857
+ lead_idx = valid_p.idxmin()
858
+ lead_positions.append(int(region_df.loc[lead_idx, pos_col]))
807
859
  else:
808
860
  lead_positions.append(None)
809
861
 
@@ -912,10 +964,10 @@ class LocusZoomPlotter:
912
964
  zorder=1,
913
965
  )
914
966
 
915
- # Add SNP labels (matplotlib only - interactive backends use hover tooltips)
967
+ # Add SNP labels (capability check - interactive backends use hover tooltips)
916
968
  if snp_labels and rs_col in df.columns and label_top_n > 0 and not df.empty:
917
- if self.backend_name == "matplotlib":
918
- add_snp_labels(
969
+ if self._backend.supports_snp_labels:
970
+ self._backend.add_snp_labels(
919
971
  ax,
920
972
  df,
921
973
  pos_col=pos_col,
@@ -928,10 +980,8 @@ class LocusZoomPlotter:
928
980
 
929
981
  # Add recombination overlay (only on first panel, all backends)
930
982
  if i == 0 and recomb_df is not None and not recomb_df.empty:
931
- if self.backend_name == "matplotlib":
932
- add_recombination_overlay(ax, recomb_df, start, end)
933
- else:
934
- self._add_recombination_overlay_generic(ax, recomb_df, start, end)
983
+ if self._backend.supports_secondary_axis:
984
+ self._add_recombination_overlay(ax, recomb_df, start, end)
935
985
 
936
986
  # Format axes
937
987
  self._backend.set_ylabel(ax, r"$-\log_{10}$ P")
@@ -940,50 +990,11 @@ class LocusZoomPlotter:
940
990
 
941
991
  # Add panel label
942
992
  if panel_labels and i < len(panel_labels):
943
- if self.backend_name == "matplotlib":
944
- ax.annotate(
945
- panel_labels[i],
946
- xy=(0.02, 0.95),
947
- xycoords="axes fraction",
948
- fontsize=11,
949
- fontweight="bold",
950
- va="top",
951
- ha="left",
952
- )
953
- elif self.backend_name == "plotly":
954
- fig, row = ax
955
- fig.add_annotation(
956
- text=f"<b>{panel_labels[i]}</b>",
957
- xref=f"x{row} domain" if row > 1 else "x domain",
958
- yref=f"y{row} domain" if row > 1 else "y domain",
959
- x=0.02,
960
- y=0.95,
961
- showarrow=False,
962
- font=dict(size=11),
963
- xanchor="left",
964
- yanchor="top",
965
- )
966
- elif self.backend_name == "bokeh":
967
- from bokeh.models import Label
968
-
969
- # Get y-axis range for positioning
970
- y_max = ax.y_range.end if ax.y_range.end else 10
971
- x_min = ax.x_range.start if ax.x_range.start else start
972
- label = Label(
973
- x=x_min + (end - start) * 0.02,
974
- y=y_max * 0.95,
975
- text=panel_labels[i],
976
- text_font_size="11pt",
977
- text_font_style="bold",
978
- )
979
- ax.add_layout(label)
993
+ self._backend.add_panel_label(ax, panel_labels[i])
980
994
 
981
995
  # Add LD legend (only on first panel, all backends)
982
996
  if i == 0 and panel_ld_col is not None and panel_ld_col in df.columns:
983
- if self.backend_name == "matplotlib":
984
- self._add_ld_legend(ax)
985
- else:
986
- self._backend.add_ld_legend(ax, LD_BINS, LEAD_SNP_COLOR)
997
+ self._backend.add_ld_legend(ax, LD_BINS, LEAD_SNP_COLOR)
987
998
 
988
999
  # Track current panel index
989
1000
  panel_idx = n_gwas
@@ -1050,63 +1061,58 @@ class LocusZoomPlotter:
1050
1061
  eqtl_data["p_value"].clip(lower=1e-300)
1051
1062
  )
1052
1063
 
1053
- def _build_eqtl_hover_data(
1054
- subset_df: pd.DataFrame,
1055
- ) -> Optional[pd.DataFrame]:
1056
- """Build hover data for eQTL interactive backends."""
1057
- hover_cols = {}
1058
- # Position
1059
- if "pos" in subset_df.columns:
1060
- hover_cols["Position"] = subset_df["pos"].values
1061
- # P-value
1062
- if "p_value" in subset_df.columns:
1063
- hover_cols["P-value"] = subset_df["p_value"].values
1064
- # Effect size
1065
- if "effect_size" in subset_df.columns:
1066
- hover_cols["Effect"] = subset_df["effect_size"].values
1067
- # Gene
1068
- if "gene" in subset_df.columns:
1069
- hover_cols["Gene"] = subset_df["gene"].values
1070
- return pd.DataFrame(hover_cols) if hover_cols else None
1064
+ # Build hover data using HoverDataBuilder
1065
+ eqtl_extra_cols = {}
1066
+ if "effect_size" in eqtl_data.columns:
1067
+ eqtl_extra_cols["effect_size"] = "Effect"
1068
+ if "gene" in eqtl_data.columns:
1069
+ eqtl_extra_cols["gene"] = "Gene"
1070
+ eqtl_hover_config = HoverConfig(
1071
+ pos_col="pos" if "pos" in eqtl_data.columns else None,
1072
+ p_col="p_value" if "p_value" in eqtl_data.columns else None,
1073
+ extra_cols=eqtl_extra_cols,
1074
+ )
1075
+ eqtl_hover_builder = HoverDataBuilder(eqtl_hover_config)
1071
1076
 
1072
1077
  # Check if effect_size column exists for directional coloring
1073
1078
  has_effect = "effect_size" in eqtl_data.columns
1074
1079
 
1075
1080
  if has_effect:
1076
- # Plot triangles by effect direction (batch by sign for efficiency)
1081
+ # Vectorized plotting: split by sign, assign colors in bulk
1077
1082
  pos_effects = eqtl_data[eqtl_data["effect_size"] >= 0]
1078
1083
  neg_effects = eqtl_data[eqtl_data["effect_size"] < 0]
1079
1084
 
1080
- # Plot positive effects (up triangles)
1081
- for _, row in pos_effects.iterrows():
1082
- row_df = pd.DataFrame([row])
1085
+ # Vectorized color assignment using apply
1086
+ if not pos_effects.empty:
1087
+ pos_colors = pos_effects["effect_size"].apply(get_eqtl_color)
1083
1088
  self._backend.scatter(
1084
1089
  ax,
1085
- pd.Series([row["pos"]]),
1086
- pd.Series([row["neglog10p"]]),
1087
- colors=get_eqtl_color(row["effect_size"]),
1090
+ pos_effects["pos"],
1091
+ pos_effects["neglog10p"],
1092
+ colors=pos_colors.tolist(),
1088
1093
  sizes=50,
1089
1094
  marker="^",
1090
1095
  edgecolor="black",
1091
1096
  linewidth=0.5,
1092
1097
  zorder=2,
1093
- hover_data=_build_eqtl_hover_data(row_df),
1098
+ hover_data=eqtl_hover_builder.build_dataframe(pos_effects),
1094
1099
  )
1095
- # Plot negative effects (down triangles)
1096
- for _, row in neg_effects.iterrows():
1097
- row_df = pd.DataFrame([row])
1100
+
1101
+ if not neg_effects.empty:
1102
+ neg_colors = neg_effects["effect_size"].apply(get_eqtl_color)
1098
1103
  self._backend.scatter(
1099
1104
  ax,
1100
- pd.Series([row["pos"]]),
1101
- pd.Series([row["neglog10p"]]),
1102
- colors=get_eqtl_color(row["effect_size"]),
1105
+ neg_effects["pos"],
1106
+ neg_effects["neglog10p"],
1107
+ colors=neg_colors.tolist(),
1103
1108
  sizes=50,
1104
1109
  marker="v",
1105
1110
  edgecolor="black",
1106
1111
  linewidth=0.5,
1107
1112
  zorder=2,
1108
- hover_data=_build_eqtl_hover_data(row_df),
1113
+ hover_data=eqtl_hover_builder.build_dataframe(neg_effects),
1109
1114
  )
1115
+
1110
1116
  # Add eQTL effect legend (all backends)
1111
1117
  self._backend.add_eqtl_legend(
1112
1118
  ax, EQTL_POSITIVE_BINS, EQTL_NEGATIVE_BINS
@@ -1125,7 +1131,7 @@ class LocusZoomPlotter:
1125
1131
  linewidth=0.5,
1126
1132
  zorder=2,
1127
1133
  label=label,
1128
- hover_data=_build_eqtl_hover_data(eqtl_data),
1134
+ hover_data=eqtl_hover_builder.build_dataframe(eqtl_data),
1129
1135
  )
1130
1136
  self._backend.add_simple_legend(ax, label, loc="upper right")
1131
1137
 
@@ -1141,15 +1147,12 @@ class LocusZoomPlotter:
1141
1147
  self._backend.hide_spines(ax, ["top", "right"])
1142
1148
  panel_idx += 1
1143
1149
 
1144
- # Plot gene track (all backends)
1150
+ # Plot gene track (all backends use generic function)
1145
1151
  if genes_df is not None:
1146
1152
  gene_ax = axes[panel_idx]
1147
- if self.backend_name == "matplotlib":
1148
- plot_gene_track(gene_ax, genes_df, chrom, start, end, exons_df)
1149
- else:
1150
- plot_gene_track_generic(
1151
- gene_ax, self._backend, genes_df, chrom, start, end, exons_df
1152
- )
1153
+ plot_gene_track_generic(
1154
+ gene_ax, self._backend, genes_df, chrom, start, end, exons_df
1155
+ )
1153
1156
  self._backend.set_xlabel(gene_ax, f"Chromosome {chrom} (Mb)")
1154
1157
  self._backend.hide_spines(gene_ax, ["top", "right", "left"])
1155
1158
  else:
@@ -1230,18 +1233,37 @@ class LocusZoomPlotter:
1230
1233
  # Plot points by category
1231
1234
  if categories:
1232
1235
  for cat in categories:
1233
- cat_data = df[df[category_col] == cat]
1236
+ # Handle NaN category: NaN == NaN is False in pandas
1237
+ if pd.isna(cat):
1238
+ cat_data = df[df[category_col].isna()]
1239
+ else:
1240
+ cat_data = df[df[category_col] == cat]
1234
1241
  # Use upward triangles for positive effects, circles otherwise
1235
1242
  if effect_col and effect_col in cat_data.columns:
1236
- for _, row in cat_data.iterrows():
1237
- marker = "^" if row[effect_col] >= 0 else "v"
1243
+ # Vectorized: split by effect sign, 2 scatter calls per category
1244
+ pos_data = cat_data[cat_data[effect_col] >= 0]
1245
+ neg_data = cat_data[cat_data[effect_col] < 0]
1246
+
1247
+ if not pos_data.empty:
1238
1248
  self._backend.scatter(
1239
1249
  ax,
1240
- pd.Series([row["neglog10p"]]),
1241
- pd.Series([row["y_pos"]]),
1250
+ pos_data["neglog10p"],
1251
+ pos_data["y_pos"],
1242
1252
  colors=palette[cat],
1243
1253
  sizes=60,
1244
- marker=marker,
1254
+ marker="^",
1255
+ edgecolor="black",
1256
+ linewidth=0.5,
1257
+ zorder=2,
1258
+ )
1259
+ if not neg_data.empty:
1260
+ self._backend.scatter(
1261
+ ax,
1262
+ neg_data["neglog10p"],
1263
+ neg_data["y_pos"],
1264
+ colors=palette[cat],
1265
+ sizes=60,
1266
+ marker="v",
1245
1267
  edgecolor="black",
1246
1268
  linewidth=0.5,
1247
1269
  zorder=2,
@@ -1281,10 +1303,13 @@ class LocusZoomPlotter:
1281
1303
  self._backend.set_ylabel(ax, "Phenotype")
1282
1304
  self._backend.set_ylim(ax, -0.5, len(df) - 0.5)
1283
1305
 
1284
- # Set y-tick labels to phenotype names (matplotlib only)
1285
- if self.backend_name == "matplotlib":
1286
- ax.set_yticks(df["y_pos"])
1287
- ax.set_yticklabels(df[phenotype_col], fontsize=8)
1306
+ # Set y-tick labels to phenotype names
1307
+ self._backend.set_yticks(
1308
+ ax,
1309
+ positions=df["y_pos"].tolist(),
1310
+ labels=df[phenotype_col].tolist(),
1311
+ fontsize=8,
1312
+ )
1288
1313
 
1289
1314
  self._backend.set_title(ax, f"PheWAS: {variant_id}")
1290
1315
  self._backend.hide_spines(ax, ["top", "right"])
@@ -1399,10 +1424,19 @@ class LocusZoomPlotter:
1399
1424
  self._backend.set_xlabel(ax, effect_label)
1400
1425
  self._backend.set_ylim(ax, -0.5, len(df) - 0.5)
1401
1426
 
1402
- # Set y-tick labels to study names (matplotlib only)
1403
- if self.backend_name == "matplotlib":
1404
- ax.set_yticks(df["y_pos"])
1405
- ax.set_yticklabels(df[study_col], fontsize=10)
1427
+ # Ensure x-axis includes the null value with some padding
1428
+ x_min = min(df[ci_lower_col].min(), null_value)
1429
+ x_max = max(df[ci_upper_col].max(), null_value)
1430
+ x_padding = (x_max - x_min) * 0.1
1431
+ self._backend.set_xlim(ax, x_min - x_padding, x_max + x_padding)
1432
+
1433
+ # Set y-tick labels to study names
1434
+ self._backend.set_yticks(
1435
+ ax,
1436
+ positions=df["y_pos"].tolist(),
1437
+ labels=df[study_col].tolist(),
1438
+ fontsize=10,
1439
+ )
1406
1440
 
1407
1441
  self._backend.set_title(ax, f"Forest Plot: {variant_id}")
1408
1442
  self._backend.hide_spines(ax, ["top", "right"])