pylocuszoom 1.2.0__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pylocuszoom/coloc.py ADDED
@@ -0,0 +1,82 @@
1
+ """Colocalization data validation and preparation.
2
+
3
+ Validates GWAS and eQTL DataFrames for colocalization scatter plots.
4
+ """
5
+
6
+ from typing import Optional
7
+
8
+ import pandas as pd
9
+
10
+ from .validation import DataFrameValidator
11
+
12
+
13
def _validate_coloc_df(
    df: pd.DataFrame,
    df_name: str,
    pos_col: str,
    p_col: str,
    rs_col: Optional[str] = None,
) -> None:
    """Check that a results table is usable for a colocalization plot.

    Shared implementation behind the GWAS and eQTL validators. The checks
    run in this order:

    1. Position and p-value columns (plus the SNP ID column when given) exist.
    2. Position and p-value columns are numeric.
    3. P-values are range-checked against 0..1 with the lower bound
       exclusive (``exclusive_min=True``). A p-value of exactly 0 is
       therefore rejected.

    Args:
        df: Results DataFrame to check.
        df_name: Label used in error messages (e.g. "GWAS DataFrame").
        pos_col: Column name for genomic positions.
        p_col: Column name for p-values.
        rs_col: Optional column name for SNP IDs; only its presence is checked.

    Raises:
        ValidationError: If required columns are missing, are not numeric,
            or hold out-of-range p-values (raised by DataFrameValidator).
    """
    if rs_col is None:
        needed = [pos_col, p_col]
    else:
        needed = [pos_col, p_col, rs_col]

    # Each step returns the validator to continue with; the accumulated
    # checks are evaluated by the final validate() call.
    validator = DataFrameValidator(df, df_name)
    validator = validator.require_columns(needed)
    validator = validator.require_numeric([pos_col, p_col])
    validator = validator.require_range(
        p_col, min_val=0, max_val=1, exclusive_min=True
    )
    validator.validate()
43
+
44
+
45
def validate_coloc_gwas_df(
    df: pd.DataFrame,
    pos_col: str,
    p_col: str,
    rs_col: Optional[str] = None,
) -> None:
    """Check that a GWAS results table can feed the colocalization plot.

    Delegates to ``_validate_coloc_df``, labelling any error messages with
    "GWAS DataFrame".

    Args:
        df: GWAS results DataFrame.
        pos_col: Column name for genomic positions.
        p_col: Column name for p-values.
        rs_col: Optional column name for SNP IDs.

    Raises:
        ValidationError: If required columns are missing or have invalid types.
    """
    _validate_coloc_df(
        df,
        df_name="GWAS DataFrame",
        pos_col=pos_col,
        p_col=p_col,
        rs_col=rs_col,
    )
63
+
64
+
65
def validate_coloc_eqtl_df(
    df: pd.DataFrame,
    pos_col: str,
    p_col: str,
    rs_col: Optional[str] = None,
) -> None:
    """Check that an eQTL results table can feed the colocalization plot.

    Delegates to ``_validate_coloc_df``, labelling any error messages with
    "eQTL DataFrame".

    Args:
        df: eQTL results DataFrame.
        pos_col: Column name for genomic positions.
        p_col: Column name for p-values.
        rs_col: Optional column name for SNP IDs.

    Raises:
        ValidationError: If required columns are missing or have invalid types.
    """
    _validate_coloc_df(
        df,
        df_name="eQTL DataFrame",
        pos_col=pos_col,
        p_col=p_col,
        rs_col=rs_col,
    )
