pydeflate 2.1.3__py3-none-any.whl → 2.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pydeflate/core/source.py CHANGED
@@ -2,6 +2,7 @@ from dataclasses import dataclass, field
2
2
 
3
3
  import pandas as pd
4
4
 
5
+ from pydeflate.exceptions import ConfigurationError, DataSourceError
5
6
  from pydeflate.sources.common import AvailableDeflators
6
7
  from pydeflate.sources.dac import read_dac
7
8
  from pydeflate.sources.imf import read_weo
@@ -10,6 +11,12 @@ from pydeflate.sources.world_bank import read_wb, read_wb_lcu_ppp, read_wb_usd_p
10
11
 
11
12
  @dataclass
12
13
  class Source:
14
+ """Base class for data sources implementing SourceProtocol.
15
+
16
+ This class handles loading data from external sources, caching,
17
+ and validation. It implements the SourceProtocol interface.
18
+ """
19
+
13
20
  name: str
14
21
  reader: callable
15
22
  update: bool = False
@@ -17,26 +24,100 @@ class Source:
17
24
  _idx = ["pydeflate_year", "pydeflate_entity_code", "pydeflate_iso3"]
18
25
 
19
26
  def __post_init__(self):
20
- self.data = self.reader(self.update)
27
+ """Load and validate data after initialization."""
28
+ try:
29
+ self.data = self.reader(self.update)
30
+ except Exception as e:
31
+ raise DataSourceError(
32
+ f"Failed to load data: {e}",
33
+ source=self.name,
34
+ ) from e
35
+
21
36
  self.validate()
22
37
 
23
38
  def validate(self):
24
- if self.data.empty:
25
- raise ValueError(f"No data found for {self.name}")
39
+ """Validate that source data is properly formatted.
26
40
 
27
- # check all columns start with pydeflate_
28
- if not all(col.startswith("pydeflate_") for col in self.data.columns):
29
- raise ValueError(f"Invalid data format for {self.name}")
41
+ Raises:
42
+ DataSourceError: If data is empty or improperly formatted
43
+ SchemaValidationError: If data doesn't match expected schema
44
+ """
45
+ if self.data.empty:
46
+ raise DataSourceError(f"No data found", source=self.name)
47
+
48
+ # Check all columns start with pydeflate_
49
+ invalid_cols = [
50
+ col for col in self.data.columns if not col.startswith("pydeflate_")
51
+ ]
52
+ if invalid_cols:
53
+ raise DataSourceError(
54
+ f"Invalid column names (must start with 'pydeflate_'): {invalid_cols}",
55
+ source=self.name,
56
+ )
57
+
58
+ # Validate schema if available and enabled
59
+ # Note: Schema validation is currently experimental
60
+ # Set environment variable PYDEFLATE_ENABLE_VALIDATION=1 to enable
61
+ import os
62
+
63
+ if os.environ.get("PYDEFLATE_ENABLE_VALIDATION") == "1":
64
+ try:
65
+ from pydeflate.schemas import validate_source_data
66
+
67
+ validate_source_data(self.data, self.name)
68
+ except ImportError:
69
+ # Pandera not available, skip schema validation
70
+ pass
71
+ except Exception as e:
72
+ # Schema validation failed, but don't break for now
73
+ # This allows us to roll out schema validation gradually
74
+ import logging
75
+
76
+ logger = logging.getLogger("pydeflate")
77
+ logger.debug(f"Schema validation skipped for {self.name}: {e}")
30
78
 
31
79
  def lcu_usd_exchange(self) -> pd.DataFrame:
80
+ """Return local currency to USD exchange rates.
81
+
82
+ Returns:
83
+ DataFrame with exchange rate data
84
+
85
+ Raises:
86
+ DataSourceError: If exchange rate data is missing
87
+ """
88
+ if "pydeflate_EXCHANGE" not in self.data.columns:
89
+ raise DataSourceError(
90
+ "Exchange rate data (pydeflate_EXCHANGE) not available",
91
+ source=self.name,
92
+ )
32
93
  return self.data.filter(self._idx + ["pydeflate_EXCHANGE"])
