pylocuszoom 1.1.2__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pylocuszoom/plotter.py CHANGED
@@ -15,7 +15,6 @@ from typing import Any, List, Optional, Tuple
15
15
  import matplotlib.pyplot as plt
16
16
  import numpy as np
17
17
  import pandas as pd
18
- import requests
19
18
 
20
19
  from ._plotter_utils import DEFAULT_GENOMEWIDE_THRESHOLD
21
20
  from .backends import BackendType, get_backend
@@ -24,8 +23,9 @@ from .colors import (
24
23
  EQTL_NEGATIVE_BINS,
25
24
  EQTL_POSITIVE_BINS,
26
25
  LD_BINS,
26
+ LD_HEATMAP_COLORS,
27
27
  LEAD_SNP_COLOR,
28
- PIP_LINE_COLOR,
28
+ LEAD_SNP_HIGHLIGHT_COLOR,
29
29
  get_credible_set_color,
30
30
  get_eqtl_color,
31
31
  get_ld_bin,
@@ -36,6 +36,7 @@ from .ensembl import get_genes_for_region
36
36
  from .eqtl import validate_eqtl_df
37
37
  from .finemapping import (
38
38
  get_credible_sets,
39
+ plot_finemapping,
39
40
  prepare_finemapping_for_plotting,
40
41
  )
41
42
  from .gene_track import (
@@ -44,14 +45,11 @@ from .gene_track import (
44
45
  )
45
46
  from .ld import calculate_ld, find_plink
46
47
  from .logging import enable_logging, logger
47
- from .manhattan_plotter import ManhattanPlotter
48
48
  from .recombination import (
49
49
  RECOMB_COLOR,
50
- download_canine_recombination_maps,
51
- get_default_data_dir,
50
+ ensure_recomb_maps,
52
51
  get_recombination_rate_for_region,
53
52
  )
54
- from .stats_plotter import StatsPlotter
55
53
  from .utils import normalize_chrom, validate_genes_df, validate_gwas_df
56
54
 
57
55
  # Precomputed significance line value (used for plotting)
@@ -149,27 +147,6 @@ class LocusZoomPlotter:
149
147
  # Cache for loaded data
150
148
  self._recomb_cache = {}
151
149
 
152
- @property
153
- def _manhattan_plotter(self) -> ManhattanPlotter:
154
- """Lazy-load ManhattanPlotter with shared configuration."""
155
- if not hasattr(self, "_manhattan_plotter_instance"):
156
- self._manhattan_plotter_instance = ManhattanPlotter(
157
- species=self.species,
158
- backend=self._backend_name,
159
- genomewide_threshold=self.genomewide_threshold,
160
- )
161
- return self._manhattan_plotter_instance
162
-
163
- @property
164
- def _stats_plotter(self) -> StatsPlotter:
165
- """Lazy-load StatsPlotter with shared configuration."""
166
- if not hasattr(self, "_stats_plotter_instance"):
167
- self._stats_plotter_instance = StatsPlotter(
168
- backend=self._backend_name,
169
- genomewide_threshold=self.genomewide_threshold,
170
- )
171
- return self._stats_plotter_instance
172
-
173
150
  @staticmethod
174
151
  def _default_build(species: str) -> Optional[str]:
175
152
  """Get default genome build for species."""
@@ -177,37 +154,14 @@ class LocusZoomPlotter:
177
154
  return builds.get(species)
178
155
 
179
156
  def _ensure_recomb_maps(self) -> Optional[Path]:
180
- """Ensure recombination maps are downloaded.
157
+ """Ensure recombination maps are available.
181
158
 
182
- Returns path to recombination map directory, or None if not available.
159
+ Delegates to the recombination module's ensure_recomb_maps function.
160
+
161
+ Returns:
162
+ Path to recombination map directory, or None if not available.
183
163
  """
184
- if self.species == "canine":
185
- if self.recomb_data_dir:
186
- return Path(self.recomb_data_dir)
187
- # Check if already downloaded
188
- default_dir = get_default_data_dir()
189
- if (
190
- default_dir.exists()
191
- and len(list(default_dir.glob("chr*_recomb.tsv"))) >= 39
192
- ): # 38 autosomes + X
193
- return default_dir
194
- # Download
195
- try:
196
- return download_canine_recombination_maps()
197
- except (requests.RequestException, OSError, IOError) as e:
198
- # Expected network/file errors - graceful fallback
199
- logger.warning(f"Could not download recombination maps: {e}")
200
- return None
201
- except Exception as e:
202
- # JUSTIFICATION: Download failure should not prevent plotting.
203
- # We catch broadly here because graceful degradation is acceptable
204
- # for optional recombination map downloads. Error-level logging
205
- # ensures the issue is visible.
206
- logger.error(f"Unexpected error downloading recombination maps: {e}")
207
- return None
208
- elif self.recomb_data_dir:
209
- return Path(self.recomb_data_dir)
210
- return None
164
+ return ensure_recomb_maps(species=self.species, data_dir=self.recomb_data_dir)
211
165
 
212
166
  def _get_recomb_for_region(
213
167
  self, chrom: int, start: int, end: int
@@ -238,7 +192,7 @@ class LocusZoomPlotter:
238
192
  def _transform_pvalues(self, df: pd.DataFrame, p_col: str) -> pd.DataFrame:
239
193
  """Add neglog10p column with -log10 transformed p-values.
240
194
 
241
- Delegates to shared utility function. Assumes df is already a copy.
195
+ Modifies df in place. Callers should pass a copy to avoid side effects.
242
196
 
243
197
  Args:
244
198
  df: DataFrame with p-value column (should be a copy).
@@ -271,6 +225,10 @@ class LocusZoomPlotter:
271
225
  genes_df: Optional[pd.DataFrame] = None,
272
226
  exons_df: Optional[pd.DataFrame] = None,
273
227
  recomb_df: Optional[pd.DataFrame] = None,
228
+ ld_heatmap_df: Optional[pd.DataFrame] = None,
229
+ ld_heatmap_snp_ids: Optional[List[str]] = None,
230
+ ld_heatmap_height: float = 0.25,
231
+ ld_heatmap_metric: str = "r2",
274
232
  ) -> Any:
275
233
  """Create a regional association plot.
276
234
 
@@ -294,12 +252,21 @@ class LocusZoomPlotter:
294
252
  exons_df: Exon annotations with chr, start, end, gene_name.
295
253
  recomb_df: Pre-loaded recombination rate data.
296
254
  If None and show_recombination=True, loads from species default.
255
+ ld_heatmap_df: Pairwise LD matrix (square DataFrame) from
256
+ calculate_pairwise_ld. If provided with ld_heatmap_snp_ids,
257
+ renders heatmap panel below association plot.
258
+ ld_heatmap_snp_ids: List of SNP IDs in matrix order. Required if
259
+ ld_heatmap_df is provided.
260
+ ld_heatmap_height: Height ratio of heatmap panel relative to
261
+ association panel. Default 0.25.
262
+ ld_heatmap_metric: LD metric label for colorbar ("r2" or "dprime").
297
263
 
298
264
  Returns:
299
265
  Figure object (type depends on backend).
300
266
 
301
267
  Raises:
302
268
  ValidationError: If parameters or DataFrame columns are invalid.
269
+ ValueError: If ld_heatmap_df provided without ld_heatmap_snp_ids.
303
270
 
304
271
  Example:
305
272
  >>> fig = plotter.plot(
@@ -328,6 +295,12 @@ class LocusZoomPlotter:
328
295
  # Validate inputs
329
296
  validate_gwas_df(gwas_df, pos_col=pos_col, p_col=p_col)
330
297
 
298
+ # Validate LD heatmap parameters
299
+ if ld_heatmap_df is not None and ld_heatmap_snp_ids is None:
300
+ raise ValueError(
301
+ "ld_heatmap_snp_ids is required when ld_heatmap_df is provided"
302
+ )
303
+
331
304
  # Auto-fetch genes if enabled and not provided
332
305
  if genes_df is None and self._auto_genes:
333
306
  logger.debug(
@@ -400,8 +373,33 @@ class LocusZoomPlotter:
400
373
  if show_recombination and recomb_df is None:
401
374
  recomb_df = self._get_recomb_for_region(chrom, start, end)
402
375
 
376
+ # Transform heatmap to genomic coordinates if provided
377
+ heatmap_data = None
378
+ if ld_heatmap_df is not None and ld_heatmap_snp_ids is not None:
379
+ heatmap_data = self._transform_heatmap_to_genomic_coords(
380
+ ld_matrix=ld_heatmap_df,
381
+ snp_ids=ld_heatmap_snp_ids,
382
+ gwas_df=df,
383
+ start=start,
384
+ end=end,
385
+ rs_col=rs_col,
386
+ pos_col=pos_col,
387
+ )
388
+ if heatmap_data is None:
389
+ logger.warning(
390
+ "No SNPs from LD heatmap overlap with region - heatmap not rendered"
391
+ )
392
+
403
393
  # Create figure layout
404
- fig, ax, gene_ax = self._create_figure(genes_df, chrom, start, end, figsize)
394
+ fig, ax, gene_ax, heatmap_ax = self._create_figure_with_heatmap(
395
+ genes_df=genes_df,
396
+ chrom=chrom,
397
+ start=start,
398
+ end=end,
399
+ figsize=figsize,
400
+ heatmap_data=heatmap_data,
401
+ heatmap_height=ld_heatmap_height,
402
+ )
405
403
 
406
404
  # Plot association data
407
405
  self._plot_association(ax, df, pos_col, ld_col, lead_pos, rs_col, p_col)
@@ -418,9 +416,11 @@ class LocusZoomPlotter:
418
416
  )
419
417
 
420
418
  # Add SNP labels (capability check - interactive backends use hover tooltips)
419
+ # Create labels without adjusting - we'll adjust after axis limits are set
420
+ snp_label_texts: list = []
421
421
  if snp_labels and rs_col in df.columns and label_top_n > 0 and not df.empty:
422
422
  if self._backend.supports_snp_labels:
423
- self._backend.add_snp_labels(
423
+ snp_label_texts = self._backend.add_snp_labels(
424
424
  ax,
425
425
  df,
426
426
  pos_col=pos_col,
@@ -429,6 +429,7 @@ class LocusZoomPlotter:
429
429
  label_top_n=label_top_n,
430
430
  genes_df=genes_df,
431
431
  chrom=chrom,
432
+ adjust=False, # Defer adjustment until after axis limits set
432
433
  )
433
434
 
434
435
  # Add recombination overlay (all backends with secondary axis support)
@@ -455,11 +456,38 @@ class LocusZoomPlotter:
455
456
  plot_gene_track_generic(
456
457
  gene_ax, self._backend, genes_df, chrom, start, end, exons_df
457
458
  )
458
- self._backend.set_xlabel(gene_ax, f"Chromosome {chrom} (Mb)")
459
459
  self._backend.hide_spines(gene_ax, ["top", "right", "left"])
460
460
  # Format both axes for interactive backends (they don't share x-axis)
461
461
  self._backend.format_xaxis_mb(gene_ax)
462
- else:
462
+ # Only set x-label on gene track if no heatmap below
463
+ if heatmap_ax is None:
464
+ self._backend.set_xlabel(gene_ax, f"Chromosome {chrom} (Mb)")
465
+
466
+ # Render LD heatmap panel if data available
467
+ if heatmap_ax is not None and heatmap_data is not None:
468
+ filtered_matrix, x_positions, filtered_snp_ids = heatmap_data
469
+ # Find lead SNP ID if lead_pos is set
470
+ lead_snp_id = None
471
+ if lead_pos is not None and rs_col in df.columns:
472
+ lead_row = df[df[pos_col] == lead_pos]
473
+ if not lead_row.empty:
474
+ lead_snp_id = lead_row[rs_col].iloc[0]
475
+ self._render_heatmap_panel(
476
+ ax=heatmap_ax,
477
+ fig=fig,
478
+ ld_matrix=filtered_matrix,
479
+ x_positions=x_positions,
480
+ snp_ids=filtered_snp_ids,
481
+ metric=ld_heatmap_metric,
482
+ lead_snp_id=lead_snp_id,
483
+ start=start,
484
+ end=end,
485
+ )
486
+ # Heatmap is at bottom - set x-label on it
487
+ self._backend.set_xlabel(heatmap_ax, f"Chromosome {chrom} (Mb)")
488
+ self._backend.format_xaxis_mb(heatmap_ax)
489
+ elif gene_ax is None and heatmap_ax is None:
490
+ # No gene track and no heatmap - set x-label on association plot
463
491
  self._backend.set_xlabel(ax, f"Chromosome {chrom} (Mb)")
464
492
 
465
493
  # Format x-axis with Mb labels (association axis always needs formatting)
@@ -468,6 +496,11 @@ class LocusZoomPlotter:
468
496
  # Adjust layout
469
497
  self._backend.finalize_layout(fig, hspace=0.1)
470
498
 
499
+ # Adjust SNP labels AFTER all axis limits and layout are finalized
500
+ # adjustText needs final plot bounds to position labels correctly
501
+ if snp_label_texts:
502
+ self._backend.adjust_snp_labels(ax, snp_label_texts)
503
+
471
504
  return fig
472
505
 
473
506
  def _create_figure(
@@ -519,6 +552,304 @@ class LocusZoomPlotter:
519
552
  )
520
553
  return fig, axes[0], None
521
554
 
555
+ def _create_figure_with_heatmap(
556
+ self,
557
+ genes_df: Optional[pd.DataFrame],
558
+ chrom: int,
559
+ start: int,
560
+ end: int,
561
+ figsize: Tuple[float, float],
562
+ heatmap_data: Optional[Tuple[pd.DataFrame, List[int], List[str]]],
563
+ heatmap_height: float = 0.25,
564
+ ) -> Tuple[Any, Any, Optional[Any], Optional[Any]]:
565
+ """Create figure with optional gene track and heatmap panel.
566
+
567
+ Args:
568
+ genes_df: Gene annotations DataFrame.
569
+ chrom: Chromosome number.
570
+ start: Region start position.
571
+ end: Region end position.
572
+ figsize: Base figure size as (width, height).
573
+ heatmap_data: Tuple of (filtered_matrix, x_positions, snp_ids) from
574
+ _transform_heatmap_to_genomic_coords, or None.
575
+ heatmap_height: Height ratio of heatmap relative to association panel.
576
+
577
+ Returns:
578
+ Tuple of (fig, assoc_ax, gene_ax, heatmap_ax). gene_ax and heatmap_ax
579
+ are None if not included.
580
+ """
581
+ # Calculate base heights
582
+ assoc_height = figsize[1] * 0.6
583
+
584
+ # Calculate gene track height if needed
585
+ gene_track_height = 0.0
586
+ if genes_df is not None:
587
+ chrom_str = normalize_chrom(chrom)
588
+ region_genes = genes_df[
589
+ (
590
+ genes_df["chr"].astype(str).str.replace("chr", "", regex=False)
591
+ == chrom_str
592
+ )
593
+ & (genes_df["end"] >= start)
594
+ & (genes_df["start"] <= end)
595
+ ]
596
+ if not region_genes.empty:
597
+ temp_positions = assign_gene_positions(
598
+ region_genes.sort_values("start"), start, end
599
+ )
600
+ n_gene_rows = max(temp_positions) + 1 if temp_positions else 1
601
+ else:
602
+ n_gene_rows = 1
603
+
604
+ base_gene_height = 1.0
605
+ per_row_height = 0.5
606
+ gene_track_height = base_gene_height + (n_gene_rows - 1) * per_row_height
607
+
608
+ # Calculate heatmap height if needed
609
+ actual_heatmap_height = 0.0
610
+ if heatmap_data is not None:
611
+ actual_heatmap_height = assoc_height * heatmap_height
612
+
613
+ # Build panel list (top to bottom): assoc, gene track, heatmap
614
+ n_panels = 1 # Association panel always present
615
+ height_ratios = [assoc_height]
616
+
617
+ if genes_df is not None:
618
+ n_panels += 1
619
+ height_ratios.append(gene_track_height)
620
+
621
+ if heatmap_data is not None:
622
+ n_panels += 1
623
+ height_ratios.append(actual_heatmap_height)
624
+
625
+ total_height = sum(height_ratios)
626
+
627
+ # Create figure
628
+ fig, axes = self._backend.create_figure(
629
+ n_panels=n_panels,
630
+ height_ratios=height_ratios,
631
+ figsize=(figsize[0], total_height),
632
+ sharex=True,
633
+ )
634
+
635
+ # Assign axes
636
+ assoc_ax = axes[0]
637
+ gene_ax = None
638
+ heatmap_ax = None
639
+
640
+ panel_idx = 1
641
+ if genes_df is not None:
642
+ gene_ax = axes[panel_idx]
643
+ panel_idx += 1
644
+ if heatmap_data is not None:
645
+ heatmap_ax = axes[panel_idx]
646
+
647
+ return fig, assoc_ax, gene_ax, heatmap_ax
648
+
649
+ def _transform_heatmap_to_genomic_coords(
650
+ self,
651
+ ld_matrix: pd.DataFrame,
652
+ snp_ids: List[str],
653
+ gwas_df: pd.DataFrame,
654
+ start: int,
655
+ end: int,
656
+ rs_col: str,
657
+ pos_col: str,
658
+ ) -> Optional[Tuple[pd.DataFrame, List[int], List[str]]]:
659
+ """Transform heatmap matrix to genomic coordinates.
660
+
661
+ Args:
662
+ ld_matrix: Square LD matrix from calculate_pairwise_ld.
663
+ snp_ids: SNP IDs in matrix order.
664
+ gwas_df: GWAS DataFrame with position column.
665
+ start: Region start position.
666
+ end: Region end position.
667
+ rs_col: SNP ID column name.
668
+ pos_col: Position column name.
669
+
670
+ Returns:
671
+ Tuple of (filtered_matrix, x_positions, filtered_snp_ids), or None
672
+ if no SNPs overlap with the region.
673
+ """
674
+ # Build SNP-to-position mapping from GWAS data
675
+ if rs_col not in gwas_df.columns:
676
+ logger.warning(
677
+ f"Cannot map heatmap to genomic coords: column '{rs_col}' not in GWAS data"
678
+ )
679
+ return None
680
+
681
+ snp_to_pos = dict(zip(gwas_df[rs_col], gwas_df[pos_col]))
682
+
683
+ # Filter to SNPs present in GWAS and within region
684
+ filtered_indices = []
685
+ filtered_snp_ids = []
686
+ x_positions = []
687
+
688
+ for i, snp_id in enumerate(snp_ids):
689
+ if snp_id in snp_to_pos:
690
+ pos = snp_to_pos[snp_id]
691
+ if start <= pos <= end:
692
+ filtered_indices.append(i)
693
+ filtered_snp_ids.append(snp_id)
694
+ x_positions.append(int(pos))
695
+
696
+ if not filtered_indices:
697
+ return None
698
+
699
+ # Filter matrix to matching rows/columns
700
+ filtered_matrix = ld_matrix.iloc[filtered_indices, filtered_indices].copy()
701
+
702
+ return filtered_matrix, x_positions, filtered_snp_ids
703
+
704
+ def _render_heatmap_panel(
705
+ self,
706
+ ax: Any,
707
+ fig: Any,
708
+ ld_matrix: pd.DataFrame,
709
+ x_positions: List[int],
710
+ snp_ids: List[str],
711
+ metric: str,
712
+ lead_snp_id: Optional[str],
713
+ start: int,
714
+ end: int,
715
+ ) -> None:
716
+ """Render LD heatmap panel with genomic x-coordinates.
717
+
718
+ Args:
719
+ ax: Axes object for heatmap panel.
720
+ fig: Figure object.
721
+ ld_matrix: Filtered LD matrix.
722
+ x_positions: Genomic positions for each SNP (x-axis).
723
+ snp_ids: SNP IDs in filtered order.
724
+ metric: LD metric label ("r2" or "dprime").
725
+ lead_snp_id: Lead SNP ID to highlight (if present in snp_ids).
726
+ start: Region start for x-axis limits.
727
+ end: Region end for x-axis limits.
728
+ """
729
+ data = ld_matrix.values
730
+ n_snps = len(snp_ids)
731
+
732
+ # Skip rendering if only one SNP (can't show pairwise LD)
733
+ if n_snps < 2:
734
+ logger.debug("Skipping heatmap: fewer than 2 SNPs after filtering")
735
+ return
736
+
737
+ # Render triangular heatmap at genomic positions
738
+ mappable = self._backend.add_heatmap(
739
+ ax,
740
+ data=data,
741
+ x_coords=x_positions,
742
+ y_coords=list(range(n_snps)), # Keep y as indices (0, 1, 2, ...)
743
+ cmap_colors=LD_HEATMAP_COLORS,
744
+ vmin=0.0,
745
+ vmax=1.0,
746
+ mask_upper=True,
747
+ )
748
+
749
+ # Add colorbar
750
+ label = "R²" if metric == "r2" else "D'"
751
+ self._backend.add_colorbar(ax, mappable, label=label)
752
+
753
+ # Highlight lead SNP if present
754
+ if lead_snp_id is not None and lead_snp_id in snp_ids:
755
+ lead_idx = snp_ids.index(lead_snp_id)
756
+ self._highlight_heatmap_snp(ax, fig, lead_idx, n_snps)
757
+
758
+ # Set x-axis limits to match regional plot
759
+ self._backend.set_xlim(ax, start, end)
760
+
761
+ # Hide y-axis (SNP indices are not meaningful for viewer)
762
+ self._backend.set_yticks(ax, [], [])
763
+ self._backend.hide_spines(ax, ["top", "right", "left"])
764
+
765
+ def _highlight_heatmap_snp(
766
+ self, ax: Any, fig: Any, snp_idx: int, n_snps: int
767
+ ) -> None:
768
+ """Highlight a SNP's row/column in the heatmap.
769
+
770
+ Args:
771
+ ax: Axes object.
772
+ fig: Figure object.
773
+ snp_idx: Index of SNP to highlight.
774
+ n_snps: Total number of SNPs in matrix.
775
+ """
776
+ if self._backend_name == "matplotlib":
777
+ from matplotlib.patches import Rectangle
778
+
779
+ # Highlight row (cells in row snp_idx, columns 0 to snp_idx)
780
+ for j in range(snp_idx + 1):
781
+ rect = Rectangle(
782
+ (j - 0.5, snp_idx - 0.5),
783
+ 1.0,
784
+ 1.0,
785
+ fill=False,
786
+ edgecolor=LEAD_SNP_HIGHLIGHT_COLOR,
787
+ linewidth=2,
788
+ zorder=10,
789
+ )
790
+ ax.add_patch(rect)
791
+
792
+ # Highlight column (cells in column snp_idx, rows snp_idx to n_snps-1)
793
+ for i in range(snp_idx + 1, n_snps):
794
+ rect = Rectangle(
795
+ (snp_idx - 0.5, i - 0.5),
796
+ 1.0,
797
+ 1.0,
798
+ fill=False,
799
+ edgecolor=LEAD_SNP_HIGHLIGHT_COLOR,
800
+ linewidth=2,
801
+ zorder=10,
802
+ )
803
+ ax.add_patch(rect)
804
+
805
+ elif self._backend_name == "plotly":
806
+ # For plotly, add shapes for row and column highlights
807
+ for j in range(snp_idx + 1):
808
+ fig.add_shape(
809
+ type="rect",
810
+ x0=j - 0.5,
811
+ x1=j + 0.5,
812
+ y0=snp_idx - 0.5,
813
+ y1=snp_idx + 0.5,
814
+ line=dict(color=LEAD_SNP_HIGHLIGHT_COLOR, width=2),
815
+ fillcolor="rgba(0,0,0,0)",
816
+ )
817
+
818
+ for i in range(snp_idx + 1, n_snps):
819
+ fig.add_shape(
820
+ type="rect",
821
+ x0=snp_idx - 0.5,
822
+ x1=snp_idx + 0.5,
823
+ y0=i - 0.5,
824
+ y1=i + 0.5,
825
+ line=dict(color=LEAD_SNP_HIGHLIGHT_COLOR, width=2),
826
+ fillcolor="rgba(0,0,0,0)",
827
+ )
828
+
829
+ elif self._backend_name == "bokeh":
830
+ # For bokeh, add rect glyphs for highlights
831
+ for j in range(snp_idx + 1):
832
+ ax.rect(
833
+ x=j,
834
+ y=snp_idx,
835
+ width=1,
836
+ height=1,
837
+ fill_alpha=0,
838
+ line_color=LEAD_SNP_HIGHLIGHT_COLOR,
839
+ line_width=2,
840
+ )
841
+
842
+ for i in range(snp_idx + 1, n_snps):
843
+ ax.rect(
844
+ x=snp_idx,
845
+ y=i,
846
+ width=1,
847
+ height=1,
848
+ fill_alpha=0,
849
+ line_color=LEAD_SNP_HIGHLIGHT_COLOR,
850
+ line_width=2,
851
+ )
852
+
522
853
  def _plot_association(
523
854
  self,
524
855
  ax: Any,
@@ -670,107 +1001,6 @@ class LocusZoomPlotter:
670
1001
  if isinstance(twin_result, Axes):
671
1002
  secondary_ax.spines["top"].set_visible(False)
672
1003
 
673
- def _plot_finemapping(
674
- self,
675
- ax: Any,
676
- df: pd.DataFrame,
677
- pos_col: str = "pos",
678
- pip_col: str = "pip",
679
- cs_col: Optional[str] = "cs",
680
- show_credible_sets: bool = True,
681
- pip_threshold: float = 0.0,
682
- ) -> None:
683
- """Plot fine-mapping results (PIP line with credible set coloring).
684
-
685
- Args:
686
- ax: Matplotlib axes object.
687
- df: Fine-mapping DataFrame with pos and pip columns.
688
- pos_col: Column name for position.
689
- pip_col: Column name for posterior inclusion probability.
690
- cs_col: Column name for credible set assignment (optional).
691
- show_credible_sets: Whether to color points by credible set.
692
- pip_threshold: Minimum PIP to display as scatter point.
693
- """
694
- # Build hover data using HoverDataBuilder
695
- extra_cols = {pip_col: "PIP"}
696
- if cs_col and cs_col in df.columns:
697
- extra_cols[cs_col] = "Credible Set"
698
- hover_config = HoverConfig(
699
- pos_col=pos_col if pos_col in df.columns else None,
700
- extra_cols=extra_cols,
701
- )
702
- hover_builder = HoverDataBuilder(hover_config)
703
-
704
- # Sort by position for line plotting
705
- df = df.sort_values(pos_col)
706
-
707
- # Plot PIP as line
708
- self._backend.line(
709
- ax,
710
- df[pos_col],
711
- df[pip_col],
712
- color=PIP_LINE_COLOR,
713
- linewidth=1.5,
714
- alpha=0.8,
715
- zorder=1,
716
- )
717
-
718
- # Check if credible sets are available
719
- has_cs = cs_col is not None and cs_col in df.columns and show_credible_sets
720
- credible_sets = get_credible_sets(df, cs_col) if has_cs else []
721
-
722
- if credible_sets:
723
- # Plot points colored by credible set
724
- for cs_id in credible_sets:
725
- cs_data = df[df[cs_col] == cs_id]
726
- color = get_credible_set_color(cs_id)
727
- self._backend.scatter(
728
- ax,
729
- cs_data[pos_col],
730
- cs_data[pip_col],
731
- colors=color,
732
- sizes=50,
733
- marker="o",
734
- edgecolor="black",
735
- linewidth=0.5,
736
- zorder=3,
737
- hover_data=hover_builder.build_dataframe(cs_data),
738
- )
739
- # Plot variants not in any credible set
740
- non_cs_data = df[(df[cs_col].isna()) | (df[cs_col] == 0)]
741
- if not non_cs_data.empty and pip_threshold > 0:
742
- non_cs_data = non_cs_data[non_cs_data[pip_col] >= pip_threshold]
743
- if not non_cs_data.empty:
744
- self._backend.scatter(
745
- ax,
746
- non_cs_data[pos_col],
747
- non_cs_data[pip_col],
748
- colors="#BEBEBE",
749
- sizes=30,
750
- marker="o",
751
- edgecolor="black",
752
- linewidth=0.3,
753
- zorder=2,
754
- hover_data=hover_builder.build_dataframe(non_cs_data),
755
- )
756
- else:
757
- # No credible sets - show all points above threshold
758
- if pip_threshold > 0:
759
- high_pip = df[df[pip_col] >= pip_threshold]
760
- if not high_pip.empty:
761
- self._backend.scatter(
762
- ax,
763
- high_pip[pos_col],
764
- high_pip[pip_col],
765
- colors=PIP_LINE_COLOR,
766
- sizes=50,
767
- marker="o",
768
- edgecolor="black",
769
- linewidth=0.5,
770
- zorder=3,
771
- hover_data=hover_builder.build_dataframe(high_pip),
772
- )
773
-
774
1004
  def plot_stacked(
775
1005
  self,
776
1006
  gwas_dfs: List[pd.DataFrame],
@@ -797,6 +1027,10 @@ class LocusZoomPlotter:
797
1027
  finemapping_df: Optional[pd.DataFrame] = None,
798
1028
  finemapping_cs_col: Optional[str] = "cs",
799
1029
  recomb_df: Optional[pd.DataFrame] = None,
1030
+ ld_heatmap_df: Optional[pd.DataFrame] = None,
1031
+ ld_heatmap_snp_ids: Optional[List[str]] = None,
1032
+ ld_heatmap_height: float = 0.25,
1033
+ ld_heatmap_metric: str = "r2",
800
1034
  ) -> Any:
801
1035
  """Create stacked regional association plots for multiple GWAS.
802
1036
 
@@ -829,10 +1063,21 @@ class LocusZoomPlotter:
829
1063
  Displayed as PIP line with optional credible set coloring.
830
1064
  finemapping_cs_col: Column name for credible set assignment.
831
1065
  recomb_df: Pre-loaded recombination rate data.
1066
+ ld_heatmap_df: Pairwise LD matrix (square DataFrame) from
1067
+ calculate_pairwise_ld. If provided with ld_heatmap_snp_ids,
1068
+ renders heatmap panel at the very bottom of the stack.
1069
+ ld_heatmap_snp_ids: List of SNP IDs in matrix order. Required if
1070
+ ld_heatmap_df is provided.
1071
+ ld_heatmap_height: Height ratio of heatmap panel relative to
1072
+ association panel. Default 0.25.
1073
+ ld_heatmap_metric: LD metric label for colorbar ("r2" or "dprime").
832
1074
 
833
1075
  Returns:
834
1076
  Figure object (type depends on backend).
835
1077
 
1078
+ Raises:
1079
+ ValueError: If ld_heatmap_df provided without ld_heatmap_snp_ids.
1080
+
836
1081
  Example:
837
1082
  >>> fig = plotter.plot_stacked(
838
1083
  ... [gwas_height, gwas_bmi, gwas_whr],
@@ -888,6 +1133,12 @@ class LocusZoomPlotter:
888
1133
  if eqtl_df is not None:
889
1134
  validate_eqtl_df(eqtl_df)
890
1135
 
1136
+ # Validate LD heatmap parameters
1137
+ if ld_heatmap_df is not None and ld_heatmap_snp_ids is None:
1138
+ raise ValueError(
1139
+ "ld_heatmap_snp_ids is required when ld_heatmap_df is provided"
1140
+ )
1141
+
891
1142
  # Handle lead positions
892
1143
  if lead_positions is None:
893
1144
  lead_positions = []
@@ -911,10 +1162,31 @@ class LocusZoomPlotter:
911
1162
  if ld_reference_files is None and ld_reference_file is not None:
912
1163
  ld_reference_files = [ld_reference_file] * n_gwas
913
1164
 
1165
+ # Transform heatmap to genomic coordinates if provided (use first GWAS for mapping)
1166
+ heatmap_data = None
1167
+ if ld_heatmap_df is not None and ld_heatmap_snp_ids is not None:
1168
+ # Use first GWAS DataFrame for SNP-to-position mapping
1169
+ first_gwas = gwas_dfs[0].copy()
1170
+ first_gwas = self._transform_pvalues(first_gwas, p_col)
1171
+ heatmap_data = self._transform_heatmap_to_genomic_coords(
1172
+ ld_matrix=ld_heatmap_df,
1173
+ snp_ids=ld_heatmap_snp_ids,
1174
+ gwas_df=first_gwas,
1175
+ start=start,
1176
+ end=end,
1177
+ rs_col=rs_col,
1178
+ pos_col=pos_col,
1179
+ )
1180
+ if heatmap_data is None:
1181
+ logger.warning(
1182
+ "No SNPs from LD heatmap overlap with region - heatmap not rendered"
1183
+ )
1184
+
914
1185
  # Calculate panel layout
915
1186
  panel_height = 2.5 # inches per GWAS panel
916
1187
  eqtl_height = 2.0 if eqtl_df is not None else 0
917
1188
  finemapping_height = 1.5 if finemapping_df is not None else 0
1189
+ heatmap_height_inches = panel_height * ld_heatmap_height if heatmap_data else 0
918
1190
 
919
1191
  # Gene track height
920
1192
  if genes_df is not None:
@@ -939,11 +1211,13 @@ class LocusZoomPlotter:
939
1211
  gene_track_height = 0
940
1212
 
941
1213
  # Calculate total panels and heights
1214
+ # Order from top to bottom: GWAS, finemapping, eQTL, gene track, heatmap
942
1215
  n_panels = (
943
1216
  n_gwas
944
1217
  + (1 if finemapping_df is not None else 0)
945
1218
  + (1 if eqtl_df is not None else 0)
946
1219
  + (1 if genes_df is not None else 0)
1220
+ + (1 if heatmap_data is not None else 0)
947
1221
  )
948
1222
  height_ratios = [panel_height] * n_gwas
949
1223
  if finemapping_df is not None:
@@ -952,6 +1226,8 @@ class LocusZoomPlotter:
952
1226
  height_ratios.append(eqtl_height)
953
1227
  if genes_df is not None:
954
1228
  height_ratios.append(gene_track_height)
1229
+ if heatmap_data is not None:
1230
+ height_ratios.append(heatmap_height_inches)
955
1231
 
956
1232
  # Calculate figure height
957
1233
  total_height = figsize[1] if figsize[1] else sum(height_ratios)
@@ -973,6 +1249,9 @@ class LocusZoomPlotter:
973
1249
  sharex=True,
974
1250
  )
975
1251
 
1252
+ # Collect label texts for deferred adjustment
1253
+ all_snp_label_texts: list[tuple] = []
1254
+
976
1255
  # Plot each GWAS panel
977
1256
  for i, (gwas_df, lead_pos) in enumerate(zip(gwas_dfs, lead_positions)):
978
1257
  ax = axes[i]
@@ -1023,9 +1302,10 @@ class LocusZoomPlotter:
1023
1302
  )
