pydeflate 2.1.3__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,207 @@
1
+ from functools import wraps
2
+
3
+ import pandas as pd
4
+
5
+ from pydeflate.core.api import BaseExchange
6
+ from pydeflate.core.source import DAC, IMF, WorldBank, WorldBankPPP
7
+
8
+
9
+ def _generate_get_rates_docstring(source_name: str) -> str:
10
+ """Generate docstring for each get exchange rates function."""
11
+ return (
12
+ f"Get exchange rate data from {source_name} without requiring user data.\n\n"
13
+ f"This function returns a DataFrame containing exchange rates for the specified parameters.\n\n"
14
+ "Args:\n"
15
+ " source_currency (str, optional): The source currency code. Defaults to 'USA'.\n"
16
+ " target_currency (str, optional): The target currency code. Defaults to 'USA'.\n"
17
+ " countries (list[str] | None, optional): List of country codes to include. If None, returns all. Defaults to None.\n"
18
+ " years (list[int] | range | None, optional): List or range of years to include. If None, returns all. Defaults to None.\n"
19
+ " use_source_codes (bool, optional): Use source-specific entity codes. Defaults to False.\n"
20
+ " update_rates (bool, optional): Update the exchange rate data before retrieval. Defaults to False.\n\n"
21
+ "Returns:\n"
22
+ " pd.DataFrame: DataFrame with columns:\n"
23
+ " - iso_code (or entity_code if use_source_codes=True): Country/entity identifier\n"
24
+ " - year: Year\n"
25
+ " - exchange_rate: Exchange rate from source to target currency\n"
26
+ )
27
+
28
+
29
def _get_exchange_rates(exchange_source_cls, **fixed_params):
    """Build a decorator that replaces a stub with an exchange-rate getter.

    The decorated function only contributes its metadata (copied with
    ``functools.wraps``); its body is never called. The wrapper that takes
    its place instantiates ``exchange_source_cls``, builds a ``BaseExchange``
    and returns the filtered, renamed exchange-rate table.

    Args:
        exchange_source_cls: Source class instantiated on every call. A class
            named ``"WorldBankPPP"`` is additionally passed ``from_lcu``.
        **fixed_params: Internally pinned values. Only ``target_currency`` is
            consulted; when present it replaces the caller's argument.

    Returns:
        A decorator producing the keyword-only wrapper.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(
            *,
            source_currency: str = "USA",
            target_currency: str = "USA",
            countries: list[str] | None = None,
            years: list[int] | range | None = None,
            use_source_codes: bool = False,
            update_rates: bool = False,
        ):
            # A pinned target currency always wins over the caller's value
            # (set internally, so no validation is performed).
            target_currency = fixed_params.get("target_currency", target_currency)

            if exchange_source_cls.__name__ == "WorldBankPPP":
                # For the PPP source, "USA" is expressed as "LCU" together
                # with from_lcu=False; any other currency uses from_lcu=True.
                converting_from_usd = source_currency == "USA"
                source = exchange_source_cls(
                    update=update_rates,
                    from_lcu=not converting_from_usd,
                )
                if converting_from_usd:
                    source_currency = "LCU"
            else:
                source = exchange_source_cls(update=update_rates)

            exchange = BaseExchange(
                exchange_source=source,
                source_currency=source_currency,
                target_currency=target_currency,
                use_source_codes=use_source_codes,
            )

            # Internal entity column and the public name it is exposed under.
            if use_source_codes:
                entity_col, public_entity_col = "pydeflate_entity_code", "entity_code"
            else:
                entity_col, public_entity_col = "pydeflate_iso3", "iso_code"

            rates = exchange.pydeflate_data.copy()

            if countries is not None:
                rates = rates[rates[entity_col].isin(countries)]

            if years is not None:
                wanted_years = list(years) if isinstance(years, range) else years
                rates = rates[rates["pydeflate_year"].isin(wanted_years)]

            # Key order doubles as the output column order.
            column_names = {
                entity_col: public_entity_col,
                "pydeflate_year": "year",
                "pydeflate_EXCHANGE": "exchange_rate",
            }
            selected = rates[list(column_names)].copy()
            return selected.rename(columns=column_names).reset_index(drop=True)

        # Replace the docstring copied by ``wraps`` with the generated one.
        wrapper.__doc__ = _generate_get_rates_docstring(exchange_source_cls.__name__)
        return wrapper

    return decorator
100
+
101
+
102
# Signature-only stub: `_get_exchange_rates(DAC)` returns a wrapper that does
# the actual work and is given a generated docstring, so the `...` body below
# is never executed. The signature mirrors the wrapper's keyword-only
# parameters.
@_get_exchange_rates(DAC)
def get_oecd_dac_exchange_rates(
    *,
    source_currency: str = "USA",
    target_currency: str = "USA",
    countries: list[str] | None = None,
    years: list[int] | range | None = None,
    use_source_codes: bool = False,
    update_rates: bool = False,
) -> pd.DataFrame: ...
112
+
113
+
114
# Signature-only stub: `_get_exchange_rates(WorldBank)` returns a wrapper that
# does the actual work and is given a generated docstring, so the `...` body
# below is never executed. The signature mirrors the wrapper's keyword-only
# parameters.
@_get_exchange_rates(WorldBank)
def get_wb_exchange_rates(
    *,
    source_currency: str = "USA",
    target_currency: str = "USA",
    countries: list[str] | None = None,
    years: list[int] | range | None = None,
    use_source_codes: bool = False,
    update_rates: bool = False,
) -> pd.DataFrame: ...
124
+
125
+
126
def get_wb_ppp_rates(
    *,
    source_currency: str = "USA",
    countries: list[str] | None = None,
    years: list[int] | range | None = None,
    use_source_codes: bool = False,
    update_rates: bool = False,
) -> pd.DataFrame:
    """Get PPP exchange rate data from WorldBankPPP without requiring user data.

    The target currency is always ``"PPP"``. A ``source_currency`` of
    ``"USA"`` is passed on as ``"LCU"`` with ``from_lcu=False``; any other
    value is passed through unchanged with ``from_lcu=True``.

    Args:
        source_currency (str, optional): The source currency code. Defaults to 'USA'.
        countries (list[str] | None, optional): Country codes to keep, matched
            against the entity column selected by ``use_source_codes``.
            If None, returns all. Defaults to None.
        years (list[int] | range | None, optional): Years to keep. If None,
            returns all. Defaults to None.
        use_source_codes (bool, optional): Use source-specific entity codes. Defaults to False.
        update_rates (bool, optional): Update the exchange rate data before retrieval. Defaults to False.

    Returns:
        pd.DataFrame: DataFrame with a fresh index and the columns:
            - iso_code (or entity_code if use_source_codes=True): Country/entity identifier
            - year: Year
            - exchange_rate: PPP exchange rate
    """
    converting_from_usd = source_currency == "USA"

    ppp_source = WorldBankPPP(
        update=update_rates,
        from_lcu=not converting_from_usd,
    )
    exchange = BaseExchange(
        exchange_source=ppp_source,
        source_currency="LCU" if converting_from_usd else source_currency,
        target_currency="PPP",
        use_source_codes=use_source_codes,
    )

    # Internal entity column and the public name it is exposed under.
    if use_source_codes:
        entity_col, public_entity_col = "pydeflate_entity_code", "entity_code"
    else:
        entity_col, public_entity_col = "pydeflate_iso3", "iso_code"

    rates = exchange.pydeflate_data.copy()

    if countries is not None:
        rates = rates[rates[entity_col].isin(countries)]

    if years is not None:
        wanted_years = list(years) if isinstance(years, range) else years
        rates = rates[rates["pydeflate_year"].isin(wanted_years)]

    # Key order doubles as the output column order.
    column_names = {
        entity_col: public_entity_col,
        "pydeflate_year": "year",
        "pydeflate_EXCHANGE": "exchange_rate",
    }
    selected = rates[list(column_names)].copy()
    return selected.rename(columns=column_names).reset_index(drop=True)
196
+
197
+
198
# Signature-only stub: `_get_exchange_rates(IMF)` returns a wrapper that does
# the actual work and is given a generated docstring, so the `...` body below
# is never executed. The signature mirrors the wrapper's keyword-only
# parameters.
@_get_exchange_rates(IMF)
def get_imf_exchange_rates(
    *,
    source_currency: str = "USA",
    target_currency: str = "USA",
    countries: list[str] | None = None,
    years: list[int] | range | None = None,
    use_source_codes: bool = False,
    update_rates: bool = False,
) -> pd.DataFrame: ...
pydeflate/plugins.py ADDED
@@ -0,0 +1,289 @@
1
+ """Plugin system for custom data sources.
2
+
3
+ This module provides a registry-based plugin architecture that allows
4
+ users to register custom data sources without modifying pydeflate's code.
5
+
6
+ Example:
7
+ >>> from pydeflate.plugins import register_source
8
+ >>> from pydeflate.protocols import SourceProtocol
9
+ >>>
10
+ >>> @register_source("my_custom_source")
11
+ ... class MyCustomSource:
12
+ ... def __init__(self, update: bool = False):
13
+ ... self.name = "my_custom_source"
14
+ ... self.data = load_my_data(update)
15
+ ... self._idx = ["pydeflate_year", "pydeflate_entity_code", "pydeflate_iso3"]
16
+ ...
17
+ ... def lcu_usd_exchange(self): ...
18
+ ... def price_deflator(self, kind): ...
19
+ ... def validate(self): ...
20
+ >>>
21
+ >>> # Now use it:
22
+ >>> result = deflate_with_source("my_custom_source", data=df, ...)
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ from typing import Any, Callable, TypeVar
28
+
29
+ from pydeflate.exceptions import PluginError
30
+ from pydeflate.protocols import SourceProtocol
31
+
32
# Type variable for source classes. It is bound to ``type`` so it stands for a
# class object, which lets ``register_source`` be annotated as returning the
# same class it decorates.
SourceType = TypeVar("SourceType", bound=type)
34
+
35
+
36
class SourceRegistry:
    """Registry for data source implementations.

    Maps source names to either an implementation class or a factory
    callable, and provides methods to register, instantiate, list and
    remove sources. A name refers to one registration at a time: registering
    with ``override=True`` replaces whatever was stored under that name.
    """

    def __init__(self):
        """Initialize an empty registry."""
        self._sources: dict[str, type] = {}
        self._factories: dict[str, Callable[..., SourceProtocol]] = {}

    def register(
        self,
        name: str,
        source_class: type | None = None,
        factory: Callable[..., SourceProtocol] | None = None,
        *,
        override: bool = False,
    ) -> None:
        """Register a data source.

        Args:
            name: Unique name for this source (e.g., 'IMF', 'WorldBank')
            source_class: Class that implements SourceProtocol
            factory: Factory function that returns a SourceProtocol instance
            override: If True, allow replacing existing sources

        Raises:
            PluginError: If neither source_class nor factory is provided
            PluginError: If name is already registered (as a class or as a
                factory) and override=False
            PluginError: If source_class lacks the required methods
        """
        if source_class is None and factory is None:
            raise PluginError(
                "Must provide either source_class or factory",
                plugin_name=name,
            )

        # A name is taken whether it was registered as a class or a factory.
        if not override and self.is_registered(name):
            raise PluginError(
                f"Source '{name}' already registered. Use override=True to replace.",
                plugin_name=name,
            )

        # Validate before touching any state so a failed call changes nothing.
        if source_class is not None and not self._check_protocol_conformance(
            source_class
        ):
            raise PluginError(
                "Source class must implement SourceProtocol",
                plugin_name=name,
            )

        # Drop any earlier registration of either kind. Without this, a stale
        # factory would keep shadowing a newly registered class, because
        # ``get`` consults factories first.
        self._sources.pop(name, None)
        self._factories.pop(name, None)

        if source_class is not None:
            self._sources[name] = source_class
        if factory is not None:
            self._factories[name] = factory

    def _check_protocol_conformance(self, source_class: type) -> bool:
        """Check that a class exposes the methods SourceProtocol requires.

        This is a best-effort check on the class object only. The protocol's
        data attributes (``name``, ``data``, ``_idx``) may be set in
        ``__init__`` and cannot be verified without instantiating the class,
        so they are not checked here.

        Args:
            source_class: Class to check

        Returns:
            True if every required method exists and is callable
        """
        required_methods = ("lcu_usd_exchange", "price_deflator", "validate")
        return all(
            callable(getattr(source_class, method, None))
            for method in required_methods
        )

    def get(self, name: str, **kwargs) -> SourceProtocol:
        """Get a source instance by name.

        A factory registered under ``name`` is tried first, then a class.

        Args:
            name: Name of the registered source
            **kwargs: Keyword arguments passed to the factory or constructor

        Returns:
            Instance of the source

        Raises:
            PluginError: If the source is not found, or if the factory or
                constructor raises (the original exception is chained)
        """
        if name in self._factories:
            try:
                return self._factories[name](**kwargs)
            except Exception as e:
                raise PluginError(
                    f"Factory function failed: {e}",
                    plugin_name=name,
                ) from e

        if name in self._sources:
            try:
                return self._sources[name](**kwargs)
            except Exception as e:
                raise PluginError(
                    f"Source instantiation failed: {e}",
                    plugin_name=name,
                ) from e

        raise PluginError(
            f"Source '{name}' not found. Available sources: {self.list_sources()}",
            plugin_name=name,
        )

    def list_sources(self) -> list[str]:
        """List all registered source names.

        Returns:
            Sorted list of names registered as a class or as a factory
        """
        return sorted(set(self._sources) | set(self._factories))

    def is_registered(self, name: str) -> bool:
        """Check if a source is registered.

        Args:
            name: Source name to check

        Returns:
            True if registered as a class or as a factory
        """
        return name in self._sources or name in self._factories

    def unregister(self, name: str) -> None:
        """Remove a source from the registry.

        Args:
            name: Source name to remove

        Raises:
            PluginError: If source not found
        """
        if not self.is_registered(name):
            raise PluginError(
                f"Cannot unregister '{name}': not found",
                plugin_name=name,
            )

        self._sources.pop(name, None)
        self._factories.pop(name, None)