33
94
 
34
95
  def price_deflator(self, kind: AvailableDeflators = "NGDP_D") -> pd.DataFrame:
35
-
36
- if f"pydeflate_{kind}" not in self.data.columns:
37
- raise ValueError(f"No deflator data found for {kind} in {self.name} data.")
38
-
39
- return self.data.filter(self._idx + [f"pydeflate_{kind}"])
96
+ """Return price deflator data for specified kind.
97
+
98
+ Args:
99
+ kind: Type of deflator (e.g., 'NGDP_D', 'CPI')
100
+
101
+ Returns:
102
+ DataFrame with deflator data
103
+
104
+ Raises:
105
+ ConfigurationError: If deflator kind not available for this source
106
+ """
107
+ column_name = f"pydeflate_{kind}"
108
+ if column_name not in self.data.columns:
109
+ available = [
110
+ col.replace("pydeflate_", "")
111
+ for col in self.data.columns
112
+ if col.startswith("pydeflate_") and col not in self._idx
113
+ ]
114
+ raise ConfigurationError(
115
+ f"Deflator '{kind}' not available for {self.name}. "
116
+ f"Available deflators: {', '.join(available)}",
117
+ parameter="kind",
118
+ )
119
+
120
+ return self.data.filter(self._idx + [column_name])
40
121
 
41
122
 
42
123
  class IMF(Source):
@@ -3,7 +3,7 @@ from functools import wraps
3
3
  import pandas as pd
4
4
 
5
5
  from pydeflate.core.api import BaseDeflate
6
- from pydeflate.core.source import DAC, WorldBank, IMF
6
+ from pydeflate.core.source import DAC, IMF, WorldBank
7
7
 
8
8
 
9
9
  def _generate_docstring(source_name: str, price_kind: str) -> str:
@@ -4,7 +4,7 @@ import pandas as pd
4
4
  from pandas.util._decorators import deprecate_kwarg
5
5
 
6
6
  from pydeflate.core.api import BaseDeflate
7
- from pydeflate.core.source import DAC, WorldBank, IMF
7
+ from pydeflate.core.source import DAC, IMF, WorldBank
8
8
 
9
9
 
10
10
  @deprecate_kwarg(old_arg_name="method", new_arg_name="deflator_method")
