pylocuszoom 0.6.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

pylocuszoom/ensembl.py ADDED
@@ -0,0 +1,476 @@
+# src/pylocuszoom/ensembl.py
+"""Ensembl REST API integration for reference data fetching.
+
+Provides functions to fetch gene and exon annotations from the Ensembl REST API
+(https://rest.ensembl.org) for any species.
+
+Note: Recombination rates are NOT available from Ensembl for most species.
+Use species-specific recombination maps instead (see recombination.py).
+"""
+
+import hashlib
+import os
+import sys
+import time
+from pathlib import Path
+
+import pandas as pd
+import requests
+
+from .logging import logger
+from .utils import ValidationError
+
+# Ensembl API limits regions to 5Mb
+ENSEMBL_MAX_REGION_SIZE = 5_000_000
+
+# Species name aliases -> Ensembl species names
+SPECIES_ALIASES: dict[str, str] = {
+    # Canine
+    "canine": "canis_lupus_familiaris",
+    "dog": "canis_lupus_familiaris",
+    "canis_familiaris": "canis_lupus_familiaris",
+    # Feline
+    "feline": "felis_catus",
+    "cat": "felis_catus",
+    # Human
+    "human": "homo_sapiens",
+    # Mouse
+    "mouse": "mus_musculus",
+    # Rat
+    "rat": "rattus_norvegicus",
+}
+
+
+ENSEMBL_REST_URL = "https://rest.ensembl.org"
+ENSEMBL_REQUEST_TIMEOUT = 30  # seconds
+ENSEMBL_MAX_RETRIES = 3
+ENSEMBL_RETRY_DELAY = 1.0  # seconds, doubles on each retry
+
+
+def _normalize_chrom(chrom: str | int) -> str:
+    """Normalize chromosome name by removing 'chr' prefix."""
+    return str(chrom).replace("chr", "")
+
+
+def _validate_region_size(start: int, end: int, context: str) -> None:
+    """Validate region size is within Ensembl API limits.
+
+    Args:
+        start: Region start position.
+        end: Region end position.
+        context: Context for error message (e.g., "genes_df", "exons_df").
+
+    Raises:
+        ValidationError: If region exceeds 5Mb limit.
+    """
+    region_size = end - start
+    if region_size > ENSEMBL_MAX_REGION_SIZE:
+        raise ValidationError(
+            f"Region size {region_size:,} bp exceeds Ensembl API limit of 5Mb. "
+            f"Please use a smaller region or provide {context} directly."
+        )
+
+
+def get_ensembl_species_name(species: str) -> str:
+    """Convert species alias to Ensembl species name.
+
+    Args:
+        species: Species name or alias (e.g., "canine", "dog", "human").
+
+    Returns:
+        Ensembl-compatible species name (e.g., "canis_lupus_familiaris").
+    """
+    return SPECIES_ALIASES.get(species.lower(), species.lower())
+
+
+def get_ensembl_cache_dir() -> Path:
+    """Get the cache directory for Ensembl data.
+
+    Uses the same base location as recombination maps: ~/.cache/snp-scope-plot/ensembl
+
+    Returns:
+        Path to cache directory (created if it doesn't exist).
+    """
+    if sys.platform == "darwin":
+        base = Path.home() / ".cache"
+    elif sys.platform == "win32":
+        base = Path(os.environ.get("LOCALAPPDATA", Path.home() / "AppData" / "Local"))
+    else:
+        base = Path(os.environ.get("XDG_CACHE_HOME", Path.home() / ".cache"))
+
+    cache_dir = base / "snp-scope-plot" / "ensembl"
+    cache_dir.mkdir(parents=True, exist_ok=True)
+    return cache_dir
+
+
+def _cache_key(species: str, chrom: str, start: int, end: int) -> str:
+    """Generate cache key for a region."""
+    key_str = f"{species}_{chrom}_{start}_{end}"
+    return hashlib.md5(key_str.encode()).hexdigest()[:16]
+
+
+def get_cached_genes(
+    cache_dir: Path,
+    species: str,
+    chrom: str | int,
+    start: int,
+    end: int,
+) -> pd.DataFrame | None:
+    """Load cached genes if available.
+
+    Args:
+        cache_dir: Cache directory path.
+        species: Species name or alias.
+        chrom: Chromosome name or number.
+        start: Region start position.
+        end: Region end position.
+
+    Returns:
+        DataFrame if cache hit, None if cache miss.
+    """
+    ensembl_species = get_ensembl_species_name(species)
+    chrom_str = _normalize_chrom(chrom)
+    cache_key = _cache_key(ensembl_species, chrom_str, start, end)
+
+    species_dir = cache_dir / ensembl_species
+    cache_file = species_dir / f"genes_{cache_key}.csv"
+
+    if not cache_file.exists():
+        return None
+
+    logger.debug(f"Cache hit: {cache_file}")
+    return pd.read_csv(cache_file)
+
+
+def save_cached_genes(
+    df: pd.DataFrame,
+    cache_dir: Path,
+    species: str,
+    chrom: str | int,
+    start: int,
+    end: int,
+) -> None:
+    """Save genes to cache as CSV.
+
+    Args:
+        df: DataFrame with gene annotations to cache.
+        cache_dir: Cache directory path.
+        species: Species name or alias.
+        chrom: Chromosome name or number.
+        start: Region start position.
+        end: Region end position.
+    """
+    ensembl_species = get_ensembl_species_name(species)
+    chrom_str = _normalize_chrom(chrom)
+    cache_key = _cache_key(ensembl_species, chrom_str, start, end)
+
+    species_dir = cache_dir / ensembl_species
+    species_dir.mkdir(parents=True, exist_ok=True)
+
+    cache_file = species_dir / f"genes_{cache_key}.csv"
+    df.to_csv(cache_file, index=False)
+    logger.debug(f"Cached genes to: {cache_file}")
+
+
+def _make_ensembl_request(
+    url: str,
+    params: dict,
+    max_retries: int = ENSEMBL_MAX_RETRIES,
+    raise_on_error: bool = False,
+) -> list | None:
+    """Make request to Ensembl API with retry logic.
+
+    Args:
+        url: API endpoint URL.
+        params: Query parameters.
+        max_retries: Maximum retry attempts for retryable errors.
+        raise_on_error: If True, raise exception on error instead of returning None.
+
+    Returns:
+        JSON response as list, or None on non-retryable error.
+
+    Raises:
+        ValidationError: If raise_on_error=True and request fails.
+    """
+    delay = ENSEMBL_RETRY_DELAY
+
+    for attempt in range(max_retries):
+        try:
+            response = requests.get(
+                url,
+                params=params,
+                headers={"Content-Type": "application/json"},
+                timeout=ENSEMBL_REQUEST_TIMEOUT,
+            )
+        except requests.RequestException as e:
+            logger.warning(f"Ensembl API request failed (attempt {attempt + 1}): {e}")
+            if attempt < max_retries - 1:
+                time.sleep(delay)
+                delay *= 2
+                continue
+            if raise_on_error:
+                raise ValidationError(
+                    f"Ensembl API request failed after {max_retries} attempts: {e}"
+                )
+            return None
+
+        # Success
+        if response.ok:
+            return response.json()
+
+        # Retryable errors (429 rate limit, 503 service unavailable)
+        if response.status_code in (429, 503) and attempt < max_retries - 1:
+            logger.warning(
+                f"Ensembl API returned {response.status_code} "
+                f"(attempt {attempt + 1}), retrying..."
+            )
+            time.sleep(delay)
+            delay *= 2
+            continue
+
+        # Non-retryable error
+        error_msg = f"Ensembl API error {response.status_code}: {response.text[:200]}"
+        logger.warning(error_msg)
+        if raise_on_error:
+            raise ValidationError(error_msg)
+        return None
+
+    return None
+
+
+def fetch_genes_from_ensembl(
+    species: str,
+    chrom: str | int,
+    start: int,
+    end: int,
+    biotype: str = "protein_coding",
+    raise_on_error: bool = False,
+) -> pd.DataFrame:
+    """Fetch gene annotations from Ensembl REST API.
+
+    Args:
+        species: Species name or alias.
+        chrom: Chromosome name or number.
+        start: Region start position (1-based).
+        end: Region end position (1-based).
+        biotype: Gene biotype filter (default: protein_coding).
+        raise_on_error: If True, raise ValidationError on API errors.
+
+    Returns:
+        DataFrame with columns: chr, start, end, gene_name, strand, gene_id, biotype.
+        Returns empty DataFrame on API error (unless raise_on_error=True).
+
+    Raises:
+        ValidationError: If region > 5Mb or if raise_on_error=True and API fails.
+    """
+    _validate_region_size(start, end, "genes_df")
+
+    ensembl_species = get_ensembl_species_name(species)
+    chrom_str = _normalize_chrom(chrom)
+
+    # Build region string
+    region = f"{chrom_str}:{start}-{end}"
+
+    # Build API URL
+    url = f"{ENSEMBL_REST_URL}/overlap/region/{ensembl_species}/{region}"
+    params = {"feature": "gene", "biotype": biotype}
+
+    logger.debug(f"Fetching genes from Ensembl: {url}")
+
+    data = _make_ensembl_request(url, params, raise_on_error=raise_on_error)
+
+    if data is None:
+        return pd.DataFrame()
+
+    if not data:
+        logger.debug(f"No genes found in region {region}")
+        return pd.DataFrame()
+
+    # Convert to DataFrame
+    records = []
+    for gene in data:
+        if gene.get("feature_type") != "gene":
+            continue
+        records.append(
+            {
+                "chr": str(gene.get("seq_region_name", chrom_str)),
+                "start": gene.get("start"),
+                "end": gene.get("end"),
+                "gene_name": gene.get("external_name", gene.get("id", "")),
+                "strand": "+" if gene.get("strand", 1) == 1 else "-",
+                "gene_id": gene.get("id", ""),
+                "biotype": gene.get("biotype", ""),
+            }
+        )
+
+    df = pd.DataFrame(records)
+    logger.debug(f"Fetched {len(df)} genes from Ensembl")
+    return df
+
+
+def fetch_exons_from_ensembl(
+    species: str,
+    chrom: str | int,
+    start: int,
+    end: int,
+    raise_on_error: bool = False,
+) -> pd.DataFrame:
+    """Fetch exon annotations from Ensembl REST API.
+
+    Args:
+        species: Species name or alias.
+        chrom: Chromosome name or number.
+        start: Region start position (1-based).
+        end: Region end position (1-based).
+        raise_on_error: If True, raise ValidationError on API errors.
+
+    Returns:
+        DataFrame with columns: chr, start, end, gene_name, exon_id, transcript_id.
+        Returns empty DataFrame on API error (unless raise_on_error=True).
+
+    Raises:
+        ValidationError: If region > 5Mb or if raise_on_error=True and API fails.
+    """
+    _validate_region_size(start, end, "exons_df")
+
+    ensembl_species = get_ensembl_species_name(species)
+    chrom_str = _normalize_chrom(chrom)
+    region = f"{chrom_str}:{start}-{end}"
+
+    url = f"{ENSEMBL_REST_URL}/overlap/region/{ensembl_species}/{region}"
+    params = {"feature": "exon"}
+
+    logger.debug(f"Fetching exons from Ensembl: {url}")
+
+    data = _make_ensembl_request(url, params, raise_on_error=raise_on_error)
+
+    if data is None:
+        return pd.DataFrame()
+
+    if not data:
+        return pd.DataFrame()
+
+    records = []
+    for exon in data:
+        if exon.get("feature_type") != "exon":
+            continue
+        records.append(
+            {
+                "chr": str(exon.get("seq_region_name", chrom_str)),
+                "start": exon.get("start"),
+                "end": exon.get("end"),
+                "gene_name": "",  # Exon endpoint doesn't include gene name
+                "exon_id": exon.get("id", ""),
+                "transcript_id": exon.get("Parent", ""),
+            }
+        )
+
+    df = pd.DataFrame(records)
+    logger.debug(f"Fetched {len(df)} exons from Ensembl")
+    return df
+
+
+def get_genes_for_region(
+    species: str,
+    chrom: str | int,
+    start: int,
+    end: int,
+    cache_dir: Path | None = None,
+    use_cache: bool = True,
+    include_exons: bool = False,
+    raise_on_error: bool = False,
+) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
+    """Get gene annotations for a genomic region.
+
+    Checks cache first, fetches from Ensembl API if not cached.
+
+    Args:
+        species: Species name or alias.
+        chrom: Chromosome name or number.
+        start: Region start position (1-based).
+        end: Region end position (1-based).
+        cache_dir: Cache directory (uses default if None).
+        use_cache: Whether to use disk cache.
+        include_exons: If True, also fetch exons and return tuple (genes_df, exons_df).
+        raise_on_error: If True, raise ValidationError on API errors.
+
+    Returns:
+        If include_exons=False: DataFrame with gene annotations.
+        If include_exons=True: Tuple of (genes_df, exons_df).
+
+    Raises:
+        ValidationError: If region > 5Mb or if raise_on_error=True and API fails.
+
+    Note:
+        Gene annotations are cached to disk. Exons are fetched from the API
+        on each call when include_exons=True (not cached separately).
+    """
+    if cache_dir is None:
+        cache_dir = get_ensembl_cache_dir()
+
+    chrom_str = _normalize_chrom(chrom)
+
+    # Check cache first
+    if use_cache:
+        cached = get_cached_genes(cache_dir, species, chrom_str, start, end)
+        if cached is not None:
+            if include_exons:
+                # Exons not cached separately (yet)
+                exons_df = fetch_exons_from_ensembl(
+                    species, chrom_str, start, end, raise_on_error=raise_on_error
+                )
+                return cached, exons_df
+            return cached
+
+    # Fetch from Ensembl API
+    genes_df = fetch_genes_from_ensembl(
+        species, chrom_str, start, end, raise_on_error=raise_on_error
+    )
+
+    # Cache the result (even if empty, to avoid repeated API calls for gene-sparse regions)
+    if use_cache:
+        save_cached_genes(genes_df, cache_dir, species, chrom_str, start, end)
+
+    if include_exons:
+        exons_df = fetch_exons_from_ensembl(
+            species, chrom_str, start, end, raise_on_error=raise_on_error
+        )
+        return genes_df, exons_df
+
+    return genes_df
+
+
+def clear_ensembl_cache(
+    cache_dir: Path | None = None,
+    species: str | None = None,
+) -> int:
+    """Clear cached Ensembl data.
+
+    Args:
+        cache_dir: Cache directory (uses default if None).
+        species: If provided, only clear cache for this species.
+
+    Returns:
+        Number of files deleted.
+    """
+    if cache_dir is None:
+        cache_dir = get_ensembl_cache_dir()
+
+    deleted = 0
+
+    if species:
+        # Clear only specific species
+        ensembl_species = get_ensembl_species_name(species)
+        species_dir = cache_dir / ensembl_species
+        if species_dir.exists():
+            for cache_file in species_dir.glob("*.csv"):
+                cache_file.unlink()
+                deleted += 1
+    else:
+        # Clear all species
+        for cache_file in cache_dir.glob("**/*.csv"):
+            cache_file.unlink()
+            deleted += 1
+
+    logger.info(f"Cleared {deleted} cached Ensembl files from {cache_dir}")
+    return deleted
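
The new module wires alias resolution, region validation, disk caching, and retrying API requests into one entry point, `get_genes_for_region`. A minimal usage sketch; the species alias and coordinates below are illustrative, not taken from the package:

```python
from pylocuszoom.ensembl import clear_ensembl_cache, get_genes_for_region

# First call fetches from rest.ensembl.org and writes a CSV cache under
# ~/.cache/snp-scope-plot/ensembl/<species>/; repeat calls for the same
# region are served from disk.
genes_df, exons_df = get_genes_for_region(
    species="dog",       # alias, resolved to "canis_lupus_familiaris"
    chrom="chr38",       # "chr" prefix is stripped by _normalize_chrom
    start=11_650_000,    # hypothetical coordinates; region must stay under 5 Mb
    end=11_750_000,
    include_exons=True,  # exons are fetched on each call, not cached
)
print(genes_df[["gene_name", "start", "end", "strand"]].head())

clear_ensembl_cache(species="dog")  # remove cached CSVs for one species only
```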
pylocuszoom/eqtl.py CHANGED
@@ -9,18 +9,15 @@ from typing import List, Optional
 import numpy as np
 import pandas as pd
 
+from .exceptions import EQTLValidationError, ValidationError
 from .logging import logger
+from .utils import filter_by_region
+from .validation import DataFrameValidator
 
 REQUIRED_EQTL_COLS = ["pos", "p_value"]
 OPTIONAL_EQTL_COLS = ["gene", "effect_size", "rs", "se"]
 
 
-class EQTLValidationError(ValueError):
-    """Raised when eQTL DataFrame validation fails."""
-
-    pass
-
-
 def validate_eqtl_df(
     df: pd.DataFrame,
     pos_col: str = "pos",
@@ -36,17 +33,15 @@ def validate_eqtl_df(
     Raises:
         EQTLValidationError: If required columns are missing.
     """
-    missing = []
-    if pos_col not in df.columns:
-        missing.append(pos_col)
-    if p_col not in df.columns:
-        missing.append(p_col)
-
-    if missing:
-        raise EQTLValidationError(
-            f"eQTL DataFrame missing required columns: {missing}. "
-            f"Required: {pos_col} (position), {p_col} (p-value)"
+    try:
+        (
+            DataFrameValidator(df, "eQTL DataFrame")
+            .require_columns([pos_col, p_col])
+            .require_numeric([p_col])
+            .validate()
         )
+    except ValidationError as e:
+        raise EQTLValidationError(str(e)) from e
 
 
 def filter_eqtl_by_gene(
@@ -99,15 +94,12 @@ def filter_eqtl_by_region(
     Returns:
         Filtered DataFrame containing only eQTLs in the region.
     """
-    mask = (df[pos_col] >= start) & (df[pos_col] <= end)
-
-    # Filter by chromosome if column exists
-    if chrom_col and chrom_col in df.columns:
-        chrom_str = str(chrom).replace("chr", "")
-        df_chrom = df[chrom_col].astype(str).str.replace("chr", "", regex=False)
-        mask = mask & (df_chrom == chrom_str)
-
-    filtered = df[mask].copy()
+    filtered = filter_by_region(
+        df,
+        region=(chrom, start, end),
+        chrom_col=chrom_col or "",
+        pos_col=pos_col,
+    )
     logger.debug(
         f"Filtered eQTL data to {len(filtered)} variants in region chr{chrom}:{start}-{end}"
     )
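
The eqtl.py changes replace the inline column checks with the shared `DataFrameValidator` chain while keeping the public exception type. A sketch of the caller-visible behavior; the sample data is invented for illustration:

```python
import pandas as pd

from pylocuszoom.eqtl import validate_eqtl_df
from pylocuszoom.exceptions import EQTLValidationError

eqtl = pd.DataFrame({"pos": [1_204_523, 1_208_871], "p_value": [3.2e-8, 0.041]})
validate_eqtl_df(eqtl)  # passes: both required columns present, p_value numeric

try:
    validate_eqtl_df(eqtl.drop(columns=["p_value"]))
except EQTLValidationError as exc:
    # DataFrameValidator raises ValidationError; validate_eqtl_df re-raises it
    # as the domain-specific subclass, so pre-1.0 handlers keep working.
    print(exc)
```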
pylocuszoom/exceptions.py ADDED
@@ -0,0 +1,33 @@
+"""Exception hierarchy for pyLocusZoom.
+
+All pyLocusZoom exceptions inherit from PyLocusZoomError, enabling users to
+catch all library errors with `except PyLocusZoomError`.
+"""
+
+
+class PyLocusZoomError(Exception):
+    """Base exception for all pyLocusZoom errors."""
+
+
+class ValidationError(PyLocusZoomError, ValueError):
+    """Raised when input validation fails. Inherits ValueError for backward compat."""
+
+
+class EQTLValidationError(ValidationError):
+    """Raised when eQTL DataFrame validation fails."""
+
+
+class FinemappingValidationError(ValidationError):
+    """Raised when fine-mapping DataFrame validation fails."""
+
+
+class LoaderValidationError(ValidationError):
+    """Raised when loaded data fails validation."""
+
+
+class BackendError(PyLocusZoomError):
+    """Raised when backend operations fail."""
+
+
+class DataDownloadError(PyLocusZoomError, RuntimeError):
+    """Raised when data download operations fail."""
pylocuszoom/finemapping.py CHANGED
@@ -8,19 +8,16 @@ from typing import List, Optional
 
 import pandas as pd
 
+from .exceptions import FinemappingValidationError, ValidationError
 from .logging import logger
+from .utils import filter_by_region
+from .validation import DataFrameValidator
 
 # Required columns for fine-mapping data
 REQUIRED_FINEMAPPING_COLS = ["pos", "pip"]
 OPTIONAL_FINEMAPPING_COLS = ["rs", "cs", "cs_id", "effect", "se"]
 
 
-class FinemappingValidationError(ValueError):
-    """Raised when fine-mapping DataFrame validation fails."""
-
-    pass
-
-
 def validate_finemapping_df(
     df: pd.DataFrame,
     pos_col: str = "pos",
@@ -36,24 +33,16 @@ def validate_finemapping_df(
     Raises:
         FinemappingValidationError: If required columns are missing.
     """
-    missing = []
-    if pos_col not in df.columns:
-        missing.append(pos_col)
-    if pip_col not in df.columns:
-        missing.append(pip_col)
-
-    if missing:
-        raise FinemappingValidationError(
-            f"Fine-mapping DataFrame missing required columns: {missing}. "
-            f"Required: {pos_col} (position), {pip_col} (posterior inclusion probability)"
-        )
-
-    # Validate PIP values are in [0, 1]
-    if not df[pip_col].between(0, 1).all():
-        invalid_count = (~df[pip_col].between(0, 1)).sum()
-        raise FinemappingValidationError(
-            f"PIP values must be between 0 and 1. Found {invalid_count} invalid values."
+    try:
+        (
+            DataFrameValidator(df, "Fine-mapping DataFrame")
+            .require_columns([pos_col, pip_col])
+            .require_numeric([pip_col])
+            .require_range(pip_col, min_val=0, max_val=1)
+            .validate()
         )
+    except ValidationError as e:
+        raise FinemappingValidationError(str(e)) from e
 
 
 def filter_finemapping_by_region(
@@ -77,15 +66,12 @@
     Returns:
         Filtered DataFrame containing only variants in the region.
     """
-    mask = (df[pos_col] >= start) & (df[pos_col] <= end)
-
-    # Filter by chromosome if column exists
-    if chrom_col and chrom_col in df.columns:
-        chrom_str = str(chrom).replace("chr", "")
-        df_chrom = df[chrom_col].astype(str).str.replace("chr", "", regex=False)
-        mask = mask & (df_chrom == chrom_str)
-
-    filtered = df[mask].copy()
+    filtered = filter_by_region(
+        df,
+        region=(chrom, start, end),
+        chrom_col=chrom_col or "",
+        pos_col=pos_col,
+    )
     logger.debug(
         f"Filtered fine-mapping data to {len(filtered)} variants in region "
         f"chr{chrom}:{start}-{end}"
pylocuszoom/forest.py CHANGED
@@ -5,7 +5,7 @@ Validates and prepares meta-analysis/forest plot data for visualization.
 
 import pandas as pd
 
-from .utils import ValidationError
+from .validation import DataFrameValidator
 
 
 def validate_forest_df(
@@ -15,7 +15,7 @@ def validate_forest_df(
     ci_lower_col: str = "ci_lower",
     ci_upper_col: str = "ci_upper",
 ) -> None:
-    """Validate forest plot DataFrame has required columns.
+    """Validate forest plot DataFrame has required columns and types.
 
     Args:
         df: Forest plot data DataFrame.
@@ -25,13 +25,12 @@ def validate_forest_df(
         ci_upper_col: Column name for upper confidence interval.
 
     Raises:
-        ValidationError: If required columns are missing.
+        ValidationError: If required columns are missing or have invalid types.
     """
-    required = [study_col, effect_col, ci_lower_col, ci_upper_col]
-    missing = [col for col in required if col not in df.columns]
-
-    if missing:
-        raise ValidationError(
-            f"Forest plot DataFrame missing required columns: {missing}. "
-            f"Required: {required}. Found: {list(df.columns)}"
-        )
+    (
+        DataFrameValidator(df, "Forest plot DataFrame")
+        .require_columns([study_col, effect_col, ci_lower_col, ci_upper_col])
+        .require_numeric([effect_col, ci_lower_col, ci_upper_col])
+        .require_ci_ordering(ci_lower_col, effect_col, ci_upper_col)
+        .validate()
+    )
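
The forest plot validator gains numeric-type and confidence-interval checks on top of the old column check; judging by its arguments, `require_ci_ordering` presumably enforces ci_lower <= effect <= ci_upper row-wise, though the validation module itself is not shown in this diff. A passing example with invented study data:

```python
import pandas as pd

from pylocuszoom.forest import validate_forest_df

studies = pd.DataFrame(
    {
        "study": ["Cohort A", "Cohort B", "Meta"],
        "effect": [0.42, 0.18, 0.27],
        "ci_lower": [0.21, 0.05, 0.14],
        "ci_upper": [0.63, 0.31, 0.40],
    }
)
# Passes: all required columns exist, the three numeric columns are numeric,
# and each row satisfies ci_lower <= effect <= ci_upper.
validate_forest_df(studies)
```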