pylocuszoom 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,223 @@
1
+ """Fine-mapping/SuSiE data handling for pyLocusZoom.
2
+
3
+ Provides utilities for loading, validating, and preparing statistical
4
+ fine-mapping results (SuSiE, FINEMAP, etc.) for visualization.
5
+ """
6
+
7
+ from typing import List, Optional
8
+
9
+ import pandas as pd
10
+
11
+ from .logging import logger
12
+
13
# Column names a fine-mapping table is required to carry.
REQUIRED_FINEMAPPING_COLS = [
    "pos",
    "pip",
]
# Column names listed as optional for a fine-mapping table.
OPTIONAL_FINEMAPPING_COLS = [
    "rs",
    "cs",
    "cs_id",
    "effect",
    "se",
]
16
+
17
+
18
class FinemappingValidationError(ValueError):
    """Signal that a fine-mapping DataFrame did not pass validation.

    Derives from ``ValueError``, so handlers written for ``ValueError``
    catch it as well.
    """
22
+
23
+
24
def validate_finemapping_df(
    df: pd.DataFrame,
    pos_col: str = "pos",
    pip_col: str = "pip",
) -> None:
    """Validate that a fine-mapping DataFrame has usable position/PIP columns.

    Args:
        df: Fine-mapping DataFrame to validate.
        pos_col: Column name for genomic position.
        pip_col: Column name for posterior inclusion probability.

    Raises:
        FinemappingValidationError: If a required column is missing, if the
            PIP column cannot be compared against numbers, or if any PIP
            value falls outside [0, 1]. Missing values (NaN) fail the range
            check and are therefore counted as invalid.
    """
    missing = [col for col in (pos_col, pip_col) if col not in df.columns]
    if missing:
        raise FinemappingValidationError(
            f"Fine-mapping DataFrame missing required columns: {missing}. "
            f"Required: {pos_col} (position), {pip_col} (posterior inclusion probability)"
        )

    # Build the out-of-range mask once; it drives both the check and the count.
    try:
        out_of_range = ~df[pip_col].between(0, 1)
    except TypeError as err:
        # e.g. a column of strings: surface it as the documented error type
        # instead of leaking a raw comparison TypeError to the caller.
        raise FinemappingValidationError(
            f"PIP column '{pip_col}' must contain numeric values."
        ) from err

    invalid_count = int(out_of_range.sum())
    if invalid_count:
        raise FinemappingValidationError(
            f"PIP values must be between 0 and 1. Found {invalid_count} invalid values."
        )
57
+
58
+
59
def filter_finemapping_by_region(
    df: pd.DataFrame,
    chrom: int,
    start: int,
    end: int,
    pos_col: str = "pos",
    chrom_col: Optional[str] = "chr",
) -> pd.DataFrame:
    """Filter fine-mapping data to a genomic region.

    Positions are matched inclusively on both ends. The chromosome is only
    used when ``chrom_col`` is set and present in ``df``; both sides are
    compared as strings with any "chr" text removed.

    Args:
        df: Fine-mapping DataFrame.
        chrom: Chromosome number (a "chr"-prefixed string is also accepted).
        start: Start position (inclusive).
        end: End position (inclusive).
        pos_col: Column name for position.
        chrom_col: Column name for chromosome (if present).

    Returns:
        Filtered copy of ``df`` containing only variants in the region.
    """
    # Normalise once so the comparison and the log message agree; the raw
    # value would log as "chrchr1" when the caller passes "chr1".
    chrom_label = str(chrom).replace("chr", "")

    in_region = (df[pos_col] >= start) & (df[pos_col] <= end)

    # Filter by chromosome only if the column exists
    if chrom_col and chrom_col in df.columns:
        df_chrom = df[chrom_col].astype(str).str.replace("chr", "", regex=False)
        in_region = in_region & (df_chrom == chrom_label)

    filtered = df[in_region].copy()
    logger.debug(
        f"Filtered fine-mapping data to {len(filtered)} variants in region "
        f"chr{chrom_label}:{start}-{end}"
    )
    return filtered