@@ -0,0 +1,166 @@
1
+ """Custom exception hierarchy for pydeflate.
2
+
3
+ This module defines specific exception types that allow users to handle
4
+ different failure modes appropriately (e.g., retry on network errors,
5
+ fail fast on validation errors).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+
11
class PydeflateError(Exception):
    """Root of the pydeflate exception hierarchy.

    The other exception classes in this module derive from this one, so
    catching ``PydeflateError`` catches all of them.
    """
19
+
20
+
21
class DataSourceError(PydeflateError):
    """Error related to a data source.

    Serves as the parent class for the more specific data-source errors.
    """

    def __init__(self, message: str, source: str | None = None):
        """Store the source name and build the final message.

        Args:
            message: Description of the error.
            source: Name of the data source (e.g., 'IMF', 'World Bank').
                When truthy it is prepended to the message as ``[source]``.
        """
        self.source = source
        if source:
            text = f"[{source}] {message}"
        else:
            text = message
        super().__init__(text)
36
+
37
+
38
class NetworkError(DataSourceError):
    """Data-source error for failed network operations.

    Meant for failures that are usually transient, so a caller may
    reasonably retry the operation.
    """
45
+
46
+
47
class SchemaValidationError(DataSourceError):
    """Data-source error for data that does not match the expected schema.

    Possible origins of such a mismatch:
    - External API changes
    - Corrupted downloaded data
    - User input with wrong columns/types
    """

    def __init__(
        self,
        message: str,
        source: str | None = None,
        expected_schema: dict | None = None,
        actual_schema: dict | None = None,
    ):
        """Build the error and keep both schemas for later inspection.

        Args:
            message: Description of the validation failure.
            source: Name of the data source.
            expected_schema: Expected column types/names.
            actual_schema: Column types/names actually found.
        """
        super().__init__(message, source)
        self.actual_schema = actual_schema
        self.expected_schema = expected_schema
74
+
75
+
76
class CacheError(PydeflateError):
    """Error for failed cache operations.

    Examples:
        - Unable to write to cache directory
        - Corrupted cache files
        - Lock file acquisition timeout
    """

    def __init__(self, message: str, cache_path: str | None = None):
        """Store the cache path and build the final message.

        Args:
            message: Description of the cache error.
            cache_path: Path to the cache file/directory involved. When
                truthy it is included in the message.
        """
        self.cache_path = cache_path
        if cache_path:
            text = f"Cache error at {cache_path}: {message}"
        else:
            text = message
        super().__init__(text)
96
+
97
+
98
class ConfigurationError(PydeflateError):
    """Error for invalid configuration parameters.

    Examples:
        - Invalid currency code
        - Base year out of range
        - Missing required columns in user data
        - Conflicting parameter combinations
    """

    def __init__(self, message: str, parameter: str | None = None):
        """Store the parameter name and build the final message.

        Args:
            message: Description of the configuration issue.
            parameter: Name of the problematic parameter. When truthy it is
                named in the message.
        """
        self.parameter = parameter
        if parameter:
            text = f"Invalid configuration for '{parameter}': {message}"
        else:
            text = message
        super().__init__(text)
121
+
122
+
123
class MissingDataError(PydeflateError):
    """Error for required deflator or exchange data that is unavailable.

    Typical situations:
    - Requested country/year combination has no data in the source
    - Data gaps in historical records
    - Future years beyond available estimates
    """

    def __init__(
        self,
        message: str,
        missing_entities: dict[str, list[int]] | None = None,
    ):
        """Build the error and keep the details of what is missing.

        Args:
            message: Description of the missing data.
            missing_entities: Dict mapping entity codes to missing years.
        """
        super().__init__(message)
        self.missing_entities = missing_entities
145
+
146
+
147
class PluginError(PydeflateError):
    """Error for failed plugin registration or loading.

    Examples:
        - Plugin doesn't implement required protocol
        - Plugin name conflicts with existing source
        - Plugin initialization fails
    """

    def __init__(self, message: str, plugin_name: str | None = None):
        """Store the plugin name and build the final message.

        Args:
            message: Description of the plugin error.
            plugin_name: Name of the plugin that failed. When truthy it is
                named in the message.
        """
        self.plugin_name = plugin_name
        if plugin_name:
            text = f"Plugin '{plugin_name}' error: {message}"
        else:
            text = message
        super().__init__(text)
@@ -3,7 +3,7 @@ from functools import wraps
3
3
  import pandas as pd
4
4
 
5
5
  from pydeflate.core.api import BaseExchange
6
- from pydeflate.core.source import DAC, WorldBank, IMF, WorldBankPPP
6
+ from pydeflate.core.source import DAC, IMF, WorldBank, WorldBankPPP
7
7
 
8
8
 
9
9
  def _generate_docstring(source_name: str) -> str:
pydeflate/plugins.py ADDED
@@ -0,0 +1,289 @@
1
+ """Plugin system for custom data sources.
2
+
3
+ This module provides a registry-based plugin architecture that allows
4
+ users to register custom data sources without modifying pydeflate's code.
5
+
6
+ Example:
7
+ >>> from pydeflate.plugins import register_source
8
+ >>> from pydeflate.protocols import SourceProtocol
9
+ >>>
10
+ >>> @register_source("my_custom_source")
11
+ ... class MyCustomSource:
12
+ ... def __init__(self, update: bool = False):
13
+ ... self.name = "my_custom_source"
14
+ ... self.data = load_my_data(update)
15
+ ... self._idx = ["pydeflate_year", "pydeflate_entity_code", "pydeflate_iso3"]
16
+ ...
17
+ ... def lcu_usd_exchange(self): ...
18
+ ... def price_deflator(self, kind): ...
19
+ ... def validate(self): ...
20
+ >>>
21
+ >>> # Now use it:
22
+ >>> result = deflate_with_source("my_custom_source", data=df, ...)
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ from typing import Any, Callable, TypeVar
28
+
29
+ from pydeflate.exceptions import PluginError
30
+ from pydeflate.protocols import SourceProtocol
31
+
32
+ # Type variable for source classes
33
+ SourceType = TypeVar("SourceType", bound=type)
34
+
35
+
36
class SourceRegistry:
    """Registry for data source implementations.

    Two tables are kept: ``_sources`` maps a name to a source class and
    ``_factories`` maps a name to a callable that builds a source. A name may
    have an entry in either or both; ``get`` prefers the factory.
    """

    def __init__(self):
        """Initialize an empty registry."""
        self._sources: dict[str, type] = {}
        self._factories: dict[str, Callable[..., SourceProtocol]] = {}

    def register(
        self,
        name: str,
        source_class: type | None = None,
        factory: Callable[..., SourceProtocol] | None = None,
        *,
        override: bool = False,
    ) -> None:
        """Register a data source under ``name``.

        Args:
            name: Unique name for this source (e.g., 'IMF', 'WorldBank')
            source_class: Class that implements SourceProtocol
            factory: Factory function that returns a SourceProtocol instance
            override: If True, replace any existing registration for ``name``

        Raises:
            PluginError: If neither source_class nor factory is provided
            PluginError: If name is already registered (as a class or as a
                factory) and override=False
            PluginError: If source_class lacks the required methods
        """
        if source_class is None and factory is None:
            raise PluginError(
                "Must provide either source_class or factory",
                plugin_name=name,
            )

        # A name is taken whether it was registered as a class or a factory.
        if not override and self.is_registered(name):
            raise PluginError(
                f"Source '{name}' already registered. Use override=True to replace.",
                plugin_name=name,
            )

        # Validate before touching the tables so a rejected class leaves the
        # registry unchanged.
        if source_class is not None and not self._check_protocol_conformance(
            source_class
        ):
            raise PluginError(
                "Source class must implement SourceProtocol",
                plugin_name=name,
            )

        # Drop the previous registration entirely: otherwise a stale factory
        # would keep shadowing a newly registered class in get().
        self._sources.pop(name, None)
        self._factories.pop(name, None)

        if source_class is not None:
            self._sources[name] = source_class
        if factory is not None:
            self._factories[name] = factory

    def _check_protocol_conformance(self, source_class: type) -> bool:
        """Best-effort check that a class provides the SourceProtocol methods.

        Only the methods are checked. The ``name``, ``data`` and ``_idx``
        attributes are not, since they may be created in ``__init__`` and so
        cannot be verified without instantiating the class.

        Args:
            source_class: Class to check

        Returns:
            True if ``lcu_usd_exchange``, ``price_deflator`` and ``validate``
            all exist on the class and are callable
        """
        required_methods = ("lcu_usd_exchange", "price_deflator", "validate")
        return all(
            callable(getattr(source_class, method, None))
            for method in required_methods
        )

    def get(self, name: str, **kwargs) -> SourceProtocol:
        """Build and return a source instance by name.

        A registered factory takes precedence over a registered class.

        Args:
            name: Name of the registered source
            **kwargs: Keyword arguments passed to the factory or constructor

        Returns:
            Instance of the source

        Raises:
            PluginError: If the source is not found, or if the factory or
                constructor raises (the original exception is chained)
        """
        if name in self._factories:
            try:
                return self._factories[name](**kwargs)
            except Exception as e:
                raise PluginError(
                    f"Factory function failed: {e}",
                    plugin_name=name,
                ) from e

        if name in self._sources:
            try:
                return self._sources[name](**kwargs)
            except Exception as e:
                raise PluginError(
                    f"Source instantiation failed: {e}",
                    plugin_name=name,
                ) from e

        raise PluginError(
            f"Source '{name}' not found. Available sources: {self.list_sources()}",
            plugin_name=name,
        )

    def list_sources(self) -> list[str]:
        """List all registered source names.

        Returns:
            Sorted list of names registered as a class, a factory, or both
        """
        return sorted(set(self._sources) | set(self._factories))

    def is_registered(self, name: str) -> bool:
        """Check if a source is registered.

        Args:
            name: Source name to check

        Returns:
            True if the name is registered as a class or as a factory
        """
        return name in self._sources or name in self._factories

    def unregister(self, name: str) -> None:
        """Remove a source from the registry.

        Args:
            name: Source name to remove

        Raises:
            PluginError: If source not found
        """
        if not self.is_registered(name):
            raise PluginError(
                f"Cannot unregister '{name}': not found",
                plugin_name=name,
            )

        self._sources.pop(name, None)
        self._factories.pop(name, None)
