pydeflate 2.1.3__py3-none-any.whl → 2.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydeflate/__init__.py +64 -20
- pydeflate/cache.py +139 -0
- pydeflate/constants.py +121 -0
- pydeflate/context.py +211 -0
- pydeflate/core/api.py +33 -11
- pydeflate/core/source.py +92 -11
- pydeflate/deflate/deflators.py +1 -1
- pydeflate/deflate/legacy_deflate.py +1 -1
- pydeflate/exceptions.py +166 -0
- pydeflate/exchange/exchangers.py +1 -1
- pydeflate/plugins.py +289 -0
- pydeflate/protocols.py +168 -0
- pydeflate/pydeflate_config.py +77 -6
- pydeflate/schemas.py +297 -0
- pydeflate/sources/common.py +59 -107
- pydeflate/sources/dac.py +39 -52
- pydeflate/sources/imf.py +23 -39
- pydeflate/sources/world_bank.py +44 -117
- pydeflate/utils.py +14 -9
- {pydeflate-2.1.3.dist-info → pydeflate-2.2.0.dist-info}/METADATA +119 -18
- pydeflate-2.2.0.dist-info/RECORD +32 -0
- pydeflate-2.2.0.dist-info/WHEEL +4 -0
- {pydeflate-2.1.3.dist-info → pydeflate-2.2.0.dist-info/licenses}/LICENSE +1 -1
- pydeflate-2.1.3.dist-info/RECORD +0 -25
- pydeflate-2.1.3.dist-info/WHEEL +0 -4
pydeflate/protocols.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
"""Protocol definitions for type safety and extensibility.
|
|
2
|
+
|
|
3
|
+
This module defines the core protocols (interfaces) that pydeflate components
|
|
4
|
+
must implement. Using protocols enables:
|
|
5
|
+
- Type checking without inheritance
|
|
6
|
+
- Duck typing with static verification
|
|
7
|
+
- Clear contracts for extensibility
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations

from pathlib import Path
from typing import Protocol, runtime_checkable

import pandas as pd

from pydeflate.sources.common import AvailableDeflators
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@runtime_checkable
class SourceProtocol(Protocol):
    """Protocol for data sources (IMF, World Bank, DAC, etc.).

    All data sources must implement these methods to be compatible
    with pydeflate's core deflation and exchange logic.

    ``@runtime_checkable`` allows ``isinstance(obj, SourceProtocol)``
    checks; note these only verify member *presence*, not signatures.
    """

    name: str
    """Human-readable name of the data source (e.g., 'IMF', 'World Bank')"""

    data: pd.DataFrame
    """The raw data loaded from this source"""

    _idx: list[str]
    """Standard index columns: ['pydeflate_year', 'pydeflate_entity_code', 'pydeflate_iso3']"""

    def lcu_usd_exchange(self) -> pd.DataFrame:
        """Return local currency to USD exchange rates.

        Returns:
            DataFrame with columns: pydeflate_year, pydeflate_entity_code,
            pydeflate_iso3, pydeflate_EXCHANGE
        """
        ...

    def price_deflator(self, kind: AvailableDeflators = "NGDP_D") -> pd.DataFrame:
        """Return price deflator data for the specified type.

        Args:
            kind: Type of deflator (e.g., 'NGDP_D', 'CPI', 'PCPI')

        Returns:
            DataFrame with columns: pydeflate_year, pydeflate_entity_code,
            pydeflate_iso3, pydeflate_{kind}

        Raises:
            ValueError: If the requested deflator kind is not available
        """
        ...

    def validate(self) -> None:
        """Validate that the source data is properly formatted.

        Raises:
            ValueError: If data is empty or improperly formatted
            SchemaValidationError: If data doesn't match expected schema
        """
        ...
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@runtime_checkable
class DeflatorProtocol(Protocol):
    """Protocol for deflator objects (price or exchange deflators).

    Implementations hold deflator series from a :class:`SourceProtocol`
    and keep them rebased so that ``base_year`` equals 100.
    """

    source: SourceProtocol
    """The data source providing deflator data"""

    deflator_type: str
    """Type of deflator: 'price' or 'exchange'"""

    base_year: int
    """The base year for rebasing deflator values (value = 100 at base year)"""

    deflator_data: pd.DataFrame
    """The deflator data, rebased to base_year"""

    def rebase_deflator(self) -> None:
        """Rebase deflator values so that base_year has value 100."""
        ...
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@runtime_checkable
class ExchangeProtocol(Protocol):
    """Protocol for exchange rate objects.

    Implementations convert values between ``source_currency`` and
    ``target_currency`` using rates from a :class:`SourceProtocol`.
    """

    source: SourceProtocol
    """The data source providing exchange rate data"""

    source_currency: str
    """Source currency code (ISO3 country code or 'LCU')"""

    target_currency: str
    """Target currency code (ISO3 country code)"""

    exchange_data: pd.DataFrame
    """Exchange rate data for converting source_currency to target_currency"""

    def exchange_rate(self, from_currency: str, to_currency: str) -> pd.DataFrame:
        """Calculate exchange rates between two currencies.

        Args:
            from_currency: Source currency code
            to_currency: Target currency code

        Returns:
            DataFrame with exchange rates and deflators
        """
        ...

    def deflator(self) -> pd.DataFrame:
        """Get exchange rate deflator data.

        Returns:
            DataFrame with columns: pydeflate_year, pydeflate_entity_code,
            pydeflate_iso3, pydeflate_EXCHANGE_D
        """
        ...
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@runtime_checkable
class CacheManagerProtocol(Protocol):
    """Protocol for cache management."""

    # BUG FIX: was annotated ``pd.DataFrame`` (copy-paste error) — this is a
    # filesystem location, as its own docstring and ``ensure``'s return type show.
    base_dir: Path
    """Base directory for cached data"""

    def ensure(self, entry, *, refresh: bool = False) -> Path:
        """Ensure a cache entry exists, downloading if needed.

        Args:
            entry: CacheEntry describing the dataset
            refresh: If True, force re-download

        Returns:
            Path to the cached file
        """
        ...

    def clear(self, key: str | None = None) -> None:
        """Clear cache entries.

        Args:
            key: If provided, clear only this entry. If None, clear all.
        """
        ...
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# Type aliases for common DataFrame schemas
# NOTE: these are plain aliases — all resolve to pd.DataFrame. They document
# intent in signatures only; they do not enforce different schemas.
ExchangeDataFrame = pd.DataFrame
"""DataFrame containing exchange rate data"""

DeflatorDataFrame = pd.DataFrame
"""DataFrame containing deflator data"""

SourceDataFrame = pd.DataFrame
"""DataFrame containing raw source data"""

UserDataFrame = pd.DataFrame
"""DataFrame provided by users for deflation/exchange"""
|
pydeflate/pydeflate_config.py
CHANGED
|
@@ -1,14 +1,79 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import logging
|
|
4
|
+
import os
|
|
5
|
+
from dataclasses import dataclass
|
|
2
6
|
from pathlib import Path
|
|
3
7
|
|
|
4
8
|
|
|
5
|
-
|
|
6
|
-
|
|
9
|
+
from platformdirs import user_cache_dir
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
DATA_DIR_ENV = "PYDEFLATE_DATA_DIR"
|
|
13
|
+
|
|
14
|
+
_PACKAGE_ROOT = Path(__file__).resolve().parent.parent
|
|
15
|
+
_SETTINGS_DIR = _PACKAGE_ROOT / "pydeflate" / "settings"
|
|
16
|
+
_TEST_DATA_DIR = _PACKAGE_ROOT / "tests" / "test_files"
|
|
17
|
+
|
|
18
|
+
_DATA_DIR_OVERRIDE: Path | None = None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _ensure_dir(path: Path) -> Path:
|
|
22
|
+
path.mkdir(parents=True, exist_ok=True)
|
|
23
|
+
return path
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _default_data_dir() -> Path:
    """Resolve the default cache directory.

    The ``PYDEFLATE_DATA_DIR`` environment variable takes precedence;
    otherwise fall back to the platform-specific user cache directory.
    Either way the directory is created if missing.
    """
    override = os.environ.get(DATA_DIR_ENV)
    if not override:
        return _ensure_dir(Path(user_cache_dir("pydeflate", "pydeflate")))
    return _ensure_dir(Path(override).expanduser().resolve())
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def get_data_dir() -> Path:
    """Return the directory where pydeflate caches data files.

    A process-level override set via :func:`set_data_dir` wins; otherwise
    the environment/platform default is used.
    """
    override = _DATA_DIR_OVERRIDE
    if override is None:
        return _default_data_dir()
    return _ensure_dir(override)
|
|
39
|
+
|
|
7
40
|
|
|
8
|
-
|
|
9
|
-
data
|
|
10
|
-
|
|
11
|
-
|
|
41
|
+
def set_data_dir(path: str | Path) -> Path:
    """Override the pydeflate data directory for the current process.

    The directory is created if missing; the resolved path is returned.
    """
    global _DATA_DIR_OVERRIDE
    target = Path(path).expanduser().resolve()
    _DATA_DIR_OVERRIDE = _ensure_dir(target)
    return _DATA_DIR_OVERRIDE
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def reset_data_dir() -> None:
    """Drop any process-level override so defaults apply again."""
    global _DATA_DIR_OVERRIDE
    _DATA_DIR_OVERRIDE = None
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass(frozen=True)
|
|
58
|
+
class _Paths:
|
|
59
|
+
package: Path
|
|
60
|
+
settings: Path
|
|
61
|
+
test_data: Path
|
|
62
|
+
|
|
63
|
+
@property
|
|
64
|
+
def data(self) -> Path:
|
|
65
|
+
return get_data_dir()
|
|
66
|
+
|
|
67
|
+
@data.setter # type: ignore[override]
|
|
68
|
+
def data(self, value: Path | str) -> None: # pragma: no cover - simple proxy
|
|
69
|
+
set_data_dir(value)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# Singleton holding the package's well-known paths. Its ``data`` attribute is
# a property that resolves the cache directory dynamically via get_data_dir().
PYDEFLATE_PATHS = _Paths(
    package=_PACKAGE_ROOT,
    settings=_SETTINGS_DIR,
    test_data=_TEST_DATA_DIR,
)
|
|
12
77
|
|
|
13
78
|
|
|
14
79
|
def setup_logger(name) -> logging.Logger:
|
|
@@ -41,3 +106,9 @@ def setup_logger(name) -> logging.Logger:
|
|
|
41
106
|
|
|
42
107
|
|
|
43
108
|
logger = setup_logger("pydeflate")
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def set_pydeflate_path(path: str | Path) -> Path:
    """Set the path to the data folder (public API).

    Thin wrapper kept for backward compatibility; delegates to
    :func:`set_data_dir` and returns the resolved directory.
    """
    return set_data_dir(path)
|
pydeflate/schemas.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
1
|
+
"""Pandera schemas for data validation.
|
|
2
|
+
|
|
3
|
+
This module defines validation schemas for all DataFrame structures used
|
|
4
|
+
in pydeflate. This ensures data integrity from external sources and
|
|
5
|
+
catches API changes early.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import pandas as pd
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
# New import (pandera >= 0.20)
|
|
14
|
+
import pandera.pandas as pa
|
|
15
|
+
from pandera.pandas import Check, Column, DataFrameSchema
|
|
16
|
+
except ImportError:
|
|
17
|
+
# Fallback to old import for older pandera versions
|
|
18
|
+
import pandera as pa
|
|
19
|
+
from pandera import Check, Column, DataFrameSchema
|
|
20
|
+
|
|
21
|
+
# Column definitions for reuse (object-based DataFrameSchema API)
YEAR_COLUMN = Column(
    int,
    checks=[
        Check.ge(1960),  # No data before 1960
        Check.le(2100),  # No projections beyond 2100
    ],
    nullable=False,
    description="Year as integer",
)