193
+
194
+
195
# Global registry instance used by the module-level helpers in this file
# (register_source, get_source, list_sources, is_source_registered).
_global_registry = SourceRegistry()
197
+
198
+
199
def register_source(
    name: str,
    *,
    override: bool = False,
) -> Callable[[SourceType], SourceType]:
    """Class decorator that adds a source class to the global registry.

    Registration happens when the decorator is applied; the class itself is
    returned unchanged.

    Example:
        >>> @register_source("my_source")
        ... class MySource:
        ...     def __init__(self, update: bool = False):
        ...         self.name = "my_source"
        ...         ...
        ...
        ...     def lcu_usd_exchange(self): ...
        ...     def price_deflator(self, kind): ...
        ...     def validate(self): ...

    Args:
        name: Unique name for this source
        override: Allow replacing existing sources

    Returns:
        Decorator function
    """

    def _register_class(cls: SourceType) -> SourceType:
        _global_registry.register(name, override=override, source_class=cls)
        return cls

    return _register_class
229
+
230
+
231
def get_source(name: str, **kwargs) -> SourceProtocol:
    """Get a registered source instance from the global registry.

    Thin wrapper that forwards to ``_global_registry.get``.

    Args:
        name: Name of the source
        **kwargs: Arguments forwarded unchanged to the registry lookup, which
            passes them to the source's factory or constructor

    Returns:
        Source instance

    Raises:
        PluginError: If source not found or instantiation fails
    """
    return _global_registry.get(name, **kwargs)
