pylocuszoom 0.5.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pylocuszoom/colors.py CHANGED
@@ -239,3 +239,44 @@ def get_credible_set_color_palette(n_sets: int = 10) -> dict[int, str]:
     return {
         i + 1: CREDIBLE_SET_COLORS[i % len(CREDIBLE_SET_COLORS)] for i in range(n_sets)
     }
+
+
+# PheWAS category colors - distinct colors for phenotype categories
+PHEWAS_CATEGORY_COLORS: List[str] = [
+    "#E41A1C",  # red
+    "#377EB8",  # blue
+    "#4DAF4A",  # green
+    "#984EA3",  # purple
+    "#FF7F00",  # orange
+    "#FFFF33",  # yellow
+    "#A65628",  # brown
+    "#F781BF",  # pink
+    "#999999",  # grey
+    "#66C2A5",  # teal
+    "#FC8D62",  # salmon
+    "#8DA0CB",  # periwinkle
+]
+
+
+def get_phewas_category_color(category_idx: int) -> str:
+    """Get color for a PheWAS category by index.
+
+    Args:
+        category_idx: Zero-indexed category number.
+
+    Returns:
+        Hex color code string.
+    """
+    return PHEWAS_CATEGORY_COLORS[category_idx % len(PHEWAS_CATEGORY_COLORS)]
+
+
+def get_phewas_category_palette(categories: List[str]) -> dict[str, str]:
+    """Get color palette mapping category names to colors.
+
+    Args:
+        categories: List of unique category names.
+
+    Returns:
+        Dictionary mapping category names to hex colors.
+    """
+    return {cat: get_phewas_category_color(i) for i, cat in enumerate(categories)}
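For context, a minimal usage sketch of the two helpers added above (assuming they are imported from pylocuszoom.colors as defined; not part of the diff itself):

    from pylocuszoom.colors import get_phewas_category_palette

    # Indices wrap modulo the 12-color list, so a 13th category reuses "#E41A1C".
    categories = ["cardiovascular", "metabolic", "neurological"]
    palette = get_phewas_category_palette(categories)
    # {'cardiovascular': '#E41A1C', 'metabolic': '#377EB8', 'neurological': '#4DAF4A'}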
pylocuszoom/ensembl.py ADDED
@@ -0,0 +1,476 @@
+# src/pylocuszoom/ensembl.py
+"""Ensembl REST API integration for reference data fetching.
+
+Provides functions to fetch gene and exon annotations from the Ensembl REST API
+(https://rest.ensembl.org) for any species.
+
+Note: Recombination rates are NOT available from Ensembl for most species.
+Use species-specific recombination maps instead (see recombination.py).
+"""
+
+import hashlib
+import os
+import sys
+import time
+from pathlib import Path
+
+import pandas as pd
+import requests
+
+from .logging import logger
+from .utils import ValidationError
+
+# Ensembl API limits regions to 5Mb
+ENSEMBL_MAX_REGION_SIZE = 5_000_000
+
+# Species name aliases -> Ensembl species names
+SPECIES_ALIASES: dict[str, str] = {
+    # Canine
+    "canine": "canis_lupus_familiaris",
+    "dog": "canis_lupus_familiaris",
+    "canis_familiaris": "canis_lupus_familiaris",
+    # Feline
+    "feline": "felis_catus",
+    "cat": "felis_catus",
+    # Human
+    "human": "homo_sapiens",
+    # Mouse
+    "mouse": "mus_musculus",
+    # Rat
+    "rat": "rattus_norvegicus",
+}
+
+
+ENSEMBL_REST_URL = "https://rest.ensembl.org"
+ENSEMBL_REQUEST_TIMEOUT = 30  # seconds
+ENSEMBL_MAX_RETRIES = 3
+ENSEMBL_RETRY_DELAY = 1.0  # seconds, doubles on each retry
+
+
+def _normalize_chrom(chrom: str | int) -> str:
+    """Normalize chromosome name by removing 'chr' prefix."""
+    return str(chrom).replace("chr", "")
+
+
+def _validate_region_size(start: int, end: int, context: str) -> None:
+    """Validate region size is within Ensembl API limits.
+
+    Args:
+        start: Region start position.
+        end: Region end position.
+        context: Context for error message (e.g., "genes_df", "exons_df").
+
+    Raises:
+        ValidationError: If region exceeds 5Mb limit.
+    """
+    region_size = end - start
+    if region_size > ENSEMBL_MAX_REGION_SIZE:
+        raise ValidationError(
+            f"Region size {region_size:,} bp exceeds Ensembl API limit of 5Mb. "
+            f"Please use a smaller region or provide {context} directly."
+        )
+
+
+def get_ensembl_species_name(species: str) -> str:
+    """Convert species alias to Ensembl species name.
+
+    Args:
+        species: Species name or alias (e.g., "canine", "dog", "human").
+
+    Returns:
+        Ensembl-compatible species name (e.g., "canis_lupus_familiaris").
+    """
+    return SPECIES_ALIASES.get(species.lower(), species.lower())
+
+
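To illustrate the alias table and its fallback behavior (a sketch; the example species are arbitrary):

    from pylocuszoom.ensembl import get_ensembl_species_name

    get_ensembl_species_name("dog")    # -> "canis_lupus_familiaris"
    get_ensembl_species_name("Human")  # -> "homo_sapiens" (input is lowercased first)
    get_ensembl_species_name("sus_scrofa")  # unknown names pass through lowercased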
+def get_ensembl_cache_dir() -> Path:
+    """Get the cache directory for Ensembl data.
+
+    Uses same base location as recombination maps: ~/.cache/snp-scope-plot/ensembl
+
+    Returns:
+        Path to cache directory (created if doesn't exist).
+    """
+    if sys.platform == "darwin":
+        base = Path.home() / ".cache"
+    elif sys.platform == "win32":
+        base = Path(os.environ.get("LOCALAPPDATA", Path.home() / "AppData" / "Local"))
+    else:
+        base = Path(os.environ.get("XDG_CACHE_HOME", Path.home() / ".cache"))
+
+    cache_dir = base / "snp-scope-plot" / "ensembl"
+    cache_dir.mkdir(parents=True, exist_ok=True)
+    return cache_dir
+
+
+def _cache_key(species: str, chrom: str, start: int, end: int) -> str:
+    """Generate cache key for a region."""
+    key_str = f"{species}_{chrom}_{start}_{end}"
+    return hashlib.md5(key_str.encode()).hexdigest()[:16]
+
+
+def get_cached_genes(
+    cache_dir: Path,
+    species: str,
+    chrom: str | int,
+    start: int,
+    end: int,
+) -> pd.DataFrame | None:
+    """Load cached genes if available.
+
+    Args:
+        cache_dir: Cache directory path.
+        species: Species name or alias.
+        chrom: Chromosome name or number.
+        start: Region start position.
+        end: Region end position.
+
+    Returns:
+        DataFrame if cache hit, None if cache miss.
+    """
+    ensembl_species = get_ensembl_species_name(species)
+    chrom_str = _normalize_chrom(chrom)
+    cache_key = _cache_key(ensembl_species, chrom_str, start, end)
+
+    species_dir = cache_dir / ensembl_species
+    cache_file = species_dir / f"genes_{cache_key}.csv"
+
+    if not cache_file.exists():
+        return None
+
+    logger.debug(f"Cache hit: {cache_file}")
+    return pd.read_csv(cache_file)
+
+
+def save_cached_genes(
+    df: pd.DataFrame,
+    cache_dir: Path,
+    species: str,
+    chrom: str | int,
+    start: int,
+    end: int,
+) -> None:
+    """Save genes to cache as CSV.
+
+    Args:
+        df: DataFrame with gene annotations to cache.
+        cache_dir: Cache directory path.
+        species: Species name or alias.
+        chrom: Chromosome name or number.
+        start: Region start position.
+        end: Region end position.
+    """
+    ensembl_species = get_ensembl_species_name(species)
+    chrom_str = _normalize_chrom(chrom)
+    cache_key = _cache_key(ensembl_species, chrom_str, start, end)
+
+    species_dir = cache_dir / ensembl_species
+    species_dir.mkdir(parents=True, exist_ok=True)
+
+    cache_file = species_dir / f"genes_{cache_key}.csv"
+    df.to_csv(cache_file, index=False)
+    logger.debug(f"Cached genes to: {cache_file}")
+
+
+def _make_ensembl_request(
+    url: str,
+    params: dict,
+    max_retries: int = ENSEMBL_MAX_RETRIES,
+    raise_on_error: bool = False,
+) -> list | None:
+    """Make request to Ensembl API with retry logic.
+
+    Args:
+        url: API endpoint URL.
+        params: Query parameters.
+        max_retries: Maximum retry attempts for retryable errors.
+        raise_on_error: If True, raise exception on error instead of returning None.
+
+    Returns:
+        JSON response as list, or None on non-retryable error.
+
+    Raises:
+        ValidationError: If raise_on_error=True and request fails.
+    """
+    delay = ENSEMBL_RETRY_DELAY
+
+    for attempt in range(max_retries):
+        try:
+            response = requests.get(
+                url,
+                params=params,
+                headers={"Content-Type": "application/json"},
+                timeout=ENSEMBL_REQUEST_TIMEOUT,
+            )
+        except requests.RequestException as e:
+            logger.warning(f"Ensembl API request failed (attempt {attempt + 1}): {e}")
+            if attempt < max_retries - 1:
+                time.sleep(delay)
+                delay *= 2
+                continue
+            if raise_on_error:
+                raise ValidationError(
+                    f"Ensembl API request failed after {max_retries} attempts: {e}"
+                )
+            return None
+
+        # Success
+        if response.ok:
+            return response.json()
+
+        # Retryable errors (429 rate limit, 503 service unavailable)
+        if response.status_code in (429, 503) and attempt < max_retries - 1:
+            logger.warning(
+                f"Ensembl API returned {response.status_code} "
+                f"(attempt {attempt + 1}), retrying..."
+            )
+            time.sleep(delay)
+            delay *= 2
+            continue
+
+        # Non-retryable error
+        error_msg = f"Ensembl API error {response.status_code}: {response.text[:200]}"
+        logger.warning(error_msg)
+        if raise_on_error:
+            raise ValidationError(error_msg)
+        return None
+
+    return None
+
+
+def fetch_genes_from_ensembl(
+    species: str,
+    chrom: str | int,
+    start: int,
+    end: int,
+    biotype: str = "protein_coding",
+    raise_on_error: bool = False,
+) -> pd.DataFrame:
+    """Fetch gene annotations from Ensembl REST API.
+
+    Args:
+        species: Species name or alias.
+        chrom: Chromosome name or number.
+        start: Region start position (1-based).
+        end: Region end position (1-based).
+        biotype: Gene biotype filter (default: protein_coding).
+        raise_on_error: If True, raise ValidationError on API errors.
+
+    Returns:
+        DataFrame with columns: chr, start, end, gene_name, strand, gene_id, biotype.
+        Returns empty DataFrame on API error (unless raise_on_error=True).
+
+    Raises:
+        ValidationError: If region > 5Mb or if raise_on_error=True and API fails.
+    """
+    _validate_region_size(start, end, "genes_df")
+
+    ensembl_species = get_ensembl_species_name(species)
+    chrom_str = _normalize_chrom(chrom)
+
+    # Build region string
+    region = f"{chrom_str}:{start}-{end}"
+
+    # Build API URL
+    url = f"{ENSEMBL_REST_URL}/overlap/region/{ensembl_species}/{region}"
+    params = {"feature": "gene", "biotype": biotype}
+
+    logger.debug(f"Fetching genes from Ensembl: {url}")
+
+    data = _make_ensembl_request(url, params, raise_on_error=raise_on_error)
+
+    if data is None:
+        return pd.DataFrame()
+
+    if not data:
+        logger.debug(f"No genes found in region {region}")
+        return pd.DataFrame()
+
+    # Convert to DataFrame
+    records = []
+    for gene in data:
+        if gene.get("feature_type") != "gene":
+            continue
+        records.append(
+            {
+                "chr": str(gene.get("seq_region_name", chrom_str)),
+                "start": gene.get("start"),
+                "end": gene.get("end"),
+                "gene_name": gene.get("external_name", gene.get("id", "")),
+                "strand": "+" if gene.get("strand", 1) == 1 else "-",
+                "gene_id": gene.get("id", ""),
+                "biotype": gene.get("biotype", ""),
+            }
+        )
+
+    df = pd.DataFrame(records)
+    logger.debug(f"Fetched {len(df)} genes from Ensembl")
+    return df
+
+
+def fetch_exons_from_ensembl(
+    species: str,
+    chrom: str | int,
+    start: int,
+    end: int,
+    raise_on_error: bool = False,
+) -> pd.DataFrame:
+    """Fetch exon annotations from Ensembl REST API.
+
+    Args:
+        species: Species name or alias.
+        chrom: Chromosome name or number.
+        start: Region start position (1-based).
+        end: Region end position (1-based).
+        raise_on_error: If True, raise ValidationError on API errors.
+
+    Returns:
+        DataFrame with columns: chr, start, end, gene_name, exon_id, transcript_id.
+        Returns empty DataFrame on API error (unless raise_on_error=True).
+
+    Raises:
+        ValidationError: If region > 5Mb or if raise_on_error=True and API fails.
+    """
+    _validate_region_size(start, end, "exons_df")
+
+    ensembl_species = get_ensembl_species_name(species)
+    chrom_str = _normalize_chrom(chrom)
+    region = f"{chrom_str}:{start}-{end}"
+
+    url = f"{ENSEMBL_REST_URL}/overlap/region/{ensembl_species}/{region}"
+    params = {"feature": "exon"}
+
+    logger.debug(f"Fetching exons from Ensembl: {url}")
+
+    data = _make_ensembl_request(url, params, raise_on_error=raise_on_error)
+
+    if data is None:
+        return pd.DataFrame()
+
+    if not data:
+        return pd.DataFrame()
+
+    records = []
+    for exon in data:
+        if exon.get("feature_type") != "exon":
+            continue
+        records.append(
+            {
+                "chr": str(exon.get("seq_region_name", chrom_str)),
+                "start": exon.get("start"),
+                "end": exon.get("end"),
+                "gene_name": "",  # Exon endpoint doesn't include gene name
+                "exon_id": exon.get("id", ""),
+                "transcript_id": exon.get("Parent", ""),
+            }
+        )
+
+    df = pd.DataFrame(records)
+    logger.debug(f"Fetched {len(df)} exons from Ensembl")
+    return df
+
+
+def get_genes_for_region(
+    species: str,
+    chrom: str | int,
+    start: int,
+    end: int,
+    cache_dir: Path | None = None,
+    use_cache: bool = True,
+    include_exons: bool = False,
+    raise_on_error: bool = False,
+) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
+    """Get gene annotations for a genomic region.
+
+    Checks cache first, fetches from Ensembl API if not cached.
+
+    Args:
+        species: Species name or alias.
+        chrom: Chromosome name or number.
+        start: Region start position (1-based).
+        end: Region end position (1-based).
+        cache_dir: Cache directory (uses default if None).
+        use_cache: Whether to use disk cache.
+        include_exons: If True, also fetch exons and return tuple (genes_df, exons_df).
+        raise_on_error: If True, raise ValidationError on API errors.
+
+    Returns:
+        If include_exons=False: DataFrame with gene annotations.
+        If include_exons=True: Tuple of (genes_df, exons_df).
+
+    Raises:
+        ValidationError: If region > 5Mb or if raise_on_error=True and API fails.
+
+    Note:
+        Gene annotations are cached to disk. Exons are fetched from the API
+        on each call when include_exons=True (not cached separately).
+    """
+    if cache_dir is None:
+        cache_dir = get_ensembl_cache_dir()
+
+    chrom_str = _normalize_chrom(chrom)
+
+    # Check cache first
+    if use_cache:
+        cached = get_cached_genes(cache_dir, species, chrom_str, start, end)
+        if cached is not None:
+            if include_exons:
+                # Exons not cached separately (yet)
+                exons_df = fetch_exons_from_ensembl(
+                    species, chrom_str, start, end, raise_on_error=raise_on_error
+                )
+                return cached, exons_df
+            return cached
+
+    # Fetch from Ensembl API
+    genes_df = fetch_genes_from_ensembl(
+        species, chrom_str, start, end, raise_on_error=raise_on_error
+    )
+
+    # Cache the result (even if empty, to avoid repeated API calls for gene-sparse regions)
+    if use_cache:
+        save_cached_genes(genes_df, cache_dir, species, chrom_str, start, end)
+
+    if include_exons:
+        exons_df = fetch_exons_from_ensembl(
+            species, chrom_str, start, end, raise_on_error=raise_on_error
+        )
+        return genes_df, exons_df
+
+    return genes_df
+
+
+def clear_ensembl_cache(
+    cache_dir: Path | None = None,
+    species: str | None = None,
+) -> int:
+    """Clear cached Ensembl data.
+
+    Args:
+        cache_dir: Cache directory (uses default if None).
+        species: If provided, only clear cache for this species.
+
+    Returns:
+        Number of files deleted.
+    """
+    if cache_dir is None:
+        cache_dir = get_ensembl_cache_dir()
+
+    deleted = 0
+
+    if species:
+        # Clear only specific species
+        ensembl_species = get_ensembl_species_name(species)
+        species_dir = cache_dir / ensembl_species
+        if species_dir.exists():
+            for cache_file in species_dir.glob("*.csv"):
+                cache_file.unlink()
+                deleted += 1
+    else:
+        # Clear all species
+        for cache_file in cache_dir.glob("**/*.csv"):
+            cache_file.unlink()
+            deleted += 1
+
+    logger.info(f"Cleared {deleted} cached Ensembl files from {cache_dir}")
+    return deleted
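Putting the new module together, a minimal usage sketch of its public entry points (the region coordinates are illustrative; regions must span at most 5Mb):

    from pylocuszoom.ensembl import clear_ensembl_cache, get_genes_for_region

    # First call hits the Ensembl REST API and writes a CSV to the cache;
    # later calls for the same (species, chrom, start, end) read from disk.
    genes_df, exons_df = get_genes_for_region(
        "dog", chrom=18, start=20_000_000, end=21_000_000, include_exons=True
    )

    # Remove cached files for one species (or pass species=None to clear all).
    deleted = clear_ensembl_cache(species="dog")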
pylocuszoom/eqtl.py CHANGED
@@ -10,6 +10,8 @@ import numpy as np
 import pandas as pd

 from .logging import logger
+from .utils import ValidationError, filter_by_region
+from .validation import DataFrameValidator

 REQUIRED_EQTL_COLS = ["pos", "p_value"]
 OPTIONAL_EQTL_COLS = ["gene", "effect_size", "rs", "se"]
@@ -36,17 +38,14 @@ def validate_eqtl_df(
     Raises:
         EQTLValidationError: If required columns are missing.
     """
-    missing = []
-    if pos_col not in df.columns:
-        missing.append(pos_col)
-    if p_col not in df.columns:
-        missing.append(p_col)
-
-    if missing:
-        raise EQTLValidationError(
-            f"eQTL DataFrame missing required columns: {missing}. "
-            f"Required: {pos_col} (position), {p_col} (p-value)"
+    try:
+        (
+            DataFrameValidator(df, "eQTL DataFrame")
+            .require_columns([pos_col, p_col])
+            .validate()
         )
+    except ValidationError as e:
+        raise EQTLValidationError(str(e)) from e


 def filter_eqtl_by_gene(
@@ -99,15 +98,12 @@ def filter_eqtl_by_region(
     Returns:
         Filtered DataFrame containing only eQTLs in the region.
     """
-    mask = (df[pos_col] >= start) & (df[pos_col] <= end)
-
-    # Filter by chromosome if column exists
-    if chrom_col and chrom_col in df.columns:
-        chrom_str = str(chrom).replace("chr", "")
-        df_chrom = df[chrom_col].astype(str).str.replace("chr", "", regex=False)
-        mask = mask & (df_chrom == chrom_str)
-
-    filtered = df[mask].copy()
+    filtered = filter_by_region(
+        df,
+        region=(chrom, start, end),
+        chrom_col=chrom_col or "",
+        pos_col=pos_col,
+    )
     logger.debug(
         f"Filtered eQTL data to {len(filtered)} variants in region chr{chrom}:{start}-{end}"
     )
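The net effect of the eqtl.py changes: column validation and region filtering are delegated to the shared DataFrameValidator and filter_by_region utilities, while failures still surface as EQTLValidationError. A sketch of the preserved behavior (assuming the default column names "pos" and "p_value" from REQUIRED_EQTL_COLS):

    import pandas as pd
    from pylocuszoom.eqtl import EQTLValidationError, validate_eqtl_df

    df = pd.DataFrame({"pos": [101, 202]})  # missing the "p_value" column
    try:
        validate_eqtl_df(df)
    except EQTLValidationError as e:
        print(e)  # message now produced by DataFrameValidator.require_columns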
pylocuszoom/finemapping.py CHANGED
@@ -9,6 +9,8 @@ from typing import List, Optional
 import pandas as pd

 from .logging import logger
+from .utils import ValidationError, filter_by_region
+from .validation import DataFrameValidator

 # Required columns for fine-mapping data
 REQUIRED_FINEMAPPING_COLS = ["pos", "pip"]
@@ -36,24 +38,16 @@ def validate_finemapping_df(
     Raises:
         FinemappingValidationError: If required columns are missing.
     """
-    missing = []
-    if pos_col not in df.columns:
-        missing.append(pos_col)
-    if pip_col not in df.columns:
-        missing.append(pip_col)
-
-    if missing:
-        raise FinemappingValidationError(
-            f"Fine-mapping DataFrame missing required columns: {missing}. "
-            f"Required: {pos_col} (position), {pip_col} (posterior inclusion probability)"
-        )
-
-    # Validate PIP values are in [0, 1]
-    if not df[pip_col].between(0, 1).all():
-        invalid_count = (~df[pip_col].between(0, 1)).sum()
-        raise FinemappingValidationError(
-            f"PIP values must be between 0 and 1. Found {invalid_count} invalid values."
+    try:
+        (
+            DataFrameValidator(df, "Fine-mapping DataFrame")
+            .require_columns([pos_col, pip_col])
+            .require_numeric([pip_col])
+            .require_range(pip_col, min_val=0, max_val=1)
+            .validate()
         )
+    except ValidationError as e:
+        raise FinemappingValidationError(str(e)) from e


 def filter_finemapping_by_region(
@@ -77,15 +71,12 @@ def filter_finemapping_by_region(
     Returns:
         Filtered DataFrame containing only variants in the region.
     """
-    mask = (df[pos_col] >= start) & (df[pos_col] <= end)
-
-    # Filter by chromosome if column exists
-    if chrom_col and chrom_col in df.columns:
-        chrom_str = str(chrom).replace("chr", "")
-        df_chrom = df[chrom_col].astype(str).str.replace("chr", "", regex=False)
-        mask = mask & (df_chrom == chrom_str)
-
-    filtered = df[mask].copy()
+    filtered = filter_by_region(
+        df,
+        region=(chrom, start, end),
+        chrom_col=chrom_col or "",
+        pos_col=pos_col,
+    )
     logger.debug(
         f"Filtered fine-mapping data to {len(filtered)} variants in region "
         f"chr{chrom}:{start}-{end}"
pylocuszoom/forest.py ADDED
@@ -0,0 +1,35 @@
+"""Forest plot data validation and preparation.
+
+Validates and prepares meta-analysis/forest plot data for visualization.
+"""
+
+import pandas as pd
+
+from .validation import DataFrameValidator
+
+
+def validate_forest_df(
+    df: pd.DataFrame,
+    study_col: str = "study",
+    effect_col: str = "effect",
+    ci_lower_col: str = "ci_lower",
+    ci_upper_col: str = "ci_upper",
+) -> None:
+    """Validate forest plot DataFrame has required columns and types.
+
+    Args:
+        df: Forest plot data DataFrame.
+        study_col: Column name for study/phenotype names.
+        effect_col: Column name for effect sizes (beta, OR, HR).
+        ci_lower_col: Column name for lower confidence interval.
+        ci_upper_col: Column name for upper confidence interval.
+
+    Raises:
+        ValidationError: If required columns are missing or have invalid types.
+    """
+    (
+        DataFrameValidator(df, "Forest plot DataFrame")
+        .require_columns([study_col, effect_col, ci_lower_col, ci_upper_col])
+        .require_numeric([effect_col, ci_lower_col, ci_upper_col])
+        .validate()
+    )
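A minimal usage sketch of the new validator (the DataFrame below follows the documented default column names; the data values are illustrative):

    import pandas as pd
    from pylocuszoom.forest import validate_forest_df

    df = pd.DataFrame(
        {
            "study": ["Cohort A", "Cohort B"],
            "effect": [0.42, 0.31],
            "ci_lower": [0.18, 0.05],
            "ci_upper": [0.66, 0.57],
        }
    )
    validate_forest_df(df)  # raises ValidationError on missing or non-numeric columns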