193
+
194
+
195
+ # Global registry instance
196
+ _global_registry = SourceRegistry()
197
+
198
+
199
def register_source(
    name: str,
    *,
    override: bool = False,
) -> Callable[[SourceType], SourceType]:
    """Decorator to register a source class in the global registry.

    The decorated class is returned unchanged. Registration requires the
    class to define ``lcu_usd_exchange``, ``price_deflator`` and ``validate``.

    Example:
        >>> @register_source("my_source")
        ... class MySource:
        ...     def __init__(self, update: bool = False):
        ...         self.name = "my_source"
        ...         ...
        ...
        ...     def lcu_usd_exchange(self): ...
        ...     def price_deflator(self, kind): ...
        ...     def validate(self): ...

    Args:
        name: Unique name for this source
        override: Allow replacing existing sources

    Returns:
        Decorator function
    """

    def decorator(cls: SourceType) -> SourceType:
        _global_registry.register(name, source_class=cls, override=override)
        return cls

    return decorator
229
+
230
+
231
def get_source(name: str, **kwargs) -> SourceProtocol:
    """Build a source instance from the global registry.

    Args:
        name: Name of the source
        **kwargs: Arguments forwarded to the registered factory or class

    Returns:
        Source instance

    Raises:
        PluginError: If source not found or instantiation fails
    """
    instance = _global_registry.get(name, **kwargs)
    return instance
