pylocuszoom 0.8.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pylocuszoom/plotter.py CHANGED
@@ -15,7 +15,9 @@ from typing import Any, List, Optional, Tuple
15
15
  import matplotlib.pyplot as plt
16
16
  import numpy as np
17
17
  import pandas as pd
18
+ import requests
18
19
 
20
+ from ._plotter_utils import DEFAULT_GENOMEWIDE_THRESHOLD
19
21
  from .backends import BackendType, get_backend
20
22
  from .backends.hover import HoverConfig, HoverDataBuilder
21
23
  from .colors import (
@@ -28,32 +30,31 @@ from .colors import (
28
30
  get_eqtl_color,
29
31
  get_ld_bin,
30
32
  get_ld_color_palette,
31
- get_phewas_category_palette,
32
33
  )
34
+ from .config import PlotConfig, StackedPlotConfig
33
35
  from .ensembl import get_genes_for_region
34
36
  from .eqtl import validate_eqtl_df
35
37
  from .finemapping import (
36
38
  get_credible_sets,
37
39
  prepare_finemapping_for_plotting,
38
40
  )
39
- from .forest import validate_forest_df
40
41
  from .gene_track import (
41
42
  assign_gene_positions,
42
43
  plot_gene_track_generic,
43
44
  )
44
45
  from .ld import calculate_ld, find_plink
45
46
  from .logging import enable_logging, logger
46
- from .phewas import validate_phewas_df
47
+ from .manhattan_plotter import ManhattanPlotter
47
48
  from .recombination import (
48
49
  RECOMB_COLOR,
49
50
  download_canine_recombination_maps,
50
51
  get_default_data_dir,
51
52
  get_recombination_rate_for_region,
52
53
  )
54
+ from .stats_plotter import StatsPlotter
53
55
  from .utils import normalize_chrom, validate_genes_df, validate_gwas_df
54
56
 
55
- # Default significance threshold: 5e-8 (genome-wide significance)
56
- DEFAULT_GENOMEWIDE_THRESHOLD = 5e-8
57
+ # Precomputed significance line value (used for plotting)
57
58
  DEFAULT_GENOMEWIDE_LINE = -np.log10(DEFAULT_GENOMEWIDE_THRESHOLD)
58
59
 
59
60
 
@@ -138,6 +139,7 @@ class LocusZoomPlotter:
138
139
  genome_build if genome_build else self._default_build(species)
139
140
  )
140
141
  self._backend = get_backend(backend)
142
+ self._backend_name = backend # Store for delegation to child plotters
141
143
  self.plink_path = plink_path or find_plink()
142
144
  self.recomb_data_dir = recomb_data_dir
143
145
  self.genomewide_threshold = genomewide_threshold
@@ -147,6 +149,27 @@ class LocusZoomPlotter:
147
149
  # Cache for loaded data
148
150
  self._recomb_cache = {}
149
151
 
152
+ @property
153
+ def _manhattan_plotter(self) -> ManhattanPlotter:
154
+ """Lazy-load ManhattanPlotter with shared configuration."""
155
+ if not hasattr(self, "_manhattan_plotter_instance"):
156
+ self._manhattan_plotter_instance = ManhattanPlotter(
157
+ species=self.species,
158
+ backend=self._backend_name,
159
+ genomewide_threshold=self.genomewide_threshold,
160
+ )
161
+ return self._manhattan_plotter_instance
162
+
163
+ @property
164
+ def _stats_plotter(self) -> StatsPlotter:
165
+ """Lazy-load StatsPlotter with shared configuration."""
166
+ if not hasattr(self, "_stats_plotter_instance"):
167
+ self._stats_plotter_instance = StatsPlotter(
168
+ backend=self._backend_name,
169
+ genomewide_threshold=self.genomewide_threshold,
170
+ )
171
+ return self._stats_plotter_instance
172
+
150
173
  @staticmethod
151
174
  def _default_build(species: str) -> Optional[str]:
152
175
  """Get default genome build for species."""
@@ -171,9 +194,17 @@ class LocusZoomPlotter:
171
194
  # Download
172
195
  try:
173
196
  return download_canine_recombination_maps()
174
- except Exception as e:
197
+ except (requests.RequestException, OSError, IOError) as e:
198
+ # Expected network/file errors - graceful fallback
175
199
  logger.warning(f"Could not download recombination maps: {e}")
176
200
  return None
201
+ except Exception as e:
202
+ # JUSTIFICATION: Download failure should not prevent plotting.
203
+ # We catch broadly here because graceful degradation is acceptable
204
+ # for optional recombination map downloads. Error-level logging
205
+ # ensures the issue is visible.
206
+ logger.error(f"Unexpected error downloading recombination maps: {e}")
207
+ return None
177
208
  elif self.recomb_data_dir:
178
209
  return Path(self.recomb_data_dir)
179
210
  return None
@@ -204,56 +235,96 @@ class LocusZoomPlotter:
204
235
  except FileNotFoundError:
205
236
  return None
206
237
 