94
+
95
+
96
def get_credible_sets(
    df: pd.DataFrame,
    cs_col: str = "cs",
) -> List[int]:
    """Get the sorted list of unique credible set IDs.

    Args:
        df: Fine-mapping DataFrame.
        cs_col: Column containing credible set assignments.

    Returns:
        Sorted list of unique credible set IDs, excluding 0 and NA. An empty
        list is returned when ``cs_col`` is not a column of ``df``.
    """
    if cs_col not in df.columns:
        return []

    # Drop variants not in a credible set (cs == 0 or NA)
    cs_values = df[cs_col].dropna()
    cs_values = cs_values[cs_values != 0]

    # A column that contains NA is stored as float64 by pandas, so IDs would
    # otherwise come back as 1.0, 2.0, ... Restore whole-number floats to int
    # so the result matches the annotated return type. The values still
    # compare (and hash) equal to the originals.
    unique_ids = {
        int(value) if isinstance(value, float) and value.is_integer() else value
        for value in cs_values.unique().tolist()
    }
    return sorted(unique_ids)
115
+
116
+
117
def filter_by_credible_set(
    df: pd.DataFrame,
    cs_id: int,
    cs_col: str = "cs",
) -> pd.DataFrame:
    """Return the variants assigned to one credible set.

    Args:
        df: Fine-mapping DataFrame.
        cs_id: Credible set ID to keep.
        cs_col: Column containing credible set assignments.

    Returns:
        Copy of the rows of ``df`` whose ``cs_col`` value equals ``cs_id``.

    Raises:
        FinemappingValidationError: If ``cs_col`` is not a column of ``df``.
    """
    column_present = cs_col in df.columns
    if not column_present:
        raise FinemappingValidationError(
            f"Cannot filter by credible set: column '{cs_col}' not found. "
            f"Available columns: {list(df.columns)}"
        )

    in_set = df[cs_col].eq(cs_id)
    members = df.loc[in_set]
    # Copy so edits to the result never touch the caller's frame.
    return members.copy()
138
+
139
+
140
def prepare_finemapping_for_plotting(
    df: pd.DataFrame,
    pos_col: str = "pos",
    pip_col: str = "pip",
    chrom: Optional[int] = None,
    start: Optional[int] = None,
    end: Optional[int] = None,
) -> pd.DataFrame:
    """Prepare fine-mapping data for plotting.

    Validates the input, optionally restricts it to a region, and sorts it
    by position for plotting as a line or scatter. The caller's DataFrame
    is not modified.

    Args:
        df: Raw fine-mapping DataFrame.
        pos_col: Column name for position.
        pip_col: Column name for PIP.
        chrom: Optional chromosome for region filtering.
        start: Optional start position for region filtering.
        end: Optional end position for region filtering.

    Returns:
        New DataFrame sorted by position. Region filtering is applied only
        when ``chrom``, ``start`` and ``end`` are all provided.

    Raises:
        FinemappingValidationError: If validation of ``df`` fails.
    """
    validate_finemapping_df(df, pos_col=pos_col, pip_col=pip_col)

    region_given = [arg is not None for arg in (chrom, start, end)]
    if all(region_given):
        # The region filter already returns a copy of the matching rows.
        result = filter_finemapping_by_region(
            df, chrom, start, end, pos_col=pos_col
        )
    else:
        if any(region_given):
            # A partial region used to be dropped without any trace; leave a
            # diagnostic so the missing filter can be noticed.
            logger.debug(
                "Region filtering skipped: chrom, start and end must all be "
                f"provided (got chrom={chrom}, start={start}, end={end})"
            )
        result = df

    # sort_values returns a new frame, so no up-front copy of df is needed.
    return result.sort_values(pos_col)
