pydeflate 2.1.3__py3-none-any.whl → 2.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydeflate/__init__.py +64 -20
- pydeflate/cache.py +139 -0
- pydeflate/constants.py +121 -0
- pydeflate/context.py +211 -0
- pydeflate/core/api.py +33 -11
- pydeflate/core/source.py +92 -11
- pydeflate/deflate/deflators.py +1 -1
- pydeflate/deflate/legacy_deflate.py +1 -1
- pydeflate/exceptions.py +166 -0
- pydeflate/exchange/exchangers.py +1 -1
- pydeflate/plugins.py +289 -0
- pydeflate/protocols.py +168 -0
- pydeflate/pydeflate_config.py +77 -6
- pydeflate/schemas.py +297 -0
- pydeflate/sources/common.py +59 -107
- pydeflate/sources/dac.py +39 -52
- pydeflate/sources/imf.py +23 -39
- pydeflate/sources/world_bank.py +44 -117
- pydeflate/utils.py +14 -9
- {pydeflate-2.1.3.dist-info → pydeflate-2.2.0.dist-info}/METADATA +119 -18
- pydeflate-2.2.0.dist-info/RECORD +32 -0
- pydeflate-2.2.0.dist-info/WHEEL +4 -0
- {pydeflate-2.1.3.dist-info → pydeflate-2.2.0.dist-info/licenses}/LICENSE +1 -1
- pydeflate-2.1.3.dist-info/RECORD +0 -25
- pydeflate-2.1.3.dist-info/WHEEL +0 -4
pydeflate/core/source.py
CHANGED
|
@@ -2,6 +2,7 @@ from dataclasses import dataclass, field
|
|
|
2
2
|
|
|
3
3
|
import pandas as pd
|
|
4
4
|
|
|
5
|
+
from pydeflate.exceptions import ConfigurationError, DataSourceError
|
|
5
6
|
from pydeflate.sources.common import AvailableDeflators
|
|
6
7
|
from pydeflate.sources.dac import read_dac
|
|
7
8
|
from pydeflate.sources.imf import read_weo
|
|
@@ -10,6 +11,12 @@ from pydeflate.sources.world_bank import read_wb, read_wb_lcu_ppp, read_wb_usd_p
|
|
|
10
11
|
|
|
11
12
|
@dataclass
|
|
12
13
|
class Source:
|
|
14
|
+
"""Base class for data sources implementing SourceProtocol.
|
|
15
|
+
|
|
16
|
+
This class handles loading data from external sources, caching,
|
|
17
|
+
and validation. It implements the SourceProtocol interface.
|
|
18
|
+
"""
|
|
19
|
+
|
|
13
20
|
name: str
|
|
14
21
|
reader: callable
|
|
15
22
|
update: bool = False
|
|
@@ -17,26 +24,100 @@ class Source:
|
|
|
17
24
|
_idx = ["pydeflate_year", "pydeflate_entity_code", "pydeflate_iso3"]
|
|
18
25
|
|
|
19
26
|
def __post_init__(self):
    """Load the source data through ``self.reader`` and validate it.

    ``self.reader`` is called with ``self.update``; its result is stored in
    ``self.data`` and then checked by ``self.validate()``.

    Raises:
        DataSourceError: If the reader fails. A ``DataSourceError`` raised by
            the reader itself (including subclasses such as ``NetworkError``)
            propagates unchanged, so callers can still tell retryable
            failures apart. Any other exception is wrapped in a
            ``DataSourceError`` with the original chained as ``__cause__``.
    """
    try:
        self.data = self.reader(self.update)
    except DataSourceError:
        # Already a pydeflate data-source error: re-wrapping it would hide
        # the more specific subclass (e.g. NetworkError) from callers.
        raise
    except Exception as e:
        raise DataSourceError(
            f"Failed to load data: {e}",
            source=self.name,
        ) from e

    self.validate()