@@ -0,0 +1,390 @@
1
+ """Colocalization scatter plot for GWAS-eQTL visualization.
2
+
3
+ Creates scatter plots comparing GWAS -log10(p) vs eQTL -log10(p)
4
+ with points colored by LD to the lead SNP.
5
+ """
6
+
7
+ from typing import Any, Optional, Tuple
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ from scipy import stats
12
+
13
+ from .backends import BackendType, get_backend
14
+ from .coloc import validate_coloc_eqtl_df, validate_coloc_gwas_df
15
+ from .colors import (
16
+ EFFECT_CONGRUENT_COLOR,
17
+ EFFECT_INCONGRUENT_COLOR,
18
+ LD_BINS,
19
+ LD_NA_COLOR,
20
+ LEAD_SNP_COLOR,
21
+ get_ld_color,
22
+ )
23
+
24
+
25
def _resolve_merged_column(
    merged: pd.DataFrame,
    col: Optional[str],
    suffix: str,
) -> Optional[str]:
    """Map an original column name to its name in a merged DataFrame.

    ``pd.merge`` appends a suffix to columns present in both inputs. The
    suffixed form is therefore tried first, then the bare name (a column
    that existed in only one input keeps its name).

    Args:
        merged: Merged DataFrame to search.
        col: Original column name, or None.
        suffix: Suffix the merge may have appended (e.g. "_gwas", "_eqtl").

    Returns:
        The matching column name in ``merged``. Returns None when ``col`` is
        None or neither candidate is present.
    """
    if col is None:
        return None

    for candidate in (f"{col}{suffix}", col):
        if candidate in merged.columns:
            return candidate
    return None
51
+
52
+
53
def _get_effect_agreement_color(gwas_effect: float, eqtl_effect: float) -> str:
    """Choose a point color from whether two effect sizes agree in direction.

    An effect counts as "positive" only when strictly greater than zero, so
    an effect of exactly zero is grouped with the negative effects.

    Args:
        gwas_effect: GWAS effect size (beta coefficient).
        eqtl_effect: eQTL effect size (beta coefficient).

    Returns:
        LD_NA_COLOR if either effect is missing (per ``pd.isna``).
        EFFECT_INCONGRUENT_COLOR when exactly one effect is positive.
        EFFECT_CONGRUENT_COLOR otherwise.
    """
    if pd.isna(gwas_effect) or pd.isna(eqtl_effect):
        return LD_NA_COLOR

    gwas_up = gwas_effect > 0
    eqtl_up = eqtl_effect > 0
    if gwas_up != eqtl_up:
        return EFFECT_INCONGRUENT_COLOR
    return EFFECT_CONGRUENT_COLOR