177
+
178
+
179
def get_top_pip_variants(
    df: pd.DataFrame,
    n: int = 5,
    pip_col: str = "pip",
    pip_threshold: float = 0.0,
) -> pd.DataFrame:
    """Return the highest-PIP variants at or above a threshold.

    Args:
        df: Fine-mapping DataFrame.
        n: Maximum number of variants to return.
        pip_col: Column containing PIP values.
        pip_threshold: Minimum PIP a variant must reach to be considered.

    Returns:
        DataFrame holding up to ``n`` rows with the largest ``pip_col``
        values among those meeting the threshold.
    """
    meets_threshold = df[pip_col].ge(pip_threshold)
    candidates = df.loc[meets_threshold]
    return candidates.nlargest(n, pip_col)
198
+
199
+
200
def calculate_credible_set_coverage(
    df: pd.DataFrame,
    cs_col: str = "cs",
    pip_col: str = "pip",
) -> dict:
    """Sum the PIP values of the variants in each credible set.

    Args:
        df: Fine-mapping DataFrame.
        cs_col: Column containing credible set assignments.
        pip_col: Column containing PIP values.

    Returns:
        Dictionary mapping each credible set ID to the summed PIP of its
        variants; empty when ``cs_col`` is not a column of ``df``.
    """
    if cs_col not in df.columns:
        return {}

    return {
        set_id: filter_by_credible_set(df, set_id, cs_col)[pip_col].sum()
        for set_id in get_credible_sets(df, cs_col)
    }
pylocuszoom/gene_track.py CHANGED
@@ -7,7 +7,7 @@ Provides LocusZoom-style gene track plotting with:
7
7
  - Gene name labels