245
+
246
+
247
def list_sources() -> list[str]:
    """List all sources available in the global registry.

    Thin wrapper that forwards to ``_global_registry.list_sources``.

    Returns:
        Sorted list of source names
    """
    return _global_registry.list_sources()
254
+
255
+
256
def is_source_registered(name: str) -> bool:
    """Check if a source is registered in the global registry.

    Thin wrapper that forwards to ``_global_registry.is_registered``.

    Args:
        name: Source name

    Returns:
        True if registered
    """
    return _global_registry.is_registered(name)
266
+
267
+
268
+ # Pre-register built-in sources
269
def _register_builtin_sources():
    """Register pydeflate's built-in sources and their aliases.

    Every entry is registered with ``override=True``, so calling this again
    simply re-binds the same names.
    """
    # Imported here rather than at module level, as in the original layout.
    from pydeflate.core.source import DAC, IMF, WorldBank, WorldBankPPP

    builtin_sources = (
        # Canonical names
        ("IMF", IMF),
        ("World Bank", WorldBank),
        ("World Bank PPP", WorldBankPPP),
        ("DAC", DAC),
        # Aliases for convenience
        ("imf", IMF),
        ("wb", WorldBank),
        ("worldbank", WorldBank),
        ("dac", DAC),
        ("oecd", DAC),
    )
    for source_name, source_cls in builtin_sources:
        _global_registry.register(source_name, source_class=source_cls, override=True)