67
+
68
+
69
class ColocPlotter:
    """Colocalization scatter plot generator.

    Creates scatter plots comparing GWAS -log10(p) vs eQTL -log10(p)
    with points colored by LD to the lead SNP (or, optionally, by whether
    the GWAS and eQTL effect directions agree).

    Supports multiple rendering backends:
    - matplotlib (default): Static publication-quality plots
    - plotly: Interactive HTML with hover tooltips
    - bokeh: Interactive HTML for dashboards

    Args:
        backend: Plotting backend ('matplotlib', 'plotly', or 'bokeh').

    Example:
        >>> plotter = ColocPlotter()
        >>> fig = plotter.plot_coloc(gwas_df, eqtl_df, lead_snp="rs12345")
        >>> fig.savefig("coloc.png", dpi=150)
    """

    def __init__(
        self,
        backend: BackendType = "matplotlib",
    ):
        """Initialize the colocalization plotter.

        Args:
            backend: Backend name, resolved to a backend object via get_backend().
        """
        self._backend = get_backend(backend)
        self.backend_name = backend

    def plot_coloc(
        self,
        gwas_df: pd.DataFrame,
        eqtl_df: pd.DataFrame,
        pos_col: str = "pos",
        gwas_p_col: str = "p_gwas",
        eqtl_p_col: str = "p_eqtl",
        rs_col: Optional[str] = "rs",
        ld_col: Optional[str] = None,
        lead_snp: Optional[str] = None,
        gwas_threshold: float = 5e-8,
        eqtl_threshold: float = 1e-5,
        show_correlation: bool = True,
        color_by_effect: bool = False,
        gwas_effect_col: Optional[str] = None,
        eqtl_effect_col: Optional[str] = None,
        h4_posterior: Optional[float] = None,
        figsize: Tuple[float, float] = (8.0, 8.0),
        title: Optional[str] = None,
    ) -> Any:
        """Create GWAS-eQTL colocalization scatter plot.

        Args:
            gwas_df: GWAS results DataFrame with positions and p-values.
            eqtl_df: eQTL results DataFrame with positions and p-values.
            pos_col: Column name for genomic positions (must exist in both).
            gwas_p_col: Column name for GWAS p-values.
            eqtl_p_col: Column name for eQTL p-values.
            rs_col: Column name for SNP IDs (optional, for labeling lead SNP).
            ld_col: Column name for LD R² values in GWAS df (optional).
            lead_snp: SNP ID to highlight as lead variant. If None and ld_col
                is provided, auto-selects SNP with highest combined -log10(p).
            gwas_threshold: P-value threshold for GWAS significance line,
                in (0, 1].
            eqtl_threshold: P-value threshold for eQTL significance line,
                in (0, 1].
            show_correlation: Whether to display Pearson correlation. It is
                shown only with at least 3 points and when neither axis
                is constant.
            color_by_effect: Whether to color points by effect direction agreement.
            gwas_effect_col: Column name for GWAS effect sizes (required if
                color_by_effect=True).
            eqtl_effect_col: Column name for eQTL effect sizes (required if
                color_by_effect=True).
            h4_posterior: Optional COLOC H4 posterior probability to display.
            figsize: Figure size as (width, height).
            title: Plot title.

        Returns:
            Figure object (type depends on backend).

        Raises:
            ValidationError: If required columns are missing or invalid.
            ValueError: If no overlapping positions between GWAS and eQTL.
            ValueError: If lead_snp specified but not found in merged data.
            ValueError: If color_by_effect=True but effect columns not provided.
            ValueError: If h4_posterior is not in [0, 1] range.
            ValueError: If gwas_threshold or eqtl_threshold is not in (0, 1].

        Example:
            >>> fig = plotter.plot_coloc(
            ...     gwas_df, eqtl_df,
            ...     ld_col="ld", lead_snp="rs12345",
            ... )
            >>> # With effect coloring
            >>> fig = plotter.plot_coloc(
            ...     gwas_df, eqtl_df,
            ...     color_by_effect=True,
            ...     gwas_effect_col="beta_gwas",
            ...     eqtl_effect_col="beta_eqtl",
            ... )
        """
        # Validate inputs before doing any work
        validate_coloc_gwas_df(gwas_df, pos_col, gwas_p_col, rs_col)
        validate_coloc_eqtl_df(eqtl_df, pos_col, eqtl_p_col, rs_col)
        self._check_plot_options(
            color_by_effect,
            gwas_effect_col,
            eqtl_effect_col,
            h4_posterior,
            gwas_threshold,
            eqtl_threshold,
        )

        # Keep only variants present in both studies (joined on position)
        merged = pd.merge(
            gwas_df,
            eqtl_df,
            on=pos_col,
            how="inner",
            suffixes=("_gwas", "_eqtl"),
        )
        if len(merged) == 0:
            raise ValueError(
                "No overlapping positions between GWAS and eQTL DataFrames"
            )

        # Resolve column names after merge (pandas adds suffixes to duplicates)
        merged_rs_col = _resolve_merged_column(merged, rs_col, "_gwas")
        ld_col_merged = _resolve_merged_column(merged, ld_col, "_gwas")
        gwas_p_merged = _resolve_merged_column(merged, gwas_p_col, "_gwas")
        eqtl_p_merged = _resolve_merged_column(merged, eqtl_p_col, "_eqtl")

        # Transform p-values to -log10(p); clipping keeps p == 0 finite
        merged["neglog10_gwas"] = -np.log10(merged[gwas_p_merged].clip(lower=1e-300))
        merged["neglog10_eqtl"] = -np.log10(merged[eqtl_p_merged].clip(lower=1e-300))

        # Drop rows with NaN in either transformed p-value
        merged = merged.dropna(subset=["neglog10_gwas", "neglog10_eqtl"])
        if len(merged) == 0:
            raise ValueError("No valid data points after removing NaN p-values")

        merged["color"] = self._point_colors(
            merged,
            color_by_effect,
            gwas_effect_col,
            eqtl_effect_col,
            ld_col_merged,
        )
        lead_idx = self._select_lead_index(
            merged, lead_snp, merged_rs_col, ld_col_merged
        )

        # Create figure
        fig, axes = self._backend.create_figure(
            n_panels=1,
            height_ratios=[1.0],
            figsize=figsize,
        )
        ax = axes[0]

        # Separate lead SNP from other points so it is drawn on top
        if lead_idx is not None:
            lead_row = merged.loc[[lead_idx]]
            other_rows = merged.drop(lead_idx)
        else:
            lead_row = pd.DataFrame()
            other_rows = merged

        # Plot non-lead points
        if len(other_rows) > 0:
            self._backend.scatter(
                ax,
                other_rows["neglog10_gwas"],
                other_rows["neglog10_eqtl"],
                colors=other_rows["color"].tolist(),
                sizes=60,
                marker="o",
                edgecolor="black",
                linewidth=0.5,
                zorder=2,
            )

        # Plot lead SNP as diamond
        if len(lead_row) > 0:
            self._backend.scatter(
                ax,
                lead_row["neglog10_gwas"],
                lead_row["neglog10_eqtl"],
                colors=LEAD_SNP_COLOR,
                sizes=100,
                marker="D",
                edgecolor="black",
                linewidth=0.5,
                zorder=5,
            )

            # Add lead SNP label; str() so non-string IDs render as text
            if merged_rs_col is not None:
                label = str(lead_row[merged_rs_col].values[0])
                x_pos = lead_row["neglog10_gwas"].values[0]
                y_pos = lead_row["neglog10_eqtl"].values[0]
                self._backend.add_text(
                    ax,
                    x_pos,
                    y_pos + 0.5,  # Offset above the point (data units)
                    label,
                    fontsize=9,
                    ha="center",
                    va="bottom",
                )

        # Add significance threshold lines (thresholds validated to (0, 1])
        gwas_sig_line = -np.log10(gwas_threshold)
        eqtl_sig_line = -np.log10(eqtl_threshold)

        self._backend.axvline(
            ax, x=gwas_sig_line, color="grey", linestyle="--", linewidth=1, alpha=0.7
        )
        self._backend.axhline(
            ax, y=eqtl_sig_line, color="grey", linestyle="--", linewidth=1, alpha=0.7
        )

        # Calculate data bounds once for text positioning
        x_min, x_max = merged["neglog10_gwas"].min(), merged["neglog10_gwas"].max()
        y_min, y_max = merged["neglog10_eqtl"].min(), merged["neglog10_eqtl"].max()
        x_range = x_max - x_min
        y_range = y_max - y_min

        # Display correlation in top-left corner
        if show_correlation and len(merged) >= 3:
            gwas_vals = merged["neglog10_gwas"]
            eqtl_vals = merged["neglog10_eqtl"]
            # Pearson r is undefined for constant input (scipy returns NaN
            # and warns), so only annotate when both axes actually vary.
            if gwas_vals.nunique() > 1 and eqtl_vals.nunique() > 1:
                r, p = stats.pearsonr(gwas_vals, eqtl_vals)
                p_str = "p < 0.001" if p < 0.001 else f"p = {p:.3f}"
                corr_text = f"r = {r:.3f}\n{p_str}"
                self._backend.add_text(
                    ax,
                    x_min + 0.05 * x_range,
                    y_max - 0.05 * y_range,
                    corr_text,
                    fontsize=10,
                    ha="left",
                    va="top",
                )

        # Display H4 posterior probability in bottom-right corner
        if h4_posterior is not None:
            self._backend.add_text(
                ax,
                x_max - 0.05 * x_range,
                y_min + 0.05 * y_range,
                f"H4 PP = {h4_posterior:.3f}",
                fontsize=10,
                ha="right",
                va="bottom",
            )

        # Set axis labels
        self._backend.set_xlabel(ax, r"GWAS $-\log_{10}$ P")
        self._backend.set_ylabel(ax, r"eQTL $-\log_{10}$ P")

        # Set title
        if title:
            self._backend.set_title(ax, title)

        # Hide top and right spines
        self._backend.hide_spines(ax, ["top", "right"])

        # Add legend matching the coloring mode
        if color_by_effect:
            self._add_effect_legend(ax)
        elif ld_col_merged is not None:
            self._backend.add_ld_legend(ax, LD_BINS, LEAD_SNP_COLOR)

        # Finalize layout
        self._backend.finalize_layout(fig)

        return fig

    @staticmethod
    def _check_plot_options(
        color_by_effect: bool,
        gwas_effect_col: Optional[str],
        eqtl_effect_col: Optional[str],
        h4_posterior: Optional[float],
        gwas_threshold: float,
        eqtl_threshold: float,
    ) -> None:
        """Validate scalar plot options that do not depend on the data.

        Raises:
            ValueError: If effect coloring lacks its columns, h4_posterior is
                outside [0, 1], or a threshold is outside (0, 1].
        """
        if color_by_effect:
            if gwas_effect_col is None or eqtl_effect_col is None:
                raise ValueError(
                    "color_by_effect=True requires gwas_effect_col and eqtl_effect_col"
                )

        # The negated comparison also rejects NaN.
        if h4_posterior is not None:
            if not (0 <= h4_posterior <= 1):
                raise ValueError(f"h4_posterior must be in [0, 1], got {h4_posterior}")

        # -log10 of a non-positive threshold is inf/NaN, which would place the
        # significance line at infinity; mirrors ColocConfig (gt=0, le=1).
        for name, value in (
            ("gwas_threshold", gwas_threshold),
            ("eqtl_threshold", eqtl_threshold),
        ):
            if not (0 < value <= 1):
                raise ValueError(f"{name} must be in (0, 1], got {value}")

    @staticmethod
    def _point_colors(
        merged: pd.DataFrame,
        color_by_effect: bool,
        gwas_effect_col: Optional[str],
        eqtl_effect_col: Optional[str],
        ld_col_merged: Optional[str],
    ) -> Any:
        """Compute per-point colors for the merged data.

        Returns a list or Series aligned with ``merged``, or a single color
        string that is broadcast to every row on assignment.

        Raises:
            ValueError: If effect coloring is requested but an effect column
                cannot be found after the merge.
        """
        if color_by_effect:
            gwas_effect_merged = _resolve_merged_column(
                merged, gwas_effect_col, "_gwas"
            )
            if gwas_effect_merged is None:
                raise ValueError(
                    f"gwas_effect_col '{gwas_effect_col}' not found in merged data"
                )
            eqtl_effect_merged = _resolve_merged_column(
                merged, eqtl_effect_col, "_eqtl"
            )
            if eqtl_effect_merged is None:
                raise ValueError(
                    f"eqtl_effect_col '{eqtl_effect_col}' not found in merged data"
                )
            # Column-wise zip avoids the per-row Series built by apply(axis=1).
            return [
                _get_effect_agreement_color(gwas_beta, eqtl_beta)
                for gwas_beta, eqtl_beta in zip(
                    merged[gwas_effect_merged], merged[eqtl_effect_merged]
                )
            ]
        if ld_col_merged is not None:
            return merged[ld_col_merged].apply(get_ld_color)
        return LD_NA_COLOR

    @staticmethod
    def _select_lead_index(
        merged: pd.DataFrame,
        lead_snp: Optional[str],
        rs_col_merged: Optional[str],
        ld_col_merged: Optional[str],
    ) -> Any:
        """Return the index label of the lead SNP, or None if there is none.

        An explicit ``lead_snp`` is looked up by ID (first match wins).
        Otherwise, when LD coloring is active, the row with the highest
        combined -log10(p) is chosen.

        Raises:
            ValueError: If ``lead_snp`` is given but there is no SNP ID column
                or the ID is absent from the merged data.
        """
        if lead_snp is not None:
            if rs_col_merged is None:
                raise ValueError(
                    f"lead_snp '{lead_snp}' specified but rs_col not found"
                )
            matches = merged[merged[rs_col_merged] == lead_snp]
            if len(matches) == 0:
                raise ValueError(f"lead_snp '{lead_snp}' not found in merged data")
            return matches.index[0]
        if ld_col_merged is not None:
            combined_score = merged["neglog10_gwas"] + merged["neglog10_eqtl"]
            return combined_score.idxmax()
        return None

    def _add_effect_legend(self, ax: Any) -> None:
        """Add effect direction legend to plot (all backends)."""
        effect_bins = [
            (0.0, "Same direction", EFFECT_CONGRUENT_COLOR),
            (0.0, "Opposite direction", EFFECT_INCONGRUENT_COLOR),
            (0.0, "Missing effect", LD_NA_COLOR),
        ]
        self._backend.add_effect_legend(ax, effect_bins)
