pylocuszoom 0.5.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pylocuszoom/plotter.py CHANGED
@@ -15,12 +15,9 @@ from typing import Any, List, Optional, Tuple
15
15
  import matplotlib.pyplot as plt
16
16
  import numpy as np
17
17
  import pandas as pd
18
- from matplotlib.axes import Axes
19
- from matplotlib.figure import Figure
20
- from matplotlib.lines import Line2D
21
- from matplotlib.patches import Patch
22
18
 
23
19
  from .backends import BackendType, get_backend
20
+ from .backends.hover import HoverConfig, HoverDataBuilder
24
21
  from .colors import (
25
22
  EQTL_NEGATIVE_BINS,
26
23
  EQTL_POSITIVE_BINS,
@@ -31,23 +28,24 @@ from .colors import (
31
28
  get_eqtl_color,
32
29
  get_ld_bin,
33
30
  get_ld_color_palette,
31
+ get_phewas_category_palette,
34
32
  )
33
+ from .ensembl import get_genes_for_region
35
34
  from .eqtl import validate_eqtl_df
36
35
  from .finemapping import (
37
36
  get_credible_sets,
38
37
  prepare_finemapping_for_plotting,
39
38
  )
39
+ from .forest import validate_forest_df
40
40
  from .gene_track import (
41
41
  assign_gene_positions,
42
- plot_gene_track,
43
42
  plot_gene_track_generic,
44
43
  )
45
- from .labels import add_snp_labels
46
44
  from .ld import calculate_ld, find_plink
47
45
  from .logging import enable_logging, logger
46
+ from .phewas import validate_phewas_df
48
47
  from .recombination import (
49
48
  RECOMB_COLOR,
50
- add_recombination_overlay,
51
49
  download_canine_recombination_maps,
52
50
  get_default_data_dir,
53
51
  get_recombination_rate_for_region,
@@ -116,8 +114,21 @@ class LocusZoomPlotter:
116
114
  recomb_data_dir: Optional[str] = None,
117
115
  genomewide_threshold: float = DEFAULT_GENOMEWIDE_THRESHOLD,
118
116
  log_level: Optional[str] = "INFO",
117
+ auto_genes: bool = False,
119
118
  ):
120
- """Initialize the plotter."""
119
+ """Initialize the plotter.
120
+
121
+ Args:
122
+ species: Species name ('canine', 'feline', or None for custom).
123
+ genome_build: Genome build for coordinate system.
124
+ backend: Plotting backend ('matplotlib', 'plotly', or 'bokeh').
125
+ plink_path: Path to PLINK executable for LD calculation.
126
+ recomb_data_dir: Directory containing recombination maps.
127
+ genomewide_threshold: P-value threshold for significance line.
128
+ log_level: Logging level.
129
+ auto_genes: If True, automatically fetch genes from Ensembl when
130
+ genes_df is not provided. Default False for backward compatibility.
131
+ """
121
132
  # Configure logging
122
133
  if log_level is not None:
123
134
  enable_logging(log_level)
@@ -126,12 +137,12 @@ class LocusZoomPlotter:
126
137
  self.genome_build = (
127
138
  genome_build if genome_build else self._default_build(species)
128
139
  )
129
- self.backend_name = backend
130
140
  self._backend = get_backend(backend)
131
141
  self.plink_path = plink_path or find_plink()
132
142
  self.recomb_data_dir = recomb_data_dir
133
143
  self.genomewide_threshold = genomewide_threshold
134
144
  self._genomewide_line = -np.log10(genomewide_threshold)
145
+ self._auto_genes = auto_genes
135
146
 
136
147
  # Cache for loaded data
137
148
  self._recomb_cache = {}
@@ -245,6 +256,22 @@ class LocusZoomPlotter:
245
256
  """
246
257
  # Validate inputs
247
258
  validate_gwas_df(gwas_df, pos_col=pos_col, p_col=p_col)
259
+
260
+ # Auto-fetch genes if enabled and not provided
261
+ if genes_df is None and self._auto_genes:
262
+ logger.debug(
263
+ f"auto_genes enabled, fetching genes for chr{chrom}:{start}-{end}"
264
+ )
265
+ genes_df = get_genes_for_region(
266
+ species=self.species,
267
+ chrom=chrom,
268
+ start=start,
269
+ end=end,
270
+ )
271
+ if genes_df.empty:
272
+ logger.debug("No genes found in region from Ensembl")
273
+ genes_df = None
274
+
248
275
  if genes_df is not None:
249
276
  validate_genes_df(genes_df)
250
277
 
@@ -302,10 +329,10 @@ class LocusZoomPlotter:
302
329
  zorder=1,
303
330
  )
304
331
 
305
- # Add SNP labels (matplotlib only - interactive backends use hover tooltips)
332
+ # Add SNP labels (capability check - interactive backends use hover tooltips)
306
333
  if snp_labels and rs_col in df.columns and label_top_n > 0 and not df.empty:
307
- if self.backend_name == "matplotlib":
308
- add_snp_labels(
334
+ if self._backend.supports_snp_labels:
335
+ self._backend.add_snp_labels(
309
336
  ax,
310
337
  df,
311
338
  pos_col=pos_col,
@@ -316,12 +343,10 @@ class LocusZoomPlotter:
316
343
  chrom=chrom,
317
344
  )
318
345
 
319
- # Add recombination overlay (all backends)
346
+ # Add recombination overlay (all backends with secondary axis support)
320
347
  if recomb_df is not None and not recomb_df.empty:
321
- if self.backend_name == "matplotlib":
322
- add_recombination_overlay(ax, recomb_df, start, end)
323
- else:
324
- self._add_recombination_overlay_generic(ax, recomb_df, start, end)
348
+ if self._backend.supports_secondary_axis:
349
+ self._add_recombination_overlay(ax, recomb_df, start, end)
325
350
 
326
351
  # Format axes
327
352
  self._backend.set_ylabel(ax, r"$-\log_{10}$ P")
@@ -330,19 +355,13 @@ class LocusZoomPlotter:
330
355
 
331
356
  # Add LD legend (all backends)
332
357
  if ld_col is not None and ld_col in df.columns:
333
- if self.backend_name == "matplotlib":
334
- self._add_ld_legend(ax)
335
- else:
336
- self._backend.add_ld_legend(ax, LD_BINS, LEAD_SNP_COLOR)
358
+ self._backend.add_ld_legend(ax, LD_BINS, LEAD_SNP_COLOR)
337
359
 
338
- # Plot gene track (all backends)
360
+ # Plot gene track (all backends use generic function)
339
361
  if genes_df is not None and gene_ax is not None:
340
- if self.backend_name == "matplotlib":
341
- plot_gene_track(gene_ax, genes_df, chrom, start, end, exons_df)
342
- else:
343
- plot_gene_track_generic(
344
- gene_ax, self._backend, genes_df, chrom, start, end, exons_df
345
- )
362
+ plot_gene_track_generic(
363
+ gene_ax, self._backend, genes_df, chrom, start, end, exons_df
364
+ )
346
365
  self._backend.set_xlabel(gene_ax, f"Chromosome {chrom} (Mb)")
347
366
  self._backend.hide_spines(gene_ax, ["top", "right", "left"])
348
367
  else:
@@ -363,7 +382,7 @@ class LocusZoomPlotter:
363
382
  start: int,
364
383
  end: int,
365
384
  figsize: Tuple[int, int],
366
- ) -> Tuple[Figure, Axes, Optional[Axes]]:
385
+ ) -> Tuple[Any, Any, Optional[Any]]:
367
386
  """Create figure with optional gene track."""
368
387
  if genes_df is not None:
369
388
  # Calculate dynamic height based on gene rows
@@ -407,7 +426,7 @@ class LocusZoomPlotter:
407
426
 
408
427
  def _plot_association(
409
428
  self,
410
- ax: Axes,
429
+ ax: Any,
411
430
  df: pd.DataFrame,
412
431
  pos_col: str,
413
432
  ld_col: Optional[str],
@@ -416,23 +435,14 @@ class LocusZoomPlotter:
416
435
  p_col: Optional[str] = None,
417
436
  ) -> None:
418
437
  """Plot association scatter with LD coloring."""
419
-
420
- def _build_hover_data(subset_df: pd.DataFrame) -> Optional[pd.DataFrame]:
421
- """Build hover data for interactive backends."""
422
- hover_cols = {}
423
- # RS ID first (will be bold in hover)
424
- if rs_col and rs_col in subset_df.columns:
425
- hover_cols["SNP"] = subset_df[rs_col].values
426
- # Position
427
- if pos_col in subset_df.columns:
428
- hover_cols["Position"] = subset_df[pos_col].values
429
- # P-value
430
- if p_col and p_col in subset_df.columns:
431
- hover_cols["P-value"] = subset_df[p_col].values
432
- # LD
433
- if ld_col and ld_col in subset_df.columns:
434
- hover_cols["R²"] = subset_df[ld_col].values
435
- return pd.DataFrame(hover_cols) if hover_cols else None
438
+ # Build hover data using HoverDataBuilder
439
+ hover_config = HoverConfig(
440
+ snp_col=rs_col if rs_col and rs_col in df.columns else None,
441
+ pos_col=pos_col if pos_col in df.columns else None,
442
+ p_col=p_col if p_col and p_col in df.columns else None,
443
+ ld_col=ld_col if ld_col and ld_col in df.columns else None,
444
+ )
445
+ hover_builder = HoverDataBuilder(hover_config)
436
446
 
437
447
  # LD-based coloring
438
448
  if ld_col is not None and ld_col in df.columns:
@@ -451,7 +461,7 @@ class LocusZoomPlotter:
451
461
  edgecolor="black",
452
462
  linewidth=0.5,
453
463
  zorder=2,
454
- hover_data=_build_hover_data(bin_data),
464
+ hover_data=hover_builder.build_dataframe(bin_data),
455
465
  )
456
466
  else:
457
467
  # Default: grey points
@@ -464,7 +474,7 @@ class LocusZoomPlotter:
464
474
  edgecolor="black",
465
475
  linewidth=0.5,
466
476
  zorder=2,
467
- hover_data=_build_hover_data(df),
477
+ hover_data=hover_builder.build_dataframe(df),
468
478
  )
469
479
 
470
480
  # Highlight lead SNP with larger, more prominent marker
@@ -481,57 +491,21 @@ class LocusZoomPlotter:
481
491
  edgecolor="black",
482
492
  linewidth=1.5,
483
493
  zorder=10,
484
- hover_data=_build_hover_data(lead_snp),
494
+ hover_data=hover_builder.build_dataframe(lead_snp),
485
495
  )
486
496
 
487
- def _add_ld_legend(self, ax: Axes) -> None:
488
- """Add LD color legend to plot."""
489
- palette = get_ld_color_palette()
490
- legend_elements = [
491
- Line2D(
492
- [0],
493
- [0],
494
- marker="D",
495
- color="w",
496
- markerfacecolor=LEAD_SNP_COLOR,
497
- markeredgecolor="black",
498
- markersize=6,
499
- label="Lead SNP",
500
- ),
501
- ]
502
-
503
- for threshold, label, _ in LD_BINS:
504
- legend_elements.append(
505
- Patch(
506
- facecolor=palette[label],
507
- edgecolor="black",
508
- label=label,
509
- )
510
- )
511
-
512
- ax.legend(
513
- handles=legend_elements,
514
- loc="upper right",
515
- fontsize=9,
516
- frameon=True,
517
- framealpha=0.9,
518
- title=r"$r^2$",
519
- title_fontsize=10,
520
- handlelength=1.5,
521
- handleheight=1.0,
522
- labelspacing=0.4,
523
- )
524
-
525
- def _add_recombination_overlay_generic(
497
+ def _add_recombination_overlay(
526
498
  self,
527
499
  ax: Any,
528
500
  recomb_df: pd.DataFrame,
529
501
  start: int,
530
502
  end: int,
531
503
  ) -> None:
532
- """Add recombination overlay for interactive backends (plotly/bokeh).
504
+ """Add recombination overlay for all backends.
533
505
 
534
506
  Creates a secondary y-axis with recombination rate line and fill.
507
+ Uses backend-agnostic secondary axis methods that work across
508
+ matplotlib, plotly, and bokeh.
535
509
  """
536
510
  # Filter to region
537
511
  region_recomb = recomb_df[
@@ -588,7 +562,7 @@ class LocusZoomPlotter:
588
562
 
589
563
  def _plot_finemapping(
590
564
  self,
591
- ax: Axes,
565
+ ax: Any,
592
566
  df: pd.DataFrame,
593
567
  pos_col: str = "pos",
594
568
  pip_col: str = "pip",
@@ -607,22 +581,15 @@ class LocusZoomPlotter:
607
581
  show_credible_sets: Whether to color points by credible set.
608
582
  pip_threshold: Minimum PIP to display as scatter point.
609
583
  """
610
-
611
- def _build_finemapping_hover_data(
612
- subset_df: pd.DataFrame,
613
- ) -> Optional[pd.DataFrame]:
614
- """Build hover data for interactive backends."""
615
- hover_cols = {}
616
- # Position
617
- if pos_col in subset_df.columns:
618
- hover_cols["Position"] = subset_df[pos_col].values
619
- # PIP
620
- if pip_col in subset_df.columns:
621
- hover_cols["PIP"] = subset_df[pip_col].values
622
- # Credible set
623
- if cs_col and cs_col in subset_df.columns:
624
- hover_cols["Credible Set"] = subset_df[cs_col].values
625
- return pd.DataFrame(hover_cols) if hover_cols else None
584
+ # Build hover data using HoverDataBuilder
585
+ extra_cols = {pip_col: "PIP"}
586
+ if cs_col and cs_col in df.columns:
587
+ extra_cols[cs_col] = "Credible Set"
588
+ hover_config = HoverConfig(
589
+ pos_col=pos_col if pos_col in df.columns else None,
590
+ extra_cols=extra_cols,
591
+ )
592
+ hover_builder = HoverDataBuilder(hover_config)
626
593
 
627
594
  # Sort by position for line plotting
628
595
  df = df.sort_values(pos_col)
@@ -657,7 +624,7 @@ class LocusZoomPlotter:
657
624
  edgecolor="black",
658
625
  linewidth=0.5,
659
626
  zorder=3,
660
- hover_data=_build_finemapping_hover_data(cs_data),
627
+ hover_data=hover_builder.build_dataframe(cs_data),
661
628
  )
662
629
  # Plot variants not in any credible set
663
630
  non_cs_data = df[(df[cs_col].isna()) | (df[cs_col] == 0)]
@@ -674,7 +641,7 @@ class LocusZoomPlotter:
674
641
  edgecolor="black",
675
642
  linewidth=0.3,
676
643
  zorder=2,
677
- hover_data=_build_finemapping_hover_data(non_cs_data),
644
+ hover_data=hover_builder.build_dataframe(non_cs_data),
678
645
  )
679
646
  else:
680
647
  # No credible sets - show all points above threshold
@@ -691,7 +658,7 @@ class LocusZoomPlotter:
691
658
  edgecolor="black",
692
659
  linewidth=0.5,
693
660
  zorder=3,
694
- hover_data=_build_finemapping_hover_data(high_pip),
661
+ hover_data=hover_builder.build_dataframe(high_pip),
695
662
  )
696
663
 
697
664
  def plot_stacked(
@@ -909,10 +876,10 @@ class LocusZoomPlotter:
909
876
  zorder=1,
910
877
  )
911
878
 
912
- # Add SNP labels (matplotlib only - interactive backends use hover tooltips)
879
+ # Add SNP labels (capability check - interactive backends use hover tooltips)
913
880
  if snp_labels and rs_col in df.columns and label_top_n > 0 and not df.empty:
914
- if self.backend_name == "matplotlib":
915
- add_snp_labels(
881
+ if self._backend.supports_snp_labels:
882
+ self._backend.add_snp_labels(
916
883
  ax,
917
884
  df,
918
885
  pos_col=pos_col,
@@ -925,10 +892,8 @@ class LocusZoomPlotter:
925
892
 
926
893
  # Add recombination overlay (only on first panel, all backends)
927
894
  if i == 0 and recomb_df is not None and not recomb_df.empty:
928
- if self.backend_name == "matplotlib":
929
- add_recombination_overlay(ax, recomb_df, start, end)
930
- else:
931
- self._add_recombination_overlay_generic(ax, recomb_df, start, end)
895
+ if self._backend.supports_secondary_axis:
896
+ self._add_recombination_overlay(ax, recomb_df, start, end)
932
897
 
933
898
  # Format axes
934
899
  self._backend.set_ylabel(ax, r"$-\log_{10}$ P")
@@ -937,50 +902,11 @@ class LocusZoomPlotter:
937
902
 
938
903
  # Add panel label
939
904
  if panel_labels and i < len(panel_labels):
940
- if self.backend_name == "matplotlib":
941
- ax.annotate(
942
- panel_labels[i],
943
- xy=(0.02, 0.95),
944
- xycoords="axes fraction",
945
- fontsize=11,
946
- fontweight="bold",
947
- va="top",
948
- ha="left",
949
- )
950
- elif self.backend_name == "plotly":
951
- fig, row = ax
952
- fig.add_annotation(
953
- text=f"<b>{panel_labels[i]}</b>",
954
- xref=f"x{row} domain" if row > 1 else "x domain",
955
- yref=f"y{row} domain" if row > 1 else "y domain",
956
- x=0.02,
957
- y=0.95,
958
- showarrow=False,
959
- font=dict(size=11),
960
- xanchor="left",
961
- yanchor="top",
962
- )
963
- elif self.backend_name == "bokeh":
964
- from bokeh.models import Label
965
-
966
- # Get y-axis range for positioning
967
- y_max = ax.y_range.end if ax.y_range.end else 10
968
- x_min = ax.x_range.start if ax.x_range.start else start
969
- label = Label(
970
- x=x_min + (end - start) * 0.02,
971
- y=y_max * 0.95,
972
- text=panel_labels[i],
973
- text_font_size="11pt",
974
- text_font_style="bold",
975
- )
976
- ax.add_layout(label)
905
+ self._backend.add_panel_label(ax, panel_labels[i])
977
906
 
978
907
  # Add LD legend (only on first panel, all backends)
979
908
  if i == 0 and panel_ld_col is not None and panel_ld_col in df.columns:
980
- if self.backend_name == "matplotlib":
981
- self._add_ld_legend(ax)
982
- else:
983
- self._backend.add_ld_legend(ax, LD_BINS, LEAD_SNP_COLOR)
909
+ self._backend.add_ld_legend(ax, LD_BINS, LEAD_SNP_COLOR)
984
910
 
985
911
  # Track current panel index
986
912
  panel_idx = n_gwas
@@ -1030,35 +956,35 @@ class LocusZoomPlotter:
1030
956
  if eqtl_gene and "gene" in eqtl_data.columns:
1031
957
  eqtl_data = eqtl_data[eqtl_data["gene"] == eqtl_gene]
1032
958
 
1033
- # Filter by region
959
+ # Filter by region (position and chromosome)
1034
960
  if "pos" in eqtl_data.columns:
1035
- eqtl_data = eqtl_data[
1036
- (eqtl_data["pos"] >= start) & (eqtl_data["pos"] <= end)
1037
- ]
961
+ mask = (eqtl_data["pos"] >= start) & (eqtl_data["pos"] <= end)
962
+ # Also filter by chromosome if column exists
963
+ if "chr" in eqtl_data.columns:
964
+ chrom_str = str(chrom).replace("chr", "")
965
+ eqtl_chrom = (
966
+ eqtl_data["chr"].astype(str).str.replace("chr", "", regex=False)
967
+ )
968
+ mask = mask & (eqtl_chrom == chrom_str)
969
+ eqtl_data = eqtl_data[mask]
1038
970
 
1039
971
  if not eqtl_data.empty:
1040
972
  eqtl_data["neglog10p"] = -np.log10(
1041
973
  eqtl_data["p_value"].clip(lower=1e-300)
1042
974
  )
1043
975
 
1044
- def _build_eqtl_hover_data(
1045
- subset_df: pd.DataFrame,
1046
- ) -> Optional[pd.DataFrame]:
1047
- """Build hover data for eQTL interactive backends."""
1048
- hover_cols = {}
1049
- # Position
1050
- if "pos" in subset_df.columns:
1051
- hover_cols["Position"] = subset_df["pos"].values
1052
- # P-value
1053
- if "p_value" in subset_df.columns:
1054
- hover_cols["P-value"] = subset_df["p_value"].values
1055
- # Effect size
1056
- if "effect_size" in subset_df.columns:
1057
- hover_cols["Effect"] = subset_df["effect_size"].values
1058
- # Gene
1059
- if "gene" in subset_df.columns:
1060
- hover_cols["Gene"] = subset_df["gene"].values
1061
- return pd.DataFrame(hover_cols) if hover_cols else None
976
+ # Build hover data using HoverDataBuilder
977
+ eqtl_extra_cols = {}
978
+ if "effect_size" in eqtl_data.columns:
979
+ eqtl_extra_cols["effect_size"] = "Effect"
980
+ if "gene" in eqtl_data.columns:
981
+ eqtl_extra_cols["gene"] = "Gene"
982
+ eqtl_hover_config = HoverConfig(
983
+ pos_col="pos" if "pos" in eqtl_data.columns else None,
984
+ p_col="p_value" if "p_value" in eqtl_data.columns else None,
985
+ extra_cols=eqtl_extra_cols,
986
+ )
987
+ eqtl_hover_builder = HoverDataBuilder(eqtl_hover_config)
1062
988
 
1063
989
  # Check if effect_size column exists for directional coloring
1064
990
  has_effect = "effect_size" in eqtl_data.columns
@@ -1081,7 +1007,7 @@ class LocusZoomPlotter:
1081
1007
  edgecolor="black",
1082
1008
  linewidth=0.5,
1083
1009
  zorder=2,
1084
- hover_data=_build_eqtl_hover_data(row_df),
1010
+ hover_data=eqtl_hover_builder.build_dataframe(row_df),
1085
1011
  )
1086
1012
  # Plot negative effects (down triangles)
1087
1013
  for _, row in neg_effects.iterrows():
@@ -1096,7 +1022,7 @@ class LocusZoomPlotter:
1096
1022
  edgecolor="black",
1097
1023
  linewidth=0.5,
1098
1024
  zorder=2,
1099
- hover_data=_build_eqtl_hover_data(row_df),
1025
+ hover_data=eqtl_hover_builder.build_dataframe(row_df),
1100
1026
  )
1101
1027
  # Add eQTL effect legend (all backends)
1102
1028
  self._backend.add_eqtl_legend(
@@ -1116,7 +1042,7 @@ class LocusZoomPlotter:
1116
1042
  linewidth=0.5,
1117
1043
  zorder=2,
1118
1044
  label=label,
1119
- hover_data=_build_eqtl_hover_data(eqtl_data),
1045
+ hover_data=eqtl_hover_builder.build_dataframe(eqtl_data),
1120
1046
  )
1121
1047
  self._backend.add_simple_legend(ax, label, loc="upper right")
1122
1048
 
@@ -1132,15 +1058,12 @@ class LocusZoomPlotter:
1132
1058
  self._backend.hide_spines(ax, ["top", "right"])
1133
1059
  panel_idx += 1
1134
1060
 
1135
- # Plot gene track (all backends)
1061
+ # Plot gene track (all backends use generic function)
1136
1062
  if genes_df is not None:
1137
1063
  gene_ax = axes[panel_idx]
1138
- if self.backend_name == "matplotlib":
1139
- plot_gene_track(gene_ax, genes_df, chrom, start, end, exons_df)
1140
- else:
1141
- plot_gene_track_generic(
1142
- gene_ax, self._backend, genes_df, chrom, start, end, exons_df
1143
- )
1064
+ plot_gene_track_generic(
1065
+ gene_ax, self._backend, genes_df, chrom, start, end, exons_df
1066
+ )
1144
1067
  self._backend.set_xlabel(gene_ax, f"Chromosome {chrom} (Mb)")
1145
1068
  self._backend.hide_spines(gene_ax, ["top", "right", "left"])
1146
1069
  else:
@@ -1155,3 +1078,260 @@ class LocusZoomPlotter:
1155
1078
  self._backend.finalize_layout(fig, hspace=0.1)
1156
1079
 
1157
1080
  return fig
1081
+
1082
+ def plot_phewas(
1083
+ self,
1084
+ phewas_df: pd.DataFrame,
1085
+ variant_id: str,
1086
+ phenotype_col: str = "phenotype",
1087
+ p_col: str = "p_value",
1088
+ category_col: str = "category",
1089
+ effect_col: Optional[str] = None,
1090
+ significance_threshold: float = 5e-8,
1091
+ figsize: Tuple[float, float] = (10, 8),
1092
+ ) -> Any:
1093
+ """Create a PheWAS (Phenome-Wide Association Study) plot.
1094
+
1095
+ Shows associations of a single variant across multiple phenotypes,
1096
+ with phenotypes grouped by category and colored accordingly.
1097
+
1098
+ Args:
1099
+ phewas_df: DataFrame with phenotype associations.
1100
+ variant_id: Variant identifier (e.g., "rs12345") for plot title.
1101
+ phenotype_col: Column name for phenotype names.
1102
+ p_col: Column name for p-values.
1103
+ category_col: Column name for phenotype categories.
1104
+ effect_col: Optional column name for effect direction (beta/OR).
1105
+ significance_threshold: P-value threshold for significance line.
1106
+ figsize: Figure size as (width, height).
1107
+
1108
+ Returns:
1109
+ Figure object (type depends on backend).
1110
+
1111
+ Example:
1112
+ >>> fig = plotter.plot_phewas(
1113
+ ... phewas_df,
1114
+ ... variant_id="rs12345",
1115
+ ... category_col="category",
1116
+ ... )
1117
+ """
1118
+ validate_phewas_df(phewas_df, phenotype_col, p_col, category_col)
1119
+
1120
+ df = phewas_df.copy()
1121
+ df["neglog10p"] = -np.log10(df[p_col].clip(lower=1e-300))
1122
+
1123
+ # Sort by category then by p-value for consistent ordering
1124
+ if category_col in df.columns:
1125
+ df = df.sort_values([category_col, p_col])
1126
+ categories = df[category_col].unique().tolist()
1127
+ palette = get_phewas_category_palette(categories)
1128
+ else:
1129
+ df = df.sort_values(p_col)
1130
+ categories = []
1131
+ palette = {}
1132
+
1133
+ # Create figure
1134
+ fig, axes = self._backend.create_figure(
1135
+ n_panels=1,
1136
+ height_ratios=[1.0],
1137
+ figsize=figsize,
1138
+ )
1139
+ ax = axes[0]
1140
+
1141
+ # Assign y-positions (one per phenotype)
1142
+ df["y_pos"] = range(len(df))
1143
+
1144
+ # Plot points by category
1145
+ if categories:
1146
+ for cat in categories:
1147
+ cat_data = df[df[category_col] == cat]
1148
+ # Use upward triangles for positive effects, circles otherwise
1149
+ if effect_col and effect_col in cat_data.columns:
1150
+ for _, row in cat_data.iterrows():
1151
+ marker = "^" if row[effect_col] >= 0 else "v"
1152
+ self._backend.scatter(
1153
+ ax,
1154
+ pd.Series([row["neglog10p"]]),
1155
+ pd.Series([row["y_pos"]]),
1156
+ colors=palette[cat],
1157
+ sizes=60,
1158
+ marker=marker,
1159
+ edgecolor="black",
1160
+ linewidth=0.5,
1161
+ zorder=2,
1162
+ )
1163
+ else:
1164
+ self._backend.scatter(
1165
+ ax,
1166
+ cat_data["neglog10p"],
1167
+ cat_data["y_pos"],
1168
+ colors=palette[cat],
1169
+ sizes=60,
1170
+ marker="o",
1171
+ edgecolor="black",
1172
+ linewidth=0.5,
1173
+ zorder=2,
1174
+ )
1175
+ else:
1176
+ self._backend.scatter(
1177
+ ax,
1178
+ df["neglog10p"],
1179
+ df["y_pos"],
1180
+ colors="#4169E1",
1181
+ sizes=60,
1182
+ edgecolor="black",
1183
+ linewidth=0.5,
1184
+ zorder=2,
1185
+ )
1186
+
1187
+ # Add significance threshold line
1188
+ sig_line = -np.log10(significance_threshold)
1189
+ self._backend.axvline(
1190
+ ax, x=sig_line, color="red", linestyle="--", linewidth=1, alpha=0.7
1191
+ )
1192
+
1193
+ # Set axis labels and limits
1194
+ self._backend.set_xlabel(ax, r"$-\log_{10}$ P")
1195
+ self._backend.set_ylabel(ax, "Phenotype")
1196
+ self._backend.set_ylim(ax, -0.5, len(df) - 0.5)
1197
+
1198
+ # Set y-tick labels to phenotype names
1199
+ self._backend.set_yticks(
1200
+ ax,
1201
+ positions=df["y_pos"].tolist(),
1202
+ labels=df[phenotype_col].tolist(),
1203
+ fontsize=8,
1204
+ )
1205
+
1206
+ self._backend.set_title(ax, f"PheWAS: {variant_id}")
1207
+ self._backend.hide_spines(ax, ["top", "right"])
1208
+ self._backend.finalize_layout(fig)
1209
+
1210
+ return fig
1211
+
1212
+ def plot_forest(
1213
+ self,
1214
+ forest_df: pd.DataFrame,
1215
+ variant_id: str,
1216
+ study_col: str = "study",
1217
+ effect_col: str = "effect",
1218
+ ci_lower_col: str = "ci_lower",
1219
+ ci_upper_col: str = "ci_upper",
1220
+ weight_col: Optional[str] = None,
1221
+ null_value: float = 0.0,
1222
+ effect_label: str = "Effect Size",
1223
+ figsize: Tuple[float, float] = (8, 6),
1224
+ ) -> Any:
1225
+ """Create a forest plot showing effect sizes with confidence intervals.
1226
+
1227
+ Args:
1228
+ forest_df: DataFrame with effect sizes and confidence intervals.
1229
+ variant_id: Variant identifier for plot title.
1230
+ study_col: Column name for study/phenotype names.
1231
+ effect_col: Column name for effect sizes.
1232
+ ci_lower_col: Column name for lower confidence interval.
1233
+ ci_upper_col: Column name for upper confidence interval.
1234
+ weight_col: Optional column for study weights (affects marker size).
1235
+ null_value: Reference value for null effect (0 for beta, 1 for OR).
1236
+ effect_label: X-axis label.
1237
+ figsize: Figure size as (width, height).
1238
+
1239
+ Returns:
1240
+ Figure object (type depends on backend).
1241
+
1242
+ Example:
1243
+ >>> fig = plotter.plot_forest(
1244
+ ... forest_df,
1245
+ ... variant_id="rs12345",
1246
+ ... effect_label="Odds Ratio",
1247
+ ... null_value=1.0,
1248
+ ... )
1249
+ """
1250
+ validate_forest_df(forest_df, study_col, effect_col, ci_lower_col, ci_upper_col)
1251
+
1252
+ df = forest_df.copy()
1253
+
1254
+ # Create figure
1255
+ fig, axes = self._backend.create_figure(
1256
+ n_panels=1,
1257
+ height_ratios=[1.0],
1258
+ figsize=figsize,
1259
+ )
1260
+ ax = axes[0]
1261
+
1262
+ # Assign y-positions (reverse so first study is at top)
1263
+ df["y_pos"] = range(len(df) - 1, -1, -1)
1264
+
1265
+ # Calculate marker sizes from weights
1266
+ if weight_col and weight_col in df.columns:
1267
+ # Scale weights to marker sizes (min 40, max 200)
1268
+ weights = df[weight_col]
1269
+ min_size, max_size = 40, 200
1270
+ weight_range = weights.max() - weights.min()
1271
+ if weight_range > 0:
1272
+ sizes = min_size + (weights - weights.min()) / weight_range * (
1273
+ max_size - min_size
1274
+ )
1275
+ else:
1276
+ sizes = (min_size + max_size) / 2
1277
+ else:
1278
+ sizes = 80
1279
+
1280
+ # Calculate error bar extents
1281
+ xerr_lower = df[effect_col] - df[ci_lower_col]
1282
+ xerr_upper = df[ci_upper_col] - df[effect_col]
1283
+
1284
+ # Plot error bars (confidence intervals)
1285
+ self._backend.errorbar_h(
1286
+ ax,
1287
+ x=df[effect_col],
1288
+ y=df["y_pos"],
1289
+ xerr_lower=xerr_lower,
1290
+ xerr_upper=xerr_upper,
1291
+ color="black",
1292
+ linewidth=1.5,
1293
+ capsize=3,
1294
+ zorder=2,
1295
+ )
1296
+
1297
+ # Plot effect size markers
1298
+ self._backend.scatter(
1299
+ ax,
1300
+ df[effect_col],
1301
+ df["y_pos"],
1302
+ colors="#4169E1",
1303
+ sizes=sizes,
1304
+ marker="s", # square markers typical for forest plots
1305
+ edgecolor="black",
1306
+ linewidth=0.5,
1307
+ zorder=3,
1308
+ )
1309
+
1310
+ # Add null effect line
1311
+ self._backend.axvline(
1312
+ ax, x=null_value, color="grey", linestyle="--", linewidth=1, alpha=0.7
1313
+ )
1314
+
1315
+ # Set axis labels and limits
1316
+ self._backend.set_xlabel(ax, effect_label)
1317
+ self._backend.set_ylim(ax, -0.5, len(df) - 0.5)
1318
+
1319
+ # Ensure x-axis includes the null value with some padding
1320
+ x_min = min(df[ci_lower_col].min(), null_value)
1321
+ x_max = max(df[ci_upper_col].max(), null_value)
1322
+ x_padding = (x_max - x_min) * 0.1
1323
+ self._backend.set_xlim(ax, x_min - x_padding, x_max + x_padding)
1324
+
1325
+ # Set y-tick labels to study names
1326
+ self._backend.set_yticks(
1327
+ ax,
1328
+ positions=df["y_pos"].tolist(),
1329
+ labels=df[study_col].tolist(),
1330
+ fontsize=10,
1331
+ )
1332
+
1333
+ self._backend.set_title(ax, f"Forest Plot: {variant_id}")
1334
+ self._backend.hide_spines(ax, ["top", "right"])
1335
+ self._backend.finalize_layout(fig)
1336
+
1337
+ return fig