dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1901
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
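
The headline change in 20.0.0 is structural: every flat `ml_tools/<name>.py` module becomes an `ml_tools/<name>/` package with private `_*.py` submodules, and several modules are renamed outright (`MICE_imputation` → `MICE`, `VIF_factor` → `VIF`, `ML_chaining_utilities` → `ML_chain`). A minimal migration sketch for downstream imports follows, assuming each new package's `__init__.py` re-exports the public API listed in its `__all__` (the MICE `__all__` is visible in the diff below); verify the exact paths against the 20.0.0 RECORD.

```python
# Hedged migration sketch based on the renames in the RECORD diff above.
# Assumption: each new package's __init__.py re-exports its public symbols.

# 19.13.0 (flat modules, deleted in 20.0.0):
# from ml_tools.MICE_imputation import run_mice_pipeline   # dropped from __all__, see below

# 20.0.0 (subpackages):
from ml_tools.schema import FeatureSchema      # now ml_tools/schema/__init__.py
from ml_tools.MICE import (                    # was ml_tools/MICE_imputation.py
    DragonMICE,
    get_convergence_diagnostic,
    get_imputed_distributions,
)
```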
ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322 (lines truncated by the diff viewer are marked `…`):

@@ -1,288 +1,29 @@
 import pandas as pd
-import miceforest as mf
 from pathlib import Path
+from typing import Union
+import miceforest as mf
 import matplotlib.pyplot as plt
 import numpy as np
 from plotnine import ggplot, labs, theme, element_blank # type: ignore
-from typing import Optional, Union
 
-from …
-from …
-
-from …
-from …
-from …
+from ..utilities import load_dataframe, merge_dataframes, save_dataframe_filename
+from ..schema import FeatureSchema
+
+from ..math_utilities import discretize_categorical_values
+from ..path_manager import make_fullpath, list_csv_paths, sanitize_filename
+from .._core import get_logger
 
 
-_LOGGER = get_logger("…
+_LOGGER = get_logger("DragonMICE")
 
 
 __all__ = [
     "DragonMICE",
-    "apply_mice",
-    "save_imputed_datasets",
     "get_convergence_diagnostic",
     "get_imputed_distributions",
-    "run_mice_pipeline",
 ]
 
 
-def apply_mice(df: pd.DataFrame, df_name: str, binary_columns: Optional[list[str]]=None, resulting_datasets: int=1, iterations: int=20, random_state: int=101):
-
-    # Initialize kernel with number of imputed datasets to generate
-    kernel = mf.ImputationKernel(
-        data=df,
-        num_datasets=resulting_datasets,
-        random_state=random_state
-    )
-
-    _LOGGER.info("➡️ MICE imputation running...")
-
-    # Perform MICE with n iterations per dataset
-    kernel.mice(iterations)
-
-    # Retrieve the imputed datasets
-    imputed_datasets = [kernel.complete_data(dataset=i) for i in range(resulting_datasets)]
-
-    if imputed_datasets is None or len(imputed_datasets) == 0:
-        _LOGGER.error("No imputed datasets were generated. Check the MICE process.")
-        raise ValueError()
-
-    # threshold binary columns
-    if binary_columns is not None:
-        invalid_binary_columns = set(binary_columns) - set(df.columns)
-        if invalid_binary_columns:
-            _LOGGER.warning(f"These 'binary columns' are not in the dataset:")
-            for invalid_binary_col in invalid_binary_columns:
-                print(f" - {invalid_binary_col}")
-        valid_binary_columns = [col for col in binary_columns if col not in invalid_binary_columns]
-        for imputed_df in imputed_datasets:
-            for binary_column_name in valid_binary_columns:
-                imputed_df[binary_column_name] = threshold_binary_values(imputed_df[binary_column_name]) # type: ignore
-
-    if resulting_datasets == 1:
-        imputed_dataset_names = [f"{df_name}_MICE"]
-    else:
-        imputed_dataset_names = [f"{df_name}_MICE_{i+1}" for i in range(resulting_datasets)]
-
-    # Ensure indexes match
-    for imputed_df, subname in zip(imputed_datasets, imputed_dataset_names):
-        assert imputed_df.shape[0] == df.shape[0], f"❌ Row count mismatch in dataset {subname}" # type: ignore
-        assert all(imputed_df.index == df.index), f"❌ Index mismatch in dataset {subname}" # type: ignore
-    # print("✅ All imputed datasets match the original DataFrame indexes.")
-
-    _LOGGER.info("MICE imputation complete.")
-
-    return kernel, imputed_datasets, imputed_dataset_names
-
-
-def save_imputed_datasets(save_dir: Union[str, Path], imputed_datasets: list, df_targets: pd.DataFrame, imputed_dataset_names: list[str]):
-    for imputed_df, subname in zip(imputed_datasets, imputed_dataset_names):
-        merged_df = merge_dataframes(imputed_df, df_targets, direction="horizontal", verbose=False)
-        save_dataframe_filename(df=merged_df, save_dir=save_dir, filename=subname)
-
-
-#Get names of features that had missing values before imputation
-def _get_na_column_names(df: pd.DataFrame):
-    return [col for col in df.columns if df[col].isna().any()]
-
-
-#Convergence diagnostic
-def get_convergence_diagnostic(kernel: mf.ImputationKernel, imputed_dataset_names: list[str], column_names: list[str], root_dir: Union[str,Path], fontsize: int=16):
-    """
-    Generate and save convergence diagnostic plots for imputed variables.
-
-    Parameters:
-    - kernel: Trained miceforest.ImputationKernel.
-    - imputed_dataset_names: Names assigned to each imputed dataset.
-    - column_names: List of feature names to track over iterations.
-    - root_dir: Directory to save convergence plots.
-    """
-    # get number of iterations used
-    iterations_cap = kernel.iteration_count()
-    dataset_count = kernel.num_datasets
-
-    if dataset_count != len(imputed_dataset_names):
-        _LOGGER.error(f"Expected {dataset_count} names in imputed_dataset_names, got {len(imputed_dataset_names)}")
-        raise ValueError()
-
-    # Check path
-    root_path = make_fullpath(root_dir, make=True)
-
-    # Styling parameters
-    label_font = {'size': fontsize, 'weight': 'bold'}
-
-    # iterate over each imputed dataset
-    for dataset_id, imputed_dataset_name in zip(range(dataset_count), imputed_dataset_names):
-        #Check directory for current dataset
-        dataset_file_dir = f"Convergence_Metrics_{imputed_dataset_name}"
-        local_save_dir = make_fullpath(input_path=root_path / dataset_file_dir, make=True)
-
-        for feature_name in column_names:
-            means_per_iteration = []
-            for iteration in range(iterations_cap):
-                current_imputed = kernel.complete_data(dataset=dataset_id, iteration=iteration)
-                means_per_iteration.append(np.mean(current_imputed[feature_name])) # type: ignore
-
-            plt.figure(figsize=(10, 8))
-            plt.plot(means_per_iteration, marker='o')
-            plt.xlabel("Iteration", **label_font)
-            plt.ylabel("Mean of Imputed Values", **label_font)
-            plt.title(f"Mean Convergence for '{feature_name}'", **label_font)
-
-            # Adjust plot display for the X axis
-            _ticks = np.arange(iterations_cap)
-            _labels = np.arange(1, iterations_cap + 1)
-            plt.xticks(ticks=_ticks, labels=_labels) # type: ignore
-            plt.grid(True)
-
-            feature_save_name = sanitize_filename(feature_name)
-            feature_save_name = feature_save_name + ".svg"
-            save_path = local_save_dir / feature_save_name
-            plt.savefig(save_path, bbox_inches='tight', format="svg")
-            plt.close()
-
-        _LOGGER.info(f"{dataset_file_dir} process completed.")
-
-
-# Imputed distributions
-def get_imputed_distributions(kernel: mf.ImputationKernel, df_name: str, root_dir: Union[str, Path], column_names: list[str], one_plot: bool=False, fontsize: int=14):
-    '''
-    It works using miceforest's authors implementation of the method `.plot_imputed_distributions()`.
-
-    Set `one_plot=True` to save a single image including all feature distribution plots instead.
-    '''
-    # Check path
-    root_path = make_fullpath(root_dir, make=True)
-
-    local_dir_name = f"Distribution_Metrics_{df_name}_imputed"
-    local_save_dir = make_fullpath(root_path / local_dir_name, make=True)
-
-    # Styling parameters
-    legend_kwargs = {'frameon': True, 'facecolor': 'white', 'framealpha': 0.8}
-    label_font = {'size': fontsize, 'weight': 'bold'}
-
-    def _process_figure(fig, filename: str):
-        """Helper function to add labels and legends to a figure"""
-
-        if not isinstance(fig, ggplot):
-            _LOGGER.error(f"Expected a plotnine.ggplot object, received {type(fig)}.")
-            raise TypeError()
-
-        # Edit labels and title
-        fig = fig + theme(
-            plot_title=element_blank(), # removes labs(title=...)
-            strip_text=element_blank() # removes facet_wrap labels
-        )
-
-        fig = fig + labs(y="", x="")
-
-        # Render to matplotlib figure
-        fig = fig.draw()
-
-        if not hasattr(fig, 'axes') or len(fig.axes) == 0:
-            _LOGGER.error("Rendered figure has no axes to modify.")
-            raise RuntimeError()
-
-        if filename == "Combined_Distributions":
-            custom_xlabel = "Feature Values"
-        else:
-            custom_xlabel = filename
-
-        for ax in fig.axes:
-            # Set axis labels
-            ax.set_xlabel(custom_xlabel, **label_font)
-            ax.set_ylabel('Distribution', **label_font)
-
-            # Add legend based on line colors
-            lines = ax.get_lines()
-            if len(lines) >= 1:
-                lines[0].set_label('Original Data')
-            if len(lines) > 1:
-                lines[1].set_label('Imputed Data')
-            ax.legend(**legend_kwargs)
-
-        # Adjust layout and save
-        # fig.tight_layout()
-        # fig.subplots_adjust(bottom=0.2, left=0.2) # Optional, depending on overflow
-
-        # sanitize savename
-        feature_save_name = sanitize_filename(filename)
-        feature_save_name = feature_save_name + ".svg"
-        new_save_path = local_save_dir / feature_save_name
-
-        fig.savefig(
-            new_save_path,
-            format='svg',
-            bbox_inches='tight',
-            pad_inches=0.1
-        )
-        plt.close(fig)
-
-    if one_plot:
-        # Generate combined plot
-        fig = kernel.plot_imputed_distributions(variables=column_names)
-        _process_figure(fig, "Combined_Distributions")
-    # Generate individual plots per feature
-    else:
-        for feature in column_names:
-            fig = kernel.plot_imputed_distributions(variables=[feature])
-            _process_figure(fig, feature)
-
-    _LOGGER.info(f"{local_dir_name} completed.")
-
-
-def run_mice_pipeline(df_path_or_dir: Union[str,Path], target_columns: list[str],
-                      save_datasets_dir: Union[str,Path], save_metrics_dir: Union[str,Path],
-                      binary_columns: Optional[list[str]]=None,
-                      resulting_datasets: int=1,
-                      iterations: int=20,
-                      random_state: int=101):
-    """
-    Call functions in sequence for each dataset in the provided path or directory:
-    1. Load dataframe
-    2. Apply MICE
-    3. Save imputed dataset(s)
-    4. Save convergence metrics
-    5. Save distribution metrics
-
-    Target columns must be skipped from the imputation. Binary columns will be thresholded after imputation.
-    """
-    # Check paths
-    save_datasets_path = make_fullpath(save_datasets_dir, make=True)
-    save_metrics_path = make_fullpath(save_metrics_dir, make=True)
-
-    input_path = make_fullpath(df_path_or_dir)
-    if input_path.is_file():
-        all_file_paths = [input_path]
-    else:
-        all_file_paths = list(list_csv_paths(input_path, raise_on_empty=True).values())
-
-    for df_path in all_file_paths:
-        df: pd.DataFrame
-        df, df_name = load_dataframe(df_path=df_path, kind="pandas") # type: ignore
-
-        df, df_targets = _skip_targets(df, target_columns)
-
-        kernel, imputed_datasets, imputed_dataset_names = apply_mice(df=df, df_name=df_name, binary_columns=binary_columns, resulting_datasets=resulting_datasets, iterations=iterations, random_state=random_state)
-
-        save_imputed_datasets(save_dir=save_datasets_path, imputed_datasets=imputed_datasets, df_targets=df_targets, imputed_dataset_names=imputed_dataset_names)
-
-        imputed_column_names = _get_na_column_names(df=df)
-
-        get_convergence_diagnostic(kernel=kernel, imputed_dataset_names=imputed_dataset_names, column_names=imputed_column_names, root_dir=save_metrics_path)
-
-        get_imputed_distributions(kernel=kernel, df_name=df_name, root_dir=save_metrics_path, column_names=imputed_column_names)
-
-
-def _skip_targets(df: pd.DataFrame, target_cols: list[str]):
-    valid_targets = [col for col in target_cols if col in df.columns]
-    df_targets = df[valid_targets]
-    df_feats = df.drop(columns=valid_targets)
-    return df_feats, df_targets
-
-
-# modern implementation
 class DragonMICE:
     """
     A modern MICE imputation pipeline that uses a FeatureSchema
@@ -293,71 +34,80 @@ class DragonMICE:
     def __init__(self,
                  schema: FeatureSchema,
                  impute_targets: bool = False,
-                 iterations: int = …
+                 iterations: int = 30,
                  resulting_datasets: int = 1,
                  random_state: int = 101):
 
-        …
-        …
-        …
-        …
-        …
+        # Validation
+        if not isinstance(schema, FeatureSchema):
+            raise TypeError(f"schema must be a FeatureSchema, got {type(schema)}")
+        if iterations < 1:
+            raise ValueError("iterations must be >= 1")
+        if resulting_datasets < 1:
+            raise ValueError("resulting_datasets must be >= 1")
+
+        # Private Attributes
+        self._schema = schema
+        self._impute_targets = impute_targets
+        self._random_state = random_state
+        self._iterations = iterations
+        self._resulting_datasets = resulting_datasets
 
         # --- Store schema info ---
 
         # 1. Categorical info
-        if not self.…
+        if not self._schema.categorical_index_map:
            _LOGGER.warning("FeatureSchema has no 'categorical_index_map'. No discretization will be applied.")
-            self.…
+            self._cat_info = {}
         else:
-            self.…
+            self._cat_info = self._schema.categorical_index_map
 
         # 2. Ordered feature names (critical for index mapping)
-        …
+        # Convert to list immediately to avoid Pandas Tuple indexing errors
+        self._ordered_features = list(self._schema.feature_names)
 
         # 3. Names of categorical features
-        self.…
+        self._categorical_features = list(self._schema.categorical_feature_names)
+
+        _LOGGER.info(f"DragonMICE initialized. Impute Targets: {self._impute_targets}. Found {len(self._cat_info)} categorical features to discretize.")
 
-    …
+    @property
+    def schema(self) -> FeatureSchema:
+        """Exposes the used FeatureSchema as read-only for inspection/logging purposes."""
+        return self._schema
 
     def _post_process(self, imputed_df: pd.DataFrame) -> pd.DataFrame:
         """
        Applies schema-based discretization to a completed dataframe.
-
-        This method works around the behavior of `discretize_categorical_values`
-        (which returns a full int32 array) by:
-        1. Extracting *only* the schema features.
-        2. Discretizing them.
-        3. Updating the original dataframe (which may contain targets) with these integers.
         """
         # If no categorical features are defined, return the df as-is.
-        if not self.…
+        if not self._cat_info:
            return imputed_df
 
         try:
            # 1. Extract the features strictly defined in the schema
            # We must respect the schema order for index-based discretization
-            df_schema_features = imputed_df[self.…
+            df_schema_features = imputed_df[self._ordered_features]
 
            # 2. Convert to NumPy array
            array_ordered = df_schema_features.to_numpy()
 
-            # 3. Apply discretization utility (returns int32 array)
+            # 3. Apply discretization utility (returns int32 array usually, or floats)
            discretized_array_int32 = discretize_categorical_values(
                array_ordered,
-                self.…
+                self._cat_info,
                start_at_zero=True
            )
 
            # 4. Create a DataFrame for the discretized values
            df_discretized_full = pd.DataFrame(
                discretized_array_int32,
-                columns=self.…
+                columns=self._ordered_features,
                index=df_schema_features.index
            )
 
            # 5. Isolate only the categorical columns that changed
-            df_discretized_cats = df_discretized_full[self.…
+            df_discretized_cats = df_discretized_full[self._categorical_features]
 
            # 6. Update the original imputed DF
            # This preserves Target columns if they exist in imputed_df
@@ -367,7 +117,7 @@ class DragonMICE:
            return final_df
 
        except Exception as e:
-            _LOGGER.error(f"Failed during post-processing discretization:\n\tSchema features: {len(self.…
+            _LOGGER.error(f"Failed during post-processing discretization:\n\tSchema features: {len(self._ordered_features)}\n{e}")
            raise
 
    def _run_mice(self,
@@ -378,43 +128,45 @@ class DragonMICE:
 
        Parameters:
            df (pd.DataFrame): The input dataframe.
-                If impute_targets=False, this should only be features.
-                If impute_targets=True, this can be the full dataset.
-            df_name (str): The base name for the dataset.
        """
        # Validation: Ensure Schema features exist in the input
-        …
+        # Note: self._ordered_features is already a list
+        missing_cols = [col for col in self._ordered_features if col not in df.columns]
        if missing_cols:
            _LOGGER.error(f"Input DataFrame is missing required schema columns: {missing_cols}")
-            raise ValueError()
+            raise ValueError(f"Missing columns: {missing_cols}")
 
        # If NOT imputing targets, we strictly filter to features.
        # If we ARE imputing targets, we use the whole DF provided (Features + Targets).
-        if not self.…
-            data_for_mice = df[self.…
+        if not self._impute_targets:
+            data_for_mice = df[self._ordered_features]
        else:
            data_for_mice = df
 
        # 1. Initialize kernel
        kernel = mf.ImputationKernel(
            data=data_for_mice,
-            num_datasets=self.…
-            random_state=self.…
+            num_datasets=self._resulting_datasets,
+            random_state=self._random_state
        )
 
        # base message
        message = "➡️ Schema-based MICE imputation running"
-        if self.…
+        if self._impute_targets:
            message += " (Targets included)"
 
        _LOGGER.info(message)
 
        # 2. Perform MICE
-        …
+        try:
+            kernel.mice(self._iterations)
+        except Exception as e:
+            _LOGGER.error(f"MICE imputation failed during execution: {e}")
+            raise
 
        # 3. Retrieve, process, and collect datasets
        imputed_datasets = []
-        for i in range(self.…
+        for i in range(self._resulting_datasets):
            # complete_data returns a pd.DataFrame
            completed_df = kernel.complete_data(dataset=i)
 
@@ -431,15 +183,19 @@ class DragonMICE:
                raise ValueError()
 
        # 4. Generate names
-        if self.…
+        if self._resulting_datasets == 1:
            imputed_dataset_names = [f"{df_name}_MICE"]
        else:
-            imputed_dataset_names = [f"{df_name}_MICE_{i+1}" for i in range(self.…
+            imputed_dataset_names = [f"{df_name}_MICE_{i+1}" for i in range(self._resulting_datasets)]
 
        # 5. Validate indexes and Row Counts
        for imputed_df, subname in zip(imputed_datasets, imputed_dataset_names):
-            …
-            …
+            if imputed_df.shape[0] != df.shape[0]:
+                _LOGGER.error(f"Row count mismatch in dataset {subname}")
+                raise ValueError()
+            if not all(imputed_df.index == df.index):
+                _LOGGER.error(f"Index mismatch in dataset {subname}")
+                raise ValueError()
 
        _LOGGER.info("Schema-based MICE imputation complete.")
 
@@ -452,34 +208,51 @@ class DragonMICE:
    ):
        """
        Runs the complete MICE imputation pipeline.
+
+        Parameters:
+            df_path_or_dir (str | Path): Path to a CSV file or directory containing CSV files.
+            save_datasets_dir (str | Path): Directory to save imputed datasets.
+            save_metrics_dir (str | Path): Directory to save convergence and distribution metrics.
        """
        # Check paths
-        save_datasets_path = make_fullpath(save_datasets_dir, make=True)
-        save_metrics_path = make_fullpath(save_metrics_dir, make=True)
+        save_datasets_path = make_fullpath(save_datasets_dir, make=True, enforce="directory")
+        save_metrics_path = make_fullpath(save_metrics_dir, make=True, enforce="directory")
 
        input_path = make_fullpath(df_path_or_dir)
        if input_path.is_file():
            all_file_paths = [input_path]
-        …
+        elif input_path.is_dir():
            all_file_paths = list(list_csv_paths(input_path, raise_on_empty=True).values())
+        else:
+            _LOGGER.error(f"Input path '{input_path}' is neither a file nor a directory.")
+            raise FileNotFoundError()
 
        for df_path in all_file_paths:
 
-            df, df_name = load_dataframe(df_path=df_path, kind="pandas")
+            df, df_name = load_dataframe(df_path=df_path, kind="pandas") # type: ignore
 
            # --- SPLIT LOGIC BASED ON CONFIGURATION ---
-            if self.…
+            if self._impute_targets:
                # If we impute targets, we pass the whole DF to MICE.
                # We pass an empty DF as 'targets' to save_imputed_datasets to prevent duplication.
                df_input = df
                df_targets_to_save = pd.DataFrame(index=df.index)
 
-                # …
-                imputed_column_names = …
+                # Monitor all columns that had NaNs
+                imputed_column_names = [col for col in df.columns if df[col].isna().any()]
            else:
-                # …
-                …
-                …
+                # Explicitly cast tuple to list for Pandas indexing
+                feature_cols = list(self._schema.feature_names)
+
+                # Check for column existence before slicing
+                if not set(feature_cols).issubset(df.columns):
+                    missing = set(feature_cols) - set(df.columns)
+                    _LOGGER.error(f"Dataset '{df_name}' is missing schema features: {missing}")
+                    raise KeyError(f"Missing features: {missing}")
+
+                df_input = df[feature_cols]
+                # Drop features to get targets (more robust than explicit selection if targets vary)
+                df_targets_to_save = df.drop(columns=feature_cols)
 
            imputed_column_names = _get_na_column_names(df=df_input) # type: ignore
 
@@ -487,7 +260,7 @@ class DragonMICE:
            kernel, imputed_datasets, imputed_dataset_names = self._run_mice(df=df_input, df_name=df_name) # type: ignore
 
            # Save (merges imputed_datasets with df_targets_to_save)
-            …
+            _save_imputed_datasets(
                save_dir=save_datasets_path,
                imputed_datasets=imputed_datasets,
                df_targets=df_targets_to_save,
@@ -510,5 +283,153 @@ class DragonMICE:
            )
 
 
-def …
-    …
+def _save_imputed_datasets(save_dir: Union[str, Path], imputed_datasets: list, df_targets: pd.DataFrame, imputed_dataset_names: list[str]):
+    for imputed_df, subname in zip(imputed_datasets, imputed_dataset_names):
+        merged_df = merge_dataframes(imputed_df, df_targets, direction="horizontal", verbose=False)
+        save_dataframe_filename(df=merged_df, save_dir=save_dir, filename=subname)
+
+
+#Convergence diagnostic
+def get_convergence_diagnostic(kernel: mf.ImputationKernel, imputed_dataset_names: list[str], column_names: list[str], root_dir: Union[str,Path], fontsize: int=16):
+    """
+    Generate and save convergence diagnostic plots for imputed variables.
+
+    Parameters:
+    - kernel: Trained miceforest.ImputationKernel.
+    - imputed_dataset_names: Names assigned to each imputed dataset.
+    - column_names: List of feature names to track over iterations.
+    - root_dir: Directory to save convergence plots.
+    """
+    # get number of iterations used
+    iterations_cap = kernel.iteration_count()
+    dataset_count = kernel.num_datasets
+
+    if dataset_count != len(imputed_dataset_names):
+        _LOGGER.error(f"Expected {dataset_count} names in imputed_dataset_names, got {len(imputed_dataset_names)}")
+        raise ValueError()
+
+    # Check path
+    root_path = make_fullpath(root_dir, make=True)
+
+    # Styling parameters
+    label_font = {'size': fontsize, 'weight': 'bold'}
+
+    # iterate over each imputed dataset
+    for dataset_id, imputed_dataset_name in zip(range(dataset_count), imputed_dataset_names):
+        #Check directory for current dataset
+        dataset_file_dir = f"Convergence_Metrics_{imputed_dataset_name}"
+        local_save_dir = make_fullpath(input_path=root_path / dataset_file_dir, make=True)
+
+        for feature_name in column_names:
+            means_per_iteration = []
+            for iteration in range(iterations_cap):
+                current_imputed = kernel.complete_data(dataset=dataset_id, iteration=iteration)
+                means_per_iteration.append(np.mean(current_imputed[feature_name])) # type: ignore
+
+            plt.figure(figsize=(10, 8))
+            plt.plot(means_per_iteration, marker='o')
+            plt.xlabel("Iteration", **label_font)
+            plt.ylabel("Mean of Imputed Values", **label_font)
+            plt.title(f"Mean Convergence for '{feature_name}'", **label_font)
+
+            # Adjust plot display for the X axis
+            _ticks = np.arange(iterations_cap)
+            _labels = np.arange(1, iterations_cap + 1)
+            plt.xticks(ticks=_ticks, labels=_labels) # type: ignore
+            plt.grid(True)
+
+            feature_save_name = sanitize_filename(feature_name)
+            feature_save_name = feature_save_name + ".svg"
+            save_path = local_save_dir / feature_save_name
+            plt.savefig(save_path, bbox_inches='tight', format="svg")
+            plt.close()
+
+        _LOGGER.info(f"{dataset_file_dir} process completed.")
+
+
+# Imputed distributions
+def get_imputed_distributions(kernel: mf.ImputationKernel, df_name: str, root_dir: Union[str, Path], column_names: list[str], one_plot: bool=False, fontsize: int=14):
+    '''
+    It works using miceforest's authors implementation of the method `.plot_imputed_distributions()`.
+
+    Set `one_plot=True` to save a single image including all feature distribution plots instead.
+    '''
+    # Check path
+    root_path = make_fullpath(root_dir, make=True)
+
+    local_dir_name = f"Distribution_Metrics_{df_name}_imputed"
+    local_save_dir = make_fullpath(root_path / local_dir_name, make=True)
+
+    # Styling parameters
+    legend_kwargs = {'frameon': True, 'facecolor': 'white', 'framealpha': 0.8}
+    label_font = {'size': fontsize, 'weight': 'bold'}
+
+    def _process_figure(fig, filename: str):
+        """Helper function to add labels and legends to a figure"""
+
+        if not isinstance(fig, ggplot):
+            _LOGGER.error(f"Expected a plotnine.ggplot object, received {type(fig)}.")
+            raise TypeError()
+
+        # Edit labels and title
+        fig = fig + theme(
+            plot_title=element_blank(), # removes labs(title=...)
+            strip_text=element_blank() # removes facet_wrap labels
+        )
+
+        fig = fig + labs(y="", x="")
+
+        # Render to matplotlib figure
+        fig = fig.draw()
+
+        if not hasattr(fig, 'axes') or len(fig.axes) == 0:
+            _LOGGER.error("Rendered figure has no axes to modify.")
+            raise RuntimeError()
+
+        if filename == "Combined_Distributions":
+            custom_xlabel = "Feature Values"
+        else:
+            custom_xlabel = filename
+
+        for ax in fig.axes:
+            # Set axis labels
+            ax.set_xlabel(custom_xlabel, **label_font)
+            ax.set_ylabel('Distribution', **label_font)
+
+            # Add legend based on line colors
+            lines = ax.get_lines()
+            if len(lines) >= 1:
+                lines[0].set_label('Original Data')
+            if len(lines) > 1:
+                lines[1].set_label('Imputed Data')
+            ax.legend(**legend_kwargs)
+
+        # Adjust layout and save
+        # fig.tight_layout()
+        # fig.subplots_adjust(bottom=0.2, left=0.2) # Optional, depending on overflow
+
+        # sanitize savename
+        feature_save_name = sanitize_filename(filename)
+        feature_save_name = feature_save_name + ".svg"
+        new_save_path = local_save_dir / feature_save_name
+
+        fig.savefig(
+            new_save_path,
+            format='svg',
+            bbox_inches='tight',
+            pad_inches=0.1
+        )
+        plt.close(fig)
+
+    if one_plot:
+        # Generate combined plot
+        fig = kernel.plot_imputed_distributions(variables=column_names)
+        _process_figure(fig, "Combined_Distributions")
+    # Generate individual plots per feature
+    else:
+        for feature in column_names:
+            fig = kernel.plot_imputed_distributions(variables=[feature])
+            _process_figure(fig, feature)
+
+    _LOGGER.info(f"{local_dir_name} completed.")
+
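
Taken together, the hunks above replace the module-level `apply_mice`/`run_mice_pipeline` functions with the schema-driven `DragonMICE` class. A minimal usage sketch follows; the constructor signature and defaults are taken from the diff, but the name of the pipeline method is not visible in these hunks, so `run_pipeline` below is a placeholder to verify against `ml_tools/MICE/_dragon_mice.py`.

```python
# Hedged usage sketch of the new schema-based MICE pipeline in 20.0.0.
from ml_tools.MICE import DragonMICE
from ml_tools.schema import FeatureSchema

# `schema` must be a FeatureSchema; how it is built is outside this diff.
schema: FeatureSchema = ...  # type: ignore[assignment]

mice = DragonMICE(
    schema=schema,
    impute_targets=False,   # targets are split off before MICE and re-merged on save
    iterations=30,          # new default per the __init__ diff (must be >= 1)
    resulting_datasets=1,
    random_state=101,
)

# Placeholder method name -- not shown in the hunks above; the parameters are
# documented in the added docstring (df_path_or_dir, save_datasets_dir, save_metrics_dir).
mice.run_pipeline(
    df_path_or_dir="data/raw",          # a single CSV file or a directory of CSVs
    save_datasets_dir="data/imputed",   # one "<name>_MICE" CSV per imputed dataset
    save_metrics_dir="reports/mice",    # convergence + distribution plots (SVG)
)
```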