pylocuszoom/colors.py CHANGED
@@ -280,3 +280,29 @@ def get_phewas_category_palette(categories: List[str]) -> dict[str, str]:
280
280
  Dictionary mapping category names to hex colors.
281
281
  """
282
282
  return {cat: get_phewas_category_color(i) for i, cat in enumerate(categories)}
283
+
284
+
285
+ # =============================================================================
286
+ # LD Heatmap Colors
287
+ # =============================================================================
288
+
289
+ # Custom colormap name for LD heatmaps
290
+ LD_HEATMAP_CMAP_NAME = "ld_heatmap"
291
+
292
+ # White-to-red gradient for R² heatmaps (0 = white, 1 = red)
293
+ LD_HEATMAP_COLORS: List[str] = ["#FFFFFF", "#FF0000"]
294
+
295
+ # Color for missing/NaN LD values in heatmaps
296
+ LD_HEATMAP_MISSING_COLOR = "#808080" # grey
297
+
298
+ # Highlight colors for lead and secondary SNPs in heatmaps
299
+ LEAD_SNP_HIGHLIGHT_COLOR = "#FF0000" # red border/outline for lead SNP
300
+ SECONDARY_HIGHLIGHT_COLOR = "#0000FF" # blue for other highlighted SNPs
301
+
302
+ # =============================================================================
303
+ # Colocalization Effect Direction Colors
304
+ # =============================================================================
305
+
306
+ # Colors for effect direction agreement in colocalization plots
307
+ EFFECT_CONGRUENT_COLOR = "#4DAF4A" # green - same direction
308
+ EFFECT_INCONGRUENT_COLOR = "#E41A1C" # red - opposite direction
pylocuszoom/config.py CHANGED
@@ -355,6 +355,66 @@ class StackedPlotConfig(BaseModel):
355
355
  )
356
356
 
357
357
 
358
class ColocConfig(BaseModel):
    """Settings for the GWAS/eQTL colocalization scatter plot.

    The model is frozen (``ConfigDict(frozen=True)``), so instances cannot
    be modified after construction.

    Attributes:
        gwas_p_col: Column name for GWAS p-values.
        eqtl_p_col: Column name for eQTL p-values.
        pos_col: Column name for genomic position.
        rs_col: Optional column name for SNP identifiers.
        ld_col: Optional column name for pre-computed LD values.
        lead_snp: Optional lead SNP identifier for highlighting.
        gwas_threshold: GWAS significance threshold in (0, 1] (default 5e-8).
        eqtl_threshold: eQTL significance threshold in (0, 1] (default 1e-5).
        show_correlation: Whether to display Pearson correlation.
        color_by_effect: Whether to color by effect direction agreement.
        gwas_effect_col: Column name for GWAS effect sizes.
        eqtl_effect_col: Column name for eQTL effect sizes.
        h4_posterior: Optional COLOC H4 posterior probability in [0, 1].
        figsize: Figure size as (width, height).
    """

    model_config = ConfigDict(frozen=True)

    # Input column names
    gwas_p_col: str = Field(default="p_gwas", description="GWAS p-value column")
    eqtl_p_col: str = Field(default="p_eqtl", description="eQTL p-value column")
    pos_col: str = Field(default="pos", description="Position column")
    rs_col: Optional[str] = Field(default="rs", description="SNP ID column")
    ld_col: Optional[str] = Field(default=None, description="Pre-computed LD column")
    lead_snp: Optional[str] = Field(default=None, description="Lead SNP ID")

    # Significance thresholds (p-value scale)
    gwas_threshold: float = Field(
        default=5e-8,
        gt=0,
        le=1,
        description="GWAS significance",
    )
    eqtl_threshold: float = Field(
        default=1e-5,
        gt=0,
        le=1,
        description="eQTL significance",
    )

    # Display options
    show_correlation: bool = Field(default=True, description="Show Pearson correlation")
    color_by_effect: bool = Field(
        default=False,
        description="Color by effect agreement",
    )
    gwas_effect_col: Optional[str] = Field(
        default=None,
        description="GWAS effect column",
    )
    eqtl_effect_col: Optional[str] = Field(
        default=None,
        description="eQTL effect column",
    )
    h4_posterior: Optional[float] = Field(
        default=None,
        ge=0,
        le=1,
        description="COLOC H4 PP",
    )
    figsize: Tuple[float, float] = Field(default=(8.0, 8.0), description="Figure size")

    @model_validator(mode="after")
    def validate_effect_coloring(self) -> "ColocConfig":
        """Require both effect columns whenever effect coloring is enabled."""
        effect_col_missing = (
            self.gwas_effect_col is None or self.eqtl_effect_col is None
        )
        if self.color_by_effect and effect_col_missing:
            raise ValueError(
                "color_by_effect=True requires gwas_effect_col and eqtl_effect_col"
            )
        return self
416
+
417
+
358
418
  __all__ = [
359
419
  "RegionConfig",
360
420
  "ColumnConfig",
@@ -362,4 +422,5 @@ __all__ = [
362
422
  "LDConfig",
363
423
  "PlotConfig",
364
424
  "StackedPlotConfig",
425
+ "ColocConfig",
365
426
  ]
pylocuszoom/labels.py CHANGED
@@ -24,6 +24,7 @@ def add_snp_labels(
24
24
  genes_df: Optional[pd.DataFrame] = None,
25
25
  chrom: Optional[Union[int, str]] = None,
26
26
  max_label_length: int = 15,
27
+ adjust: bool = True,
27
28
  **kwargs: Any,
28
29
  ) -> List[Annotation]:
29
30
  """Add text labels to top SNPs in the regional plot.