8
8
  """
9
9
 
10
- from typing import List, Optional, Union
10
+ from typing import Any, List, Optional, Union
11
11
 
12
12
  import pandas as pd
13
13
  from matplotlib.axes import Axes
@@ -15,17 +15,17 @@ from matplotlib.patches import Polygon, Rectangle
15
15
 
16
16
  from .utils import normalize_chrom
17
17
 
18
- # Strand-specific colors (bold, distinct)
18
+ # Strand-specific colors (distinct from LD palette)
19
19
  STRAND_COLORS: dict[Optional[str], str] = {
20
- "+": "#6A3D9A", # Bold purple for forward strand
21
- "-": "#1F78B4", # Bold teal/blue for reverse strand
22
- None: "#666666", # Grey if no strand info
20
+ "+": "#DAA520", # Goldenrod for forward strand
21
+ "-": "#6BB3FF", # Light blue for reverse strand
22
+ None: "#999999", # Light grey if no strand info
23
23
  }
24
24
 
25
25
  # Layout constants
26
- ROW_HEIGHT = 0.40 # Total height per row
27
- GENE_AREA = 0.28 # Bottom portion for gene drawing
28
- EXON_HEIGHT = 0.22 # Exon rectangle height
26
+ ROW_HEIGHT = 0.35 # Total height per row (reduced for tighter spacing)
27
+ GENE_AREA = 0.25 # Bottom portion for gene drawing
28
+ EXON_HEIGHT = 0.20 # Exon rectangle height
29
29
  INTRON_HEIGHT = 0.02 # Thin intron line
30
30
 
31
31
 
@@ -145,7 +145,7 @@ def plot_gene_track(
145
145
  ].copy()
146
146
 
147
147
  ax.set_xlim(start, end)
148
- ax.set_ylabel("Genes", fontsize=10)
148
+ ax.set_ylabel("")
149
149
  ax.set_yticks([])
150
150
 
151
151
  # theme_classic: only bottom spine
@@ -175,7 +175,7 @@ def plot_gene_track(
175
175
  # Set y-axis limits - small bottom margin for gene body, tight top
176
176
  max_row = max(positions) if positions else 0
177
177
  bottom_margin = EXON_HEIGHT / 2 + 0.02 # Room for bottom gene
178
- top_margin = 0.15 # Small space above top label
178
+ top_margin = 0.05 # Minimal space above top label
179
179
  ax.set_ylim(
180
180
  -bottom_margin,
181
181
  (max_row + 1) * ROW_HEIGHT - ROW_HEIGHT + GENE_AREA + top_margin,
@@ -193,6 +193,8 @@ def plot_gene_track(
193
193
  & (exons_df["start"] <= end)
194
194
  ].copy()
195
195
 
196
+ region_width = end - start
197
+
196
198
  for idx, (_, gene) in enumerate(region_genes.iterrows()):
197
199
  gene_start = max(int(gene["start"]), start)
198
200
  gene_end = min(int(gene["end"]), end)
@@ -255,43 +257,59 @@ def plot_gene_track(
255
257
  )
256
258
  )
257
259
 
258
- # Add strand direction triangle at gene tip
260
+ # Add strand direction triangles (tip, center, tail)
259
261
  if "strand" in gene.index:
260
262
  strand = gene["strand"]
261
- region_width = end - start
262
263
  arrow_dir = 1 if strand == "+" else -1
263
264
 
264
- # Triangle dimensions - whole arrow past gene end
265
+ # Triangle dimensions
265
266
  tri_height = EXON_HEIGHT * 0.35
266
267
  tri_width = region_width * 0.006
267
268
 
268
- # Triangle entirely past gene tip
269
- if arrow_dir == 1: # Forward strand: arrow starts at gene end
270
- base_x = gene_end
271
- tip_x = base_x + tri_width
272
- tri_points = [
273
- [tip_x, y_gene], # Tip pointing right
274
- [base_x, y_gene + tri_height],
275
- [base_x, y_gene - tri_height],
269
+ # Arrow positions: front, middle, back (tip positions)
270
+ tip_offset = tri_width / 2 # Tiny offset to keep tip inside gene
271
+ tail_offset = tri_width * 1.5 # Offset for tail arrow from gene start/end
272
+ gene_center = (gene_start + gene_end) / 2
273
+ if arrow_dir == 1: # Forward strand
274
+ arrow_tip_positions = [
275
+ gene_start + tail_offset, # Tail (tip inside gene)
276
+ gene_center + tri_width / 2, # Middle (arrow center at gene center)
277
+ gene_end - tip_offset, # Tip (near gene end)
276
278
  ]
277
- else: # Reverse strand: arrow starts at gene start
278
- base_x = gene_start
279
- tip_x = base_x - tri_width
280
- tri_points = [
281
- [tip_x, y_gene], # Tip pointing left
282
- [base_x, y_gene + tri_height],
283
- [base_x, y_gene - tri_height],
279
+ arrow_color = "#000000" # Black for forward
280
+ else: # Reverse strand
281
+ arrow_tip_positions = [
282
+ gene_end - tail_offset, # Tail (tip inside gene)
283
+ gene_center - tri_width / 2, # Middle (arrow center at gene center)
284
+ gene_start + tip_offset, # Tip (near gene start)
284
285
  ]
285
-
286
- triangle = Polygon(
287
- tri_points,
288
- closed=True,
289
- facecolor="black",
290
- edgecolor="black",
291
- linewidth=0.5,
292
- zorder=5,
293
- )
294
- ax.add_patch(triangle)
286
+ arrow_color = "#333333" # Dark grey for reverse
287
+
288
+ for tip_x in arrow_tip_positions:
289
+ if arrow_dir == 1:
290
+ base_x = tip_x - tri_width
291
+ tri_points = [
292
+ [tip_x, y_gene], # Tip pointing right
293
+ [base_x, y_gene + tri_height],
294
+ [base_x, y_gene - tri_height],
295
+ ]
296
+ else:
297
+ base_x = tip_x + tri_width
298
+ tri_points = [
299
+ [tip_x, y_gene], # Tip pointing left
300
+ [base_x, y_gene + tri_height],
301
+ [base_x, y_gene - tri_height],
302
+ ]
303
+
304
+ triangle = Polygon(
305
+ tri_points,
306
+ closed=True,
307
+ facecolor=arrow_color,
308
+ edgecolor=arrow_color,
309
+ linewidth=0.5,
310
+ zorder=5,
311
+ )
312
+ ax.add_patch(triangle)
295
313
 
296
314
  # Add gene name label in the gap above gene
297
315
  if gene_name:
@@ -309,3 +327,206 @@ def plot_gene_track(
309
327
  zorder=4,
310
328
  clip_on=True,
311
329
  )
330
+
331
+
332
def _strand_arrow_triangles(
    gene_start: float,
    gene_end: float,
    y_gene: float,
    forward: bool,
    tri_width: float,
    tri_height: float,
) -> List[List[List[float]]]:
    """Compute vertex lists for the three strand arrows of one gene."""
    tip_offset = tri_width / 2  # Tiny offset to keep tip inside gene
    tail_offset = tri_width * 1.5  # Offset for tail arrow from gene start/end
    gene_center = (gene_start + gene_end) / 2

    if forward:
        # Tips for the tail, middle and leading arrow, pointing right.
        tip_positions = [
            gene_start + tail_offset,
            gene_center + tri_width / 2,
            gene_end - tip_offset,
        ]
        base_shift = -tri_width  # Base sits to the left of the tip
    else:
        # Mirror image: arrows point left, leading arrow near gene start.
        tip_positions = [
            gene_end - tail_offset,
            gene_center - tri_width / 2,
            gene_start + tip_offset,
        ]
        base_shift = tri_width  # Base sits to the right of the tip

    return [
        [
            [tip_x, y_gene],
            [tip_x + base_shift, y_gene + tri_height],
            [tip_x + base_shift, y_gene - tri_height],
        ]
        for tip_x in tip_positions
    ]


def plot_gene_track_generic(
    ax: Any,
    backend: Any,
    genes_df: pd.DataFrame,
    chrom: Union[int, str],
    start: int,
    end: int,
    exons_df: Optional[pd.DataFrame] = None,
) -> None:
    """Plot gene annotations through a backend object.

    All drawing goes through ``backend`` (``set_xlim``, ``set_ylabel``,
    ``set_ylim``, ``add_text``, ``add_rectangle``, ``add_polygon``), so the
    function is intended to work with the matplotlib, plotly, and bokeh
    backends.

    Args:
        ax: Axes object (format depends on backend).
        backend: Backend instance with drawing methods.
        genes_df: Gene annotations with chr, start, end, gene_name,
            and optionally strand (+/-) column.
        chrom: Chromosome number or string.
        start: Region start position.
        end: Region end position.
        exons_df: Exon annotations with chr, start, end, gene_name
            columns for drawing exon structure. Optional.
    """
    chrom_str = normalize_chrom(chrom)
    region_genes = genes_df[
        (genes_df["chr"].astype(str).str.replace("chr", "", regex=False) == chrom_str)
        & (genes_df["end"] >= start)
        & (genes_df["start"] <= end)
    ].copy()

    backend.set_xlim(ax, start, end)
    backend.set_ylabel(ax, "", fontsize=10)

    if region_genes.empty:
        # Nothing overlaps the region: show a placeholder and stop.
        backend.set_ylim(ax, 0, 1)
        backend.add_text(
            ax,
            (start + end) / 2,
            0.5,
            "No genes",
            fontsize=9,
            ha="center",
            va="center",
            color="grey",
        )
        return

    # Assign vertical positions to avoid overlap
    region_genes = region_genes.sort_values("start")
    positions = assign_gene_positions(region_genes, start, end)

    # Set y-axis limits - small bottom margin for gene body, tight top
    max_row = max(positions) if positions else 0
    bottom_margin = EXON_HEIGHT / 2 + 0.02  # Room for bottom gene
    top_margin = 0.05  # Minimal space above top label
    backend.set_ylim(
        ax,
        -bottom_margin,
        (max_row + 1) * ROW_HEIGHT - ROW_HEIGHT + GENE_AREA + top_margin,
    )

    # Filter exons for this region if available
    region_exons = None
    if exons_df is not None and not exons_df.empty:
        region_exons = exons_df[
            (
                exons_df["chr"].astype(str).str.replace("chr", "", regex=False)
                == chrom_str
            )
            & (exons_df["end"] >= start)
            & (exons_df["start"] <= end)
        ].copy()
    have_exons = region_exons is not None and not region_exons.empty

    # Arrow size depends only on the region and layout constants, so it is
    # computed once instead of once per gene.
    region_width = end - start
    tri_height = EXON_HEIGHT * 0.35
    tri_width = region_width * 0.006

    for idx, (_, gene) in enumerate(region_genes.iterrows()):
        # Clip the gene to the plotted window.
        gene_start = max(int(gene["start"]), start)
        gene_end = min(int(gene["end"]), end)
        row = positions[idx]
        gene_name = gene.get("gene_name", "")

        # Strand drives both the fill colour and the arrow direction.
        has_strand = "strand" in gene.index
        strand = gene["strand"] if has_strand else None
        gene_col = STRAND_COLORS.get(strand, STRAND_COLORS[None])

        # Y position: bottom of row + offset for gene area
        y_gene = row * ROW_HEIGHT + 0.05
        y_label = y_gene + EXON_HEIGHT / 2 + 0.01  # Just above gene top

        # Look up exon data for this gene, if any
        gene_exons = None
        if have_exons and gene_name:
            gene_exons = region_exons[region_exons["gene_name"] == gene_name].copy()

        if gene_exons is not None and not gene_exons.empty:
            # Draw intron line (thin horizontal line spanning gene)
            backend.add_rectangle(
                ax,
                (gene_start, y_gene - INTRON_HEIGHT / 2),
                gene_end - gene_start,
                INTRON_HEIGHT,
                facecolor=gene_col,
                edgecolor=gene_col,
                linewidth=0.5,
                zorder=1,
            )

            # Draw exons (thick rectangles), clipped to the window
            for _, exon in gene_exons.iterrows():
                exon_start = max(int(exon["start"]), start)
                exon_end = min(int(exon["end"]), end)
                backend.add_rectangle(
                    ax,
                    (exon_start, y_gene - EXON_HEIGHT / 2),
                    exon_end - exon_start,
                    EXON_HEIGHT,
                    facecolor=gene_col,
                    edgecolor=gene_col,
                    linewidth=0.5,
                    zorder=2,
                )
        else:
            # No exon data - draw full gene body as rectangle (fallback)
            backend.add_rectangle(
                ax,
                (gene_start, y_gene - EXON_HEIGHT / 2),
                gene_end - gene_start,
                EXON_HEIGHT,
                facecolor=gene_col,
                edgecolor=gene_col,
                linewidth=0.5,
                zorder=2,
            )

        # Add strand direction triangles (tail, center, tip)
        if has_strand:
            # NOTE(review): any strand value other than "+" (including a
            # missing/NaN value) is drawn with reverse-strand arrows, as
            # before - confirm that is intended for unknown strands.
            forward = strand == "+"
            arrow_color = "#000000" if forward else "#333333"

            for tri_points in _strand_arrow_triangles(
                gene_start, gene_end, y_gene, forward, tri_width, tri_height
            ):
                backend.add_polygon(
                    ax,
                    tri_points,
                    facecolor=arrow_color,
                    edgecolor=arrow_color,
                    linewidth=0.5,
                    zorder=5,
                )

        # Add gene name label in the gap above gene
        if gene_name:
            label_pos = (gene_start + gene_end) / 2
            backend.add_text(
                ax,
                label_pos,
                y_label,
                gene_name,
                fontsize=6,
                ha="center",
                va="bottom",
                color="#000000",
            )