238
+ def _transform_pvalues(self, df: pd.DataFrame, p_col: str) -> pd.DataFrame:
239
+ """Add neglog10p column with -log10 transformed p-values.
240
+
241
+ Delegates to shared utility function. Assumes df is already a copy.
242
+
243
+ Args:
244
+ df: DataFrame with p-value column (should be a copy).
245
+ p_col: Name of p-value column.
246
+
247
+ Returns:
248
+ DataFrame with neglog10p column added.
249
+ """
250
+ # Use shared utility - note: df should already be a copy at call sites
251
+ df["neglog10p"] = -np.log10(df[p_col].clip(lower=1e-300))
252
+ return df
253
+
207
254
  def plot(
208
255
  self,
209
256
  gwas_df: pd.DataFrame,
257
+ *,
210
258
  chrom: int,
211
259
  start: int,
212
260
  end: int,
261
+ pos_col: str = "ps",
262
+ p_col: str = "p_wald",
263
+ rs_col: str = "rs",
264
+ snp_labels: bool = True,
265
+ label_top_n: int = 5,
266
+ show_recombination: bool = True,
267
+ figsize: Tuple[float, float] = (12.0, 8.0),
213
268
  lead_pos: Optional[int] = None,
214
269
  ld_reference_file: Optional[str] = None,
215
270
  ld_col: Optional[str] = None,
216
271
  genes_df: Optional[pd.DataFrame] = None,
217
272
  exons_df: Optional[pd.DataFrame] = None,
218
273
  recomb_df: Optional[pd.DataFrame] = None,
219
- show_recombination: bool = True,
220
- snp_labels: bool = True,
221
- label_top_n: int = 5,
222
- pos_col: str = "ps",
223
- p_col: str = "p_wald",
224
- rs_col: str = "rs",
225
- figsize: Tuple[int, int] = (12, 8),
226
274
  ) -> Any:
227
275
  """Create a regional association plot.
228
276
 
229
277
  Args:
230
278
  gwas_df: GWAS results DataFrame.
231
279
  chrom: Chromosome number.
232
- start: Start position of the region.
233
- end: End position of the region.
234
- lead_pos: Position of the lead/index SNP to highlight.
235
- ld_reference_file: PLINK binary fileset for LD calculation.
236
- If provided with lead_pos, calculates LD on the fly.
237
- ld_col: Column name for pre-computed LD (R²) values.
238
- Use this if LD was calculated externally.
280
+ start: Start position in base pairs.
281
+ end: End position in base pairs.
282
+ pos_col: Column name for genomic position.
283
+ p_col: Column name for p-value.
284
+ rs_col: Column name for SNP identifier.
285
+ snp_labels: Whether to show SNP labels on plot.
286
+ label_top_n: Number of top SNPs to label.
287
+ show_recombination: Whether to show recombination rate overlay.
288
+ figsize: Figure size as (width, height) in inches.
289
+ lead_pos: Position of lead SNP to highlight. For stacked plots with
290
+ multiple regions, use plot_stacked() with lead_positions (plural).
291
+ ld_reference_file: Path to PLINK binary fileset for LD calculation.
292
+ ld_col: Column name for pre-computed LD (R^2) values.
239
293
  genes_df: Gene annotations with chr, start, end, gene_name.
240
294
  exons_df: Exon annotations with chr, start, end, gene_name.
241
295
  recomb_df: Pre-loaded recombination rate data.
242
296
  If None and show_recombination=True, loads from species default.
243
- show_recombination: Whether to show recombination rate overlay.
244
- snp_labels: Whether to label top SNPs.
245
- label_top_n: Number of top SNPs to label.
246
- pos_col: Column name for position.
247
- p_col: Column name for p-value.
248
- rs_col: Column name for SNP ID.
249
- figsize: Figure size.
250
297
 
251
298
  Returns:
252
- Matplotlib Figure object.
299
+ Figure object (type depends on backend).
253
300
 
254
301
  Raises:
255
- ValidationError: If required DataFrame columns are missing.
302
+ ValidationError: If parameters or DataFrame columns are invalid.
303
+
304
+ Example:
305
+ >>> fig = plotter.plot(
306
+ ... gwas_df,
307
+ ... chrom=1, start=1000000, end=2000000,
308
+ ... lead_pos=1500000, snp_labels=True,
309
+ ... )
256
310
  """
311
+ # Validate parameters via Pydantic
312
+ PlotConfig.from_kwargs(
313
+ chrom=chrom,
314
+ start=start,
315
+ end=end,
316
+ pos_col=pos_col,
317
+ p_col=p_col,
318
+ rs_col=rs_col,
319
+ snp_labels=snp_labels,
320
+ label_top_n=label_top_n,
321
+ show_recombination=show_recombination,
322
+ figsize=figsize,
323
+ lead_pos=lead_pos,
324
+ ld_reference_file=ld_reference_file,
325
+ ld_col=ld_col,
326
+ )
327
+
257
328
  # Validate inputs
258
329
  validate_gwas_df(gwas_df, pos_col=pos_col, p_col=p_col)
259
330
 
@@ -282,7 +353,24 @@ class LocusZoomPlotter:
282
353
 
283
354
  # Prepare data
284
355
  df = gwas_df.copy()
285
- df["neglog10p"] = -np.log10(df[p_col].clip(lower=1e-300))
356
+
357
+ # Validate p-values and warn about issues
358
+ p_values = df[p_col]
359
+ nan_count = p_values.isna().sum()
360
+ if nan_count > 0:
361
+ logger.warning(
362
+ f"GWAS data contains {nan_count} NaN p-values which will be excluded"
363
+ )
364
+ invalid_count = ((p_values < 0) | (p_values > 1)).sum()
365
+ if invalid_count > 0:
366
+ logger.warning(
367
+ f"GWAS data contains {invalid_count} p-values outside [0, 1] range"
368
+ )
369
+ clipped_count = (p_values < 1e-300).sum()
370
+ if clipped_count > 0:
371
+ logger.debug(f"Clipping {clipped_count} p-values below 1e-300 to 1e-300")
372
+
373
+ df = self._transform_pvalues(df, p_col)
286
374
 
287
375
  # Calculate LD if reference file provided
288
376
  if ld_reference_file and lead_pos and ld_col is None:
@@ -351,7 +439,12 @@ class LocusZoomPlotter:
351
439
  # Format axes
352
440
  self._backend.set_ylabel(ax, r"$-\log_{10}$ P")
353
441
  self._backend.set_xlim(ax, start, end)