286
+
287
+
288
# Register built-in sources when module is imported, so they are present in
# the global registry without any explicit setup call.
_register_builtin_sources()
pydeflate/protocols.py ADDED
@@ -0,0 +1,168 @@
1
+ """Protocol definitions for type safety and extensibility.
2
+
3
+ This module defines the core protocols (interfaces) that pydeflate components
4
+ must implement. Using protocols enables:
5
+ - Type checking without inheritance
6
+ - Duck typing with static verification
7
+ - Clear contracts for extensibility
8
+ """
9
+
10
from __future__ import annotations

from pathlib import Path
from typing import Protocol, runtime_checkable

import pandas as pd

from pydeflate.sources.common import AvailableDeflators
17
+
18
+
19
@runtime_checkable
class SourceProtocol(Protocol):
    """Structural interface for data sources (IMF, World Bank, DAC, etc.).

    Any object providing these attributes and methods is compatible with
    pydeflate's core deflation and exchange logic; no inheritance is needed.
    """

    #: Human-readable name of the data source (e.g., 'IMF', 'World Bank')
    name: str

    #: The raw data loaded from this source
    data: pd.DataFrame

    #: Standard index columns:
    #: ['pydeflate_year', 'pydeflate_entity_code', 'pydeflate_iso3']
    _idx: list[str]

    def lcu_usd_exchange(self) -> pd.DataFrame:
        """Return local currency to USD exchange rates.

        Returns:
            DataFrame with columns: pydeflate_year, pydeflate_entity_code,
            pydeflate_iso3, pydeflate_EXCHANGE
        """

    def price_deflator(self, kind: AvailableDeflators = "NGDP_D") -> pd.DataFrame:
        """Return price deflator data for the specified type.

        Args:
            kind: Type of deflator (e.g., 'NGDP_D', 'CPI', 'PCPI')

        Returns:
            DataFrame with columns: pydeflate_year, pydeflate_entity_code,
            pydeflate_iso3, pydeflate_{kind}

        Raises:
            ValueError: If the requested deflator kind is not available
        """

    def validate(self) -> None:
        """Validate that the source data is properly formatted.

        Raises:
            ValueError: If data is empty or improperly formatted
            SchemaValidationError: If data doesn't match expected schema
        """
