pydeflate 2.1.3-py3-none-any.whl → 2.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pydeflate/protocols.py ADDED
@@ -0,0 +1,169 @@
+ """Protocol definitions for type safety and extensibility.
+
+ This module defines the core protocols (interfaces) that pydeflate components
+ must implement. Using protocols enables:
+ - Type checking without inheritance
+ - Duck typing with static verification
+ - Clear contracts for extensibility
+ """
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Protocol, runtime_checkable
+
+ import pandas as pd
+
+ from pydeflate.sources.common import AvailableDeflators
+
+
+ @runtime_checkable
+ class SourceProtocol(Protocol):
+     """Protocol for data sources (IMF, World Bank, DAC, etc.).
+
+     All data sources must implement these methods to be compatible
+     with pydeflate's core deflation and exchange logic.
+     """
+
+     name: str
+     """Human-readable name of the data source (e.g., 'IMF', 'World Bank')"""
+
+     data: pd.DataFrame
+     """The raw data loaded from this source"""
+
+     _idx: list[str]
+     """Standard index columns: ['pydeflate_year', 'pydeflate_entity_code', 'pydeflate_iso3']"""
+
+     def lcu_usd_exchange(self) -> pd.DataFrame:
+         """Return local currency to USD exchange rates.
+
+         Returns:
+             DataFrame with columns: pydeflate_year, pydeflate_entity_code,
+             pydeflate_iso3, pydeflate_EXCHANGE
+         """
+         ...
+
+     def price_deflator(self, kind: AvailableDeflators = "NGDP_D") -> pd.DataFrame:
+         """Return price deflator data for the specified type.
+
+         Args:
+             kind: Type of deflator (e.g., 'NGDP_D', 'CPI', 'PCPI')
+
+         Returns:
+             DataFrame with columns: pydeflate_year, pydeflate_entity_code,
+             pydeflate_iso3, pydeflate_{kind}
+
+         Raises:
+             ValueError: If the requested deflator kind is not available
+         """
+         ...
+
+     def validate(self) -> None:
+         """Validate that the source data is properly formatted.
+
+         Raises:
+             ValueError: If data is empty or improperly formatted
+             SchemaValidationError: If data doesn't match expected schema
+         """
+         ...
+
+
+ @runtime_checkable
+ class DeflatorProtocol(Protocol):
+     """Protocol for deflator objects (price or exchange deflators)."""
+
+     source: SourceProtocol
+     """The data source providing deflator data"""
+
+     deflator_type: str
+     """Type of deflator: 'price' or 'exchange'"""
+
+     base_year: int
+     """The base year for rebasing deflator values (value = 100 at base year)"""
+
+     deflator_data: pd.DataFrame
+     """The deflator data, rebased to base_year"""
+
+     def rebase_deflator(self) -> None:
+         """Rebase deflator values so that base_year has value 100."""
+         ...
+
+
+ @runtime_checkable
+ class ExchangeProtocol(Protocol):
+     """Protocol for exchange rate objects."""
+
+     source: SourceProtocol
+     """The data source providing exchange rate data"""
+
+     source_currency: str
+     """Source currency code (ISO3 country code or 'LCU')"""
+
+     target_currency: str
+     """Target currency code (ISO3 country code)"""
+
+     exchange_data: pd.DataFrame
+     """Exchange rate data for converting source_currency to target_currency"""
+
+     def exchange_rate(self, from_currency: str, to_currency: str) -> pd.DataFrame:
+         """Calculate exchange rates between two currencies.
+
+         Args:
+             from_currency: Source currency code
+             to_currency: Target currency code
+
+         Returns:
+             DataFrame with exchange rates and deflators
+         """
+         ...
+
+     def deflator(self) -> pd.DataFrame:
+         """Get exchange rate deflator data.
+
+         Returns:
+             DataFrame with columns: pydeflate_year, pydeflate_entity_code,
+             pydeflate_iso3, pydeflate_EXCHANGE_D
+         """
+         ...
+
+
+ @runtime_checkable
+ class CacheManagerProtocol(Protocol):
+     """Protocol for cache management."""
+
+     base_dir: Path
+     """Base directory for cached data"""
+
+     def ensure(self, entry, *, refresh: bool = False) -> Path:
+         """Ensure a cache entry exists, downloading if needed.
+
+         Args:
+             entry: CacheEntry describing the dataset
+             refresh: If True, force re-download
+
+         Returns:
+             Path to the cached file
+         """
+         ...
+
+     def clear(self, key: str | None = None) -> None:
+         """Clear cache entries.
+
+         Args:
+             key: If provided, clear only this entry. If None, clear all.
+         """
+         ...
+
+
+ # Type aliases for common DataFrame schemas
+ ExchangeDataFrame = pd.DataFrame
+ """DataFrame containing exchange rate data"""
+
+ DeflatorDataFrame = pd.DataFrame
+ """DataFrame containing deflator data"""
+
+ SourceDataFrame = pd.DataFrame
+ """DataFrame containing raw source data"""
+
+ UserDataFrame = pd.DataFrame
+ """DataFrame provided by users for deflation/exchange"""
@@ -1,14 +1,79 @@
+ from __future__ import annotations
+
  import logging
+ import os
+ from dataclasses import dataclass
  from pathlib import Path


- class PYDEFLATE_PATHS:
-     """Class to store the paths to the data and output folders."""
+ from platformdirs import user_cache_dir
+
+
+ DATA_DIR_ENV = "PYDEFLATE_DATA_DIR"
+
+ _PACKAGE_ROOT = Path(__file__).resolve().parent.parent
+ _SETTINGS_DIR = _PACKAGE_ROOT / "pydeflate" / "settings"
+ _TEST_DATA_DIR = _PACKAGE_ROOT / "tests" / "test_files"
+
+ _DATA_DIR_OVERRIDE: Path | None = None
+
+
+ def _ensure_dir(path: Path) -> Path:
+     path.mkdir(parents=True, exist_ok=True)
+     return path
+
+
+ def _default_data_dir() -> Path:
+     env_value = os.environ.get(DATA_DIR_ENV)
+     if env_value:
+         return _ensure_dir(Path(env_value).expanduser().resolve())
+     return _ensure_dir(Path(user_cache_dir("pydeflate", "pydeflate")))
+
+
+ def get_data_dir() -> Path:
+     """Return the directory where pydeflate caches data files."""
+
+     if _DATA_DIR_OVERRIDE is not None:
+         return _ensure_dir(_DATA_DIR_OVERRIDE)
+     return _default_data_dir()
+

-     package = Path(__file__).resolve().parent.parent
-     data = package / "pydeflate" / ".pydeflate_data"
-     settings = package / "pydeflate" / "settings"
-     test_data = package / "tests" / "test_files"
+ def set_data_dir(path: str | Path) -> Path:
+     """Override the pydeflate data directory for the current process."""
+
+     global _DATA_DIR_OVERRIDE
+     resolved = _ensure_dir(Path(path).expanduser().resolve())
+     _DATA_DIR_OVERRIDE = resolved
+     return resolved
+
+
+ def reset_data_dir() -> None:
+     """Reset any process-level overrides and fall back to defaults."""
+
+     global _DATA_DIR_OVERRIDE
+     _DATA_DIR_OVERRIDE = None
+
+
+ @dataclass(frozen=True)
+ class _Paths:
+     package: Path
+     settings: Path
+     test_data: Path
+
+     @property
+     def data(self) -> Path:
+         return get_data_dir()
+
+     @data.setter  # type: ignore[override]
+     def data(self, value: Path | str) -> None:  # pragma: no cover - simple proxy
+         set_data_dir(value)
+
+
+ PYDEFLATE_PATHS = _Paths(
+     package=_PACKAGE_ROOT,
+     settings=_SETTINGS_DIR,
+     test_data=_TEST_DATA_DIR,
+ )


  def setup_logger(name) -> logging.Logger:
@@ -41,3 +106,9 @@ def setup_logger(name) -> logging.Logger:


  logger = setup_logger("pydeflate")
+
+
+ def set_pydeflate_path(path: str | Path) -> Path:
+     """Set the path to the data folder (public API)."""
+
+     return set_data_dir(path)
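This change replaces the hard-coded pydeflate/.pydeflate_data folder with a resolvable cache location: an explicit per-process override wins, then the PYDEFLATE_DATA_DIR environment variable, then the platformdirs user cache directory. A short sketch of the precedence, assuming set_pydeflate_path is still re-exported from the package root as in 2.1.x; the paths shown are placeholders:

import os

from pydeflate import set_pydeflate_path  # assumed re-export, as in 2.1.x

# Highest precedence: an explicit per-process override
# (set_pydeflate_path is a thin wrapper around set_data_dir).
set_pydeflate_path("/tmp/pydeflate-cache")  # placeholder path

# Next: the PYDEFLATE_DATA_DIR environment variable, read on every lookup
# when no override is active.
os.environ["PYDEFLATE_DATA_DIR"] = "/tmp/pydeflate-env-cache"  # placeholder

# With neither set, data is cached under the platformdirs user cache
# directory (for example ~/.cache/pydeflate on Linux) instead of the old
# pydeflate/.pydeflate_data folder inside the package.

The reset_data_dir() function in the same module drops the override and returns to the environment-variable / platformdirs fallback.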
pydeflate/schemas.py ADDED
@@ -0,0 +1,297 @@
+ """Pandera schemas for data validation.
+
+ This module defines validation schemas for all DataFrame structures used
+ in pydeflate. This ensures data integrity from external sources and
+ catches API changes early.
+ """
+
+ from __future__ import annotations
+
+ import pandas as pd
+
+ try:
+     # New import (pandera >= 0.20)
+     import pandera.pandas as pa
+     from pandera.pandas import Check, Column, DataFrameSchema
+ except ImportError:
+     # Fallback to old import for older pandera versions
+     import pandera as pa
+     from pandera import Check, Column, DataFrameSchema
+
+ # Column definitions for reuse
+ YEAR_COLUMN = Column(
+     int,
+     checks=[
+         Check.ge(1960),  # No data before 1960
+         Check.le(2100),  # No projections beyond 2100
+     ],
+     nullable=False,
+     description="Year as integer",
+ )
+
+ ENTITY_CODE_COLUMN = Column(
+     str,
+     checks=[Check(lambda s: s.str.len() <= 10)],
+     nullable=False,
+     description="Entity code from source (varies by source)",
+ )
+
+ ISO3_COLUMN = Column(
+     str,
+     checks=[
+         Check(lambda s: (s.str.len() == 3) | s.isna()),
+     ],
+     nullable=True,
+     description="ISO3 country code",
+ )
+
+ EXCHANGE_RATE_COLUMN = Column(
+     float,
+     checks=[
+         Check.gt(0),  # Exchange rates must be positive
+     ],
+     nullable=True,
+     description="Exchange rate (LCU per USD)",
+ )
+
+ DEFLATOR_COLUMN = Column(
+     float,
+     checks=[
+         Check.gt(0),  # Deflators must be positive
+     ],
+     nullable=True,
+     description="Price deflator index",
+ )
+
+
+ class SourceDataSchema(pa.DataFrameModel):
+     """Base schema for all data sources.
+
+     All sources must have these minimum columns after processing.
+     """
+
+     pydeflate_year: int = pa.Field(ge=1960, le=2100)
+     pydeflate_entity_code: str = pa.Field(str_length={"max_value": 10})
+     pydeflate_iso3: str | None = pa.Field(nullable=True)
+
+     class Config:
+         """Schema configuration."""
+
+         strict = False  # Allow additional columns
+         coerce = True  # Attempt type coercion
+
+
+ class ExchangeDataSchema(SourceDataSchema):
+     """Schema for exchange rate data."""
+
+     pydeflate_EXCHANGE = EXCHANGE_RATE_COLUMN
+     pydeflate_EXCHANGE_D = Column(
+         float,
+         checks=[Check.gt(0)],
+         nullable=True,
+         description="Exchange rate deflator (rebased)",
+     )
+
+     class Config:
+         """Schema configuration."""
+
+         strict = False
+         coerce = True
+
+
+ class IMFDataSchema(SourceDataSchema):
+     """Schema for IMF WEO data.
+
+     IMF provides GDP deflators, CPI, and exchange rates.
+     """
+
+     pydeflate_NGDP_D = Column(
+         float,
+         checks=[Check.gt(0)],
+         nullable=True,
+         description="GDP deflator",
+     )
+     pydeflate_PCPI = Column(
+         float,
+         checks=[Check.gt(0)],
+         nullable=True,
+         description="CPI (period average)",
+     )
+     pydeflate_PCPIE = Column(
+         float,
+         checks=[Check.gt(0)],
+         nullable=True,
+         description="CPI (end of period)",
+     )
+     pydeflate_EXCHANGE = EXCHANGE_RATE_COLUMN
+
+     class Config:
+         """Schema configuration."""
+
+         strict = False
+         coerce = True
+
+
+ class WorldBankDataSchema(SourceDataSchema):
+     """Schema for World Bank data."""
+
+     pydeflate_NGDP_D = Column(
+         float,
+         checks=[Check.gt(0)],
+         nullable=True,
+         description="GDP deflator",
+     )
+     pydeflate_NGDP_DL = Column(
+         float,
+         checks=[Check.gt(0)],
+         nullable=True,
+         description="GDP deflator (linked)",
+     )
+     pydeflate_CPI = Column(
+         float,
+         checks=[Check.gt(0)],
+         nullable=True,
+         description="Consumer Price Index",
+     )
+     pydeflate_EXCHANGE = EXCHANGE_RATE_COLUMN
+
+     class Config:
+         """Schema configuration."""
+
+         strict = False
+         coerce = True
+
+
+ class DACDataSchema(SourceDataSchema):
+     """Schema for OECD DAC data."""
+
+     pydeflate_DAC_DEFLATOR = Column(
+         float,
+         checks=[Check.gt(0)],
+         nullable=True,
+         description="DAC deflator",
+     )
+     pydeflate_NGDP_D = Column(
+         float,
+         checks=[Check.gt(0)],
+         nullable=True,
+         description="GDP deflator (computed)",
+     )
+     pydeflate_EXCHANGE = EXCHANGE_RATE_COLUMN
+
+     class Config:
+         """Schema configuration."""
+
+         strict = False
+         coerce = True
+
+
+ class UserInputSchema:
+     """Validation for user-provided DataFrames.
+
+     This is not a Pandera schema but provides methods to validate
+     user input with custom column names.
+     """
+
+     @staticmethod
+     def validate(
+         df,
+         id_column: str,
+         year_column: str,
+         value_column: str,
+     ) -> None:
+         """Validate user DataFrame has required columns and types.
+
+         Args:
+             df: User's DataFrame
+             id_column: Name of column with entity identifiers
+             year_column: Name of column with year data
+             value_column: Name of column with numeric values
+
+         Raises:
+             ConfigurationError: If required columns are missing
+             SchemaValidationError: If column types are invalid
+         """
+         from pydeflate.exceptions import ConfigurationError, SchemaValidationError
+
+         # Check required columns exist
+         missing_cols = []
+         for col_name, col in [
+             ("id_column", id_column),
+             ("year_column", year_column),
+             ("value_column", value_column),
+         ]:
+             if col not in df.columns:
+                 missing_cols.append(f"{col_name}='{col}'")
+
+         if missing_cols:
+             raise ConfigurationError(
+                 f"Required columns missing from DataFrame: {', '.join(missing_cols)}"
+             )
+
+         # Validate value column is numeric
+         if not pd.api.types.is_numeric_dtype(df[value_column]):
+             raise SchemaValidationError(
+                 f"Column '{value_column}' must be numeric, got {df[value_column].dtype}"
+             )
+
+         # Validate year column can be converted to datetime
+         try:
+             pd.to_datetime(df[year_column], errors="coerce")
+         except Exception as e:
+             raise SchemaValidationError(
+                 f"Column '{year_column}' cannot be interpreted as dates: {e}"
+             )
+
+
+ # Registry of schemas by source name
+ SCHEMA_REGISTRY: dict[str, type[pa.DataFrameModel]] = {
+     "IMF": IMFDataSchema,
+     "World Bank": WorldBankDataSchema,
+     "DAC": DACDataSchema,
+ }
+
+
+ def get_schema_for_source(source_name: str) -> type[pa.DataFrameModel] | None:
+     """Get the appropriate schema for a data source.
+
+     Args:
+         source_name: Name of the source (e.g., 'IMF', 'World Bank')
+
+     Returns:
+         Schema class for the source, or None if not found
+     """
+     return SCHEMA_REGISTRY.get(source_name)
+
+
+ def validate_source_data(df, source_name: str) -> None:
+     """Validate that source data matches expected schema.
+
+     Args:
+         df: DataFrame to validate
+         source_name: Name of the source
+
+     Raises:
+         SchemaValidationError: If data doesn't match schema
+     """
+     from pydeflate.exceptions import SchemaValidationError
+
+     schema_class = get_schema_for_source(source_name)
+     if schema_class is None:
+         # No schema defined for this source, skip validation
+         return
+
+     try:
+         # Validate via the class-based DataFrameModel API;
+         # lazy=True collects every failure before raising
+         schema_class.validate(df, lazy=True)
+     except pa.errors.SchemaErrors as e:
+         # Collect all validation errors
+         error_messages = []
+         for error in e.failure_cases.itertuples():
+             error_messages.append(f" - {error.check}: {error.failure_case}")
+
+         raise SchemaValidationError(
+             f"Data validation failed for {source_name}:\n" + "\n".join(error_messages),
+             source=source_name,
+         ) from e
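The new module exposes two entry points that downstream code can call directly: UserInputSchema.validate for user-supplied frames with arbitrary column names, and validate_source_data for source data registered in SCHEMA_REGISTRY. A short sketch, assuming pydeflate 2.2.0 is installed; the DataFrame contents and column names below are made up:

import pandas as pd

from pydeflate.exceptions import SchemaValidationError
from pydeflate.schemas import UserInputSchema, validate_source_data

df = pd.DataFrame(
    {"iso_code": ["FRA", "DEU"], "year": [2022, 2022], "value": [100.0, 250.0]}
)

# Checks that the named columns exist, the value column is numeric and the
# year column is date-like; raises ConfigurationError or SchemaValidationError.
UserInputSchema.validate(
    df, id_column="iso_code", year_column="year", value_column="value"
)

# Source-side validation: unknown source names are skipped silently, while a
# registered source whose data drifts from its schema raises an error that
# lists every failed check (lazy validation).
try:
    validate_source_data(df, source_name="IMF")  # toy frame lacks IMF columns
except SchemaValidationError as exc:
    print(f"schema validation failed: {exc}")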