354
- self._backend.hide_spines(ax, ["top", "right"])
442
+ # When recombination overlay is present, keep right spine for secondary y-axis
443
+ has_recomb = recomb_df is not None and not recomb_df.empty
444
+ if has_recomb and self._backend.supports_secondary_axis:
445
+ self._backend.hide_spines(ax, ["top"])
446
+ else:
447
+ self._backend.hide_spines(ax, ["top", "right"])
355
448
 
356
449
  # Add LD legend (all backends)
357
450
  if ld_col is not None and ld_col in df.columns:
@@ -364,10 +457,12 @@ class LocusZoomPlotter:
364
457
  )
365
458
  self._backend.set_xlabel(gene_ax, f"Chromosome {chrom} (Mb)")
366
459
  self._backend.hide_spines(gene_ax, ["top", "right", "left"])
460
+ # Format both axes for interactive backends (they don't share x-axis)
461
+ self._backend.format_xaxis_mb(gene_ax)
367
462
  else:
368
463
  self._backend.set_xlabel(ax, f"Chromosome {chrom} (Mb)")
369
464
 
370
- # Format x-axis with Mb labels
465
+ # Format x-axis with Mb labels (association axis always needs formatting)
371
466
  self._backend.format_xaxis_mb(ax)
372
467
 
373
468
  # Adjust layout
@@ -516,18 +611,29 @@ class LocusZoomPlotter:
516
611
  return
517
612
 
518
613
  # Create secondary y-axis
519
- yaxis_name = self._backend.create_twin_axis(ax)
520
-
521
- # For plotly, yaxis_name is a tuple (fig, row, secondary_y)
522
- # For bokeh, yaxis_name is just a string
523
- if isinstance(yaxis_name, tuple):
524
- _, _, secondary_y = yaxis_name
614
+ twin_result = self._backend.create_twin_axis(ax)
615
+
616
+ # Matplotlib returns the twin Axes object itself - use it for drawing
617
+ # Plotly returns tuple (fig, row, secondary_y_name)
618
+ # Bokeh returns string "secondary"
619
+ from matplotlib.axes import Axes
620
+
621
+ if isinstance(twin_result, Axes):
622
+ # Matplotlib: use the twin axis for all secondary axis operations
623
+ secondary_ax = twin_result
624
+ secondary_y = None # Not used for matplotlib
625
+ elif isinstance(twin_result, tuple):
626
+ # Plotly: use original ax, specify y-axis via yaxis_name
627
+ secondary_ax = ax
628
+ _, _, secondary_y = twin_result
525
629
  else:
526
- secondary_y = yaxis_name
630
+ # Bokeh: use original ax, specify y-axis via yaxis_name
631
+ secondary_ax = ax
632
+ secondary_y = twin_result
527
633
 
528
634
  # Plot fill under curve
