pylocuszoom 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pylocuszoom/plotter.py CHANGED
@@ -10,7 +10,7 @@ Supports multiple backends:
10
10
  """
11
11
 
12
12
  from pathlib import Path
13
- from typing import Any, List, Optional, Tuple, Union
13
+ from typing import Any, List, Optional, Tuple
14
14
 
15
15
  import matplotlib.pyplot as plt
16
16
  import numpy as np
@@ -19,30 +19,43 @@ from matplotlib.axes import Axes
19
19
  from matplotlib.figure import Figure
20
20
  from matplotlib.lines import Line2D
21
21
  from matplotlib.patches import Patch
22
- from matplotlib.ticker import FuncFormatter, MaxNLocator
23
-
24
- from .backends import BackendType, PlotBackend, get_backend
25
22
 
23
+ from .backends import BackendType, get_backend
26
24
  from .colors import (
25
+ EQTL_NEGATIVE_BINS,
26
+ EQTL_POSITIVE_BINS,
27
27
  LD_BINS,
28
28
  LEAD_SNP_COLOR,
29
+ PIP_LINE_COLOR,
30
+ get_credible_set_color,
31
+ get_eqtl_color,
29
32
  get_ld_bin,
30
33
  get_ld_color_palette,
31
34
  )
32
- from .gene_track import assign_gene_positions, plot_gene_track
35
+ from .eqtl import validate_eqtl_df
36
+ from .finemapping import (
37
+ get_credible_sets,
38
+ prepare_finemapping_for_plotting,
39
+ )
40
+ from .gene_track import (
41
+ assign_gene_positions,
42
+ plot_gene_track,
43
+ plot_gene_track_generic,
44
+ )
33
45
  from .labels import add_snp_labels
34
46
  from .ld import calculate_ld, find_plink
35
47
  from .logging import enable_logging, logger
36
48
  from .recombination import (
49
+ RECOMB_COLOR,
37
50
  add_recombination_overlay,
38
- download_dog_recombination_maps,
51
+ download_canine_recombination_maps,
39
52
  get_default_data_dir,
40
53
  get_recombination_rate_for_region,
41
54
  )
42
55
  from .utils import normalize_chrom, validate_genes_df, validate_gwas_df
43
56
 
44
- # Default significance threshold: 5e-8 for human, 5e-7 for dog
45
- DEFAULT_GENOMEWIDE_THRESHOLD = 5e-7
57
+ # Default significance threshold: 5e-8 (genome-wide significance)
58
+ DEFAULT_GENOMEWIDE_THRESHOLD = 5e-8
46
59
  DEFAULT_GENOMEWIDE_LINE = -np.log10(DEFAULT_GENOMEWIDE_THRESHOLD)
47
60
 
48
61
 
@@ -52,7 +65,7 @@ class LocusZoomPlotter:
52
65
  Creates LocusZoom-style regional plots with:
53
66
  - LD coloring based on R² with lead variant
54
67
  - Gene and exon tracks
55
- - Recombination rate overlays (dog built-in, or user-provided)
68
+ - Recombination rate overlays (canine built-in, or user-provided)
56
69
  - Automatic SNP labeling
57
70
 
58
71
  Supports multiple rendering backends:
@@ -61,9 +74,9 @@ class LocusZoomPlotter:
61
74
  - bokeh: Interactive HTML for dashboards
62
75
 
63
76
  Args:
64
- species: Species name ('dog', 'cat', or None for custom).
65
- Dog has built-in recombination maps.
66
- genome_build: Genome build for coordinate system. For dog:
77
+ species: Species name ('canine', 'feline', or None for custom).
78
+ Canine has built-in recombination maps.
79
+ genome_build: Genome build for coordinate system. For canine:
67
80
  "canfam3.1" (default) or "canfam4". If "canfam4", recombination
68
81
  maps are automatically lifted over from CanFam3.1.
69
82
  backend: Plotting backend ('matplotlib', 'plotly', or 'bokeh').
@@ -78,10 +91,10 @@ class LocusZoomPlotter:
78
91
 
79
92
  Example:
80
93
  >>> # Static plot (default)
81
- >>> plotter = LocusZoomPlotter(species="dog")
94
+ >>> plotter = LocusZoomPlotter(species="canine")
82
95
  >>>
83
96
  >>> # Interactive plot with plotly
84
- >>> plotter = LocusZoomPlotter(species="dog", backend="plotly")
97
+ >>> plotter = LocusZoomPlotter(species="canine", backend="plotly")
85
98
  >>>
86
99
  >>> fig = plotter.plot(
87
100
  ... gwas_df,
@@ -96,7 +109,7 @@ class LocusZoomPlotter:
96
109
 
97
110
  def __init__(
98
111
  self,
99
- species: str = "dog",
112
+ species: str = "canine",
100
113
  genome_build: Optional[str] = None,
101
114
  backend: BackendType = "matplotlib",
102
115
  plink_path: Optional[str] = None,
@@ -126,9 +139,9 @@ class LocusZoomPlotter:
126
139
  @staticmethod
127
140
  def _default_build(species: str) -> Optional[str]:
128
141
  """Get default genome build for species."""
129
- if species == "dog":
142
+ if species == "canine":
130
143
  return "canfam3.1"
131
- if species == "cat":
144
+ if species == "feline":
132
145
  return "felCat9"
133
146
  return None
134
147
 
@@ -137,7 +150,7 @@ class LocusZoomPlotter:
137
150
 
138
151
  Returns path to recombination map directory, or None if not available.
139
152
  """
140
- if self.species == "dog":
153
+ if self.species == "canine":
141
154
  if self.recomb_data_dir:
142
155
  return Path(self.recomb_data_dir)
143
156
  # Check if already downloaded
@@ -149,7 +162,7 @@ class LocusZoomPlotter:
149
162
  return default_dir
150
163
  # Download
151
164
  try:
152
- return download_dog_recombination_maps()
165
+ return download_canine_recombination_maps()
153
166
  except Exception as e:
154
167
  logger.warning(f"Could not download recombination maps: {e}")
155
168
  return None
@@ -249,20 +262,27 @@ class LocusZoomPlotter:
249
262
 
250
263
  # Calculate LD if reference file provided
251
264
  if ld_reference_file and lead_pos and ld_col is None:
252
- lead_snp_row = df[df[pos_col] == lead_pos]
253
- if not lead_snp_row.empty:
254
- lead_snp_id = lead_snp_row[rs_col].iloc[0]
255
- logger.debug(f"Calculating LD for lead SNP {lead_snp_id}")
256
- ld_df = calculate_ld(
257
- bfile_path=ld_reference_file,
258
- lead_snp=lead_snp_id,
259
- window_kb=max((end - start) // 1000, 500),
260
- plink_path=self.plink_path,
261
- species=self.species,
265
+ # Check if rs_col exists before attempting LD calculation
266
+ if rs_col not in df.columns:
267
+ logger.warning(
268
+ f"Cannot calculate LD: column '{rs_col}' not found in GWAS data. "
269
+ f"Provide rs_col parameter or add SNP IDs to DataFrame."
262
270
  )
263
- if not ld_df.empty:
264
- df = df.merge(ld_df, left_on=rs_col, right_on="SNP", how="left")
265
- ld_col = "R2"
271
+ else:
272
+ lead_snp_row = df[df[pos_col] == lead_pos]
273
+ if not lead_snp_row.empty:
274
+ lead_snp_id = lead_snp_row[rs_col].iloc[0]
275
+ logger.debug(f"Calculating LD for lead SNP {lead_snp_id}")
276
+ ld_df = calculate_ld(
277
+ bfile_path=ld_reference_file,
278
+ lead_snp=lead_snp_id,
279
+ window_kb=max((end - start) // 1000, 500),
280
+ plink_path=self.plink_path,
281
+ species=self.species,
282
+ )
283
+ if not ld_df.empty:
284
+ df = df.merge(ld_df, left_on=rs_col, right_on="SNP", how="left")
285
+ ld_col = "R2"
266
286
 
267
287
  # Load recombination data if needed
268
288
  if show_recombination and recomb_df is None:
@@ -272,61 +292,70 @@ class LocusZoomPlotter:
272
292
  fig, ax, gene_ax = self._create_figure(genes_df, chrom, start, end, figsize)
273
293
 
274
294
  # Plot association data
275
- self._plot_association(ax, df, pos_col, ld_col, lead_pos)
295
+ self._plot_association(ax, df, pos_col, ld_col, lead_pos, rs_col, p_col)
276
296
 
277
297
  # Add significance line
278
- ax.axhline(
298
+ self._backend.axhline(
299
+ ax,
279
300
  y=self._genomewide_line,
280
- color="grey",
301
+ color="red",
281
302
  linestyle="--",
282
303
  linewidth=1,
304
+ alpha=0.65,
283
305
  zorder=1,
284
306
  )
285
307
 
286
- # Add SNP labels
308
+ # Add SNP labels (matplotlib only - interactive backends use hover tooltips)
287
309
  if snp_labels and rs_col in df.columns and label_top_n > 0 and not df.empty:
288
- add_snp_labels(
289
- ax,
290
- df,
291
- pos_col=pos_col,
292
- neglog10p_col="neglog10p",
293
- rs_col=rs_col,
294
- label_top_n=label_top_n,
295
- genes_df=genes_df,
296
- chrom=chrom,
297
- )
310
+ if self.backend_name == "matplotlib":
311
+ add_snp_labels(
312
+ ax,
313
+ df,
314
+ pos_col=pos_col,
315
+ neglog10p_col="neglog10p",
316
+ rs_col=rs_col,
317
+ label_top_n=label_top_n,
318
+ genes_df=genes_df,
319
+ chrom=chrom,
320
+ )
298
321
 
299
- # Add recombination overlay
322
+ # Add recombination overlay (all backends)
300
323
  if recomb_df is not None and not recomb_df.empty:
301
- add_recombination_overlay(ax, recomb_df, start, end)
324
+ if self.backend_name == "matplotlib":
325
+ add_recombination_overlay(ax, recomb_df, start, end)
326
+ else:
327
+ self._add_recombination_overlay_generic(ax, recomb_df, start, end)
302
328
 
303
329
  # Format axes
304
- ax.set_ylabel(r"$-\log_{10}$ P")
305
- ax.set_xlim(start, end)
306
- ax.spines["top"].set_visible(False)
307
- ax.spines["right"].set_visible(False)
330
+ self._backend.set_ylabel(ax, r"$-\log_{10}$ P")
331
+ self._backend.set_xlim(ax, start, end)
332
+ self._backend.hide_spines(ax, ["top", "right"])
308
333
 
309
- # Add LD legend
334
+ # Add LD legend (all backends)
310
335
  if ld_col is not None and ld_col in df.columns:
311
- self._add_ld_legend(ax)
336
+ if self.backend_name == "matplotlib":
337
+ self._add_ld_legend(ax)
338
+ else:
339
+ self._backend.add_ld_legend(ax, LD_BINS, LEAD_SNP_COLOR)
312
340
 
313
- # Plot gene track
341
+ # Plot gene track (all backends)
314
342
  if genes_df is not None and gene_ax is not None:
315
- plot_gene_track(gene_ax, genes_df, chrom, start, end, exons_df)
316
- gene_ax.set_xlabel(f"Chromosome {chrom} (Mb)")
317
- gene_ax.spines["top"].set_visible(False)
318
- gene_ax.spines["right"].set_visible(False)
319
- gene_ax.spines["left"].set_visible(False)
343
+ if self.backend_name == "matplotlib":
344
+ plot_gene_track(gene_ax, genes_df, chrom, start, end, exons_df)
345
+ else:
346
+ plot_gene_track_generic(
347
+ gene_ax, self._backend, genes_df, chrom, start, end, exons_df
348
+ )
349
+ self._backend.set_xlabel(gene_ax, f"Chromosome {chrom} (Mb)")
350
+ self._backend.hide_spines(gene_ax, ["top", "right", "left"])
320
351
  else:
321
- ax.set_xlabel(f"Chromosome {chrom} (Mb)")
352
+ self._backend.set_xlabel(ax, f"Chromosome {chrom} (Mb)")
322
353
 
323
354
  # Format x-axis with Mb labels
324
- ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{x / 1e6:.2f}"))
325
- ax.xaxis.set_major_locator(MaxNLocator(nbins=6))
355
+ self._backend.format_xaxis_mb(ax)
326
356
 
327
357
  # Adjust layout
328
- fig.subplots_adjust(left=0.08, right=0.95, top=0.95, bottom=0.1, hspace=0.08)
329
- plt.ion()
358
+ self._backend.finalize_layout(fig, hspace=0.1)
330
359
 
331
360
  return fig
332
361
 
@@ -364,18 +393,20 @@ class LocusZoomPlotter:
364
393
  assoc_height = figsize[1] * 0.6
365
394
  total_height = assoc_height + gene_track_height
366
395
 
367
- fig, axes = plt.subplots(
368
- 2,
369
- 1,
370
- figsize=(figsize[0], total_height),
396
+ fig, axes = self._backend.create_figure(
397
+ n_panels=2,
371
398
  height_ratios=[assoc_height, gene_track_height],
399
+ figsize=(figsize[0], total_height),
372
400
  sharex=True,
373
- gridspec_kw={"hspace": 0},
374
401
  )
375
402
  return fig, axes[0], axes[1]
376
403
  else:
377
- fig, ax = plt.subplots(figsize=(figsize[0], figsize[1] * 0.75))
378
- return fig, ax, None
404
+ fig, axes = self._backend.create_figure(
405
+ n_panels=1,
406
+ height_ratios=[1.0],
407
+ figsize=(figsize[0], figsize[1] * 0.75),
408
+ )
409
+ return fig, axes[0], None
379
410
 
380
411
  def _plot_association(
381
412
  self,
@@ -384,8 +415,28 @@ class LocusZoomPlotter:
384
415
  pos_col: str,
385
416
  ld_col: Optional[str],
386
417
  lead_pos: Optional[int],
418
+ rs_col: Optional[str] = None,
419
+ p_col: Optional[str] = None,
387
420
  ) -> None:
388
421
  """Plot association scatter with LD coloring."""
422
+
423
+ def _build_hover_data(subset_df: pd.DataFrame) -> Optional[pd.DataFrame]:
424
+ """Build hover data for interactive backends."""
425
+ hover_cols = {}
426
+ # RS ID first (will be bold in hover)
427
+ if rs_col and rs_col in subset_df.columns:
428
+ hover_cols["SNP"] = subset_df[rs_col].values
429
+ # Position
430
+ if pos_col in subset_df.columns:
431
+ hover_cols["Position"] = subset_df[pos_col].values
432
+ # P-value
433
+ if p_col and p_col in subset_df.columns:
434
+ hover_cols["P-value"] = subset_df[p_col].values
435
+ # LD
436
+ if ld_col and ld_col in subset_df.columns:
437
+ hover_cols["R²"] = subset_df[ld_col].values
438
+ return pd.DataFrame(hover_cols) if hover_cols else None
439
+
389
440
  # LD-based coloring
390
441
  if ld_col is not None and ld_col in df.columns:
391
442
  df["ld_bin"] = df[ld_col].apply(get_ld_bin)
@@ -394,40 +445,46 @@ class LocusZoomPlotter:
394
445
  palette = get_ld_color_palette()
395
446
  for bin_label in df["ld_bin"].unique():
396
447
  bin_data = df[df["ld_bin"] == bin_label]
397
- ax.scatter(
448
+ self._backend.scatter(
449
+ ax,
398
450
  bin_data[pos_col],
399
451
  bin_data["neglog10p"],
400
- c=palette.get(bin_label, "#BEBEBE"),
401
- s=60,
452
+ colors=palette.get(bin_label, "#BEBEBE"),
453
+ sizes=60,
402
454
  edgecolor="black",
403
455
  linewidth=0.5,
404
456
  zorder=2,
457
+ hover_data=_build_hover_data(bin_data),
405
458
  )
406
459
  else:
407
460
  # Default: grey points
408
- ax.scatter(
461
+ self._backend.scatter(
462
+ ax,
409
463
  df[pos_col],
410
464
  df["neglog10p"],
411
- c="#BEBEBE",
412
- s=60,
465
+ colors="#BEBEBE",
466
+ sizes=60,
413
467
  edgecolor="black",
414
468
  linewidth=0.5,
415
469
  zorder=2,
470
+ hover_data=_build_hover_data(df),
416
471
  )
417
472
 
418
- # Highlight lead SNP
473
+ # Highlight lead SNP with larger, more prominent marker
419
474
  if lead_pos is not None:
420
475
  lead_snp = df[df[pos_col] == lead_pos]
421
476
  if not lead_snp.empty:
422
- ax.scatter(
477
+ self._backend.scatter(
478
+ ax,
423
479
  lead_snp[pos_col],
424
480
  lead_snp["neglog10p"],
425
- c=LEAD_SNP_COLOR,
426
- s=120,
481
+ colors=LEAD_SNP_COLOR,
482
+ sizes=120, # Larger than regular points for visibility
427
483
  marker="D",
428
- edgecolors="black",
429
- linewidths=1,
484
+ edgecolor="black",
485
+ linewidth=1.5,
430
486
  zorder=10,
487
+ hover_data=_build_hover_data(lead_snp),
431
488
  )
432
489
 
433
490
  def _add_ld_legend(self, ax: Axes) -> None:
@@ -441,8 +498,8 @@ class LocusZoomPlotter:
441
498
  color="w",
442
499
  markerfacecolor=LEAD_SNP_COLOR,
443
500
  markeredgecolor="black",
444
- markersize=8,
445
- label="Index SNP",
501
+ markersize=6,
502
+ label="Lead SNP",
446
503
  ),
447
504
  ]
448
505
 
@@ -457,7 +514,7 @@ class LocusZoomPlotter:
457
514
 
458
515
  ax.legend(
459
516
  handles=legend_elements,
460
- loc="upper left",
517
+ loc="upper right",
461
518
  fontsize=9,
462
519
  frameon=True,
463
520
  framealpha=0.9,
@@ -468,6 +525,249 @@ class LocusZoomPlotter:
468
525
  labelspacing=0.4,
469
526
  )
470
527
 
528
+ def _add_recombination_overlay_generic(
529
+ self,
530
+ ax: Any,
531
+ recomb_df: pd.DataFrame,
532
+ start: int,
533
+ end: int,
534
+ ) -> None:
535
+ """Add recombination overlay for interactive backends (plotly/bokeh).
536
+
537
+ Creates a secondary y-axis with recombination rate line and fill.
538
+ """
539
+ # Filter to region
540
+ region_recomb = recomb_df[
541
+ (recomb_df["pos"] >= start) & (recomb_df["pos"] <= end)
542
+ ].copy()
543
+
544
+ if region_recomb.empty:
545
+ return
546
+
547
+ # Create secondary y-axis
548
+ yaxis_name = self._backend.create_twin_axis(ax)
549
+
550
+ # For plotly, yaxis_name is a tuple (fig, row, secondary_y)
551
+ # For bokeh, yaxis_name is just a string
552
+ if isinstance(yaxis_name, tuple):
553
+ _, _, secondary_y = yaxis_name
554
+ else:
555
+ secondary_y = yaxis_name
556
+
557
+ # Plot fill under curve
558
+ self._backend.fill_between_secondary(
559
+ ax,
560
+ region_recomb["pos"],
561
+ 0,
562
+ region_recomb["rate"],
563
+ color=RECOMB_COLOR,
564
+ alpha=0.15,
565
+ yaxis_name=secondary_y,
566
+ )
567
+
568
+ # Plot recombination rate line
569
+ self._backend.line_secondary(
570
+ ax,
571
+ region_recomb["pos"],
572
+ region_recomb["rate"],
573
+ color=RECOMB_COLOR,
574
+ linewidth=1.5,
575
+ alpha=0.7,
576
+ yaxis_name=secondary_y,
577
+ )
578
+
579
+ # Set y-axis limits and label
580
+ max_rate = region_recomb["rate"].max()
581
+ self._backend.set_secondary_ylim(
582
+ ax, 0, max(max_rate * 1.2, 20), yaxis_name=secondary_y
583
+ )
584
+ self._backend.set_secondary_ylabel(
585
+ ax,
586
+ "Recombination rate (cM/Mb)",
587
+ color=RECOMB_COLOR,
588
+ fontsize=9,
589
+ yaxis_name=secondary_y,
590
+ )
591
+
592
+ def _add_eqtl_legend(self, ax: Axes) -> None:
593
+ """Add eQTL effect size legend to plot."""
594
+ legend_elements = []
595
+
596
+ # Positive effects (upward triangles)
597
+ for _, _, label, color in EQTL_POSITIVE_BINS:
598
+ legend_elements.append(
599
+ Line2D(
600
+ [0],
601
+ [0],
602
+ marker="^",
603
+ color="w",
604
+ markerfacecolor=color,
605
+ markeredgecolor="black",
606
+ markersize=7,
607
+ label=label,
608
+ )
609
+ )
610
+
611
+ # Negative effects (downward triangles)
612
+ for _, _, label, color in EQTL_NEGATIVE_BINS:
613
+ legend_elements.append(
614
+ Line2D(
615
+ [0],
616
+ [0],
617
+ marker="v",
618
+ color="w",
619
+ markerfacecolor=color,
620
+ markeredgecolor="black",
621
+ markersize=7,
622
+ label=label,
623
+ )
624
+ )
625
+
626
+ ax.legend(
627
+ handles=legend_elements,
628
+ loc="upper right",
629
+ fontsize=8,
630
+ frameon=True,
631
+ framealpha=0.9,
632
+ title="eQTL effect",
633
+ title_fontsize=9,
634
+ handlelength=1.2,
635
+ handleheight=1.0,
636
+ labelspacing=0.3,
637
+ )
638
+
639
+ def _plot_finemapping(
640
+ self,
641
+ ax: Axes,
642
+ df: pd.DataFrame,
643
+ pos_col: str = "pos",
644
+ pip_col: str = "pip",
645
+ cs_col: Optional[str] = "cs",
646
+ show_credible_sets: bool = True,
647
+ pip_threshold: float = 0.0,
648
+ ) -> None:
649
+ """Plot fine-mapping results (PIP line with credible set coloring).
650
+
651
+ Args:
652
+ ax: Matplotlib axes object.
653
+ df: Fine-mapping DataFrame with pos and pip columns.
654
+ pos_col: Column name for position.
655
+ pip_col: Column name for posterior inclusion probability.
656
+ cs_col: Column name for credible set assignment (optional).
657
+ show_credible_sets: Whether to color points by credible set.
658
+ pip_threshold: Minimum PIP to display as scatter point.
659
+ """
660
+ # Sort by position for line plotting
661
+ df = df.sort_values(pos_col)
662
+
663
+ # Plot PIP as line
664
+ self._backend.line(
665
+ ax,
666
+ df[pos_col],
667
+ df[pip_col],
668
+ color=PIP_LINE_COLOR,
669
+ linewidth=1.5,
670
+ alpha=0.8,
671
+ zorder=1,
672
+ )
673
+
674
+ # Check if credible sets are available
675
+ has_cs = cs_col is not None and cs_col in df.columns and show_credible_sets
676
+ credible_sets = get_credible_sets(df, cs_col) if has_cs else []
677
+
678
+ if credible_sets:
679
+ # Plot points colored by credible set
680
+ for cs_id in credible_sets:
681
+ cs_data = df[df[cs_col] == cs_id]
682
+ color = get_credible_set_color(cs_id)
683
+ self._backend.scatter(
684
+ ax,
685
+ cs_data[pos_col],
686
+ cs_data[pip_col],
687
+ colors=color,
688
+ sizes=50,
689
+ marker="o",
690
+ edgecolor="black",
691
+ linewidth=0.5,
692
+ zorder=3,
693
+ label=f"CS{cs_id}",
694
+ )
695
+ # Plot variants not in any credible set
696
+ non_cs_data = df[(df[cs_col].isna()) | (df[cs_col] == 0)]
697
+ if not non_cs_data.empty and pip_threshold > 0:
698
+ non_cs_data = non_cs_data[non_cs_data[pip_col] >= pip_threshold]
699
+ if not non_cs_data.empty:
700
+ self._backend.scatter(
701
+ ax,
702
+ non_cs_data[pos_col],
703
+ non_cs_data[pip_col],
704
+ colors="#BEBEBE",
705
+ sizes=30,
706
+ marker="o",
707
+ edgecolor="black",
708
+ linewidth=0.3,
709
+ zorder=2,
710
+ )
711
+ else:
712
+ # No credible sets - show all points above threshold
713
+ if pip_threshold > 0:
714
+ high_pip = df[df[pip_col] >= pip_threshold]
715
+ if not high_pip.empty:
716
+ self._backend.scatter(
717
+ ax,
718
+ high_pip[pos_col],
719
+ high_pip[pip_col],
720
+ colors=PIP_LINE_COLOR,
721
+ sizes=50,
722
+ marker="o",
723
+ edgecolor="black",
724
+ linewidth=0.5,
725
+ zorder=3,
726
+ )
727
+
728
+ def _add_finemapping_legend(
729
+ self,
730
+ ax: Axes,
731
+ credible_sets: List[int],
732
+ ) -> None:
733
+ """Add fine-mapping legend showing credible sets.
734
+
735
+ Args:
736
+ ax: Matplotlib axes object.
737
+ credible_sets: List of credible set IDs to include.
738
+ """
739
+ if not credible_sets:
740
+ return
741
+
742
+ legend_elements = []
743
+ for cs_id in credible_sets:
744
+ color = get_credible_set_color(cs_id)
745
+ legend_elements.append(
746
+ Line2D(
747
+ [0],
748
+ [0],
749
+ marker="o",
750
+ color="w",
751
+ markerfacecolor=color,
752
+ markeredgecolor="black",
753
+ markersize=7,
754
+ label=f"CS{cs_id}",
755
+ )
756
+ )
757
+
758
+ ax.legend(
759
+ handles=legend_elements,
760
+ loc="upper right",
761
+ fontsize=8,
762
+ frameon=True,
763
+ framealpha=0.9,
764
+ title="Credible sets",
765
+ title_fontsize=9,
766
+ handlelength=1.2,
767
+ handleheight=1.0,
768
+ labelspacing=0.3,
769
+ )
770
+
471
771
  def plot_stacked(
472
772
  self,
473
773
  gwas_dfs: List[pd.DataFrame],
@@ -478,10 +778,13 @@ class LocusZoomPlotter:
478
778
  panel_labels: Optional[List[str]] = None,
479
779
  ld_reference_file: Optional[str] = None,
480
780
  ld_reference_files: Optional[List[str]] = None,
781
+ ld_col: Optional[str] = None,
481
782
  genes_df: Optional[pd.DataFrame] = None,
482
783
  exons_df: Optional[pd.DataFrame] = None,
483
784
  eqtl_df: Optional[pd.DataFrame] = None,
484
785
  eqtl_gene: Optional[str] = None,
786
+ finemapping_df: Optional[pd.DataFrame] = None,
787
+ finemapping_cs_col: Optional[str] = "cs",
485
788
  recomb_df: Optional[pd.DataFrame] = None,
486
789
  show_recombination: bool = True,
487
790
  snp_labels: bool = True,
@@ -506,10 +809,15 @@ class LocusZoomPlotter:
506
809
  panel_labels: Labels for each panel (e.g., phenotype names).
507
810
  ld_reference_file: Single PLINK fileset for all panels.
508
811
  ld_reference_files: List of PLINK filesets (one per panel).
812
+ ld_col: Column name for pre-computed LD (R²) values in each DataFrame.
813
+ Use this if LD was calculated externally.
509
814
  genes_df: Gene annotations for bottom track.
510
815
  exons_df: Exon annotations for gene track.
511
816
  eqtl_df: eQTL data to display as additional panel.
512
817
  eqtl_gene: Filter eQTL data to this target gene.
818
+ finemapping_df: Fine-mapping/SuSiE results with pos and pip columns.
819
+ Displayed as PIP line with optional credible set coloring.
820
+ finemapping_cs_col: Column name for credible set assignment in finemapping_df.
513
821
  recomb_df: Pre-loaded recombination rate data.
514
822
  show_recombination: Whether to show recombination overlay.
515
823
  snp_labels: Whether to label top SNPs.
@@ -534,11 +842,30 @@ class LocusZoomPlotter:
534
842
  if n_gwas == 0:
535
843
  raise ValueError("At least one GWAS DataFrame required")
536
844
 
845
+ # Validate list lengths match
846
+ if lead_positions is not None and len(lead_positions) != n_gwas:
847
+ raise ValueError(
848
+ f"lead_positions length ({len(lead_positions)}) must match "
849
+ f"number of GWAS DataFrames ({n_gwas})"
850
+ )
851
+ if panel_labels is not None and len(panel_labels) != n_gwas:
852
+ raise ValueError(
853
+ f"panel_labels length ({len(panel_labels)}) must match "
854
+ f"number of GWAS DataFrames ({n_gwas})"
855
+ )
856
+ if ld_reference_files is not None and len(ld_reference_files) != n_gwas:
857
+ raise ValueError(
858
+ f"ld_reference_files length ({len(ld_reference_files)}) must match "
859
+ f"number of GWAS DataFrames ({n_gwas})"
860
+ )
861
+
537
862
  # Validate inputs
538
863
  for i, df in enumerate(gwas_dfs):
539
864
  validate_gwas_df(df, pos_col=pos_col, p_col=p_col)
540
865
  if genes_df is not None:
541
866
  validate_genes_df(genes_df)
867
+ if eqtl_df is not None:
868
+ validate_eqtl_df(eqtl_df)
542
869
 
543
870
  # Handle lead positions
544
871
  if lead_positions is None:
@@ -558,12 +885,16 @@ class LocusZoomPlotter:
558
885
  # Calculate panel layout
559
886
  panel_height = 2.5 # inches per GWAS panel
560
887
  eqtl_height = 2.0 if eqtl_df is not None else 0
888
+ finemapping_height = 1.5 if finemapping_df is not None else 0
561
889
 
562
890
  # Gene track height
563
891
  if genes_df is not None:
564
892
  chrom_str = normalize_chrom(chrom)
565
893
  region_genes = genes_df[
566
- (genes_df["chr"].astype(str).str.replace("chr", "", regex=False) == chrom_str)
894
+ (
895
+ genes_df["chr"].astype(str).str.replace("chr", "", regex=False)
896
+ == chrom_str
897
+ )
567
898
  & (genes_df["end"] >= start)
568
899
  & (genes_df["start"] <= end)
569
900
  ]
@@ -579,8 +910,15 @@ class LocusZoomPlotter:
579
910
  gene_track_height = 0
580
911
 
581
912
  # Calculate total panels and heights
582
- n_panels = n_gwas + (1 if eqtl_df is not None else 0) + (1 if genes_df is not None else 0)
913
+ n_panels = (
914
+ n_gwas
915
+ + (1 if finemapping_df is not None else 0)
916
+ + (1 if eqtl_df is not None else 0)
917
+ + (1 if genes_df is not None else 0)
918
+ )
583
919
  height_ratios = [panel_height] * n_gwas
920
+ if finemapping_df is not None:
921
+ height_ratios.append(finemapping_height)
584
922
  if eqtl_df is not None:
585
923
  height_ratios.append(eqtl_height)
586
924
  if genes_df is not None:
@@ -590,26 +928,21 @@ class LocusZoomPlotter:
590
928
  total_height = figsize[1] if figsize[1] else sum(height_ratios)
591
929
  actual_figsize = (figsize[0], total_height)
592
930
 
593
- logger.debug(f"Creating stacked plot with {n_panels} panels for chr{chrom}:{start}-{end}")
594
-
595
- # Prevent auto-display in interactive environments
596
- plt.ioff()
931
+ logger.debug(
932
+ f"Creating stacked plot with {n_panels} panels for chr{chrom}:{start}-{end}"
933
+ )
597
934
 
598
935
  # Load recombination data if needed
599
936
  if show_recombination and recomb_df is None:
600
937
  recomb_df = self._get_recomb_for_region(chrom, start, end)
601
938
 
602
- # Create figure
603
- fig, axes = plt.subplots(
604
- n_panels,
605
- 1,
606
- figsize=actual_figsize,
939
+ # Create figure using backend
940
+ fig, axes = self._backend.create_figure(
941
+ n_panels=n_panels,
607
942
  height_ratios=height_ratios,
943
+ figsize=actual_figsize,
608
944
  sharex=True,
609
- gridspec_kw={"hspace": 0.05},
610
945
  )
611
- if n_panels == 1:
612
- axes = [axes]
613
946
 
614
947
  # Plot each GWAS panel
615
948
  for i, (gwas_df, lead_pos) in enumerate(zip(gwas_dfs, lead_positions)):
@@ -617,9 +950,9 @@ class LocusZoomPlotter:
617
950
  df = gwas_df.copy()
618
951
  df["neglog10p"] = -np.log10(df[p_col].clip(lower=1e-300))
619
952
 
620
- # Calculate LD if reference provided
621
- ld_col = None
622
- if ld_reference_files and ld_reference_files[i] and lead_pos:
953
+ # Use pre-computed LD or calculate from reference
954
+ panel_ld_col = ld_col
955
+ if ld_reference_files and ld_reference_files[i] and lead_pos and not ld_col:
623
956
  lead_snp_row = df[df[pos_col] == lead_pos]
624
957
  if not lead_snp_row.empty and rs_col in df.columns:
625
958
  lead_snp_id = lead_snp_row[rs_col].iloc[0]
@@ -632,51 +965,135 @@ class LocusZoomPlotter:
632
965
  )
633
966
  if not ld_df.empty:
634
967
  df = df.merge(ld_df, left_on=rs_col, right_on="SNP", how="left")
635
- ld_col = "R2"
968
+ panel_ld_col = "R2"
636
969
 
637
970
  # Plot association
638
- self._plot_association(ax, df, pos_col, ld_col, lead_pos)
971
+ self._plot_association(ax, df, pos_col, panel_ld_col, lead_pos, rs_col, p_col)
639
972
 
640
973
  # Add significance line
641
- ax.axhline(y=self._genomewide_line, color="grey", linestyle="--", linewidth=1, zorder=1)
974
+ self._backend.axhline(
975
+ ax,
976
+ y=self._genomewide_line,
977
+ color="red",
978
+ linestyle="--",
979
+ linewidth=1,
980
+ alpha=0.65,
981
+ zorder=1,
982
+ )
642
983
 
643
- # Add SNP labels
984
+ # Add SNP labels (matplotlib only - interactive backends use hover tooltips)
644
985
  if snp_labels and rs_col in df.columns and label_top_n > 0 and not df.empty:
645
- add_snp_labels(
646
- ax, df, pos_col=pos_col, neglog10p_col="neglog10p",
647
- rs_col=rs_col, label_top_n=label_top_n, genes_df=genes_df, chrom=chrom,
648
- )
986
+ if self.backend_name == "matplotlib":
987
+ add_snp_labels(
988
+ ax,
989
+ df,
990
+ pos_col=pos_col,
991
+ neglog10p_col="neglog10p",
992
+ rs_col=rs_col,
993
+ label_top_n=label_top_n,
994
+ genes_df=genes_df,
995
+ chrom=chrom,
996
+ )
649
997
 
650
- # Add recombination overlay (only on first panel)
998
+ # Add recombination overlay (only on first panel, all backends)
651
999
  if i == 0 and recomb_df is not None and not recomb_df.empty:
652
- add_recombination_overlay(ax, recomb_df, start, end)
1000
+ if self.backend_name == "matplotlib":
1001
+ add_recombination_overlay(ax, recomb_df, start, end)
1002
+ else:
1003
+ self._add_recombination_overlay_generic(ax, recomb_df, start, end)
653
1004
 
654
1005
  # Format axes
655
- ax.set_ylabel(r"$-\log_{10}$ P")
656
- ax.set_xlim(start, end)
657
- ax.spines["top"].set_visible(False)
658
- ax.spines["right"].set_visible(False)
1006
+ self._backend.set_ylabel(ax, r"$-\log_{10}$ P")
1007
+ self._backend.set_xlim(ax, start, end)
1008
+ self._backend.hide_spines(ax, ["top", "right"])
659
1009
 
660
1010
  # Add panel label
661
1011
  if panel_labels and i < len(panel_labels):
662
- ax.annotate(
663
- panel_labels[i],
664
- xy=(0.02, 0.95),
665
- xycoords="axes fraction",
666
- fontsize=11,
667
- fontweight="bold",
668
- va="top",
669
- ha="left",
1012
+ if self.backend_name == "matplotlib":
1013
+ ax.annotate(
1014
+ panel_labels[i],
1015
+ xy=(0.02, 0.95),
1016
+ xycoords="axes fraction",
1017
+ fontsize=11,
1018
+ fontweight="bold",
1019
+ va="top",
1020
+ ha="left",
1021
+ )
1022
+ elif self.backend_name == "plotly":
1023
+ fig, row = ax
1024
+ fig.add_annotation(
1025
+ text=f"<b>{panel_labels[i]}</b>",
1026
+ xref=f"x{row} domain" if row > 1 else "x domain",
1027
+ yref=f"y{row} domain" if row > 1 else "y domain",
1028
+ x=0.02,
1029
+ y=0.95,
1030
+ showarrow=False,
1031
+ font=dict(size=11),
1032
+ xanchor="left",
1033
+ yanchor="top",
1034
+ )
1035
+ elif self.backend_name == "bokeh":
1036
+ from bokeh.models import Label
1037
+
1038
+ # Get y-axis range for positioning
1039
+ y_max = ax.y_range.end if ax.y_range.end else 10
1040
+ x_min = ax.x_range.start if ax.x_range.start else start
1041
+ label = Label(
1042
+ x=x_min + (end - start) * 0.02,
1043
+ y=y_max * 0.95,
1044
+ text=panel_labels[i],
1045
+ text_font_size="11pt",
1046
+ text_font_style="bold",
1047
+ )
1048
+ ax.add_layout(label)
1049
+
1050
+ # Add LD legend (only on first panel, all backends)
1051
+ if i == 0 and panel_ld_col is not None and panel_ld_col in df.columns:
1052
+ if self.backend_name == "matplotlib":
1053
+ self._add_ld_legend(ax)
1054
+ else:
1055
+ self._backend.add_ld_legend(ax, LD_BINS, LEAD_SNP_COLOR)
1056
+
1057
+ # Track current panel index
1058
+ panel_idx = n_gwas
1059
+
1060
+ # Plot fine-mapping panel if provided
1061
+ if finemapping_df is not None:
1062
+ ax = axes[panel_idx]
1063
+ fm_data = prepare_finemapping_for_plotting(
1064
+ finemapping_df,
1065
+ pos_col="pos",
1066
+ pip_col="pip",
1067
+ chrom=chrom,
1068
+ start=start,
1069
+ end=end,
1070
+ )
1071
+
1072
+ if not fm_data.empty:
1073
+ self._plot_finemapping(
1074
+ ax,
1075
+ fm_data,
1076
+ pos_col="pos",
1077
+ pip_col="pip",
1078
+ cs_col=finemapping_cs_col,
1079
+ show_credible_sets=True,
1080
+ pip_threshold=0.01,
670
1081
  )
671
1082
 
672
- # Add LD legend (only on first panel)
673
- if i == 0 and ld_col is not None and ld_col in df.columns:
674
- self._add_ld_legend(ax)
1083
+ # Add legend for credible sets
1084
+ credible_sets = get_credible_sets(fm_data, finemapping_cs_col)
1085
+ if credible_sets:
1086
+ self._add_finemapping_legend(ax, credible_sets)
1087
+
1088
+ self._backend.set_ylabel(ax, "PIP")
1089
+ self._backend.set_ylim(ax, -0.05, 1.05)
1090
+ self._backend.hide_spines(ax, ["top", "right"])
1091
+ panel_idx += 1
675
1092
 
676
1093
  # Plot eQTL panel if provided
677
- panel_idx = n_gwas
1094
+ eqtl_panel_idx = panel_idx
678
1095
  if eqtl_df is not None:
679
- ax = axes[panel_idx]
1096
+ ax = axes[eqtl_panel_idx]
680
1097
  eqtl_data = eqtl_df.copy()
681
1098
 
682
1099
  # Filter by gene if specified
@@ -685,49 +1102,85 @@ class LocusZoomPlotter:
685
1102
 
686
1103
  # Filter by region
687
1104
  if "pos" in eqtl_data.columns:
688
- eqtl_data = eqtl_data[(eqtl_data["pos"] >= start) & (eqtl_data["pos"] <= end)]
1105
+ eqtl_data = eqtl_data[
1106
+ (eqtl_data["pos"] >= start) & (eqtl_data["pos"] <= end)
1107
+ ]
689
1108
 
690
1109
  if not eqtl_data.empty:
691
- eqtl_data["neglog10p"] = -np.log10(eqtl_data["p_value"].clip(lower=1e-300))
692
-
693
- # Plot as diamonds (different from GWAS circles)
694
- ax.scatter(
695
- eqtl_data["pos"],
696
- eqtl_data["neglog10p"],
697
- c="#FF6B6B",
698
- s=60,
699
- marker="D",
700
- edgecolor="black",
701
- linewidth=0.5,
702
- zorder=2,
703
- label=f"eQTL ({eqtl_gene})" if eqtl_gene else "eQTL",
1110
+ eqtl_data["neglog10p"] = -np.log10(
1111
+ eqtl_data["p_value"].clip(lower=1e-300)
704
1112
  )
705
- ax.legend(loc="upper left", fontsize=9)
706
1113
 
707
- ax.set_ylabel(r"$-\log_{10}$ P (eQTL)")
708
- ax.axhline(y=self._genomewide_line, color="grey", linestyle="--", linewidth=1)
709
- ax.spines["top"].set_visible(False)
710
- ax.spines["right"].set_visible(False)
1114
+ # Check if effect_size column exists for directional coloring
1115
+ has_effect = "effect_size" in eqtl_data.columns
1116
+
1117
+ if has_effect:
1118
+ # Plot triangles by effect direction with color by magnitude
1119
+ for _, row in eqtl_data.iterrows():
1120
+ effect = row["effect_size"]
1121
+ color = get_eqtl_color(effect)
1122
+ marker = "^" if effect >= 0 else "v"
1123
+ self._backend.scatter(
1124
+ ax,
1125
+ pd.Series([row["pos"]]),
1126
+ pd.Series([row["neglog10p"]]),
1127
+ colors=color,
1128
+ sizes=50,
1129
+ marker=marker,
1130
+ edgecolor="black",
1131
+ linewidth=0.5,
1132
+ zorder=2,
1133
+ )
1134
+ # Add eQTL effect legend
1135
+ self._add_eqtl_legend(ax)
1136
+ else:
1137
+ # No effect sizes - plot as diamonds
1138
+ self._backend.scatter(
1139
+ ax,
1140
+ eqtl_data["pos"],
1141
+ eqtl_data["neglog10p"],
1142
+ colors="#FF6B6B",
1143
+ sizes=60,
1144
+ marker="D",
1145
+ edgecolor="black",
1146
+ linewidth=0.5,
1147
+ zorder=2,
1148
+ label=f"eQTL ({eqtl_gene})" if eqtl_gene else "eQTL",
1149
+ )
1150
+ ax.legend(loc="upper right", fontsize=9)
1151
+
1152
+ self._backend.set_ylabel(ax, r"$-\log_{10}$ P (eQTL)")
1153
+ self._backend.axhline(
1154
+ ax,
1155
+ y=self._genomewide_line,
1156
+ color="red",
1157
+ linestyle="--",
1158
+ linewidth=1,
1159
+ alpha=0.65,
1160
+ )
1161
+ self._backend.hide_spines(ax, ["top", "right"])
711
1162
  panel_idx += 1
712
1163
 
713
- # Plot gene track
1164
+ # Plot gene track (all backends)
714
1165
  if genes_df is not None:
715
1166
  gene_ax = axes[panel_idx]
716
- plot_gene_track(gene_ax, genes_df, chrom, start, end, exons_df)
717
- gene_ax.set_xlabel(f"Chromosome {chrom} (Mb)")
718
- gene_ax.spines["top"].set_visible(False)
719
- gene_ax.spines["right"].set_visible(False)
720
- gene_ax.spines["left"].set_visible(False)
1167
+ if self.backend_name == "matplotlib":
1168
+ plot_gene_track(gene_ax, genes_df, chrom, start, end, exons_df)
1169
+ else:
1170
+ plot_gene_track_generic(
1171
+ gene_ax, self._backend, genes_df, chrom, start, end, exons_df
1172
+ )
1173
+ self._backend.set_xlabel(gene_ax, f"Chromosome {chrom} (Mb)")
1174
+ self._backend.hide_spines(gene_ax, ["top", "right", "left"])
721
1175
  else:
722
1176
  # Set x-label on bottom panel
723
- axes[-1].set_xlabel(f"Chromosome {chrom} (Mb)")
1177
+ self._backend.set_xlabel(axes[-1], f"Chromosome {chrom} (Mb)")
724
1178
 
725
- # Format x-axis
726
- axes[0].xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{x / 1e6:.2f}"))
727
- axes[0].xaxis.set_major_locator(MaxNLocator(nbins=6))
1179
+ # Format x-axis (call for all axes - Plotly needs each subplot formatted)
1180
+ for ax in axes:
1181
+ self._backend.format_xaxis_mb(ax)
728
1182
 
729
1183
  # Adjust layout
730
- fig.subplots_adjust(left=0.08, right=0.95, top=0.95, bottom=0.08, hspace=0.05)
731
- plt.ion()
1184
+ self._backend.finalize_layout(fig, hspace=0.1)
732
1185
 
733
1186
  return fig