245
+
246
+
247
def list_sources() -> list[str]:
    """Return the names known to the global registry.

    Returns:
        Sorted list of source names
    """
    names = _global_registry.list_sources()
    return names
254
+
255
+
256
def is_source_registered(name: str) -> bool:
    """Tell whether the global registry has an entry for ``name``.

    Args:
        name: Source name

    Returns:
        True if registered
    """
    registered = _global_registry.is_registered(name)
    return registered
266
+
267
+
268
+ # Pre-register built-in sources
269
def _register_builtin_sources():
    """Register pydeflate's built-in sources and their aliases.

    Every entry is registered with ``override=True``.
    """
    # Imported here rather than at module top level.
    from pydeflate.core.source import DAC, IMF, WorldBank, WorldBankPPP

    # Canonical names come first, followed by the convenience aliases.
    builtin_entries = (
        ("IMF", IMF),
        ("World Bank", WorldBank),
        ("World Bank PPP", WorldBankPPP),
        ("DAC", DAC),
        ("imf", IMF),
        ("wb", WorldBank),
        ("worldbank", WorldBank),
        ("dac", DAC),
        ("oecd", DAC),
    )
    for label, implementation in builtin_entries:
        _global_registry.register(label, source_class=implementation, override=True)
286
+
287
+
288
+ # Register built-in sources when module is imported
289
+ _register_builtin_sources()