ENTITY_CODE_COLUMN = Column(
    str,
    checks=[Check(lambda s: s.str.len() <= 10)],
    nullable=False,
    description="Entity code from source (varies by source)",
)

ISO3_COLUMN = Column(
    str,
    checks=[
        # Explicitly let NaN entries pass the length check.
        Check(lambda s: (s.str.len() == 3) | s.isna()),
    ],
    nullable=True,
    description="ISO3 country code",
)

EXCHANGE_RATE_COLUMN = Column(
    float,
    checks=[
        Check.gt(0),  # Exchange rates must be positive
    ],
    nullable=True,
    description="Exchange rate (LCU per USD)",
)

DEFLATOR_COLUMN = Column(
    float,
    checks=[
        Check.gt(0),  # Deflators must be positive
    ],
    nullable=True,
    description="Price deflator index",
)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class SourceDataSchema(pa.DataFrameModel):
    """Base schema for all data sources.

    All sources must have these minimum columns after processing.
    """

    # Year bounds: no data before 1960, no projections beyond 2100.
    pydeflate_year: int = pa.Field(ge=1960, le=2100)
    pydeflate_entity_code: str = pa.Field(str_length={"max_value": 10})
    # NOTE(review): the `str | None` annotation may not be honored by every
    # pandera version; nullable=True is what actually permits NaN — confirm.
    pydeflate_iso3: str | None = pa.Field(nullable=True)

    class Config:
        """Schema configuration."""

        strict = False  # Allow additional columns
        coerce = True  # Attempt type coercion
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class ExchangeDataSchema(SourceDataSchema):
    """Schema for exchange rate data.

    BUG FIX: fields now use the class-based ``pa.Field`` annotation API.
    The previous implementation assigned object-based ``Column`` instances
    as plain (un-annotated) class attributes, which ``DataFrameModel``
    ignores — so these columns were never actually validated.
    """

    # Exchange rate (LCU per USD); must be positive when present.
    pydeflate_EXCHANGE: float = pa.Field(gt=0, nullable=True)
    # Exchange rate deflator (rebased); must be positive when present.
    pydeflate_EXCHANGE_D: float = pa.Field(gt=0, nullable=True)

    class Config:
        """Schema configuration."""

        strict = False
        coerce = True
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class IMFDataSchema(SourceDataSchema):
    """Schema for IMF WEO data.

    IMF provides GDP deflators, CPI, and exchange rates.

    BUG FIX: fields now use the class-based ``pa.Field`` annotation API.
    The previous implementation assigned object-based ``Column`` instances
    as plain (un-annotated) class attributes, which ``DataFrameModel``
    ignores — so these columns were never actually validated.
    """

    pydeflate_NGDP_D: float = pa.Field(gt=0, nullable=True)  # GDP deflator
    pydeflate_PCPI: float = pa.Field(gt=0, nullable=True)  # CPI (period average)
    pydeflate_PCPIE: float = pa.Field(gt=0, nullable=True)  # CPI (end of period)
    pydeflate_EXCHANGE: float = pa.Field(gt=0, nullable=True)  # LCU per USD

    class Config:
        """Schema configuration."""

        strict = False
        coerce = True
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class WorldBankDataSchema(SourceDataSchema):
    """Schema for World Bank data.

    BUG FIX: fields now use the class-based ``pa.Field`` annotation API.
    The previous implementation assigned object-based ``Column`` instances
    as plain (un-annotated) class attributes, which ``DataFrameModel``
    ignores — so these columns were never actually validated.
    """

    pydeflate_NGDP_D: float = pa.Field(gt=0, nullable=True)  # GDP deflator
    pydeflate_NGDP_DL: float = pa.Field(gt=0, nullable=True)  # GDP deflator (linked)
    pydeflate_CPI: float = pa.Field(gt=0, nullable=True)  # Consumer Price Index
    pydeflate_EXCHANGE: float = pa.Field(gt=0, nullable=True)  # LCU per USD

    class Config:
        """Schema configuration."""

        strict = False
        coerce = True
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class DACDataSchema(SourceDataSchema):
    """Schema for OECD DAC data.

    BUG FIX: fields now use the class-based ``pa.Field`` annotation API.
    The previous implementation assigned object-based ``Column`` instances
    as plain (un-annotated) class attributes, which ``DataFrameModel``
    ignores — so these columns were never actually validated.
    """

    pydeflate_DAC_DEFLATOR: float = pa.Field(gt=0, nullable=True)  # DAC deflator
    pydeflate_NGDP_D: float = pa.Field(gt=0, nullable=True)  # GDP deflator (computed)
    pydeflate_EXCHANGE: float = pa.Field(gt=0, nullable=True)  # LCU per USD

    class Config:
        """Schema configuration."""

        strict = False
        coerce = True
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class UserInputSchema:
    """Validation for user-provided DataFrames.

    This is not a Pandera schema but provides methods to validate
    user input with custom column names.
    """

    @staticmethod
    def validate(
        df,
        id_column: str,
        year_column: str,
        value_column: str,
    ) -> None:
        """Validate user DataFrame has required columns and types.

        Args:
            df: User's DataFrame
            id_column: Name of column with entity identifiers
            year_column: Name of column with year data
            value_column: Name of column with numeric values

        Raises:
            ConfigurationError: If required columns are missing
            SchemaValidationError: If column types are invalid
        """
        # Imported lazily — presumably to avoid a circular import at module
        # load time (exceptions module may import schemas); confirm.
        from pydeflate.exceptions import ConfigurationError, SchemaValidationError

        # Check required columns exist.
        missing_cols = [
            f"{label}='{col}'"
            for label, col in (
                ("id_column", id_column),
                ("year_column", year_column),
                ("value_column", value_column),
            )
            if col not in df.columns
        ]
        if missing_cols:
            raise ConfigurationError(
                f"Required columns missing from DataFrame: {', '.join(missing_cols)}"
            )

        # Validate value column is numeric.
        if not pd.api.types.is_numeric_dtype(df[value_column]):
            raise SchemaValidationError(
                f"Column '{value_column}' must be numeric, got {df[value_column].dtype}"
            )

        # Validate year column can be interpreted as dates.
        # BUG FIX: errors="coerce" never raises (invalid values become NaT),
        # so the previous check silently accepted garbage years. errors="raise"
        # makes the validation effective.
        try:
            pd.to_datetime(df[year_column], errors="raise")
        except Exception as e:
            raise SchemaValidationError(
                f"Column '{year_column}' cannot be interpreted as dates: {e}"
            ) from e
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
# Registry of schemas by source name.
# NOTE(review): values are class-based DataFrameModel subclasses, not
# object-based DataFrameSchema instances — annotation corrected accordingly.
SCHEMA_REGISTRY: dict[str, type[pa.DataFrameModel]] = {
    "IMF": IMFDataSchema,
    "World Bank": WorldBankDataSchema,
    "DAC": DACDataSchema,
}
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def get_schema_for_source(source_name: str) -> type[pa.DataFrameModel] | None:
    """Get the appropriate schema for a data source.

    Args:
        source_name: Name of the source (e.g., 'IMF', 'World Bank')

    Returns:
        Schema class for the source, or None if not found

    Note:
        Return annotation corrected — the registry holds DataFrameModel
        subclasses, not object-based DataFrameSchema.
    """
    return SCHEMA_REGISTRY.get(source_name)
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def validate_source_data(df, source_name: str) -> None:
    """Validate that source data matches expected schema.

    Args:
        df: DataFrame to validate
        source_name: Name of the source

    Raises:
        SchemaValidationError: If data doesn't match schema
    """
    # Imported lazily — presumably to avoid a circular import; confirm.
    from pydeflate.exceptions import SchemaValidationError

    schema_class = get_schema_for_source(source_name)
    if schema_class is None:
        # No schema defined for this source, skip validation.
        return

    try:
        # BUG FIX: DataFrameModel.validate is a classmethod; instantiating
        # the model (``schema_class()``) is unsupported by the class-based
        # API and errors on recent pandera versions.
        schema_class.validate(df, lazy=True)
    except pa.errors.SchemaErrors as e:
        # Collect all validation failures into a readable bullet list.
        error_messages = [
            f"  - {error.check}: {error.failure_case}"
            for error in e.failure_cases.itertuples()
        ]
        raise SchemaValidationError(
            f"Data validation failed for {source_name}:\n" + "\n".join(error_messages),
            source=source_name,
        ) from e
|