|
|
22
37
|
|
|
23
38
|
def validate(self):
    """Check that the loaded source data is usable.

    Always checks that the data is non-empty and that every column label,
    converted to ``str``, starts with ``pydeflate_``. When the environment
    variable ``PYDEFLATE_ENABLE_VALIDATION`` equals ``"1"``, it additionally
    runs the experimental ``pydeflate.schemas.validate_source_data`` check.
    An ``ImportError`` from that step is ignored. Any other failure is
    logged as a warning on the ``"pydeflate"`` logger rather than raised, so
    the schema rollout cannot break existing workflows.

    Raises:
        DataSourceError: If the data is empty or has column labels that do
            not start with ``pydeflate_``.
    """
    import logging
    import os

    if self.data.empty:
        raise DataSourceError("No data found", source=self.name)

    # str() makes non-string labels (ints, MultiIndex tuples) get reported
    # as invalid instead of crashing with AttributeError.
    invalid_cols = [
        col for col in self.data.columns if not str(col).startswith("pydeflate_")
    ]
    if invalid_cols:
        raise DataSourceError(
            f"Invalid column names (must start with 'pydeflate_'): {invalid_cols}",
            source=self.name,
        )

    # Schema validation is experimental and opt-in via environment variable.
    if os.environ.get("PYDEFLATE_ENABLE_VALIDATION") != "1":
        return

    try:
        from pydeflate.schemas import validate_source_data

        validate_source_data(self.data, self.name)
    except ImportError:
        # Optional schema dependency (pandera) not available: skip the check.
        return
    except Exception as e:
        # Don't break callers during the gradual rollout, but make the failure
        # visible to a user who explicitly enabled validation.
        logging.getLogger("pydeflate").warning(
            "Schema validation failed for %s: %s", self.name, e
        )
|
|
30
78
|
|
|
31
79
|
def lcu_usd_exchange(self) -> pd.DataFrame:
    """Local currency to USD exchange rates for this source.

    Returns:
        A DataFrame with the index columns (``self._idx``) followed by
        ``pydeflate_EXCHANGE``.

    Raises:
        DataSourceError: When ``pydeflate_EXCHANGE`` is absent from the
            loaded data.
    """
    exchange_col = "pydeflate_EXCHANGE"
    if exchange_col in self.data.columns:
        wanted = [*self._idx, exchange_col]
        return self.data.filter(wanted)

    raise DataSourceError(
        "Exchange rate data (pydeflate_EXCHANGE) not available",
        source=self.name,
    )
|
|
33
94
|
|
|
34
95
|
def price_deflator(self, kind: AvailableDeflators = "NGDP_D") -> pd.DataFrame:
    """Return price deflator data for the requested kind.

    Args:
        kind: Deflator type (e.g. 'NGDP_D', 'CPI'); looked up as the column
            ``pydeflate_{kind}``.

    Returns:
        A DataFrame with the index columns (``self._idx``) followed by the
        requested deflator column.

    Raises:
        ConfigurationError: If the deflator column is not present for this
            source. The message lists the deflator kinds that are present.
    """
    prefix = "pydeflate_"
    column_name = f"{prefix}{kind}"

    if column_name not in self.data.columns:
        # Report only genuine deflator columns: exclude the index columns and
        # the exchange-rate column, which is not a price deflator.
        non_deflators = {*self._idx, "pydeflate_EXCHANGE"}
        available = [
            col[len(prefix):]
            for col in self.data.columns
            if isinstance(col, str)
            and col.startswith(prefix)
            and col not in non_deflators
        ]
        raise ConfigurationError(
            f"Deflator '{kind}' not available for {self.name}. "
            f"Available deflators: {', '.join(available)}",
            parameter="kind",
        )

    return self.data.filter(self._idx + [column_name])
