pydeflate 2.1.3__py3-none-any.whl → 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydeflate/__init__.py +92 -20
- pydeflate/cache.py +139 -0
- pydeflate/constants.py +121 -0
- pydeflate/context.py +211 -0
- pydeflate/core/api.py +33 -11
- pydeflate/core/source.py +92 -11
- pydeflate/deflate/deflators.py +1 -1
- pydeflate/deflate/get_deflators.py +233 -0
- pydeflate/deflate/legacy_deflate.py +1 -1
- pydeflate/exceptions.py +166 -0
- pydeflate/exchange/exchangers.py +1 -1
- pydeflate/exchange/get_rates.py +207 -0
- pydeflate/plugins.py +289 -0
- pydeflate/protocols.py +168 -0
- pydeflate/pydeflate_config.py +77 -6
- pydeflate/schemas.py +297 -0
- pydeflate/sources/common.py +59 -107
- pydeflate/sources/dac.py +39 -52
- pydeflate/sources/imf.py +23 -39
- pydeflate/sources/world_bank.py +44 -117
- pydeflate/utils.py +14 -9
- {pydeflate-2.1.3.dist-info → pydeflate-2.3.0.dist-info}/METADATA +251 -18
- pydeflate-2.3.0.dist-info/RECORD +34 -0
- pydeflate-2.3.0.dist-info/WHEEL +4 -0
- {pydeflate-2.1.3.dist-info → pydeflate-2.3.0.dist-info/licenses}/LICENSE +1 -1
- pydeflate-2.1.3.dist-info/RECORD +0 -25
- pydeflate-2.1.3.dist-info/WHEEL +0 -4
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
from functools import wraps
|
|
2
|
+
|
|
3
|
+
import pandas as pd
|
|
4
|
+
|
|
5
|
+
from pydeflate.core.api import BaseExchange
|
|
6
|
+
from pydeflate.core.source import DAC, IMF, WorldBank, WorldBankPPP
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _generate_get_rates_docstring(source_name: str) -> str:
|
|
10
|
+
"""Generate docstring for each get exchange rates function."""
|
|
11
|
+
return (
|
|
12
|
+
f"Get exchange rate data from {source_name} without requiring user data.\n\n"
|
|
13
|
+
f"This function returns a DataFrame containing exchange rates for the specified parameters.\n\n"
|
|
14
|
+
"Args:\n"
|
|
15
|
+
" source_currency (str, optional): The source currency code. Defaults to 'USA'.\n"
|
|
16
|
+
" target_currency (str, optional): The target currency code. Defaults to 'USA'.\n"
|
|
17
|
+
" countries (list[str] | None, optional): List of country codes to include. If None, returns all. Defaults to None.\n"
|
|
18
|
+
" years (list[int] | range | None, optional): List or range of years to include. If None, returns all. Defaults to None.\n"
|
|
19
|
+
" use_source_codes (bool, optional): Use source-specific entity codes. Defaults to False.\n"
|
|
20
|
+
" update_rates (bool, optional): Update the exchange rate data before retrieval. Defaults to False.\n\n"
|
|
21
|
+
"Returns:\n"
|
|
22
|
+
" pd.DataFrame: DataFrame with columns:\n"
|
|
23
|
+
" - iso_code (or entity_code if use_source_codes=True): Country/entity identifier\n"
|
|
24
|
+
" - year: Year\n"
|
|
25
|
+
" - exchange_rate: Exchange rate from source to target currency\n"
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _get_exchange_rates(exchange_source_cls, **fixed_params):
    """Decorator factory building a ``get_*_exchange_rates`` function for one source.

    The decorated function is only a typed stub. ``functools.wraps`` copies its
    name and metadata onto the wrapper, but the stub's body is never called.
    The wrapper does three things:

    - instantiates ``exchange_source_cls``;
    - converts through ``BaseExchange``;
    - returns a tidy DataFrame of exchange rates.

    Args:
        exchange_source_cls: Source class to instantiate with
            ``update=update_rates`` (e.g. ``DAC``, ``WorldBank``, ``IMF``).
            ``WorldBankPPP`` or a subclass of it also receives a ``from_lcu``
            flag.
        **fixed_params: Values that override the caller's arguments. Only
            ``target_currency`` is currently honoured; other keys are ignored.

    Returns:
        A decorator that returns the working wrapper function. The wrapper's
        ``__doc__`` is produced by ``_generate_get_rates_docstring``.
    """
    # Subclass check instead of comparing class names, so subclasses of
    # WorldBankPPP also get the from_lcu handling.
    is_ppp_source = isinstance(exchange_source_cls, type) and issubclass(
        exchange_source_cls, WorldBankPPP
    )

    def decorator(func):
        @wraps(func)
        def wrapper(
            *,
            source_currency: str = "USA",
            target_currency: str = "USA",
            countries: str | list[str] | None = None,
            years: int | list[int] | range | None = None,
            use_source_codes: bool = False,
            update_rates: bool = False,
        ):
            # Apply fixed parameters - no validation needed since these are internally set
            if "target_currency" in fixed_params:
                target_currency = fixed_params["target_currency"]

            # Initialize the exchange source
            if is_ppp_source:
                source = exchange_source_cls(
                    update=update_rates,
                    from_lcu=source_currency != "USA",
                )
                source_currency = "LCU" if source_currency == "USA" else source_currency
            else:
                source = exchange_source_cls(update=update_rates)

            exchange = BaseExchange(
                exchange_source=source,
                source_currency=source_currency,
                target_currency=target_currency,
                use_source_codes=use_source_codes,
            )

            # No defensive copy needed: the filters and the column selection
            # below all produce new frames, and `data` is never mutated.
            data = exchange.pydeflate_data

            entity_col = "pydeflate_entity_code" if use_source_codes else "pydeflate_iso3"

            # A single code/year is accepted as well as a collection; pandas
            # `isin` raises TypeError when given a scalar.
            if countries is not None:
                if not pd.api.types.is_list_like(countries):
                    countries = [countries]
                data = data[data[entity_col].isin(list(countries))]

            if years is not None:
                if not pd.api.types.is_list_like(years):
                    years = [years]
                data = data[data["pydeflate_year"].isin(list(years))]

            return (
                data[[entity_col, "pydeflate_year", "pydeflate_EXCHANGE"]]
                .rename(
                    columns={
                        entity_col: "entity_code" if use_source_codes else "iso_code",
                        "pydeflate_year": "year",
                        "pydeflate_EXCHANGE": "exchange_rate",
                    }
                )
                .reset_index(drop=True)
            )

        wrapper.__doc__ = _generate_get_rates_docstring(exchange_source_cls.__name__)
        return wrapper

    return decorator
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@_get_exchange_rates(DAC)
def get_oecd_dac_exchange_rates(
    *,
    source_currency: str = "USA",
    target_currency: str = "USA",
    countries: list[str] | None = None,
    years: list[int] | range | None = None,
    use_source_codes: bool = False,
    update_rates: bool = False,
) -> pd.DataFrame:
    # Signature-only stub for the OECD DAC source. `_get_exchange_rates`
    # replaces it with a wrapper and assigns the public docstring, so this
    # body never runs.
    pass
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@_get_exchange_rates(WorldBank)
def get_wb_exchange_rates(
    *,
    source_currency: str = "USA",
    target_currency: str = "USA",
    countries: list[str] | None = None,
    years: list[int] | range | None = None,
    use_source_codes: bool = False,
    update_rates: bool = False,
) -> pd.DataFrame:
    # Signature-only stub for the World Bank source. `_get_exchange_rates`
    # replaces it with a wrapper and assigns the public docstring, so this
    # body never runs.
    pass
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def get_wb_ppp_rates(
    *,
    source_currency: str = "USA",
    countries: str | list[str] | None = None,
    years: int | list[int] | range | None = None,
    use_source_codes: bool = False,
    update_rates: bool = False,
) -> pd.DataFrame:
    """Get PPP exchange rate data from WorldBankPPP without requiring user data.

    This function returns a DataFrame containing PPP exchange rates for the
    specified parameters. The target currency is always "PPP".

    Args:
        source_currency (str, optional): The source currency code. Defaults to 'USA'.
            With 'USA', the WorldBankPPP source is built with ``from_lcu=False``
            and 'LCU' is passed to BaseExchange as the source currency. Any
            other value sets ``from_lcu=True`` and is passed through unchanged.
        countries (str | list[str] | None, optional): A country code, or a list
            of codes, to include. If None, returns all. Defaults to None.
        years (int | list[int] | range | None, optional): A year, or a list or
            range of years, to include. If None, returns all. Defaults to None.
        use_source_codes (bool, optional): Use source-specific entity codes. Defaults to False.
        update_rates (bool, optional): Update the exchange rate data before retrieval. Defaults to False.

    Returns:
        pd.DataFrame: DataFrame with columns:
            - iso_code (or entity_code if use_source_codes=True): Country/entity identifier
            - year: Year
            - exchange_rate: PPP exchange rate
    """
    is_usd = source_currency == "USA"
    source = WorldBankPPP(update=update_rates, from_lcu=not is_usd)
    source_currency_internal = "LCU" if is_usd else source_currency

    exchange = BaseExchange(
        exchange_source=source,
        source_currency=source_currency_internal,
        target_currency="PPP",
        use_source_codes=use_source_codes,
    )

    # Filters and the column selection below all create new frames, so the
    # exchange's data is never mutated and no upfront copy is needed.
    data = exchange.pydeflate_data

    entity_col = "pydeflate_entity_code" if use_source_codes else "pydeflate_iso3"

    # A single code/year is accepted as well as a collection; pandas `isin`
    # raises TypeError when given a scalar.
    if countries is not None:
        if not pd.api.types.is_list_like(countries):
            countries = [countries]
        data = data[data[entity_col].isin(list(countries))]

    if years is not None:
        if not pd.api.types.is_list_like(years):
            years = [years]
        data = data[data["pydeflate_year"].isin(list(years))]

    return (
        data[[entity_col, "pydeflate_year", "pydeflate_EXCHANGE"]]
        .rename(
            columns={
                entity_col: "entity_code" if use_source_codes else "iso_code",
                "pydeflate_year": "year",
                "pydeflate_EXCHANGE": "exchange_rate",
            }
        )
        .reset_index(drop=True)
    )
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
@_get_exchange_rates(IMF)
def get_imf_exchange_rates(
    *,
    source_currency: str = "USA",
    target_currency: str = "USA",
    countries: list[str] | None = None,
    years: list[int] | range | None = None,
    use_source_codes: bool = False,
    update_rates: bool = False,
) -> pd.DataFrame:
    # Signature-only stub for the IMF source. `_get_exchange_rates` replaces
    # it with a wrapper and assigns the public docstring, so this body never
    # runs.
    pass
|
pydeflate/plugins.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
"""Plugin system for custom data sources.
|
|
2
|
+
|
|
3
|
+
This module provides a registry-based plugin architecture that allows
|
|
4
|
+
users to register custom data sources without modifying pydeflate's code.
|
|
5
|
+
|
|
6
|
+
Example:
|
|
7
|
+
>>> from pydeflate.plugins import register_source
|
|
8
|
+
>>> from pydeflate.protocols import SourceProtocol
|
|
9
|
+
>>>
|
|
10
|
+
>>> @register_source("my_custom_source")
|
|
11
|
+
... class MyCustomSource:
|
|
12
|
+
... def __init__(self, update: bool = False):
|
|
13
|
+
... self.name = "my_custom_source"
|
|
14
|
+
... self.data = load_my_data(update)
|
|
15
|
+
... self._idx = ["pydeflate_year", "pydeflate_entity_code", "pydeflate_iso3"]
|
|
16
|
+
...
|
|
17
|
+
... def lcu_usd_exchange(self): ...
|
|
18
|
+
... def price_deflator(self, kind): ...
|
|
19
|
+
... def validate(self): ...
|
|
20
|
+
>>>
|
|
21
|
+
>>> # Now use it:
|
|
22
|
+
>>> result = deflate_with_source("my_custom_source", data=df, ...)
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
from __future__ import annotations
|
|
26
|
+
|
|
27
|
+
from typing import Any, Callable, TypeVar
|
|
28
|
+
|
|
29
|
+
from pydeflate.exceptions import PluginError
|
|
30
|
+
from pydeflate.protocols import SourceProtocol
|
|
31
|
+
|
|
32
|
+
# Type variable for source classes. `register_source` returns the class it
# decorates unchanged, so its decorator is typed
# `Callable[[SourceType], SourceType]` to preserve the concrete class type.
SourceType = TypeVar("SourceType", bound=type)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class SourceRegistry:
    """Registry for data source implementations.

    Maps source names to a source class, a factory callable, or both. When a
    name has both, :meth:`get` uses the factory.

    Registering a name that is already taken, by either mechanism, requires
    ``override=True``. An override replaces *both* previous entries for that
    name.
    """

    # Methods a source class must expose as callables to be accepted.
    _REQUIRED_METHODS = ("lcu_usd_exchange", "price_deflator", "validate")

    def __init__(self):
        """Initialize an empty registry."""
        self._sources: dict[str, type] = {}
        self._factories: dict[str, Callable[..., SourceProtocol]] = {}

    def register(
        self,
        name: str,
        source_class: type | None = None,
        factory: Callable[..., SourceProtocol] | None = None,
        *,
        override: bool = False,
    ) -> None:
        """Register a data source.

        Args:
            name: Unique name for this source (e.g., 'IMF', 'WorldBank').
            source_class: Class that implements SourceProtocol. It is
                instantiated with the keyword arguments given to :meth:`get`.
            factory: Callable returning a SourceProtocol instance. It takes
                precedence over ``source_class`` in :meth:`get`.
            override: If True, replace any class and factory already
                registered under ``name``.

        Raises:
            PluginError: If neither source_class nor factory is provided.
            PluginError: If ``name`` is already registered (as a class or a
                factory) and override=False.
            PluginError: If source_class lacks a required SourceProtocol method.
        """
        if source_class is None and factory is None:
            raise PluginError(
                "Must provide either source_class or factory",
                plugin_name=name,
            )

        # Check both maps so factory-only registrations are protected too.
        if not override and self.is_registered(name):
            raise PluginError(
                f"Source '{name}' already registered. Use override=True to replace.",
                plugin_name=name,
            )

        # Validate before mutating so a rejected class leaves the registry untouched.
        if source_class is not None and not self._check_protocol_conformance(
            source_class
        ):
            missing = ", ".join(self._missing_methods(source_class)) or "unknown"
            raise PluginError(
                "Source class must implement SourceProtocol "
                f"(missing or non-callable methods: {missing})",
                plugin_name=name,
            )

        # Drop any previous entries. A stale factory would otherwise shadow a
        # newly registered class, because get() consults factories first.
        self._sources.pop(name, None)
        self._factories.pop(name, None)

        if source_class is not None:
            self._sources[name] = source_class
        if factory is not None:
            self._factories[name] = factory

    def _check_protocol_conformance(self, source_class: type) -> bool:
        """Check if a class exposes the methods required by SourceProtocol.

        Only methods are checked. The ``name``, ``data`` and ``_idx``
        attributes are typically set in ``__init__`` and cannot be verified
        without instantiating the class, so this is a best-effort check.

        Args:
            source_class: Class to check.

        Returns:
            True if every required method exists and is callable.
        """
        return not self._missing_methods(source_class)

    @classmethod
    def _missing_methods(cls, source_class: type) -> list[str]:
        """Return the required method names that are absent or non-callable on source_class."""
        return [
            method
            for method in cls._REQUIRED_METHODS
            if not callable(getattr(source_class, method, None))
        ]

    def get(self, name: str, **kwargs) -> SourceProtocol:
        """Get a source instance by name.

        Args:
            name: Name of the registered source.
            **kwargs: Keyword arguments passed to the factory or class constructor.

        Returns:
            Instance of the source.

        Raises:
            PluginError: If the source is not found, or if its factory or
                constructor raises (the original exception is chained).
        """
        # Factories take precedence over classes.
        if name in self._factories:
            try:
                return self._factories[name](**kwargs)
            except Exception as e:
                raise PluginError(
                    f"Factory function failed: {e}",
                    plugin_name=name,
                ) from e

        if name in self._sources:
            try:
                return self._sources[name](**kwargs)
            except Exception as e:
                raise PluginError(
                    f"Source instantiation failed: {e}",
                    plugin_name=name,
                ) from e

        raise PluginError(
            f"Source '{name}' not found. Available sources: {self.list_sources()}",
            plugin_name=name,
        )

    def list_sources(self) -> list[str]:
        """List all registered source names.

        Returns:
            Sorted list of names registered as a class and/or a factory.
        """
        return sorted(self._sources.keys() | self._factories.keys())

    def is_registered(self, name: str) -> bool:
        """Check if a source is registered.

        Args:
            name: Source name to check.

        Returns:
            True if registered as a class or a factory.
        """
        return name in self._sources or name in self._factories

    def unregister(self, name: str) -> None:
        """Remove a source (class and factory) from the registry.

        Args:
            name: Source name to remove.

        Raises:
            PluginError: If source not found.
        """
        if not self.is_registered(name):
            raise PluginError(
                f"Cannot unregister '{name}': not found",
                plugin_name=name,
            )

        self._sources.pop(name, None)
        self._factories.pop(name, None)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
# Global registry instance. The module-level helpers register_source,
# get_source, list_sources and is_source_registered all delegate to this
# single registry.
_global_registry = SourceRegistry()
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def register_source(
    name: str,
    *,
    override: bool = False,
) -> Callable[[SourceType], SourceType]:
    """Decorator to register a source class in the global registry.

    The class must define callable ``lcu_usd_exchange``, ``price_deflator``
    and ``validate`` methods, otherwise registration raises PluginError.

    Example:
        >>> @register_source("my_source")
        ... class MySource:
        ...     def __init__(self, update: bool = False):
        ...         self.name = "my_source"
        ...         self.data = ...
        ...         self._idx = ["pydeflate_year", "pydeflate_entity_code", "pydeflate_iso3"]
        ...
        ...     def lcu_usd_exchange(self): ...
        ...     def price_deflator(self, kind): ...
        ...     def validate(self): ...

    Args:
        name: Unique name for this source.
        override: Allow replacing existing sources.

    Returns:
        Decorator that registers the class and returns it unchanged.

    Raises:
        PluginError: When the decorator is applied, if ``name`` is already
            registered and override is False, or if the class lacks a
            required method.
    """

    def decorator(cls: SourceType) -> SourceType:
        _global_registry.register(name, source_class=cls, override=override)
        return cls

    return decorator
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def get_source(name: str, **kwargs) -> SourceProtocol:
    """Instantiate a source from the module-wide registry.

    Args:
        name: Registered source name (e.g. 'IMF', 'DAC', or an alias such as 'wb').
        **kwargs: Forwarded unchanged to the source's factory or class.

    Returns:
        The newly created source instance.

    Raises:
        PluginError: If the name is unknown or creating the instance fails.
    """
    registry = _global_registry
    instance = registry.get(name, **kwargs)
    return instance
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def list_sources() -> list[str]:
    """Return every source name known to the global registry.

    Returns:
        Source names (including aliases) in alphabetical order.
    """
    names = _global_registry.list_sources()
    return names
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def is_source_registered(name: str) -> bool:
    """Tell whether ``name`` is present in the global registry.

    Args:
        name: Source name or alias to look up.

    Returns:
        True when a class or factory is registered under that name.
    """
    registered = _global_registry.is_registered(name)
    return registered
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
# Pre-register built-in sources
|
|
269
|
+
def _register_builtin_sources():
    """Populate the global registry with pydeflate's bundled sources and aliases.

    Every entry is registered with override=True, so calling this again
    re-registers the same classes instead of raising.
    """
    # NOTE(review): function-local import, presumably to avoid an import
    # cycle at module load — confirm before hoisting to the top of the file.
    from pydeflate.core.source import DAC, IMF, WorldBank, WorldBankPPP

    builtin_entries = (
        # Canonical names
        ("IMF", IMF),
        ("World Bank", WorldBank),
        ("World Bank PPP", WorldBankPPP),
        ("DAC", DAC),
        # Aliases for convenience
        ("imf", IMF),
        ("wb", WorldBank),
        ("worldbank", WorldBank),
        ("dac", DAC),
        ("oecd", DAC),
    )
    for source_name, source_cls in builtin_entries:
        _global_registry.register(source_name, source_class=source_cls, override=True)
|
|
285
|
+
_global_registry.register("oecd", source_class=DAC, override=True)
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
# Register built-in sources when module is imported, so the canonical names
# and aliases added by _register_builtin_sources() are available to
# get_source() and list_sources() as soon as this module is loaded.
_register_builtin_sources()
|
pydeflate/protocols.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
"""Protocol definitions for type safety and extensibility.
|
|
2
|
+
|
|
3
|
+
This module defines the core protocols (interfaces) that pydeflate components
|
|
4
|
+
must implement. Using protocols enables:
|
|
5
|
+
- Type checking without inheritance
|
|
6
|
+
- Duck typing with static verification
|
|
7
|
+
- Clear contracts for extensibility
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations

from pathlib import Path
from typing import Protocol, runtime_checkable

import pandas as pd

from pydeflate.sources.common import AvailableDeflators
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@runtime_checkable
class SourceProtocol(Protocol):
    """Structural interface shared by data sources (IMF, World Bank, DAC, ...).

    Any object exposing these attributes and methods can be plugged into
    pydeflate's deflation and exchange logic; no inheritance is required.
    """

    # Human-readable name of the data source (e.g. 'IMF', 'World Bank').
    name: str
    # Raw data loaded from this source.
    data: pd.DataFrame
    # Standard index columns: ['pydeflate_year', 'pydeflate_entity_code', 'pydeflate_iso3'].
    _idx: list[str]

    def lcu_usd_exchange(self) -> pd.DataFrame:
        """Provide local-currency-unit to USD exchange rates.

        Returns:
            DataFrame with the columns pydeflate_year, pydeflate_entity_code,
            pydeflate_iso3 and pydeflate_EXCHANGE.
        """
        ...

    def price_deflator(self, kind: AvailableDeflators = "NGDP_D") -> pd.DataFrame:
        """Provide price deflator data of the requested kind.

        Args:
            kind: Deflator type such as 'NGDP_D', 'CPI' or 'PCPI'.

        Returns:
            DataFrame with the columns pydeflate_year, pydeflate_entity_code,
            pydeflate_iso3 and pydeflate_{kind}.

        Raises:
            ValueError: When the source does not offer the requested kind.
        """
        ...

    def validate(self) -> None:
        """Check that the loaded source data is well formed.

        Raises:
            ValueError: When the data is empty or badly formatted.
            SchemaValidationError: When the data does not match the expected schema.
        """
        ...
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@runtime_checkable
class DeflatorProtocol(Protocol):
    """Structural interface for deflator objects (price or exchange deflators)."""

    # Data source that supplies the deflator values.
    source: SourceProtocol
    # Kind of deflator: 'price' or 'exchange'.
    deflator_type: str
    # Year used for rebasing (the deflator equals 100 in this year).
    base_year: int
    # Deflator values rebased to base_year.
    deflator_data: pd.DataFrame

    def rebase_deflator(self) -> None:
        """Rescale the deflator so that base_year has a value of 100."""
        ...
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@runtime_checkable
class ExchangeProtocol(Protocol):
    """Structural interface for exchange rate objects."""

    # Data source that supplies the exchange rates.
    source: SourceProtocol
    # Currency being converted from (ISO3 country code or 'LCU').
    source_currency: str
    # Currency being converted to (ISO3 country code).
    target_currency: str
    # Rates used to convert source_currency into target_currency.
    exchange_data: pd.DataFrame

    def exchange_rate(self, from_currency: str, to_currency: str) -> pd.DataFrame:
        """Compute exchange rates between two currencies.

        Args:
            from_currency: Currency code to convert from.
            to_currency: Currency code to convert to.

        Returns:
            DataFrame holding the exchange rates together with their deflators.
        """
        ...

    def deflator(self) -> pd.DataFrame:
        """Provide exchange rate deflator data.

        Returns:
            DataFrame with the columns pydeflate_year, pydeflate_entity_code,
            pydeflate_iso3 and pydeflate_EXCHANGE_D.
        """
        ...
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@runtime_checkable
class CacheManagerProtocol(Protocol):
    """Protocol for cache management."""

    # Base directory for cached data (a filesystem path, not a DataFrame).
    base_dir: Path

    def ensure(self, entry, *, refresh: bool = False) -> Path:
        """Ensure a cache entry exists, downloading if needed.

        Args:
            entry: CacheEntry describing the dataset.
            refresh: If True, force re-download.

        Returns:
            Path to the cached file.
        """
        ...

    def clear(self, key: str | None = None) -> None:
        """Clear cache entries.

        Args:
            key: If provided, clear only this entry. If None, clear all.
        """
        ...
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# Type aliases for common DataFrame schemas. Each is simply pd.DataFrame;
# the distinct names document intent in signatures.

# DataFrame containing exchange rate data.
ExchangeDataFrame = pd.DataFrame

# DataFrame containing deflator data.
DeflatorDataFrame = pd.DataFrame

# DataFrame containing raw source data.
SourceDataFrame = pd.DataFrame

# DataFrame provided by users for deflation/exchange.
UserDataFrame = pd.DataFrame
|