pydeflate 2.1.3__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,14 +1,79 @@
+from __future__ import annotations
+
 import logging
+import os
+from dataclasses import dataclass
 from pathlib import Path
 
 
-class PYDEFLATE_PATHS:
-    """Class to store the paths to the data and output folders."""
+from platformdirs import user_cache_dir
+
+
+DATA_DIR_ENV = "PYDEFLATE_DATA_DIR"
+
+_PACKAGE_ROOT = Path(__file__).resolve().parent.parent
+_SETTINGS_DIR = _PACKAGE_ROOT / "pydeflate" / "settings"
+_TEST_DATA_DIR = _PACKAGE_ROOT / "tests" / "test_files"
+
+_DATA_DIR_OVERRIDE: Path | None = None
+
+
+def _ensure_dir(path: Path) -> Path:
+    path.mkdir(parents=True, exist_ok=True)
+    return path
+
+
+def _default_data_dir() -> Path:
+    env_value = os.environ.get(DATA_DIR_ENV)
+    if env_value:
+        return _ensure_dir(Path(env_value).expanduser().resolve())
+    return _ensure_dir(Path(user_cache_dir("pydeflate", "pydeflate")))
+
+
+def get_data_dir() -> Path:
+    """Return the directory where pydeflate caches data files."""
+
+    if _DATA_DIR_OVERRIDE is not None:
+        return _ensure_dir(_DATA_DIR_OVERRIDE)
+    return _default_data_dir()
+
 
-    package = Path(__file__).resolve().parent.parent
-    data = package / "pydeflate" / ".pydeflate_data"
-    settings = package / "pydeflate" / "settings"
-    test_data = package / "tests" / "test_files"
+def set_data_dir(path: str | Path) -> Path:
+    """Override the pydeflate data directory for the current process."""
+
+    global _DATA_DIR_OVERRIDE
+    resolved = _ensure_dir(Path(path).expanduser().resolve())
+    _DATA_DIR_OVERRIDE = resolved
+    return resolved
+
+
+def reset_data_dir() -> None:
+    """Reset any process-level overrides and fall back to defaults."""
+
+    global _DATA_DIR_OVERRIDE
+    _DATA_DIR_OVERRIDE = None
+
+
+@dataclass(frozen=True)
+class _Paths:
+    package: Path
+    settings: Path
+    test_data: Path
+
+    @property
+    def data(self) -> Path:
+        return get_data_dir()
+
+    @data.setter  # type: ignore[override]
+    def data(self, value: Path | str) -> None:  # pragma: no cover - simple proxy
+        set_data_dir(value)
+
+
+PYDEFLATE_PATHS = _Paths(
+    package=_PACKAGE_ROOT,
+    settings=_SETTINGS_DIR,
+    test_data=_TEST_DATA_DIR,
+)
 
 
 def setup_logger(name) -> logging.Logger:
@@ -41,3 +106,9 @@ def setup_logger(name) -> logging.Logger:
 
 
 logger = setup_logger("pydeflate")
+
+
+def set_pydeflate_path(path: str | Path) -> Path:
+    """Set the path to the data folder (public API)."""
+
+    return set_data_dir(path)
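
Note: the two hunks above rewrite pydeflate's path configuration (the module imported later in this diff as pydeflate.pydeflate_config, which defines logger). The hard-coded `.pydeflate_data` folder inside the package is replaced by a cache directory resolved from a process-level override if one is set, otherwise the PYDEFLATE_DATA_DIR environment variable, otherwise a platformdirs user cache, and PYDEFLATE_PATHS becomes a frozen dataclass whose `data` attribute proxies that lookup. A minimal usage sketch, assuming these names are imported from pydeflate.pydeflate_config (an inference from the later `from pydeflate.pydeflate_config import logger` line):

# Sketch only: the module path is inferred, not stated in this diff.
from pydeflate.pydeflate_config import (
    PYDEFLATE_PATHS,
    get_data_dir,
    reset_data_dir,
    set_pydeflate_path,
)

print(get_data_dir())        # override > $PYDEFLATE_DATA_DIR > platformdirs cache
print(PYDEFLATE_PATHS.data)  # same value, via the dataclass property

set_pydeflate_path("~/pydeflate-cache")  # public API: sets the process-level override
print(get_data_dir())        # now the resolved ~/pydeflate-cache directory

reset_data_dir()             # drop the override, back to env var / platformdirs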
pydeflate/schemas.py ADDED
@@ -0,0 +1,297 @@
+"""Pandera schemas for data validation.
+
+This module defines validation schemas for all DataFrame structures used
+in pydeflate. This ensures data integrity from external sources and
+catches API changes early.
+"""
+
+from __future__ import annotations
+
+import pandas as pd
+
+try:
+    # New import (pandera >= 0.20)
+    import pandera.pandas as pa
+    from pandera.pandas import Check, Column, DataFrameSchema
+except ImportError:
+    # Fallback to old import for older pandera versions
+    import pandera as pa
+    from pandera import Check, Column, DataFrameSchema
+
+# Column definitions for reuse
+YEAR_COLUMN = Column(
+    int,
+    checks=[
+        Check.ge(1960),  # No data before 1960
+        Check.le(2100),  # No projections beyond 2100
+    ],
+    nullable=False,
+    description="Year as integer",
+)
+
+ENTITY_CODE_COLUMN = Column(
+    str,
+    checks=[Check(lambda s: s.str.len() <= 10)],
+    nullable=False,
+    description="Entity code from source (varies by source)",
+)
+
+ISO3_COLUMN = Column(
+    str,
+    checks=[
+        Check(lambda s: (s.str.len() == 3) | s.isna()),
+    ],
+    nullable=True,
+    description="ISO3 country code",
+)
+
+EXCHANGE_RATE_COLUMN = Column(
+    float,
+    checks=[
+        Check.gt(0),  # Exchange rates must be positive
+    ],
+    nullable=True,
+    description="Exchange rate (LCU per USD)",
+)
+
+DEFLATOR_COLUMN = Column(
+    float,
+    checks=[
+        Check.gt(0),  # Deflators must be positive
+    ],
+    nullable=True,
+    description="Price deflator index",
+)
+
+
+class SourceDataSchema(pa.DataFrameModel):
+    """Base schema for all data sources.
+
+    All sources must have these minimum columns after processing.
+    """
+
+    pydeflate_year: int = pa.Field(ge=1960, le=2100)
+    pydeflate_entity_code: str = pa.Field(str_length={"max_value": 10})
+    pydeflate_iso3: str | None = pa.Field(nullable=True)
+
+    class Config:
+        """Schema configuration."""
+
+        strict = False  # Allow additional columns
+        coerce = True  # Attempt type coercion
+
+
+class ExchangeDataSchema(SourceDataSchema):
+    """Schema for exchange rate data."""
+
+    pydeflate_EXCHANGE = EXCHANGE_RATE_COLUMN
+    pydeflate_EXCHANGE_D = Column(
+        float,
+        checks=[Check.gt(0)],
+        nullable=True,
+        description="Exchange rate deflator (rebased)",
+    )
+
+    class Config:
+        """Schema configuration."""
+
+        strict = False
+        coerce = True
+
+
+class IMFDataSchema(SourceDataSchema):
+    """Schema for IMF WEO data.
+
+    IMF provides GDP deflators, CPI, and exchange rates.
+    """
+
+    pydeflate_NGDP_D = Column(
+        float,
+        checks=[Check.gt(0)],
+        nullable=True,
+        description="GDP deflator",
+    )
+    pydeflate_PCPI = Column(
+        float,
+        checks=[Check.gt(0)],
+        nullable=True,
+        description="CPI (period average)",
+    )
+    pydeflate_PCPIE = Column(
+        float,
+        checks=[Check.gt(0)],
+        nullable=True,
+        description="CPI (end of period)",
+    )
+    pydeflate_EXCHANGE = EXCHANGE_RATE_COLUMN
+
+    class Config:
+        """Schema configuration."""
+
+        strict = False
+        coerce = True
+
+
+class WorldBankDataSchema(SourceDataSchema):
+    """Schema for World Bank data."""
+
+    pydeflate_NGDP_D = Column(
+        float,
+        checks=[Check.gt(0)],
+        nullable=True,
+        description="GDP deflator",
+    )
+    pydeflate_NGDP_DL = Column(
+        float,
+        checks=[Check.gt(0)],
+        nullable=True,
+        description="GDP deflator (linked)",
+    )
+    pydeflate_CPI = Column(
+        float,
+        checks=[Check.gt(0)],
+        nullable=True,
+        description="Consumer Price Index",
+    )
+    pydeflate_EXCHANGE = EXCHANGE_RATE_COLUMN
+
+    class Config:
+        """Schema configuration."""
+
+        strict = False
+        coerce = True
+
+
+class DACDataSchema(SourceDataSchema):
+    """Schema for OECD DAC data."""
+
+    pydeflate_DAC_DEFLATOR = Column(
+        float,
+        checks=[Check.gt(0)],
+        nullable=True,
+        description="DAC deflator",
+    )
+    pydeflate_NGDP_D = Column(
+        float,
+        checks=[Check.gt(0)],
+        nullable=True,
+        description="GDP deflator (computed)",
+    )
+    pydeflate_EXCHANGE = EXCHANGE_RATE_COLUMN
+
+    class Config:
+        """Schema configuration."""
+
+        strict = False
+        coerce = True
+
+
+class UserInputSchema:
+    """Validation for user-provided DataFrames.
+
+    This is not a Pandera schema but provides methods to validate
+    user input with custom column names.
+    """
+
+    @staticmethod
+    def validate(
+        df,
+        id_column: str,
+        year_column: str,
+        value_column: str,
+    ) -> None:
+        """Validate user DataFrame has required columns and types.
+
+        Args:
+            df: User's DataFrame
+            id_column: Name of column with entity identifiers
+            year_column: Name of column with year data
+            value_column: Name of column with numeric values
+
+        Raises:
+            ConfigurationError: If required columns are missing
+            SchemaValidationError: If column types are invalid
+        """
+        from pydeflate.exceptions import ConfigurationError, SchemaValidationError
+
+        # Check required columns exist
+        missing_cols = []
+        for col_name, col in [
+            ("id_column", id_column),
+            ("year_column", year_column),
+            ("value_column", value_column),
+        ]:
+            if col not in df.columns:
+                missing_cols.append(f"{col_name}='{col}'")
+
+        if missing_cols:
+            raise ConfigurationError(
+                f"Required columns missing from DataFrame: {', '.join(missing_cols)}"
+            )
+
+        # Validate value column is numeric
+        if not pd.api.types.is_numeric_dtype(df[value_column]):
+            raise SchemaValidationError(
+                f"Column '{value_column}' must be numeric, got {df[value_column].dtype}"
+            )
+
+        # Validate year column can be converted to datetime
+        try:
+            pd.to_datetime(df[year_column], errors="coerce")
+        except Exception as e:
+            raise SchemaValidationError(
+                f"Column '{year_column}' cannot be interpreted as dates: {e}"
+            )
+
+
+# Registry of schemas by source name
+SCHEMA_REGISTRY: dict[str, type[DataFrameSchema]] = {
+    "IMF": IMFDataSchema,
+    "World Bank": WorldBankDataSchema,
+    "DAC": DACDataSchema,
+}
+
+
+def get_schema_for_source(source_name: str) -> type[DataFrameSchema] | None:
+    """Get the appropriate schema for a data source.
+
+    Args:
+        source_name: Name of the source (e.g., 'IMF', 'World Bank')
+
+    Returns:
+        Schema class for the source, or None if not found
+    """
+    return SCHEMA_REGISTRY.get(source_name)
+
+
+def validate_source_data(df, source_name: str) -> None:
+    """Validate that source data matches expected schema.
+
+    Args:
+        df: DataFrame to validate
+        source_name: Name of the source
+
+    Raises:
+        SchemaValidationError: If data doesn't match schema
+    """
+    from pydeflate.exceptions import SchemaValidationError
+
+    schema_class = get_schema_for_source(source_name)
+    if schema_class is None:
+        # No schema defined for this source, skip validation
+        return
+
+    try:
+        # Instantiate the schema and validate
+        schema = schema_class()
+        schema.validate(df, lazy=True)
+    except pa.errors.SchemaErrors as e:
+        # Collect all validation errors
+        error_messages = []
+        for error in e.failure_cases.itertuples():
+            error_messages.append(f" - {error.check}: {error.failure_case}")
+
+        raise SchemaValidationError(
+            f"Data validation failed for {source_name}:\n" + "\n".join(error_messages),
+            source=source_name,
+        ) from e
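
Note: the new pydeflate/schemas.py centralises validation. Source frames are checked lazily against the pandera schema registered for their source name (and silently skipped for unknown sources), while user-supplied frames are checked column by column with caller-provided names. A short sketch of both entry points; the DataFrame below is illustrative, and ConfigurationError / SchemaValidationError come from pydeflate.exceptions as referenced in the code above:

import pandas as pd

from pydeflate.schemas import (
    UserInputSchema,
    get_schema_for_source,
    validate_source_data,
)

# Registry lookup: known sources map to a schema class, unknown ones to None,
# in which case validate_source_data simply returns without validating.
assert get_schema_for_source("IMF") is not None
validate_source_data(pd.DataFrame(), "some unregistered source")  # no-op

# User input is validated with the caller's own column names.
user_df = pd.DataFrame({"iso_code": ["FRA"], "year": [2021], "value": [123.4]})
UserInputSchema.validate(
    user_df, id_column="iso_code", year_column="year", value_column="value"
)  # passes: columns exist, value is numeric, year parses as a date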
@@ -1,41 +1,19 @@
-from datetime import datetime
-from pathlib import Path
+from __future__ import annotations
+
 from typing import Any, Literal
 
 import pandas as pd
 from hdx.location.country import Country
 
-from pydeflate.pydeflate_config import PYDEFLATE_PATHS, logger
+from pydeflate.pydeflate_config import logger
 
 AvailableDeflators = Literal["NGDP_D", "NGDP_DL", "CPI", "PCPI", "PCPIE"]
 
 
-def check_file_age(file: Path) -> int:
-    """Check the age of a WEO file in days.
-
-    Args:
-        file (Path): The WEO parquet file to check.
-
-    Returns:
-        int: The number of days since the file was created.
-    """
-    current_date = datetime.today()
-    # Extract date from the filename (format: weo_YYYY-MM-DD.parquet)
-    file_date = datetime.strptime(file.stem.split("_")[-1], "%Y-%m-%d")
-
-    # Return the difference in days between today and the file's date
-    return (current_date - file_date).days
-
-
 def enforce_pyarrow_types(df: pd.DataFrame) -> pd.DataFrame:
-    """Ensures that a DataFrame uses pyarrow dtypes."""
-    return df.convert_dtypes(dtype_backend="pyarrow")
-
-
-def today() -> str:
-    from datetime import datetime
+    """Ensure that a DataFrame uses pyarrow-backed dtypes."""
 
-    return datetime.today().strftime("%Y-%m-%d")
+    return df.convert_dtypes(dtype_backend="pyarrow")
 
 
 def _match_regex_to_iso3(
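
Note: the hunk above trims this utilities module down to enforce_pyarrow_types (a thin wrapper over pandas' convert_dtypes with the pyarrow backend), dropping the old file-age and date helpers, and the hunk below hardens _match_regex_to_iso3. A rough sketch of the two underlying calls, for illustration only:

import pandas as pd
from hdx.location.country import Country

# enforce_pyarrow_types is equivalent to:
df = pd.DataFrame({"year": [2020, 2021], "value": [1.5, 2.0]})
print(df.convert_dtypes(dtype_backend="pyarrow").dtypes)  # int64[pyarrow], double[pyarrow]

# _match_regex_to_iso3 keeps only the first element of this tuple (the ISO3
# code, or None when nothing matches); the second element flags an exact match.
iso3, exact = Country().get_iso3_country_code_fuzzy("Ivory Coast")
print(iso3, exact)  # expected: CIV False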
@@ -52,20 +30,17 @@ def _match_regex_to_iso3(
     if additional_mapping is None:
         additional_mapping = {}
 
-    # Create a Country object
     country = Country()
-
-    # Match the regex strings to ISO3 country codes
-    matches = {}
+    matches: dict[str, str | None] = {}
 
     for match in to_match:
         try:
             match_ = country.get_iso3_country_code_fuzzy(match)[0]
-        except:
+        except Exception:  # pragma: no cover - defensive logging
             match_ = None
         matches[match] = match_
         if match_ is None and match not in additional_mapping:
-            logger.debug(f"No ISO3 match found for {match}")
+            logger.debug("No ISO3 match found for %s", match)
 
     return matches | additional_mapping
 
@@ -76,7 +51,7 @@ def convert_id(
     to_type: str = "ISO3",
     not_found: Any = None,
     *,
-    additional_mapping: dict = None,
+    additional_mapping: dict | None = None,
 ) -> pd.Series:
     """Takes a Pandas' series with country IDs and converts them into the desired type.
 
@@ -93,7 +68,6 @@
         the same datatype as the target type.
     """
 
-    # if from and to are the same, return without changing anything
     if from_type == to_type:
        return series
 
@@ -107,7 +81,6 @@
     mapping = mapping_functions[from_type](
         to_match=s_unique, additional_mapping=additional_mapping
     )
-
     return series.map(mapping).fillna(series if not_found is None else not_found)
 
 
@@ -141,7 +114,6 @@ def add_pydeflate_iso3(
             "Sub-Sahara Africa": "SSA",
         },
     )
-
     return df
 
 
@@ -160,7 +132,6 @@ def prefix_pydeflate_to_columns(
     df.columns = [
         f"{prefix}{col}" if not col.startswith(prefix) else col for col in df.columns
     ]
-
     return df
 
 
@@ -187,7 +158,7 @@ def compute_exchange_deflator(
     base_year_measure: str | None = None,
     exchange: str = "EXCHANGE",
     year: str = "year",
-    grouper: list[str] = None,
+    grouper: list[str] | None = None,
 ) -> pd.DataFrame:
     """Compute the exchange rate deflator for each group of entities.
 
@@ -205,87 +176,68 @@
         pd.DataFrame: DataFrame with an additional column for the exchange rate deflator.
     """
 
-    def _add_deflator(
+    def _compute_deflator_for_group(
         group: pd.DataFrame,
-        measure: str | None = "NGDPD_D",
-        exchange: str = "EXCHANGE",
-        year: str = "year",
+        measure: str | None,
+        exchange_col: str,
+        year_col: str,
+        deflator_col: str,
     ) -> pd.DataFrame:
-
-        # if needed, clean exchange name
-        if exchange.endswith("_to") or exchange.endswith("_from"):
-            exchange_name = exchange.rsplit("_", 1)[0]
-        else:
-            exchange_name = exchange
-
-        # Identify the base year for the deflator
+        """Compute deflator for a single group and add it as a column."""
+        # Identify base year
         if measure is not None:
-            base_year = identify_base_year(group, measure=measure, year=year)
+            base_year = identify_base_year(group, measure=measure, year=year_col)
         else:
-            base_year = group.dropna(subset=exchange)[year].max()
+            valid_rows = group.dropna(subset=[exchange_col])
+            base_year = valid_rows[year_col].max() if not valid_rows.empty else None
 
-        # If no base year is found, return the group unchanged
+        # If no base year found, return group without deflator column
         if base_year is None or pd.isna(base_year):
             return group
 
         # Extract the exchange rate value for the base year
-        base_value = group.loc[group[year] == base_year, exchange].values
+        base_value_rows = group.loc[group[year_col] == base_year, exchange_col]
 
-        # If base value is found and valid, calculate the deflator
-        if base_value.size > 0 and pd.notna(base_value[0]):
-            group[f"{exchange_name}_D"] = round(
-                100 * group[exchange] / base_value[0], 6
-            )
+        # If no valid base value, return group without deflator column
+        if base_value_rows.empty or pd.isna(base_value_rows.iloc[0]):
+            return group
+
+        # Calculate and add deflator column
+        base_value = base_value_rows.iloc[0]
+        group = group.copy()
+        group[deflator_col] = round(100 * group[exchange_col] / base_value, 6)
 
         return group
 
     if grouper is None:
         grouper = ["entity", "entity_code"]
 
-    # Apply the deflator computation for each group of 'entity' and 'entity_code'
-    return df.groupby(grouper, group_keys=False).apply(
-        _add_deflator, measure=base_year_measure, exchange=exchange, year=year
-    )
-
-
-def read_data(
-    file_finder_func: callable,
-    download_func: callable,
-    data_name: str,
-    update: bool = False,
-) -> pd.DataFrame:
-    """Generic function to read data from parquet files or download fresh data.
-
-    Args:
-        file_finder_func (function): Function to find existing data files in the path.
-        download_func (function): Function to download fresh data if no files are
-            found or an update is needed.
-        data_name (str): Name of the dataset for logging purposes (e.g., "WEO", "DAC").
-        update (bool): If True, forces downloading of new data even if files exist.
-
-    Returns:
-        pd.DataFrame: The latest available data.
-    """
-    # Find existing files using the provided file finder function
-    files = file_finder_func(PYDEFLATE_PATHS.data)
-
-    # If no files are found or update is requested, download new data
-    if len(files) == 0 or update:
-        download_func()
-        files = file_finder_func(PYDEFLATE_PATHS.data)
-
-    # If files are found, sort them by age and load the most recent one
-    if len(files) > 0:
-        files = sorted(files, key=check_file_age)
-        latest_file = files[0]
-
-        # Check if the latest file is older than 120 days and log a warning
-        if check_file_age(latest_file) > 120:
-            logger.warn(
-                f"The latest {data_name} data is more than 120 days old.\n"
-                f"Consider updating by setting update=True in the function call."
-            )
-
-    # Read and return the latest parquet file as a DataFrame
-    logger.info(f"Reading {data_name} data from {latest_file}")
-    return pd.read_parquet(latest_file)
+    # Determine the exchange column name for the deflator
+    if exchange.endswith("_to") or exchange.endswith("_from"):
+        exchange_name = exchange.rsplit("_", 1)[0]
+    else:
+        exchange_name = exchange
+
+    deflator_col = f"{exchange_name}_D"
+
+    # Process each group and concatenate results
+    # This approach avoids the FutureWarning from groupby().apply() operating on grouping columns
+    processed_groups = []
+    for name, group in df.groupby(grouper, sort=False):
+        processed_group = _compute_deflator_for_group(
+            group=group,
+            measure=base_year_measure,
+            exchange_col=exchange,
+            year_col=year,
+            deflator_col=deflator_col,
+        )
+        processed_groups.append(processed_group)
+
+    # Concatenate all processed groups and restore original row order
+    result = pd.concat(processed_groups, ignore_index=False)
+
+    # Sort by index to restore original row order
+    # (groupby may have changed the order when grouping rows together)
+    result = result.sort_index()
+
+    return result
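
Note: the rewritten compute_exchange_deflator keeps the same contract — it adds an `<exchange>_D` column rebased to 100 in the base year — but processes each entity group explicitly and concatenates, avoiding the groupby().apply() FutureWarning. A rough usage sketch; the import path is an assumption, since this diff does not show the module's filename:

import pandas as pd

from pydeflate.utils import compute_exchange_deflator  # assumed module path

df = pd.DataFrame(
    {
        "entity": ["France"] * 3,
        "entity_code": ["FRA"] * 3,
        "year": [2020, 2021, 2022],
        "EXCHANGE": [0.88, 0.85, 0.95],
    }
)

# With base_year_measure=None the base year is the latest year with a valid
# exchange rate (2022), so EXCHANGE_D is 100.0 in 2022 and ~92.63 / ~89.47 before.
result = compute_exchange_deflator(df)
print(result[["year", "EXCHANGE_D"]])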