|
|
40
121
|
|
|
41
122
|
|
|
42
123
|
class IMF(Source):
|
pydeflate/deflate/deflators.py
CHANGED
|
@@ -3,7 +3,7 @@ from functools import wraps
|
|
|
3
3
|
import pandas as pd
|
|
4
4
|
|
|
5
5
|
from pydeflate.core.api import BaseDeflate
|
|
6
|
-
from pydeflate.core.source import DAC,
|
|
6
|
+
from pydeflate.core.source import DAC, IMF, WorldBank
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
def _generate_docstring(source_name: str, price_kind: str) -> str:
|
|
@@ -4,7 +4,7 @@ import pandas as pd
|
|
|
4
4
|
from pandas.util._decorators import deprecate_kwarg
|
|
5
5
|
|
|
6
6
|
from pydeflate.core.api import BaseDeflate
|
|
7
|
-
from pydeflate.core.source import DAC,
|
|
7
|
+
from pydeflate.core.source import DAC, IMF, WorldBank
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
@deprecate_kwarg(old_arg_name="method", new_arg_name="deflator_method")
|
pydeflate/exceptions.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
"""Custom exception hierarchy for pydeflate.
|
|
2
|
+
|
|
3
|
+
This module defines specific exception types that allow users to handle
|
|
4
|
+
different failure modes appropriately (e.g., retry on network errors,
|
|
5
|
+
fail fast on validation errors).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class PydeflateError(Exception):
    """Root of the pydeflate exception hierarchy.

    Every exception pydeflate defines derives from this class, so a single
    ``except PydeflateError`` clause catches all of them.
    """
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class DataSourceError(PydeflateError):
    """Problem encountered with a data source.

    Parent class of the more specific source-related errors.
    """

    def __init__(self, message: str, source: str | None = None):
        """Record the source name and build the final message.

        Args:
            message: Human-readable description of what went wrong.
            source: Name of the data source (e.g. 'IMF', 'World Bank'). When
                truthy, the message is prefixed with ``[source] ``.
        """
        self.source = source
        if source:
            full_message = f"[{source}] {message}"
        else:
            full_message = message
        super().__init__(full_message)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class NetworkError(DataSourceError):
    """A network operation failed.

    Usually transient, so retrying the operation may succeed.
    """
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class SchemaValidationError(DataSourceError):
    """Data whose structure does not match the expected schema.

    Typical causes:
    - an external API changed its output
    - downloaded data is corrupted
    - user input has the wrong columns or types
    """

    def __init__(
        self,
        message: str,
        source: str | None = None,
        expected_schema: dict | None = None,
        actual_schema: dict | None = None,
    ):
        """Keep both schemas for inspection and build the message.

        Args:
            message: Description of the validation failure.
            source: Name of the data source, used as a message prefix.
            expected_schema: Column names/types that were expected.
            actual_schema: Column names/types that were actually found.
        """
        self.actual_schema = actual_schema
        self.expected_schema = expected_schema
        super().__init__(message, source=source)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class CacheError(PydeflateError):
    """A cache operation failed.

    Examples:
        - the cache directory is not writable
        - cache files are corrupted
        - acquiring a lock file timed out
    """

    def __init__(self, message: str, cache_path: str | None = None):
        """Record the path involved and build the message.

        Args:
            message: Description of the cache problem.
            cache_path: Cache file or directory involved; when truthy it is
                included in the message.
        """
        self.cache_path = cache_path
        text = (
            message if not cache_path else f"Cache error at {cache_path}: {message}"
        )
        super().__init__(text)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class ConfigurationError(PydeflateError):
    """Invalid configuration parameters.

    Examples:
        - an invalid currency code
        - a base year out of range
        - required columns missing from user data
        - conflicting parameter combinations
    """

    def __init__(self, message: str, parameter: str | None = None):
        """Record the offending parameter and build the message.

        Args:
            message: Description of the configuration problem.
            parameter: Name of the problematic parameter; when truthy the
                message is framed as an invalid configuration for it.
        """
        self.parameter = parameter
        if not parameter:
            super().__init__(message)
            return
        super().__init__(f"Invalid configuration for '{parameter}': {message}")
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class MissingDataError(PydeflateError):
    """Required deflator or exchange data is unavailable.

    Happens when, for example:
    - a country/year combination has no data in the source
    - historical records have gaps
    - the requested years lie beyond the available estimates
    """

    def __init__(
        self,
        message: str,
        missing_entities: dict[str, list[int]] | None = None,
    ):
        """Store the missing entity/year map alongside the message.

        Args:
            message: Description of the missing data.
            missing_entities: Mapping of entity code to the years lacking data.
        """
        super().__init__(message)
        self.missing_entities = missing_entities
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class PluginError(PydeflateError):
    """Plugin registration or loading failed.

    Examples:
        - the plugin does not implement the required protocol
        - the plugin name clashes with an existing source
        - plugin initialization raised
    """

    def __init__(self, message: str, plugin_name: str | None = None):
        """Record the plugin name and build the message.

        Args:
            message: Description of the plugin problem.
            plugin_name: Name of the failing plugin; when truthy it is
                included in the message.
        """
        self.plugin_name = plugin_name
        if plugin_name:
            text = f"Plugin '{plugin_name}' error: {message}"
        else:
            text = message
        super().__init__(text)
|
pydeflate/exchange/exchangers.py
CHANGED
|
@@ -3,7 +3,7 @@ from functools import wraps
|
|
|
3
3
|
import pandas as pd
|
|
4
4
|
|
|
5
5
|
from pydeflate.core.api import BaseExchange
|
|
6
|
-
from pydeflate.core.source import DAC,
|
|
6
|
+
from pydeflate.core.source import DAC, IMF, WorldBank, WorldBankPPP
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
def _generate_docstring(source_name: str) -> str:
|
pydeflate/plugins.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
"""Plugin system for custom data sources.
|
|
2
|
+
|
|
3
|
+
This module provides a registry-based plugin architecture that allows
|
|
4
|
+
users to register custom data sources without modifying pydeflate's code.
|
|
5
|
+
|
|
6
|
+
Example:
|
|
7
|
+
>>> from pydeflate.plugins import register_source
|
|
8
|
+
>>> from pydeflate.protocols import SourceProtocol
|
|
9
|
+
>>>
|
|
10
|
+
>>> @register_source("my_custom_source")
|
|
11
|
+
... class MyCustomSource:
|
|
12
|
+
... def __init__(self, update: bool = False):
|
|
13
|
+
... self.name = "my_custom_source"
|
|
14
|
+
... self.data = load_my_data(update)
|
|
15
|
+
... self._idx = ["pydeflate_year", "pydeflate_entity_code", "pydeflate_iso3"]
|
|
16
|
+
...
|
|
17
|
+
... def lcu_usd_exchange(self): ...
|
|
18
|
+
... def price_deflator(self, kind): ...
|
|
19
|
+
... def validate(self): ...
|
|
20
|
+
>>>
|
|
21
|
+
>>> # Now use it:
|
|
22
|
+
>>> result = deflate_with_source("my_custom_source", data=df, ...)
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
from __future__ import annotations
|
|
26
|
+
|
|
27
|
+
from typing import Any, Callable, TypeVar
|
|
28
|
+
|
|
29
|
+
from pydeflate.exceptions import PluginError
|
|
30
|
+
from pydeflate.protocols import SourceProtocol
|
|
31
|
+
|
|
32
|
+
# Type variable for source classes
|
|
33
|
+
SourceType = TypeVar("SourceType", bound=type)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class SourceRegistry:
    """Registry mapping source names to source classes or factory functions.

    Provides methods to register, retrieve, list and remove sources. A name
    may map to a class, a factory, or both. ``get`` prefers the factory.
    """

    def __init__(self):
        """Initialize an empty registry."""
        self._sources: dict[str, type] = {}
        self._factories: dict[str, Callable[..., SourceProtocol]] = {}

    def register(
        self,
        name: str,
        source_class: type | None = None,
        factory: Callable[..., SourceProtocol] | None = None,
        *,
        override: bool = False,
    ) -> None:
        """Register a data source under ``name``.

        All checks run before the registry is modified, so a failed
        registration leaves any previous entry intact. A successful
        registration replaces any previous class *and* factory for ``name``.
        Otherwise a stale factory would keep shadowing a newly registered
        class, because ``get`` tries factories first.

        Args:
            name: Unique name for this source (e.g. 'IMF', 'WorldBank').
            source_class: Class that implements SourceProtocol.
            factory: Callable returning a SourceProtocol instance.
            override: If True, allow replacing an existing registration.

        Raises:
            PluginError: If neither source_class nor factory is provided.
            PluginError: If ``name`` is already registered (as a class or a
                factory) and ``override`` is False.
            PluginError: If source_class lacks the required methods.
            PluginError: If factory is not callable.
        """
        if source_class is None and factory is None:
            raise PluginError(
                "Must provide either source_class or factory",
                plugin_name=name,
            )

        # Check both maps: a name registered only via a factory is still taken.
        if not override and self.is_registered(name):
            raise PluginError(
                f"Source '{name}' already registered. Use override=True to replace.",
                plugin_name=name,
            )

        if source_class is not None and not self._check_protocol_conformance(
            source_class
        ):
            raise PluginError(
                "Source class must implement SourceProtocol",
                plugin_name=name,
            )

        if factory is not None and not callable(factory):
            raise PluginError("Factory must be callable", plugin_name=name)

        # Drop any previous registration so the new one fully takes effect.
        self._sources.pop(name, None)
        self._factories.pop(name, None)

        if source_class is not None:
            self._sources[name] = source_class
        if factory is not None:
            self._factories[name] = factory

    def _check_protocol_conformance(self, source_class: type) -> bool:
        """Best-effort check that ``source_class`` looks like a source.

        Only the required methods can be verified on the class itself. The
        attributes ``name``, ``data`` and ``_idx`` are typically set per
        instance (e.g. in ``__init__``), so they cannot be checked without
        instantiating the class.

        Args:
            source_class: Class to check.

        Returns:
            True if every required method exists on the class and is callable.
        """
        required_methods = ("lcu_usd_exchange", "price_deflator", "validate")
        return all(
            callable(getattr(source_class, method, None))
            for method in required_methods
        )

    def get(self, name: str, **kwargs) -> SourceProtocol:
        """Instantiate the source registered under ``name``.

        Args:
            name: Name of the registered source.
            **kwargs: Keyword arguments passed to the factory or class.

        Returns:
            Instance of the source.

        Raises:
            PluginError: If the source is not found, or if the factory/class
                raises (the original exception is chained as ``__cause__``).
        """
        # A factory, when present, takes precedence over the class.
        if name in self._factories:
            try:
                return self._factories[name](**kwargs)
            except Exception as e:
                raise PluginError(
                    f"Factory function failed: {e}",
                    plugin_name=name,
                ) from e

        if name in self._sources:
            try:
                return self._sources[name](**kwargs)
            except Exception as e:
                raise PluginError(
                    f"Source instantiation failed: {e}",
                    plugin_name=name,
                ) from e

        raise PluginError(
            f"Source '{name}' not found. Available sources: {self.list_sources()}",
            plugin_name=name,
        )

    def list_sources(self) -> list[str]:
        """Return all registered source names (classes and factories), sorted."""
        return sorted(set(self._sources) | set(self._factories))

    def is_registered(self, name: str) -> bool:
        """Return True if ``name`` is registered as a class or a factory."""
        return name in self._sources or name in self._factories

    def unregister(self, name: str) -> None:
        """Remove ``name`` (class and factory) from the registry.

        Raises:
            PluginError: If ``name`` is not registered.
        """
        if not self.is_registered(name):
            raise PluginError(
                f"Cannot unregister '{name}': not found",
                plugin_name=name,
            )

        self._sources.pop(name, None)
        self._factories.pop(name, None)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