@@ -41,6 +42,8 @@ def add_snp_labels(
41
42
  genes_df: Unused, kept for backward compatibility.
42
43
  chrom: Unused, kept for backward compatibility.
43
44
  max_label_length: Maximum label length before truncation.
45
+ adjust: If True, run adjustText immediately. If False, caller must
46
+ call adjust_snp_labels() after setting axis limits.
44
47
 
45
48
  Returns:
46
49
  List of matplotlib text annotation objects.
@@ -101,21 +104,43 @@ def add_snp_labels(
101
104
  )
102
105
  texts.append(text)
103
106
 
104
- # Only use adjustText when there are multiple labels to avoid overlap
105
- if len(texts) > 1:
106
- try:
107
- from adjustText import adjust_text
108
-
109
- adjust_text(
110
- texts,
111
- ax=ax,
112
- arrowprops=dict(arrowstyle="-", color="gray", lw=0.5),
113
- expand_points=(1.5, 1.5),
114
- )
115
- except ImportError:
116
- logger.warning(
117
- "adjustText not installed - SNP labels may overlap. "
118
- "Install with: pip install adjustText"
119
- )
107
+ if adjust:
108
+ adjust_snp_labels(ax, texts)
120
109
 
121
110
  return texts
111
+
112
+
113
def adjust_snp_labels(ax: Axes, texts: List[Annotation]) -> None:
    """Reposition SNP labels so they do not overlap one another.

    Call this only after all axis limits have been set. adjustText needs the
    final plot bounds to keep labels inside the visible area.

    With zero or one label there is nothing to resolve and the function
    returns immediately. If the optional ``adjustText`` package is not
    installed, a warning is logged and the labels are left in place.

    Args:
        ax: Matplotlib axes object.
        texts: List of text annotation objects from add_snp_labels().

    Example:
        >>> texts = add_snp_labels(ax, df, adjust=False)
        >>> ax.set_xlim(start, end)
        >>> ax.set_ylim(0, max_y)
        >>> adjust_snp_labels(ax, texts)
    """
    # Fewer than two labels cannot overlap each other.
    if len(texts) < 2:
        return

    # Thin gray connector drawn from each moved label back to its point.
    connector_style = {"arrowstyle": "-", "color": "gray", "lw": 0.5}
    try:
        from adjustText import adjust_text

        adjust_text(
            texts,
            ax=ax,
            arrowprops=connector_style,
            expand_points=(1.5, 1.5),
        )
    except ImportError:
        # adjustText is optional; degrade to possibly-overlapping labels.
        logger.warning(
            "adjustText not installed - SNP labels may overlap. "
            "Install with: pip install adjustText"
        )
+ )