1024
1303
 
1025
1304
  # Add SNP labels (capability check - interactive backends use hover tooltips)
1305
+ # Create labels without adjusting - we'll adjust after axis limits are set
1026
1306
  if snp_labels and rs_col in df.columns and label_top_n > 0 and not df.empty:
1027
1307
  if self._backend.supports_snp_labels:
1028
- self._backend.add_snp_labels(
1308
+ texts = self._backend.add_snp_labels(
1029
1309
  ax,
1030
1310
  df,
1031
1311
  pos_col=pos_col,
@@ -1034,7 +1314,10 @@ class LocusZoomPlotter:
1034
1314
  label_top_n=label_top_n,
1035
1315
  genes_df=genes_df,
1036
1316
  chrom=chrom,
1317
+ adjust=False, # Defer adjustment until after axis limits set
1037
1318
  )
1319
+ if texts:
1320
+ all_snp_label_texts.append((ax, texts))
1038
1321
 
1039
1322
  # Add recombination overlay (only on first panel, all backends)
1040
1323
  if i == 0 and recomb_df is not None and not recomb_df.empty:
@@ -1070,7 +1353,8 @@ class LocusZoomPlotter:
1070
1353
  )
1071
1354
 
1072
1355
  if not fm_data.empty:
1073
- self._plot_finemapping(
1356
+ plot_finemapping(
1357
+ self._backend,
1074
1358
  ax,
1075
1359
  fm_data,
1076
1360
  pos_col="pos",
@@ -1218,8 +1502,37 @@ class LocusZoomPlotter:
1218
1502
  plot_gene_track_generic(
1219
1503
  gene_ax, self._backend, genes_df, chrom, start, end, exons_df
1220
1504
  )
1221
- self._backend.set_xlabel(gene_ax, f"Chromosome {chrom} (Mb)")
1222
1505
  self._backend.hide_spines(gene_ax, ["top", "right", "left"])
1506
+ panel_idx += 1
1507
+
1508
+ # Plot LD heatmap panel if provided (at very bottom)
1509
+ if heatmap_data is not None:
1510
+ heatmap_ax = axes[panel_idx]
1511
+ filtered_matrix, x_positions, filtered_snp_ids = heatmap_data
1512
+ # Find lead SNP ID from first GWAS panel if lead_positions set
1513
+ lead_snp_id = None
1514
+ if lead_positions and lead_positions[0] is not None:
1515
+ first_gwas = gwas_dfs[0]
1516
+ if rs_col in first_gwas.columns:
1517
+ lead_row = first_gwas[first_gwas[pos_col] == lead_positions[0]]
1518
+ if not lead_row.empty:
1519
+ lead_snp_id = lead_row[rs_col].iloc[0]
1520
+ self._render_heatmap_panel(
1521
+ ax=heatmap_ax,
1522
+ fig=fig,
1523
+ ld_matrix=filtered_matrix,
1524
+ x_positions=x_positions,
1525
+ snp_ids=filtered_snp_ids,
1526
+ metric=ld_heatmap_metric,
1527
+ lead_snp_id=lead_snp_id,
1528
+ start=start,
1529
+ end=end,
1530
+ )
1531
+ # Heatmap is at very bottom - set x-label here
1532
+ self._backend.set_xlabel(heatmap_ax, f"Chromosome {chrom} (Mb)")
1533
+ elif genes_df is not None:
1534
+ # Gene track is at bottom (no heatmap) - set x-label on gene track
1535
+ self._backend.set_xlabel(gene_ax, f"Chromosome {chrom} (Mb)")
1223
1536
  else:
1224
1537
  # Set x-label on bottom panel
1225
1538
  self._backend.set_xlabel(axes[-1], f"Chromosome {chrom} (Mb)")
@@ -1231,181 +1544,9 @@ class LocusZoomPlotter:
1231
1544
  # Adjust layout
1232
1545
  self._backend.finalize_layout(fig, hspace=0.1)
1233
1546
 
1234
- return fig
1235
-
1236
- def plot_phewas(
1237
- self,
1238
- phewas_df: pd.DataFrame,
1239
- variant_id: str,
1240
- phenotype_col: str = "phenotype",
1241
- p_col: str = "p_value",
1242
- category_col: str = "category",
1243
- effect_col: Optional[str] = None,
1244
- significance_threshold: float = 5e-8,
1245
- figsize: Tuple[float, float] = (10, 8),
1246
- ) -> Any:
1247
- """Create a PheWAS plot. See StatsPlotter.plot_phewas for docs."""
1248
- return self._stats_plotter.plot_phewas(
1249
- phewas_df=phewas_df,
1250
- variant_id=variant_id,
1251
- phenotype_col=phenotype_col,
1252
- p_col=p_col,
1253
- category_col=category_col,
1254
- effect_col=effect_col,
1255
- significance_threshold=significance_threshold,
1256
- figsize=figsize,
1257
- )
1547
+ # Adjust SNP labels AFTER all axis limits and layout are finalized
1548
+ # adjustText needs final plot bounds to position labels correctly
1549
+ for ax, texts in all_snp_label_texts:
1550
+ self._backend.adjust_snp_labels(ax, texts)
1258
1551
 
1259
- def plot_forest(
1260
- self,
1261
- forest_df: pd.DataFrame,
1262
- variant_id: str,
1263
- study_col: str = "study",
1264
- effect_col: str = "effect",
1265
- ci_lower_col: str = "ci_lower",
1266
- ci_upper_col: str = "ci_upper",
1267
- weight_col: Optional[str] = None,
1268
- null_value: float = 0.0,
1269
- effect_label: str = "Effect Size",
1270
- figsize: Tuple[float, float] = (8, 6),
1271
- ) -> Any:
1272
- """Create a forest plot. See StatsPlotter.plot_forest for docs."""
1273
- return self._stats_plotter.plot_forest(
1274
- forest_df=forest_df,
1275
- variant_id=variant_id,
1276
- study_col=study_col,
1277
- effect_col=effect_col,
1278
- ci_lower_col=ci_lower_col,
1279
- ci_upper_col=ci_upper_col,
1280
- weight_col=weight_col,
1281
- null_value=null_value,
1282
- effect_label=effect_label,
1283
- figsize=figsize,
1284
- )
1285
-
1286
- def plot_manhattan(
1287
- self,
1288
- df: pd.DataFrame,
1289
- chrom_col: str = "chrom",
1290
- pos_col: str = "pos",
1291
- p_col: str = "p",
1292
- custom_chrom_order: Optional[List[str]] = None,
1293
- category_col: Optional[str] = None,
1294
- category_order: Optional[List[str]] = None,
1295
- significance_threshold: Optional[float] = DEFAULT_GENOMEWIDE_THRESHOLD,
1296
- figsize: Tuple[float, float] = (12, 5),
1297
- title: Optional[str] = None,
1298
- ) -> Any:
1299
- """Create a Manhattan plot. See ManhattanPlotter.plot_manhattan for docs."""
1300
- return self._manhattan_plotter.plot_manhattan(
1301
- df=df,
1302
- chrom_col=chrom_col,
1303
- pos_col=pos_col,
1304
- p_col=p_col,
1305
- custom_chrom_order=custom_chrom_order,
1306
- category_col=category_col,
1307
- category_order=category_order,
1308
- significance_threshold=significance_threshold,
1309
- figsize=figsize,
1310
- title=title,
1311
- )
1312
-
1313
- def plot_qq(
1314
- self,
1315
- df: pd.DataFrame,
1316
- p_col: str = "p",
1317
- show_confidence_band: bool = True,
1318
- show_lambda: bool = True,
1319
- figsize: Tuple[float, float] = (6, 6),
1320
- title: Optional[str] = None,
1321
- ) -> Any:
1322
- """Create a QQ plot. See ManhattanPlotter.plot_qq for docs."""
1323
- return self._manhattan_plotter.plot_qq(
1324
- df=df,
1325
- p_col=p_col,
1326
- show_confidence_band=show_confidence_band,
1327
- show_lambda=show_lambda,
1328
- figsize=figsize,
1329
- title=title,
1330
- )
1331
-
1332
- def plot_manhattan_stacked(
1333
- self,
1334
- gwas_dfs: List[pd.DataFrame],
1335
- chrom_col: str = "chrom",
1336
- pos_col: str = "pos",
1337
- p_col: str = "p",
1338
- custom_chrom_order: Optional[List[str]] = None,
1339
- significance_threshold: Optional[float] = DEFAULT_GENOMEWIDE_THRESHOLD,
1340
- panel_labels: Optional[List[str]] = None,
1341
- figsize: Tuple[float, float] = (12, 8),
1342
- title: Optional[str] = None,
1343
- ) -> Any:
1344
- """Create stacked Manhattan plots. See ManhattanPlotter.plot_manhattan_stacked for docs."""
1345
- return self._manhattan_plotter.plot_manhattan_stacked(
1346
- gwas_dfs=gwas_dfs,
1347
- chrom_col=chrom_col,
1348
- pos_col=pos_col,
1349
- p_col=p_col,
1350
- custom_chrom_order=custom_chrom_order,
1351
- significance_threshold=significance_threshold,
1352
- panel_labels=panel_labels,
1353
- figsize=figsize,
1354
- title=title,
1355
- )
1356
-
1357
- def plot_manhattan_qq(
1358
- self,
1359
- df: pd.DataFrame,
1360
- chrom_col: str = "chrom",
1361
- pos_col: str = "pos",
1362
- p_col: str = "p",
1363
- custom_chrom_order: Optional[List[str]] = None,
1364
- significance_threshold: Optional[float] = DEFAULT_GENOMEWIDE_THRESHOLD,
1365
- show_confidence_band: bool = True,
1366
- show_lambda: bool = True,
1367
- figsize: Tuple[float, float] = (14, 5),
1368
- title: Optional[str] = None,
1369
- ) -> Any:
1370
- """Create side-by-side Manhattan and QQ plots. See ManhattanPlotter.plot_manhattan_qq for docs."""
1371
- return self._manhattan_plotter.plot_manhattan_qq(
1372
- df=df,
1373
- chrom_col=chrom_col,
1374
- pos_col=pos_col,
1375
- p_col=p_col,
1376
- custom_chrom_order=custom_chrom_order,
1377
- significance_threshold=significance_threshold,
1378
- show_confidence_band=show_confidence_band,
1379
- show_lambda=show_lambda,
1380
- figsize=figsize,
1381
- title=title,
1382
- )
1383
-
1384
- def plot_manhattan_qq_stacked(
1385
- self,
1386
- gwas_dfs: List[pd.DataFrame],
1387
- chrom_col: str = "chrom",
1388
- pos_col: str = "pos",
1389
- p_col: str = "p",
1390
- custom_chrom_order: Optional[List[str]] = None,
1391
- significance_threshold: Optional[float] = DEFAULT_GENOMEWIDE_THRESHOLD,
1392
- show_confidence_band: bool = True,
1393
- show_lambda: bool = True,
1394
- panel_labels: Optional[List[str]] = None,
1395
- figsize: Tuple[float, float] = (14, 8),
1396
- title: Optional[str] = None,
1397
- ) -> Any:
1398
- """Create stacked Manhattan+QQ plots. See ManhattanPlotter.plot_manhattan_qq_stacked for docs."""
1399
- return self._manhattan_plotter.plot_manhattan_qq_stacked(
1400
- gwas_dfs=gwas_dfs,
1401
- chrom_col=chrom_col,
1402
- pos_col=pos_col,
1403
- p_col=p_col,
1404
- custom_chrom_order=custom_chrom_order,
1405
- significance_threshold=significance_threshold,
1406
- show_confidence_band=show_confidence_band,
1407
- show_lambda=show_lambda,
1408
- panel_labels=panel_labels,
1409
- figsize=figsize,
1410
- title=title,
1411
- )
1552
+ return fig