# Global registry instance
|
|
196
|
+
# Module-level registry shared by register_source, get_source, list_sources
# and is_source_registered; built-in sources are added at import time.
_global_registry = SourceRegistry()
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def register_source(
    name: str,
    *,
    override: bool = False,
) -> Callable[[SourceType], SourceType]:
    """Class decorator that registers a source class in the global registry.

    Registration happens when the class is defined. The class is returned
    unchanged, so it can still be used directly.

    Example:
        >>> @register_source("my_source")
        ... class MySource:
        ...     def __init__(self, update: bool = False):
        ...         self.name = "my_source"
        ...         ...
        ...
        ...     def lcu_usd_exchange(self): ...
        ...     def price_deflator(self, kind): ...
        ...     def validate(self): ...

    Args:
        name: Unique name for this source.
        override: Allow replacing an existing registration with the same name.

    Returns:
        A decorator that registers its argument and returns it unchanged.

    Raises:
        PluginError: Raised when the decorator is applied if registration
            fails, e.g. the name is already taken and ``override`` is False,
            or the class lacks ``lcu_usd_exchange``, ``price_deflator`` or
            ``validate``.
    """

    def decorator(cls: SourceType) -> SourceType:
        _global_registry.register(name, source_class=cls, override=override)
        return cls

    return decorator
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def get_source(name: str, **kwargs) -> SourceProtocol:
    """Instantiate a source from the global registry.

    Args:
        name: Registered name of the source.
        **kwargs: Forwarded to the source's factory or constructor.

    Returns:
        The newly created source instance.

    Raises:
        PluginError: If the name is unknown or instantiation fails.
    """
    registry = _global_registry
    return registry.get(name, **kwargs)
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def list_sources() -> list[str]:
    """Names of every source in the global registry, in sorted order."""
    registry = _global_registry
    return registry.list_sources()
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def is_source_registered(name: str) -> bool:
    """Tell whether ``name`` is known to the global registry.

    Args:
        name: Source name to look up.

    Returns:
        True when a class or factory is registered under ``name``.
    """
    registry = _global_registry
    return registry.is_registered(name)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
# Pre-register built-in sources
|
|
269
|
+
def _register_builtin_sources():
    """Add pydeflate's own sources, plus short aliases, to the global registry.

    Every entry uses ``override=True``, so registering a name again does not
    raise a duplicate-name error.
    """
    # Imported here rather than at module level, as in the original layout.
    from pydeflate.core.source import DAC, IMF, WorldBank, WorldBankPPP

    builtin_sources = (
        # Canonical names
        ("IMF", IMF),
        ("World Bank", WorldBank),
        ("World Bank PPP", WorldBankPPP),
        ("DAC", DAC),
        # Convenience aliases
        ("imf", IMF),
        ("wb", WorldBank),
        ("worldbank", WorldBank),
        ("dac", DAC),
        ("oecd", DAC),
    )
    for source_name, source_cls in builtin_sources:
        _global_registry.register(source_name, source_class=source_cls, override=True)
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
# Register built-in sources when module is imported
|
|
289
|
+
# Import-time side effect: populate _global_registry with the built-in
# sources and their aliases so get_source("IMF") etc. work immediately.
_register_builtin_sources()
|