microimpute 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- microimpute-0.1.0/PKG-INFO +53 -0
- microimpute-0.1.0/README.md +6 -0
- microimpute-0.1.0/microimpute/__init__.py +48 -0
- microimpute-0.1.0/microimpute/comparisons/__init__.py +20 -0
- microimpute-0.1.0/microimpute/comparisons/autoimpute.py +427 -0
- microimpute-0.1.0/microimpute/comparisons/data.py +481 -0
- microimpute-0.1.0/microimpute/comparisons/imputations.py +169 -0
- microimpute-0.1.0/microimpute/comparisons/quantile_loss.py +211 -0
- microimpute-0.1.0/microimpute/config.py +55 -0
- microimpute-0.1.0/microimpute/evaluations/__init__.py +6 -0
- microimpute-0.1.0/microimpute/evaluations/cross_validation.py +374 -0
- microimpute-0.1.0/microimpute/main.py +20 -0
- microimpute-0.1.0/microimpute/models/__init__.py +17 -0
- microimpute-0.1.0/microimpute/models/imputer.py +247 -0
- microimpute-0.1.0/microimpute/models/matching.py +419 -0
- microimpute-0.1.0/microimpute/models/ols.py +231 -0
- microimpute-0.1.0/microimpute/models/qrf.py +273 -0
- microimpute-0.1.0/microimpute/models/quantreg.py +248 -0
- microimpute-0.1.0/microimpute/tests/README.md +40 -0
- microimpute-0.1.0/microimpute/tests/__init__.py +1 -0
- microimpute-0.1.0/microimpute/tests/test_autoimpute.py +69 -0
- microimpute-0.1.0/microimpute/tests/test_basic.py +8 -0
- microimpute-0.1.0/microimpute/tests/test_models/README.md +126 -0
- microimpute-0.1.0/microimpute/tests/test_models/__init__.py +1 -0
- microimpute-0.1.0/microimpute/tests/test_models/test_imputers.py +152 -0
- microimpute-0.1.0/microimpute/tests/test_models/test_matching.py +252 -0
- microimpute-0.1.0/microimpute/tests/test_models/test_ols.py +131 -0
- microimpute-0.1.0/microimpute/tests/test_models/test_qrf.py +260 -0
- microimpute-0.1.0/microimpute/tests/test_models/test_quantreg.py +125 -0
- microimpute-0.1.0/microimpute/tests/test_quantile_comparison.py +105 -0
- microimpute-0.1.0/microimpute/utils/logging_utils.py +55 -0
- microimpute-0.1.0/microimpute/utils/qrf.py +275 -0
- microimpute-0.1.0/microimpute/utils/statmatch_hotdeck.py +264 -0
- microimpute-0.1.0/microimpute/visualizations/__init__.py +6 -0
- microimpute-0.1.0/microimpute/visualizations/plotting.py +627 -0
- microimpute-0.1.0/microimpute.egg-info/PKG-INFO +53 -0
- microimpute-0.1.0/microimpute.egg-info/SOURCES.txt +40 -0
- microimpute-0.1.0/microimpute.egg-info/dependency_links.txt +1 -0
- microimpute-0.1.0/microimpute.egg-info/requires.txt +43 -0
- microimpute-0.1.0/microimpute.egg-info/top_level.txt +1 -0
- microimpute-0.1.0/pyproject.toml +78 -0
- microimpute-0.1.0/setup.cfg +4 -0
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: microimpute
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Benchmarking imputation methods for microdata
|
|
5
|
+
Requires-Python: >=3.11
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
Requires-Dist: numpy<2.0.0,>=1.26.0
|
|
8
|
+
Requires-Dist: pandas<3.0.0,>=2.2.0
|
|
9
|
+
Requires-Dist: plotly<6.0.0,>=5.24.0
|
|
10
|
+
Requires-Dist: kaleido<0.3.0,>=0.2.1
|
|
11
|
+
Requires-Dist: scikit-learn<2.0.0,>=1.6.1
|
|
12
|
+
Requires-Dist: scipy<2.0.0,>=1.11.0
|
|
13
|
+
Requires-Dist: requests<3.0.0,>=2.32.0
|
|
14
|
+
Requires-Dist: tqdm<5.0.0,>=4.65.0
|
|
15
|
+
Requires-Dist: statsmodels<0.15.0,>=0.14.0
|
|
16
|
+
Requires-Dist: quantile-forest<1.5.0,>=1.4.0
|
|
17
|
+
Requires-Dist: pydantic<3.0.0,>=2.8.0
|
|
18
|
+
Requires-Dist: optuna==4.3.0
|
|
19
|
+
Requires-Dist: joblib<2.0.0,>=1.2.0
|
|
20
|
+
Provides-Extra: dev
|
|
21
|
+
Requires-Dist: pytest<9.0.0,>=8.0.0; extra == "dev"
|
|
22
|
+
Requires-Dist: pytest-cov<7.0.0,>=6.0.0; extra == "dev"
|
|
23
|
+
Requires-Dist: flake8<7.0.0,>=6.0.0; extra == "dev"
|
|
24
|
+
Requires-Dist: black>=23.0.0; extra == "dev"
|
|
25
|
+
Requires-Dist: isort<6.0.0,>=5.9.0; extra == "dev"
|
|
26
|
+
Requires-Dist: mypy<2.0.0,>=1.0.0; extra == "dev"
|
|
27
|
+
Requires-Dist: build<2.0.0,>=1.0.0; extra == "dev"
|
|
28
|
+
Requires-Dist: linecheck<0.2.0,>=0.1.0; extra == "dev"
|
|
29
|
+
Provides-Extra: matching
|
|
30
|
+
Requires-Dist: rpy2<4.0.0,>=3.5.0; extra == "matching"
|
|
31
|
+
Provides-Extra: docs
|
|
32
|
+
Requires-Dist: sphinx<6.0.0,>=5.0.0; extra == "docs"
|
|
33
|
+
Requires-Dist: docutils<0.18.0,>=0.17.0; extra == "docs"
|
|
34
|
+
Requires-Dist: jupyter-book>=0.15.0; extra == "docs"
|
|
35
|
+
Requires-Dist: sphinx-book-theme>=1.0.0; extra == "docs"
|
|
36
|
+
Requires-Dist: sphinx-copybutton>=0.5.0; extra == "docs"
|
|
37
|
+
Requires-Dist: sphinx-design>=0.3.0; extra == "docs"
|
|
38
|
+
Requires-Dist: ipywidgets<8.0.0,>=7.8.0; extra == "docs"
|
|
39
|
+
Requires-Dist: plotly<6.0.0,>=5.24.0; extra == "docs"
|
|
40
|
+
Requires-Dist: sphinx-argparse>=0.4.0; extra == "docs"
|
|
41
|
+
Requires-Dist: sphinx-math-dollar>=1.2.1; extra == "docs"
|
|
42
|
+
Requires-Dist: myst-parser==0.18.1; extra == "docs"
|
|
43
|
+
Requires-Dist: myst-nb==0.17.2; extra == "docs"
|
|
44
|
+
Requires-Dist: pyyaml; extra == "docs"
|
|
45
|
+
Requires-Dist: furo==2022.12.7; extra == "docs"
|
|
46
|
+
Requires-Dist: h5py<4.0.0,>=3.1.0; extra == "docs"
|
|
47
|
+
|
|
48
|
+
# MicroImpute
|
|
49
|
+
|
|
50
|
+
MicroImpute enables variable imputation through different statistical methods. It facilitates comparison and benchmarking across methods through quantile loss calculations.
|
|
51
|
+
|
|
52
|
+
To install, run `pip install microimpute`.
|
|
53
|
+
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""MicroImpute Package
|
|
2
|
+
|
|
3
|
+
A package for benchmarking different imputation methods using microdata.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
__version__ = "0.1.0"
|
|
7
|
+
|
|
8
|
+
# Import data handling functions
|
|
9
|
+
from microimpute.comparisons.data import prepare_scf_data, preprocess_data
|
|
10
|
+
from microimpute.comparisons.imputations import get_imputations
|
|
11
|
+
|
|
12
|
+
# Import comparison utilities
|
|
13
|
+
from microimpute.comparisons.quantile_loss import (
|
|
14
|
+
compare_quantile_loss,
|
|
15
|
+
compute_quantile_loss,
|
|
16
|
+
quantile_loss,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
# Main configuration
|
|
20
|
+
from microimpute.config import (
|
|
21
|
+
PLOT_CONFIG,
|
|
22
|
+
QUANTILES,
|
|
23
|
+
RANDOM_STATE,
|
|
24
|
+
VALIDATE_CONFIG,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
# Import evaluation modules
|
|
28
|
+
from microimpute.evaluations.cross_validation import cross_validate_model
|
|
29
|
+
|
|
30
|
+
# Import main models and utilities
|
|
31
|
+
from microimpute.models import (
|
|
32
|
+
OLS,
|
|
33
|
+
QRF,
|
|
34
|
+
Imputer,
|
|
35
|
+
ImputerResults,
|
|
36
|
+
QuantReg,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
try:
|
|
40
|
+
from microimpute.models.matching import Matching
|
|
41
|
+
except ImportError:
|
|
42
|
+
pass
|
|
43
|
+
|
|
44
|
+
# Import visualization modules
|
|
45
|
+
from microimpute.visualizations.plotting import (
|
|
46
|
+
method_comparison_results,
|
|
47
|
+
model_performance_results,
|
|
48
|
+
)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Data Comparison Utilities
|
|
2
|
+
|
|
3
|
+
This module contains utilities for comparing different imputation methods.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
# Import automated imputation utilities
|
|
7
|
+
from .autoimpute import autoimpute
|
|
8
|
+
|
|
9
|
+
# Import data handling functions
|
|
10
|
+
from .data import prepare_scf_data, preprocess_data, scf_url
|
|
11
|
+
|
|
12
|
+
# Import imputation utilities
|
|
13
|
+
from .imputations import get_imputations
|
|
14
|
+
|
|
15
|
+
# Import loss functions
|
|
16
|
+
from .quantile_loss import (
|
|
17
|
+
compare_quantile_loss,
|
|
18
|
+
compute_quantile_loss,
|
|
19
|
+
quantile_loss,
|
|
20
|
+
)
|
|
@@ -0,0 +1,427 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pipeline for autoimputation of missing values in a dataset.
|
|
3
|
+
This module integrates all steps necessary for method selection and imputation of missing values.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
import warnings
|
|
8
|
+
from functools import partial
|
|
9
|
+
from typing import Any, Dict, List, Optional, Type
|
|
10
|
+
|
|
11
|
+
import joblib
|
|
12
|
+
import pandas as pd
|
|
13
|
+
from pydantic import validate_call
|
|
14
|
+
from rpy2.robjects import pandas2ri
|
|
15
|
+
from tqdm.auto import tqdm
|
|
16
|
+
|
|
17
|
+
from microimpute.comparisons import *
|
|
18
|
+
from microimpute.comparisons.data import preprocess_data
|
|
19
|
+
from microimpute.config import (
|
|
20
|
+
QUANTILES,
|
|
21
|
+
RANDOM_STATE,
|
|
22
|
+
TRAIN_SIZE,
|
|
23
|
+
VALIDATE_CONFIG,
|
|
24
|
+
)
|
|
25
|
+
from microimpute.evaluations import cross_validate_model
|
|
26
|
+
from microimpute.models import *
|
|
27
|
+
|
|
28
|
+
log = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _validate_autoimpute_inputs(
    donor_data: pd.DataFrame,
    receiver_data: pd.DataFrame,
    predictors: List[str],
    imputed_variables: List[str],
    quantiles: Optional[List[float]],
    hyperparameters: Optional[Dict[str, Dict[str, Any]]],
    tune_hyperparameters: Optional[bool],
) -> None:
    """Validate the arguments of ``autoimpute``.

    Each problem is logged at ERROR level and raised as ``ValueError``;
    the first failing check wins.

    Raises:
        ValueError: If a quantile lies outside [0, 1], a required column is
            missing, or both explicit hyperparameters and automatic tuning
            were requested.
    """

    def fail(message: str) -> None:
        log.error(message)
        raise ValueError(message)

    if quantiles:
        invalid_quantiles = [q for q in quantiles if not 0 <= q <= 1]
        if invalid_quantiles:
            fail(
                f"Invalid quantiles (must be between 0 and 1): {invalid_quantiles}"
            )

    missing_predictors_donor = [
        col for col in predictors if col not in donor_data.columns
    ]
    if missing_predictors_donor:
        fail(
            f"Missing predictor columns in donor data: {missing_predictors_donor}"
        )

    # Only the predictors are required in the receiver data; imputed variables
    # may be absent there (autoimpute drops them before imputing anyway).
    missing_predictors_receiver = [
        col for col in predictors if col not in receiver_data.columns
    ]
    if missing_predictors_receiver:
        fail(
            f"Missing predictor columns in receiver data: {missing_predictors_receiver}"
        )

    missing_imputed_donor = [
        col for col in imputed_variables if col not in donor_data.columns
    ]
    if missing_imputed_donor:
        fail(
            f"Missing imputed variable columns in donor data: {missing_imputed_donor}"
        )

    if hyperparameters is not None and tune_hyperparameters:
        fail(
            "Cannot specify both hyperparameters and request to automatically "
            "tune hyperparameters, please select one or the other."
        )


def _log_hyperparameter_usage(
    hyperparameters: Dict[str, Dict[str, Any]], model_names: List[str]
) -> None:
    """Log which user-supplied hyperparameter sets match a compared model.

    Args:
        hyperparameters: Mapping of model class name -> hyperparameter dict.
        model_names: Class names of the models being compared.
    """
    relevant = {
        name: params
        for name, params in hyperparameters.items()
        if name in model_names
    }
    if not relevant:
        log.info(
            f"None of the hyperparameters provided are relevant for the supported models: {model_names}. Using default hyperparameters."
        )
    for name, params in relevant.items():
        log.info(f"Using hyperparameters for {name}: {params}")


def _evaluate_model(
    model: "Type[Imputer]",
    data: pd.DataFrame,
    predictors: List[str],
    imputed_variables: List[str],
    quantiles: List[float],
    k_folds: Optional[int] = 5,
    random_state: Optional[int] = RANDOM_STATE,
    tune_hyperparams: Optional[bool] = True,
    hyperparameters: Optional[Dict[str, Any]] = None,
) -> tuple[str, pd.DataFrame]:
    """Evaluate a single imputation model with cross-validation.

    Args:
        model: The imputation model class to evaluate.
        data: The dataset to use for evaluation.
        predictors: List of predictor column names.
        imputed_variables: List of columns to impute.
        quantiles: List of quantiles to evaluate.
        k_folds: Number of cross-validation folds.
        random_state: Random seed for reproducibility.
        tune_hyperparams: Whether to tune hyperparameters.
        hyperparameters: Optional hyperparameters forwarded as
            ``model_hyperparams``.

    Returns:
        Tuple of (model class name, cross-validation results DataFrame).
    """
    model_name = model.__name__
    log.info(f"Evaluating {model_name}...")

    # The R-backed Matching model needs the pandas/numpy -> R converters
    # activated in the thread that runs it.
    if model_name == "Matching":
        from rpy2.robjects import numpy2ri, pandas2ri

        pandas2ri.activate()
        numpy2ri.activate()

    # NOTE(review): tune_hyperparams is accepted but never forwarded to
    # cross_validate_model, so requesting tuning currently has no effect here.
    # The hyperparameters dict is also the full {model_name: params} mapping
    # from autoimpute, not the per-model entry -- confirm which shape
    # cross_validate_model expects.
    cv_results = cross_validate_model(
        model_class=model,
        data=data,
        predictors=predictors,
        imputed_variables=imputed_variables,
        quantiles=quantiles,
        n_splits=k_folds,
        random_state=random_state,
        model_hyperparams=hyperparameters,
    )
    return model_name, cv_results


def _unnormalize_imputations(
    imputations: Dict[float, pd.DataFrame],
    normalizing_params: Dict[str, Dict[str, float]],
) -> Dict[float, pd.DataFrame]:
    """Undo the mean/std normalization applied to the imputed variables.

    Each imputed column is multiplied by its ``"std"`` and shifted by its
    ``"mean"`` as recorded in ``normalizing_params``.

    Raises:
        KeyError: If an imputed column has no entry in ``normalizing_params``.
    """
    mean = pd.Series({col: p["mean"] for col, p in normalizing_params.items()})
    std = pd.Series({col: p["std"] for col, p in normalizing_params.items()})
    return {
        q: df.mul(std[df.columns], axis=1).add(mean[df.columns], axis=1)
        for q, df in imputations.items()
    }


@validate_call(config=VALIDATE_CONFIG)
def autoimpute(
    donor_data: pd.DataFrame,
    receiver_data: pd.DataFrame,
    predictors: List[str],
    imputed_variables: List[str],
    models: Optional[List[Type]] = None,
    quantiles: Optional[List[float]] = QUANTILES,
    hyperparameters: Optional[Dict[str, Dict[str, Any]]] = None,
    tune_hyperparameters: Optional[bool] = False,
    random_state: Optional[int] = RANDOM_STATE,
    train_size: Optional[float] = TRAIN_SIZE,
    k_folds: Optional[int] = 5,
    verbose: Optional[bool] = False,
) -> tuple[dict[float, pd.DataFrame], pd.DataFrame, Any, pd.DataFrame]:
    """Automatically select and apply the best imputation model.

    This function evaluates multiple imputation methods using cross-validation
    to determine which performs best on the provided donor data, then applies
    the winning method to impute values in the receiver data.

    Args:
        donor_data : Dataframe containing both predictor and target variables
            used to train models
        receiver_data : Dataframe containing predictor variables where imputed
            values will be generated
        predictors : List of column names of predictor variables used to
            predict imputed variables. The caller's list is not modified.
        imputed_variables : List of column names of variables to be imputed in
            the receiver data. The caller's list is not modified.
        models : List of imputer model classes to compare.
            If None, uses [QRF, OLS, QuantReg, Matching]
        quantiles : List of quantiles to predict for each imputed variable.
            Uses default QUANTILES if not passed.
        hyperparameters : Dictionary of hyperparameters for specific models,
            with model names as keys. Defaults to None and uses default model
            hyperparameters then.
        tune_hyperparameters : Whether to tune hyperparameters for the models.
            Defaults to False.
        random_state : Random seed for reproducibility; used for the
            cross-validation splits.
        train_size : Proportion of data to use for training in preprocessing
        k_folds : Number of folds for cross-validation. Defaults to 5.
        verbose : Whether to print detailed logs and progress bars.
            Defaults to False.

    Returns:
        A tuple containing:
        - Dictionary mapping quantiles to DataFrames of imputed values,
          un-normalized with the mean/std returned by preprocess_data
        - A copy of receiver_data with the imputed variables filled in from
          the 0.5-quantile imputations
        - The fitted imputation model (best performing)
        - DataFrame with cross-validation test losses for all evaluated
          models, plus a "mean_loss" column

    Raises:
        ValueError: If inputs are invalid (e.g., invalid quantiles, missing columns)
        RuntimeError: For unexpected errors during imputation
    """
    log.setLevel(logging.INFO if verbose else logging.WARNING)

    # Work on copies: dummy-column expansion below edits these lists in place
    # and must not leak back into the caller's lists.
    predictors = list(predictors)
    imputed_variables = list(imputed_variables)

    # A disabled tqdm bar is a no-op, so no ``if verbose`` guards are needed.
    main_progress = tqdm(
        total=4, desc="AutoImputation progress", disable=not verbose
    )

    try:
        # Suppress warnings for this call only instead of permanently
        # changing the process-wide warning filters.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Step 0: Input validation
            main_progress.set_description("Input validation")
            _validate_autoimpute_inputs(
                donor_data,
                receiver_data,
                predictors,
                imputed_variables,
                quantiles,
                hyperparameters,
                tune_hyperparameters,
            )
            log.info(
                f"Generating imputations to impute from {len(donor_data)} donor data to {len(receiver_data)} receiver data for variables {imputed_variables} with predictors {predictors}. "
            )

            # Step 1: Data preparation
            log.info("Preprocessing data...")
            main_progress.update(1)
            main_progress.set_description("Data preparation")

            # If imputed variables are in receiver data, remove them
            receiver_data = receiver_data.drop(
                columns=imputed_variables, errors="ignore"
            )

            training_data = donor_data.copy()
            imputing_data = receiver_data.copy()

            training_data[predictors], _ = preprocess_data(
                training_data[predictors],
                full_data=True,
                train_size=train_size,
                test_size=(1 - train_size),
            )
            training_data[imputed_variables], _, normalizing_params = (
                preprocess_data(
                    training_data[imputed_variables],
                    full_data=True,
                    train_size=train_size,
                    test_size=(1 - train_size),
                    normalizing_features=True,
                )
            )
            imputing_data, dummy_info = preprocess_data(
                imputing_data[predictors],
                full_data=True,
                train_size=train_size,
                test_size=(1 - train_size),
            )

            # NOTE(review): only the receiver-predictor dummy info is used here
            # (the training-side dummy info is discarded, as before), so the
            # imputed-variable branch below is reached only if preprocess_data
            # reports such a column for the receiver predictors -- confirm.
            for orig_col, dummy_cols in (dummy_info or {}).items():
                if orig_col in predictors:
                    predictors.remove(orig_col)
                    predictors.extend(dummy_cols)
                elif orig_col in imputed_variables:
                    imputed_variables.remove(orig_col)
                    imputed_variables.extend(dummy_cols)

            # Step 2: Imputation with each method
            main_progress.update(1)
            main_progress.set_description("Model evaluation")

            model_classes: List[Type[Imputer]] = (
                models if models else [QRF, OLS, QuantReg, Matching]
            )

            if hyperparameters:
                _log_hyperparameter_usage(
                    hyperparameters, [m.__name__ for m in model_classes]
                )

            log.info(
                "Hyperparameter tuning and cross-validation for model comparison in progress... "
            )

            # The rpy2-backed Matching model is unsafe across worker threads,
            # so fall back to sequential evaluation when it is present.
            n_jobs = -1
            if any(m.__name__ == "Matching" for m in model_classes):
                log.info(
                    "Using sequential processing (n_jobs=1) because Matching model is present"
                )
                n_jobs = 1

            tasks = [
                (
                    model_class,
                    training_data,
                    predictors,
                    imputed_variables,
                    quantiles,
                    k_folds,
                    random_state,
                    tune_hyperparameters,
                    hyperparameters,
                )
                for model_class in model_classes
            ]
            results = joblib.Parallel(n_jobs=n_jobs)(
                joblib.delayed(_evaluate_model)(*task)
                for task in tqdm(
                    tasks, desc="Evaluating models", disable=not verbose
                )
            )

            method_test_losses = {
                model_name: cv_result.loc["test"]
                for model_name, cv_result in results
            }
            method_results_df = pd.DataFrame.from_dict(
                method_test_losses, orient="index"
            )

            # Step 3: Compare imputation methods
            log.info(f"Comparing across {model_classes} methods. ")
            main_progress.update(1)
            main_progress.set_description("Model selection")

            # Average loss across all quantile/variable columns per model
            method_results_df["mean_loss"] = method_results_df.mean(axis=1)

            # Step 4: Select best method
            best_method = method_results_df["mean_loss"].idxmin()
            best_row = method_results_df.loc[best_method]
            log.info(
                f"The method with the lowest average loss is {best_method}, with an average loss across variables and quantiles of {best_row['mean_loss']}. "
            )

            # Step 5: Generate imputations with the best method on the receiver data
            log.info(
                f"Generating imputations using the best method: {best_method} on the receiver data. "
            )
            main_progress.update(1)
            main_progress.set_description("Imputation")

            chosen_model = {m.__name__: m for m in model_classes}[best_method]
            model = chosen_model()
            # Quantile whose imputations are written into receiver_data.
            imputation_q = 0.5

            # NOTE(review): user-supplied hyperparameters are not applied to
            # this final fit -- confirm whether fit() should receive them.
            if best_method == "QuantReg":
                # QuantReg must be fitted explicitly for the quantile to predict.
                fitted_model = model.fit(
                    training_data,
                    predictors,
                    imputed_variables,
                    quantiles=[imputation_q],
                )
            else:
                fitted_model = model.fit(
                    training_data, predictors, imputed_variables
                )

            imputations = fitted_model.predict(
                imputing_data, quantiles=[imputation_q]
            )
            unnormalized_imputations = _unnormalize_imputations(
                imputations, normalizing_params
            )

            log.info(
                f"Imputation generation completed for {len(receiver_data)} samples using the best method: {best_method} and the median quantile. "
            )
            main_progress.set_description("Complete")

            # NOTE(review): assignment aligns on index; assumes predict()
            # preserves receiver_data's index -- confirm for every model.
            median_imputations = unnormalized_imputations[imputation_q]
            for var in imputed_variables:
                if var in median_imputations.columns:
                    receiver_data[var] = median_imputations[var]
                else:
                    log.warning(
                        f"Imputed variable {var} not found in the imputations. "
                    )

            return (
                unnormalized_imputations,
                receiver_data,
                fitted_model,
                method_results_df,
            )

    except ValueError:
        # Re-raise validation errors directly
        raise
    except Exception as e:
        log.error(f"Unexpected error during autoimputation: {str(e)}")
        raise RuntimeError(f"Failed to generate imputations: {str(e)}") from e
    finally:
        main_progress.close()
|