68
+
69
+
70
@runtime_checkable
class DeflatorProtocol(Protocol):
    """Structural interface for deflator objects (price or exchange deflators)."""

    #: The data source providing deflator data
    source: SourceProtocol

    #: Type of deflator: 'price' or 'exchange'
    deflator_type: str

    #: The base year for rebasing deflator values (value = 100 at base year)
    base_year: int

    #: The deflator data, rebased to base_year
    deflator_data: pd.DataFrame

    def rebase_deflator(self) -> None:
        """Rebase deflator values so that base_year has value 100."""
89
+
90
+
91
@runtime_checkable
class ExchangeProtocol(Protocol):
    """Structural interface for exchange rate objects."""

    #: The data source providing exchange rate data
    source: SourceProtocol

    #: Source currency code (ISO3 country code or 'LCU')
    source_currency: str

    #: Target currency code (ISO3 country code)
    target_currency: str

    #: Exchange rate data for converting source_currency to target_currency
    exchange_data: pd.DataFrame

    def exchange_rate(self, from_currency: str, to_currency: str) -> pd.DataFrame:
        """Calculate exchange rates between two currencies.

        Args:
            from_currency: Source currency code
            to_currency: Target currency code

        Returns:
            DataFrame with exchange rates and deflators
        """

    def deflator(self) -> pd.DataFrame:
        """Get exchange rate deflator data.

        Returns:
            DataFrame with columns: pydeflate_year, pydeflate_entity_code,
            pydeflate_iso3, pydeflate_EXCHANGE_D
        """
127
+
128
+
129
@runtime_checkable
class CacheManagerProtocol(Protocol):
    """Structural interface for cache management."""

    #: Base directory for cached data. Annotated as a filesystem path: the
    #: previous ``pd.DataFrame`` annotation contradicted this description.
    base_dir: Path

    def ensure(self, entry, *, refresh: bool = False):
        """Ensure a cache entry exists, downloading if needed.

        Args:
            entry: CacheEntry describing the dataset
            refresh: If True, force re-download

        Returns:
            Path to the cached file
        """
        ...

    def clear(self, key: str | None = None) -> None:
        """Clear cache entries.

        Args:
            key: If provided, clear only this entry. If None, clear all.
        """
        ...
155
+
156
+
157
# Type aliases for common DataFrame schemas.
# Every alias below is bound to plain ``pd.DataFrame``: the distinct names
# document intent in signatures and add no schema check of their own.
ExchangeDataFrame = pd.DataFrame
"""DataFrame containing exchange rate data"""

DeflatorDataFrame = pd.DataFrame
"""DataFrame containing deflator data"""

SourceDataFrame = pd.DataFrame
"""DataFrame containing raw source data"""

UserDataFrame = pd.DataFrame
"""DataFrame provided by users for deflation/exchange"""