529
635
  self._backend.fill_between_secondary(
530
- ax,
636
+ secondary_ax,
531
637
  region_recomb["pos"],
532
638
  0,
533
639
  region_recomb["rate"],
@@ -538,28 +644,32 @@ class LocusZoomPlotter:
538
644
 
539
645
  # Plot recombination rate line
540
646
  self._backend.line_secondary(
541
- ax,
647
+ secondary_ax,
542
648
  region_recomb["pos"],
543
649
  region_recomb["rate"],
544
650
  color=RECOMB_COLOR,
545
- linewidth=1.5,
546
- alpha=0.7,
651
+ linewidth=2.5,
652
+ alpha=0.8,
547
653
  yaxis_name=secondary_y,
548
654
  )
549
655
 
550
- # Set y-axis limits and label
656
+ # Set y-axis limits and label - scale to fit data with headroom
551
657
  max_rate = region_recomb["rate"].max()
552
658
  self._backend.set_secondary_ylim(
553
- ax, 0, max(max_rate * 1.2, 20), yaxis_name=secondary_y
659
+ secondary_ax, 0, max(max_rate * 1.3, 10), yaxis_name=secondary_y
554
660
  )
555
661
  self._backend.set_secondary_ylabel(
556
- ax,
662
+ secondary_ax,
557
663
  "Recombination rate (cM/Mb)",
558
- color=RECOMB_COLOR,
664
+ color="black", # Use black for readability (line/fill color remains light blue)
559
665
  fontsize=9,
560
666
  yaxis_name=secondary_y,
561
667
  )
562
668
 
669
+ # Hide top spine on the secondary axis (matplotlib twin axis has its own frame)
670
+ if isinstance(twin_result, Axes):
671
+ secondary_ax.spines["top"].set_visible(False)
672
+
563
673
  def _plot_finemapping(
564
674
  self,
565
675
  ax: Any,
@@ -664,14 +774,22 @@ class LocusZoomPlotter:
664
774
  def plot_stacked(
665
775
  self,
666
776
  gwas_dfs: List[pd.DataFrame],
777
+ *,
667
778
  chrom: int,
668
779
  start: int,
669
780
  end: int,
781
+ pos_col: str = "ps",
782
+ p_col: str = "p_wald",
783
+ rs_col: str = "rs",
784
+ snp_labels: bool = True,
785
+ label_top_n: int = 3,
786
+ show_recombination: bool = True,
787
+ figsize: Tuple[float, float] = (12.0, 8.0),
788
+ ld_reference_file: Optional[str] = None,
789
+ ld_col: Optional[str] = None,
670
790
  lead_positions: Optional[List[int]] = None,
671
791
  panel_labels: Optional[List[str]] = None,
672
- ld_reference_file: Optional[str] = None,
673
792
  ld_reference_files: Optional[List[str]] = None,
674
- ld_col: Optional[str] = None,
675
793
  genes_df: Optional[pd.DataFrame] = None,
676
794
  exons_df: Optional[pd.DataFrame] = None,
677
795
  eqtl_df: Optional[pd.DataFrame] = None,
@@ -679,13 +797,6 @@ class LocusZoomPlotter:
679
797
  finemapping_df: Optional[pd.DataFrame] = None,
680
798
  finemapping_cs_col: Optional[str] = "cs",
681
799
  recomb_df: Optional[pd.DataFrame] = None,
682
- show_recombination: bool = True,
683
- snp_labels: bool = True,
684
- label_top_n: int = 3,
685
- pos_col: str = "ps",
686
- p_col: str = "p_wald",
687
- rs_col: str = "rs",
688
- figsize: Tuple[float, Optional[float]] = (12, None),
689
800
  ) -> Any:
690
801
  """Create stacked regional association plots for multiple GWAS.
691
802
 
@@ -695,30 +806,29 @@ class LocusZoomPlotter:
695
806
  Args:
696
807
  gwas_dfs: List of GWAS results DataFrames to stack.
697
808
  chrom: Chromosome number.
698
- start: Start position of the region.
699
- end: End position of the region.
700
- lead_positions: List of lead SNP positions (one per GWAS).
701
- If None, auto-detects from lowest p-value.
702
- panel_labels: Labels for each panel (e.g., phenotype names).
703
- ld_reference_file: Single PLINK fileset for all panels.
809
+ start: Start position in base pairs.
810
+ end: End position in base pairs.
811
+ pos_col: Column name for genomic position.
812
+ p_col: Column name for p-value.
813
+ rs_col: Column name for SNP identifier.
814
+ snp_labels: Whether to show SNP labels on plot.
815
+ label_top_n: Number of top SNPs to label (default 3 for stacked).
816
+ show_recombination: Whether to show recombination rate overlay.
817
+ figsize: Figure size as (width, height) in inches.
818
+ ld_reference_file: Single PLINK fileset (broadcast to all panels).
819
+ ld_col: Column name for pre-computed LD (R^2) values.
820
+ lead_positions: List of lead SNP positions, one per region. For single
821
+ region plots, use plot() with lead_pos (singular).
822
+ panel_labels: List of panel labels (one per panel).
704
823
  ld_reference_files: List of PLINK filesets (one per panel).
705
- ld_col: Column name for pre-computed LD (R²) values in each DataFrame.
706
- Use this if LD was calculated externally.
707
824
  genes_df: Gene annotations for bottom track.
708
825
  exons_df: Exon annotations for gene track.
709
826
  eqtl_df: eQTL data to display as additional panel.
710
827
  eqtl_gene: Filter eQTL data to this target gene.
711
828
  finemapping_df: Fine-mapping/SuSiE results with pos and pip columns.
712
829
  Displayed as PIP line with optional credible set coloring.
713
- finemapping_cs_col: Column name for credible set assignment in finemapping_df.
830
+ finemapping_cs_col: Column name for credible set assignment.
714
831
  recomb_df: Pre-loaded recombination rate data.
715
- show_recombination: Whether to show recombination overlay.
716
- snp_labels: Whether to label top SNPs.
717
- label_top_n: Number of top SNPs to label per panel.
718
- pos_col: Column name for position.
719
- p_col: Column name for p-value.
720
- rs_col: Column name for SNP ID.
721
- figsize: Figure size (width, height). If height is None, auto-calculates.
722
832
 
723
833
  Returns:
724
834
  Figure object (type depends on backend).
@@ -728,9 +838,27 @@ class LocusZoomPlotter:
728
838
  ... [gwas_height, gwas_bmi, gwas_whr],
729
839
  ... chrom=1, start=1000000, end=2000000,
730
840
  ... panel_labels=["Height", "BMI", "WHR"],
731
- ... genes_df=genes_df,
732
841
  ... )
733
842
  """
843
+ # Validate parameters via Pydantic
844
+ StackedPlotConfig.from_kwargs(
845
+ chrom=chrom,
846
+ start=start,
847
+ end=end,
848
+ pos_col=pos_col,
849
+ p_col=p_col,
850
+ rs_col=rs_col,
851
+ snp_labels=snp_labels,
852
+ label_top_n=label_top_n,
853
+ show_recombination=show_recombination,
854
+ figsize=figsize,
855
+ ld_reference_file=ld_reference_file,
856
+ ld_col=ld_col,
857
+ lead_positions=lead_positions,
858
+ panel_labels=panel_labels,
859
+ ld_reference_files=ld_reference_files,
860
+ )
861
+
734
862
  n_gwas = len(gwas_dfs)
735
863
  if n_gwas == 0:
736
864
  raise ValueError("At least one GWAS DataFrame required")
@@ -766,8 +894,16 @@ class LocusZoomPlotter:
766
894
  for df in gwas_dfs:
767
895
  region_df = df[(df[pos_col] >= start) & (df[pos_col] <= end)]
768
896
  if not region_df.empty:
769
- lead_idx = region_df[p_col].idxmin()
770
- lead_positions.append(int(region_df.loc[lead_idx, pos_col]))
897
+ # Filter out NaN p-values for lead SNP detection
898
+ valid_p = region_df[p_col].dropna()
899
+ if valid_p.empty:
900
+ logger.warning(
901
+ "All p-values in region are NaN, cannot determine lead SNP"
902
+ )
903
+ lead_positions.append(None)
904
+ else:
905
+ lead_idx = valid_p.idxmin()
906
+ lead_positions.append(int(region_df.loc[lead_idx, pos_col]))
771
907
  else:
772
908
  lead_positions.append(None)
773
909
 
@@ -841,24 +977,34 @@ class LocusZoomPlotter:
841
977
  for i, (gwas_df, lead_pos) in enumerate(zip(gwas_dfs, lead_positions)):
842
978
  ax = axes[i]
843
979
  df = gwas_df.copy()
844
- df["neglog10p"] = -np.log10(df[p_col].clip(lower=1e-300))
980
+ df = self._transform_pvalues(df, p_col)
845
981
 
846
982
  # Use pre-computed LD or calculate from reference
847
983
  panel_ld_col = ld_col
848
984
  if ld_reference_files and ld_reference_files[i] and lead_pos and not ld_col:
849
- lead_snp_row = df[df[pos_col] == lead_pos]
850
- if not lead_snp_row.empty and rs_col in df.columns:
851
- lead_snp_id = lead_snp_row[rs_col].iloc[0]
852
- ld_df = calculate_ld(
853
- bfile_path=ld_reference_files[i],
854
- lead_snp=lead_snp_id,
855
- window_kb=max((end - start) // 1000, 500),
856
- plink_path=self.plink_path,
857
- species=self.species,
985
+ # Check if rs_col exists before attempting LD calculation
986
+ if rs_col not in df.columns:
987
+ logger.warning(
988
+ f"Cannot calculate LD for panel {i + 1}: column '{rs_col}' "
989
+ f"not found in GWAS data. "
990
+ f"Provide rs_col parameter or add SNP IDs to DataFrame."
858
991
  )
859
- if not ld_df.empty:
860
- df = df.merge(ld_df, left_on=rs_col, right_on="SNP", how="left")
861
- panel_ld_col = "R2"
992
+ else:
993
+ lead_snp_row = df[df[pos_col] == lead_pos]
994
+ if not lead_snp_row.empty:
995
+ lead_snp_id = lead_snp_row[rs_col].iloc[0]
996
+ ld_df = calculate_ld(
997
+ bfile_path=ld_reference_files[i],
998
+ lead_snp=lead_snp_id,
999
+ window_kb=max((end - start) // 1000, 500),
1000
+ plink_path=self.plink_path,
1001
+ species=self.species,
1002
+ )
1003
+ if not ld_df.empty:
1004
+ df = df.merge(
1005
+ ld_df, left_on=rs_col, right_on="SNP", how="left"
1006
+ )
1007
+ panel_ld_col = "R2"
862
1008
 
863
1009
  # Plot association
864
1010
  self._plot_association(
@@ -953,8 +1099,16 @@ class LocusZoomPlotter:
953
1099
  eqtl_data = eqtl_df.copy()
954
1100
 
955
1101
  # Filter by gene if specified
956
- if eqtl_gene and "gene" in eqtl_data.columns:
957
- eqtl_data = eqtl_data[eqtl_data["gene"] == eqtl_gene]
1102
+ eqtl_gene_filtered = False
1103
+ if eqtl_gene:
1104
+ if "gene" in eqtl_data.columns:
1105
+ eqtl_data = eqtl_data[eqtl_data["gene"] == eqtl_gene]
1106
+ eqtl_gene_filtered = True
1107
+ else:
1108
+ logger.warning(
1109
+ f"eqtl_gene='{eqtl_gene}' specified but eQTL data has no 'gene' column; "
1110
+ "showing all eQTL data unfiltered"
1111
+ )
958
1112
 
959
1113
  # Filter by region (position and chromosome)
960
1114
  if "pos" in eqtl_data.columns:
@@ -969,9 +1123,7 @@ class LocusZoomPlotter:
969
1123
  eqtl_data = eqtl_data[mask]
970
1124
 
971
1125
  if not eqtl_data.empty:
972
- eqtl_data["neglog10p"] = -np.log10(
973
- eqtl_data["p_value"].clip(lower=1e-300)
974
- )
1126
+ eqtl_data = self._transform_pvalues(eqtl_data, "p_value")
975
1127
 
976
1128
  # Build hover data using HoverDataBuilder
977
1129
  eqtl_extra_cols = {}
@@ -990,47 +1142,49 @@ class LocusZoomPlotter:
990
1142
  has_effect = "effect_size" in eqtl_data.columns
991
1143
 
992
1144
  if has_effect:
993
- # Plot triangles by effect direction (batch by sign for efficiency)
1145
+ # Vectorized plotting: split by sign, assign colors in bulk
994
1146
  pos_effects = eqtl_data[eqtl_data["effect_size"] >= 0]
995
1147
  neg_effects = eqtl_data[eqtl_data["effect_size"] < 0]
996
1148
 
997
- # Plot positive effects (up triangles)
998
- for _, row in pos_effects.iterrows():
999
- row_df = pd.DataFrame([row])
1149
+ # Vectorized color assignment using apply
1150
+ if not pos_effects.empty:
1151
+ pos_colors = pos_effects["effect_size"].apply(get_eqtl_color)
1000
1152
  self._backend.scatter(
1001
1153
  ax,
1002
- pd.Series([row["pos"]]),
1003
- pd.Series([row["neglog10p"]]),
1004
- colors=get_eqtl_color(row["effect_size"]),
1154
+ pos_effects["pos"],
1155
+ pos_effects["neglog10p"],
1156
+ colors=pos_colors.tolist(),
1005
1157
  sizes=50,
1006
1158
  marker="^",
1007
1159
  edgecolor="black",
1008
1160
  linewidth=0.5,
1009
1161
  zorder=2,
1010
- hover_data=eqtl_hover_builder.build_dataframe(row_df),
1162
+ hover_data=eqtl_hover_builder.build_dataframe(pos_effects),
1011
1163
  )
1012
- # Plot negative effects (down triangles)
1013
- for _, row in neg_effects.iterrows():
1014
- row_df = pd.DataFrame([row])
1164
+
1165
+ if not neg_effects.empty:
1166
+ neg_colors = neg_effects["effect_size"].apply(get_eqtl_color)
1015
1167
  self._backend.scatter(
1016
1168
  ax,
1017
- pd.Series([row["pos"]]),
1018
- pd.Series([row["neglog10p"]]),
1019
- colors=get_eqtl_color(row["effect_size"]),
1169
+ neg_effects["pos"],
1170
+ neg_effects["neglog10p"],
1171
+ colors=neg_colors.tolist(),
1020
1172
  sizes=50,
1021
1173
  marker="v",
1022
1174
  edgecolor="black",
1023
1175
  linewidth=0.5,
1024
1176
  zorder=2,
1025
- hover_data=eqtl_hover_builder.build_dataframe(row_df),
1177
+ hover_data=eqtl_hover_builder.build_dataframe(neg_effects),
1026
1178
  )
1179
+
1027
1180
  # Add eQTL effect legend (all backends)
1028
1181
  self._backend.add_eqtl_legend(
1029
1182
  ax, EQTL_POSITIVE_BINS, EQTL_NEGATIVE_BINS
1030
1183
  )
1031
1184
  else:
1032
1185
  # No effect sizes - plot as diamonds
1033
- label = f"eQTL ({eqtl_gene})" if eqtl_gene else "eQTL"
1186
+ # Only show gene in label if filtering was actually applied
1187
+ label = f"eQTL ({eqtl_gene})" if eqtl_gene_filtered else "eQTL"
1034
1188
  self._backend.scatter(
1035
1189
  ax,
1036
1190
  eqtl_data["pos"],
@@ -1090,124 +1244,17 @@ class LocusZoomPlotter:
1090
1244
  significance_threshold: float = 5e-8,
1091
1245
  figsize: Tuple[float, float] = (10, 8),
1092
1246
  ) -> Any:
1093
- """Create a PheWAS (Phenome-Wide Association Study) plot.
1094
-
1095
- Shows associations of a single variant across multiple phenotypes,
1096
- with phenotypes grouped by category and colored accordingly.
1097
-
1098
- Args:
1099
- phewas_df: DataFrame with phenotype associations.
1100
- variant_id: Variant identifier (e.g., "rs12345") for plot title.
1101
- phenotype_col: Column name for phenotype names.
1102
- p_col: Column name for p-values.
1103
- category_col: Column name for phenotype categories.
1104
- effect_col: Optional column name for effect direction (beta/OR).
1105
- significance_threshold: P-value threshold for significance line.
1106
- figsize: Figure size as (width, height).
1107
-
1108
- Returns:
1109
- Figure object (type depends on backend).
1110
-
1111
- Example:
1112
- >>> fig = plotter.plot_phewas(
1113
- ... phewas_df,
1114
- ... variant_id="rs12345",
1115
- ... category_col="category",
1116
- ... )
1117
- """
1118
- validate_phewas_df(phewas_df, phenotype_col, p_col, category_col)
1119
-
1120
- df = phewas_df.copy()
1121
- df["neglog10p"] = -np.log10(df[p_col].clip(lower=1e-300))
1122
-
1123
- # Sort by category then by p-value for consistent ordering
1124
- if category_col in df.columns:
1125
- df = df.sort_values([category_col, p_col])
1126
- categories = df[category_col].unique().tolist()
1127
- palette = get_phewas_category_palette(categories)
1128
- else:
1129
- df = df.sort_values(p_col)
1130
- categories = []
1131
- palette = {}
1132
-
1133
- # Create figure
1134
- fig, axes = self._backend.create_figure(
1135
- n_panels=1,
1136
- height_ratios=[1.0],
1247
+ """Create a PheWAS plot. See StatsPlotter.plot_phewas for docs."""
1248
+ return self._stats_plotter.plot_phewas(
1249
+ phewas_df=phewas_df,
1250
+ variant_id=variant_id,
1251
+ phenotype_col=phenotype_col,
1252
+ p_col=p_col,
1253
+ category_col=category_col,
1254
+ effect_col=effect_col,
1255
+ significance_threshold=significance_threshold,
1137
1256
  figsize=figsize,
1138
1257
  )
1139
- ax = axes[0]
1140
-
1141
- # Assign y-positions (one per phenotype)
1142
- df["y_pos"] = range(len(df))
1143
-
1144
- # Plot points by category
1145
- if categories:
1146
- for cat in categories:
1147
- cat_data = df[df[category_col] == cat]
1148
- # Use upward triangles for positive effects, circles otherwise
1149
- if effect_col and effect_col in cat_data.columns:
1150
- for _, row in cat_data.iterrows():
1151
- marker = "^" if row[effect_col] >= 0 else "v"
1152
- self._backend.scatter(
1153
- ax,
1154
- pd.Series([row["neglog10p"]]),
1155
- pd.Series([row["y_pos"]]),
1156
- colors=palette[cat],
1157
- sizes=60,
1158
- marker=marker,
1159
- edgecolor="black",
1160
- linewidth=0.5,
1161
- zorder=2,
1162
- )
1163
- else:
1164
- self._backend.scatter(
1165
- ax,
1166
- cat_data["neglog10p"],
1167
- cat_data["y_pos"],
1168
- colors=palette[cat],
1169
- sizes=60,
1170
- marker="o",
1171
- edgecolor="black",
1172
- linewidth=0.5,
1173
- zorder=2,
1174
- )
1175
- else:
1176
- self._backend.scatter(
1177
- ax,
1178
- df["neglog10p"],
1179
- df["y_pos"],
1180
- colors="#4169E1",
1181
- sizes=60,
1182
- edgecolor="black",
1183
- linewidth=0.5,
1184
- zorder=2,
1185
- )
1186
-
1187
- # Add significance threshold line
1188
- sig_line = -np.log10(significance_threshold)
1189
- self._backend.axvline(
1190
- ax, x=sig_line, color="red", linestyle="--", linewidth=1, alpha=0.7
1191
- )
1192
-
1193
- # Set axis labels and limits
1194
- self._backend.set_xlabel(ax, r"$-\log_{10}$ P")
1195
- self._backend.set_ylabel(ax, "Phenotype")
1196
- self._backend.set_ylim(ax, -0.5, len(df) - 0.5)
1197
-
1198
- # Set y-tick labels to phenotype names
1199
- self._backend.set_yticks(
1200
- ax,
1201
- positions=df["y_pos"].tolist(),
1202
- labels=df[phenotype_col].tolist(),
1203
- fontsize=8,
1204
- )
1205
-
1206
- self._backend.set_title(ax, f"PheWAS: {variant_id}")
1207
- self._backend.hide_spines(ax, ["top", "right"])
1208
- self._backend.finalize_layout(fig)
1209
-
1210
- return fig
1211
1258
 
1212
1259
  def plot_forest(
1213
1260
  self,
@@ -1222,116 +1269,143 @@ class LocusZoomPlotter:
1222
1269
  effect_label: str = "Effect Size",
1223
1270
  figsize: Tuple[float, float] = (8, 6),
1224
1271
  ) -> Any:
1225
- """Create a forest plot showing effect sizes with confidence intervals.
1226
-
1227
- Args:
1228
- forest_df: DataFrame with effect sizes and confidence intervals.
1229
- variant_id: Variant identifier for plot title.
1230
- study_col: Column name for study/phenotype names.
1231
- effect_col: Column name for effect sizes.
1232
- ci_lower_col: Column name for lower confidence interval.
1233
- ci_upper_col: Column name for upper confidence interval.
1234
- weight_col: Optional column for study weights (affects marker size).
1235
- null_value: Reference value for null effect (0 for beta, 1 for OR).
1236
- effect_label: X-axis label.
1237
- figsize: Figure size as (width, height).
1238
-
1239
- Returns:
1240
- Figure object (type depends on backend).
1241
-
1242
- Example:
1243
- >>> fig = plotter.plot_forest(
1244
- ... forest_df,
1245
- ... variant_id="rs12345",
1246
- ... effect_label="Odds Ratio",
1247
- ... null_value=1.0,
1248
- ... )
1249
- """
1250
- validate_forest_df(forest_df, study_col, effect_col, ci_lower_col, ci_upper_col)
1251
-
1252
- df = forest_df.copy()
1253
-
1254
- # Create figure
1255
- fig, axes = self._backend.create_figure(
1256
- n_panels=1,
1257
- height_ratios=[1.0],
1272
+ """Create a forest plot. See StatsPlotter.plot_forest for docs."""
1273
+ return self._stats_plotter.plot_forest(
1274
+ forest_df=forest_df,
1275
+ variant_id=variant_id,
1276
+ study_col=study_col,
1277
+ effect_col=effect_col,
1278
+ ci_lower_col=ci_lower_col,
1279
+ ci_upper_col=ci_upper_col,
1280
+ weight_col=weight_col,
1281
+ null_value=null_value,
1282
+ effect_label=effect_label,
1258
1283
  figsize=figsize,
1259
1284
  )
1260
- ax = axes[0]
1261
-
1262
- # Assign y-positions (reverse so first study is at top)
1263
- df["y_pos"] = range(len(df) - 1, -1, -1)
1264
-
1265
- # Calculate marker sizes from weights
1266
- if weight_col and weight_col in df.columns:
1267
- # Scale weights to marker sizes (min 40, max 200)
1268
- weights = df[weight_col]
1269
- min_size, max_size = 40, 200
1270
- weight_range = weights.max() - weights.min()
1271
- if weight_range > 0:
1272
- sizes = min_size + (weights - weights.min()) / weight_range * (
1273
- max_size - min_size
1274
- )
1275
- else:
1276
- sizes = (min_size + max_size) / 2
1277
- else:
1278
- sizes = 80
1279
1285
 
1280
- # Calculate error bar extents
1281
- xerr_lower = df[effect_col] - df[ci_lower_col]
1282
- xerr_upper = df[ci_upper_col] - df[effect_col]
1283
-
1284
- # Plot error bars (confidence intervals)
1285
- self._backend.errorbar_h(
1286
- ax,
1287
- x=df[effect_col],
1288
- y=df["y_pos"],
1289
- xerr_lower=xerr_lower,
1290
- xerr_upper=xerr_upper,
1291
- color="black",
1292
- linewidth=1.5,
1293
- capsize=3,
1294
- zorder=2,
1286
+ def plot_manhattan(
1287
+ self,
1288
+ df: pd.DataFrame,
1289
+ chrom_col: str = "chrom",
1290
+ pos_col: str = "pos",
1291
+ p_col: str = "p",
1292
+ custom_chrom_order: Optional[List[str]] = None,
1293
+ category_col: Optional[str] = None,
1294
+ category_order: Optional[List[str]] = None,
1295
+ significance_threshold: Optional[float] = DEFAULT_GENOMEWIDE_THRESHOLD,
1296
+ figsize: Tuple[float, float] = (12, 5),
1297
+ title: Optional[str] = None,
1298
+ ) -> Any:
1299
+ """Create a Manhattan plot. See ManhattanPlotter.plot_manhattan for docs."""
1300
+ return self._manhattan_plotter.plot_manhattan(
1301
+ df=df,
1302
+ chrom_col=chrom_col,
1303
+ pos_col=pos_col,
1304
+ p_col=p_col,
1305
+ custom_chrom_order=custom_chrom_order,
1306
+ category_col=category_col,
1307
+ category_order=category_order,
1308
+ significance_threshold=significance_threshold,
1309
+ figsize=figsize,
1310
+ title=title,
1295
1311
  )
1296
1312
 
1297
- # Plot effect size markers
1298
- self._backend.scatter(
1299
- ax,
1300
- df[effect_col],
1301
- df["y_pos"],
1302
- colors="#4169E1",
1303
- sizes=sizes,
1304
- marker="s", # square markers typical for forest plots
1305
- edgecolor="black",
1306
- linewidth=0.5,
1307
- zorder=3,
1313
+ def plot_qq(
1314
+ self,
1315
+ df: pd.DataFrame,
1316
+ p_col: str = "p",
1317
+ show_confidence_band: bool = True,
1318
+ show_lambda: bool = True,
1319
+ figsize: Tuple[float, float] = (6, 6),
1320
+ title: Optional[str] = None,
1321
+ ) -> Any:
1322
+ """Create a QQ plot. See ManhattanPlotter.plot_qq for docs."""
1323
+ return self._manhattan_plotter.plot_qq(
1324
+ df=df,
1325
+ p_col=p_col,
1326
+ show_confidence_band=show_confidence_band,
1327
+ show_lambda=show_lambda,
1328
+ figsize=figsize,
1329
+ title=title,
1308
1330
  )
1309
1331
 
1310
- # Add null effect line
1311
- self._backend.axvline(
1312
- ax, x=null_value, color="grey", linestyle="--", linewidth=1, alpha=0.7
1332
+ def plot_manhattan_stacked(
1333
+ self,
1334
+ gwas_dfs: List[pd.DataFrame],
1335
+ chrom_col: str = "chrom",
1336
+ pos_col: str = "pos",
1337
+ p_col: str = "p",
1338
+ custom_chrom_order: Optional[List[str]] = None,
1339
+ significance_threshold: Optional[float] = DEFAULT_GENOMEWIDE_THRESHOLD,
1340
+ panel_labels: Optional[List[str]] = None,
1341
+ figsize: Tuple[float, float] = (12, 8),
1342
+ title: Optional[str] = None,
1343
+ ) -> Any:
1344
+ """Create stacked Manhattan plots. See ManhattanPlotter.plot_manhattan_stacked for docs."""
1345
+ return self._manhattan_plotter.plot_manhattan_stacked(
1346
+ gwas_dfs=gwas_dfs,
1347
+ chrom_col=chrom_col,
1348
+ pos_col=pos_col,
1349
+ p_col=p_col,
1350
+ custom_chrom_order=custom_chrom_order,
1351
+ significance_threshold=significance_threshold,
1352
+ panel_labels=panel_labels,
1353
+ figsize=figsize,
1354
+ title=title,
1313
1355
  )
1314
1356
 
1315
- # Set axis labels and limits
1316
- self._backend.set_xlabel(ax, effect_label)
1317
- self._backend.set_ylim(ax, -0.5, len(df) - 0.5)
1318
-
1319
- # Ensure x-axis includes the null value with some padding
1320
- x_min = min(df[ci_lower_col].min(), null_value)
1321
- x_max = max(df[ci_upper_col].max(), null_value)
1322
- x_padding = (x_max - x_min) * 0.1
1323
- self._backend.set_xlim(ax, x_min - x_padding, x_max + x_padding)
1324
-
1325
- # Set y-tick labels to study names
1326
- self._backend.set_yticks(
1327
- ax,
1328
- positions=df["y_pos"].tolist(),
1329
- labels=df[study_col].tolist(),
1330
- fontsize=10,
1357
+ def plot_manhattan_qq(
1358
+ self,
1359
+ df: pd.DataFrame,
1360
+ chrom_col: str = "chrom",
1361
+ pos_col: str = "pos",
1362
+ p_col: str = "p",
1363
+ custom_chrom_order: Optional[List[str]] = None,
1364
+ significance_threshold: Optional[float] = DEFAULT_GENOMEWIDE_THRESHOLD,
1365
+ show_confidence_band: bool = True,
1366
+ show_lambda: bool = True,
1367
+ figsize: Tuple[float, float] = (14, 5),
1368
+ title: Optional[str] = None,
1369
+ ) -> Any:
1370
+ """Create side-by-side Manhattan and QQ plots. See ManhattanPlotter.plot_manhattan_qq for docs."""
1371
+ return self._manhattan_plotter.plot_manhattan_qq(
1372
+ df=df,
1373
+ chrom_col=chrom_col,
1374
+ pos_col=pos_col,
1375
+ p_col=p_col,
1376
+ custom_chrom_order=custom_chrom_order,
1377
+ significance_threshold=significance_threshold,
1378
+ show_confidence_band=show_confidence_band,
1379
+ show_lambda=show_lambda,
1380
+ figsize=figsize,
1381
+ title=title,
1331
1382
  )
1332
1383
 
1333
- self._backend.set_title(ax, f"Forest Plot: {variant_id}")
1334
- self._backend.hide_spines(ax, ["top", "right"])
1335
- self._backend.finalize_layout(fig)
1336
-
1337
- return fig
1384
+ def plot_manhattan_qq_stacked(
1385
+ self,
1386
+ gwas_dfs: List[pd.DataFrame],
1387
+ chrom_col: str = "chrom",
1388
+ pos_col: str = "pos",
1389
+ p_col: str = "p",
1390
+ custom_chrom_order: Optional[List[str]] = None,
1391
+ significance_threshold: Optional[float] = DEFAULT_GENOMEWIDE_THRESHOLD,
1392
+ show_confidence_band: bool = True,
1393
+ show_lambda: bool = True,
1394
+ panel_labels: Optional[List[str]] = None,
1395
+ figsize: Tuple[float, float] = (14, 8),
1396
+ title: Optional[str] = None,
1397
+ ) -> Any:
1398
+ """Create stacked Manhattan+QQ plots. See ManhattanPlotter.plot_manhattan_qq_stacked for docs."""
1399
+ return self._manhattan_plotter.plot_manhattan_qq_stacked(
1400
+ gwas_dfs=gwas_dfs,
1401
+ chrom_col=chrom_col,
1402
+ pos_col=pos_col,
1403
+ p_col=p_col,
1404
+ custom_chrom_order=custom_chrom_order,
1405
+ significance_threshold=significance_threshold,
1406
+ show_confidence_band=show_confidence_band,
1407
+ show_lambda=show_lambda,
1408
+ panel_labels=panel_labels,
1409
+ figsize=figsize,
1410
+ title=title,
1411
+ )