dragon-ml-toolbox 13.0.0__py3-none-any.whl → 14.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/METADATA +12 -2
- dragon_ml_toolbox-14.7.0.dist-info/RECORD +49 -0
- {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
- ml_tools/MICE_imputation.py +207 -5
- ml_tools/ML_configuration.py +108 -0
- ml_tools/ML_datasetmaster.py +241 -260
- ml_tools/ML_evaluation.py +229 -76
- ml_tools/ML_evaluation_multi.py +45 -16
- ml_tools/ML_inference.py +0 -1
- ml_tools/ML_models.py +135 -55
- ml_tools/ML_models_advanced.py +323 -0
- ml_tools/ML_optimization.py +49 -36
- ml_tools/ML_trainer.py +498 -29
- ml_tools/ML_utilities.py +351 -4
- ml_tools/ML_vision_datasetmaster.py +1492 -0
- ml_tools/ML_vision_evaluation.py +260 -0
- ml_tools/ML_vision_inference.py +428 -0
- ml_tools/ML_vision_models.py +641 -0
- ml_tools/ML_vision_transformers.py +203 -0
- ml_tools/PSO_optimization.py +5 -1
- ml_tools/_ML_vision_recipe.py +88 -0
- ml_tools/__init__.py +1 -0
- ml_tools/_schema.py +96 -0
- ml_tools/custom_logger.py +37 -14
- ml_tools/data_exploration.py +576 -138
- ml_tools/ensemble_evaluation.py +53 -10
- ml_tools/keys.py +43 -1
- ml_tools/math_utilities.py +1 -1
- ml_tools/optimization_tools.py +65 -86
- ml_tools/serde.py +78 -17
- ml_tools/utilities.py +192 -3
- dragon_ml_toolbox-13.0.0.dist-info/RECORD +0 -41
- ml_tools/ML_simple_optimization.py +0 -413
- {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/top_level.txt +0 -0
{dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 13.0.0
+Version: 14.7.0
 Summary: A collection of tools for data science and machine learning projects.
 Author-email: "Karl L. Loza Vidaurre" <luigiloza@gmail.com>
 License-Expression: MIT
@@ -34,6 +34,10 @@ Requires-Dist: Pillow; extra == "ml"
 Requires-Dist: evotorch; extra == "ml"
 Requires-Dist: pyarrow; extra == "ml"
 Requires-Dist: colorlog; extra == "ml"
+Requires-Dist: torchmetrics; extra == "ml"
+Provides-Extra: py-tab
+Requires-Dist: pytorch_tabular; extra == "py-tab"
+Requires-Dist: omegaconf; extra == "py-tab"
 Provides-Extra: mice
 Requires-Dist: numpy<2.0; extra == "mice"
 Requires-Dist: pandas; extra == "mice"
@@ -137,15 +141,22 @@ ETL_cleaning
 ETL_engineering
 math_utilities
 ML_callbacks
+ML_configuration
 ML_datasetmaster
 ML_evaluation_multi
 ML_evaluation
 ML_inference
 ML_models
+ML_models_advanced # Requires the extra flag [py-tab]
 ML_optimization
 ML_scaler
 ML_trainer
 ML_utilities
+ML_vision_datasetmaster
+ML_vision_evaluation
+ML_vision_inference
+ML_vision_models
+ML_vision_transformers
 optimization_tools
 path_manager
 PSO_optimization
@@ -191,7 +202,6 @@ pip install "dragon-ml-toolbox[excel]"
 #### Modules:
 
 ```Bash
-constants
 custom_logger
 handle_excel
 path_manager
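The METADATA changes above add a new `py-tab` extra (pulling in `pytorch_tabular` and `omegaconf`) and mark the new `ML_models_advanced` module as requiring it. A minimal sketch of how a downstream project might guard that optional import; the try/except pattern is illustrative and not part of the package, and the install command simply follows the extra syntax already shown in the README snippet (`pip install "dragon-ml-toolbox[excel]"`):

```python
# Hypothetical guard for the optional [py-tab] dependency; installing it would
# follow the README's extra syntax, e.g.: pip install "dragon-ml-toolbox[py-tab]"
try:
    from ml_tools import ML_models_advanced  # listed in METADATA as requiring [py-tab]
except ImportError:
    ML_models_advanced = None  # pytorch_tabular / omegaconf not installed
```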
dragon_ml_toolbox-14.7.0.dist-info/RECORD
ADDED
@@ -0,0 +1,49 @@
+dragon_ml_toolbox-14.7.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
+dragon_ml_toolbox-14.7.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=gkOdNDbKYpIJezwSo2CEnISkLeYfYHv9t8b5K2-P69A,2687
+ml_tools/ETL_cleaning.py,sha256=2VBRllV8F-ZiPylPp8Az2gwn5ztgazN0BH5OKnRUhV0,20402
+ml_tools/ETL_engineering.py,sha256=KfYqgsxupAx6e_TxwO1LZXeu5mFkIhVXJrNjP3CzIZc,54927
+ml_tools/GUI_tools.py,sha256=Va6ig-dHULPVRwQYYtH3fvY5XPIoqRcJpRW8oXC55Hw,45413
+ml_tools/MICE_imputation.py,sha256=KLJXGQLKJ6AuWWttAG-LCCaxpS-ygM4dXPiguHDaL6Y,20815
+ml_tools/ML_callbacks.py,sha256=elD2Yr030sv_6gX_m9GVd6HTyrbmt34nFS8lrgS4HtM,15808
+ml_tools/ML_configuration.py,sha256=DaYmm7Yklcu1emLyo-pRQG74SK4YEkCYFRT6_aV3rqA,4417
+ml_tools/ML_datasetmaster.py,sha256=Zi5jBnBI_U6tD8mpCVL5bQcsqsGEMAzMsCVI_wFD2QU,30175
+ml_tools/ML_evaluation.py,sha256=EvlgFeMQeZ1RSEMtNd-nv7W0d0SVcR4n6cwW5UG16DU,25358
+ml_tools/ML_evaluation_multi.py,sha256=bQZ2gJY-dBzKQxvtd-B6wVaGBdFpQGVBr7tQZFokp5E,17166
+ml_tools/ML_inference.py,sha256=YJ953bhNWsdlPRtJQh3h2ACfMIgp8dQ9KtL9Azar-5s,23489
+ml_tools/ML_models.py,sha256=PqOcNlws7vCJMbiVCKqlPuktxvskZVUHG3VfU-Yshf8,31415
+ml_tools/ML_models_advanced.py,sha256=vk3PZBSu3DVso2S1rKTxxdS43XG8Q5FnasIL3-rMajc,12410
+ml_tools/ML_optimization.py,sha256=P0zkhKAwTpkorIBtR0AOIDcyexo5ngmvFUzo3DfNO-E,22692
+ml_tools/ML_scaler.py,sha256=tw6onj9o8_kk3FQYb930HUzvv1zsFZe2YZJdF3LtHkU,7538
+ml_tools/ML_trainer.py,sha256=salZxfv3RWRCiinp5S9xeUsHysMbMQ52EecR8GyEbaM,51461
+ml_tools/ML_utilities.py,sha256=eYe2N-65FTzaOHF5gmiJl-HmicyzhqcdvlDiIivr5_g,22993
+ml_tools/ML_vision_datasetmaster.py,sha256=bmHDC6SsBUxDSFjqQGuyzGfKuf1Imi1Ng6O2-dYF7I4,62607
+ml_tools/ML_vision_evaluation.py,sha256=t12R7i1RkOCt9zu1_lxSBr8OH6A6Get0k8ftDLctn6I,10486
+ml_tools/ML_vision_inference.py,sha256=He3KV3VJAm8PwO-fOq4b9VO8UXFr-GmpuCnoHXf4VZI,20588
+ml_tools/ML_vision_models.py,sha256=WqiRN9JAjv--BcwkDrooXAs4Qo26JHPCHh3JSPm4kMI,26226
+ml_tools/ML_vision_transformers.py,sha256=h332O9BjDMgxrBc0I-bJwJODWlcp7nJHbX1QS2etwBk,7738
+ml_tools/PSO_optimization.py,sha256=T-HWHMRJUnPvPwixdU5jif3_rnnI36TzcL8u3oSCwuA,22960
+ml_tools/RNN_forecast.py,sha256=Qa2KoZfdAvSjZ4yE78N4BFXtr3tTr0Gx7tQJZPotsh0,1967
+ml_tools/SQL.py,sha256=vXLPGfVVg8bfkbBE3HVfyEclVbdJy0TBhuQONtMwSCQ,11234
+ml_tools/VIF_factor.py,sha256=at5IVqPvicja2-DNSTSIIy3SkzDWCmLzo3qTG_qr5n8,10422
+ml_tools/_ML_vision_recipe.py,sha256=zrgxFUvTJqQVuwR7jWlbIC2FD29u6eNFPkTRoJ7yEZI,3178
+ml_tools/__init__.py,sha256=kJiankjz9_qXu7gU92mYqYg_anLvt-B6RtW0mMH8uGo,76
+ml_tools/_logger.py,sha256=dlp5cGbzooK9YSNSZYB4yjZrOaQUGW8PTrM411AOvL8,4717
+ml_tools/_schema.py,sha256=yu6aWmn_2Z4_AxAtJGDDCIa96y6JcUp-vgnCS013Qmw,3908
+ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
+ml_tools/constants.py,sha256=3br5Rk9cL2IUo638eJuMOGdbGQaWssaUecYEvSeRBLM,3322
+ml_tools/custom_logger.py,sha256=TGc0Ww2Xlqj2XE3q4bP43hV7T3qnb5ci9f0pYHXF5TY,11226
+ml_tools/data_exploration.py,sha256=bwHzFJ-IAo5GN3T53F-1J_pXUg8VHS91sG_90utAsfg,69911
+ml_tools/ensemble_evaluation.py,sha256=2sJ3jD6yBNPRNwSokyaLKqKHi0QhF13ChoFe5yd4zwg,28368
+ml_tools/ensemble_inference.py,sha256=0yLmLNj45RVVoSCLH1ZYJG9IoAhTkWUqEZmLOQTFGTY,9348
+ml_tools/ensemble_learning.py,sha256=vsIED7nlheYI4w2SBzP6SC1AnNeMfn-2A1Gqw5EfxsM,21964
+ml_tools/handle_excel.py,sha256=pfdAPb9ywegFkM9T54bRssDOsX-K7rSeV0RaMz7lEAo,14006
+ml_tools/keys.py,sha256=-OiL9G0RIOKQk6BwETKIP3LWz2s5-x6lZW2YitJa4mY,3330
+ml_tools/math_utilities.py,sha256=xeKq1quR_3DYLgowcp4Uam_4s3JltUyOnqMOGuAiYWU,8802
+ml_tools/optimization_tools.py,sha256=TYFQ2nSnp7xxs-VyoZISWgnGJghFbsWasHjruegyJRs,12763
+ml_tools/path_manager.py,sha256=CyDU16pOKmC82jPubqJPT6EBt-u-3rGVbxyPIZCvDDY,18432
+ml_tools/serde.py,sha256=c8uDYjYry_VrLvoG4ixqDj5pij88lVn6Tu4NHcPkwDU,6943
+ml_tools/utilities.py,sha256=aWqvYzmxlD74PD5Yqu1VuTekDJeYLQrmPIU_VeVyRp0,22526
+dragon_ml_toolbox-14.7.0.dist-info/METADATA,sha256=NTifVXiC2zr5RhzCUTuUMEcU-wfswXxoYOO6N3UXFmM,6492
+dragon_ml_toolbox-14.7.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-14.7.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-14.7.0.dist-info/RECORD,,
{dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md
CHANGED
@@ -27,3 +27,13 @@ This project depends on the following third-party packages. Each is governed by
 - [plotnine](https://github.com/has2k1/plotnine/blob/main/LICENSE)
 - [tqdm](https://github.com/tqdm/tqdm/blob/master/LICENSE)
 - [pyarrow](https://github.com/apache/arrow/blob/main/LICENSE.txt)
+- [colorlog](https://github.com/borntyping/python-colorlog/blob/main/LICENSE)
+- [evotorch](https://github.com/nnaisense/evotorch/blob/master/LICENSE)
+- [FreeSimpleGUI](https://github.com/spyoungtech/FreeSimpleGUI/blob/main/license.txt)
+- [nuitka](https://github.com/Nuitka/Nuitka/blob/main/LICENSE.txt)
+- [omegaconf](https://github.com/omry/omegaconf/blob/master/LICENSE)
+- [ordered-set](https://github.com/rspeer/ordered-set/blob/master/MIT-LICENSE)
+- [pyinstaller](https://github.com/pyinstaller/pyinstaller/blob/develop/COPYING.txt)
+- [pytorch_tabular](https://github.com/manujosephv/pytorch_tabular/blob/main/LICENSE)
+- [torchmetrics](https://github.com/Lightning-AI/torchmetrics/blob/master/LICENSE)
+- [zstandard](https://github.com/indygreg/python-zstandard/blob/main/LICENSE)
ml_tools/MICE_imputation.py
CHANGED
@@ -7,19 +7,20 @@ from plotnine import ggplot, labs, theme, element_blank # type: ignore
 from typing import Optional, Union
 
 from .utilities import load_dataframe, merge_dataframes, save_dataframe_filename
-from .math_utilities import threshold_binary_values
+from .math_utilities import threshold_binary_values, discretize_categorical_values
 from .path_manager import sanitize_filename, make_fullpath, list_csv_paths
 from ._logger import _LOGGER
 from ._script_info import _script_info
+from ._schema import FeatureSchema
 
 
 __all__ = [
+    "MiceImputer",
     "apply_mice",
     "save_imputed_datasets",
-    "get_na_column_names",
     "get_convergence_diagnostic",
     "get_imputed_distributions",
-    "run_mice_pipeline"
+    "run_mice_pipeline",
 ]
 
 
@@ -79,7 +80,7 @@ def save_imputed_datasets(save_dir: Union[str, Path], imputed_datasets: list, df
 
 
 #Get names of features that had missing values before imputation
-def get_na_column_names(df: pd.DataFrame):
+def _get_na_column_names(df: pd.DataFrame):
     return [col for col in df.columns if df[col].isna().any()]
 
 
@@ -264,7 +265,7 @@ def run_mice_pipeline(df_path_or_dir: Union[str,Path], target_columns: list[str]
 
     save_imputed_datasets(save_dir=save_datasets_path, imputed_datasets=imputed_datasets, df_targets=df_targets, imputed_dataset_names=imputed_dataset_names)
 
-    imputed_column_names = get_na_column_names(df=df)
+    imputed_column_names = _get_na_column_names(df=df)
 
     get_convergence_diagnostic(kernel=kernel, imputed_dataset_names=imputed_dataset_names, column_names=imputed_column_names, root_dir=save_metrics_path)
 
@@ -278,5 +279,206 @@ def _skip_targets(df: pd.DataFrame, target_cols: list[str]):
     return df_feats, df_targets
 
 
+# modern implementation
+class MiceImputer:
+    """
+    A modern MICE imputation pipeline that uses a FeatureSchema
+    to correctly discretize categorical features after imputation.
+    """
+    def __init__(self,
+                 schema: FeatureSchema,
+                 iterations: int=20,
+                 resulting_datasets: int = 1,
+                 random_state: int = 101):
+
+        self.schema = schema
+        self.random_state = random_state
+        self.iterations = iterations
+        self.resulting_datasets = resulting_datasets
+
+        # --- Store schema info ---
+
+        # 1. Categorical info
+        if not self.schema.categorical_index_map:
+            _LOGGER.warning("FeatureSchema has no 'categorical_index_map'. No discretization will be applied.")
+            self.cat_info = {}
+        else:
+            self.cat_info = self.schema.categorical_index_map
+
+        # 2. Ordered feature names (critical for index mapping)
+        self.ordered_features = list(self.schema.feature_names)
+
+        # 3. Names of categorical features
+        self.categorical_features = list(self.schema.categorical_feature_names)
+
+        _LOGGER.info(f"MiceImputer initialized. Found {len(self.cat_info)} categorical features to discretize.")
+
+    def _post_process(self, imputed_df: pd.DataFrame) -> pd.DataFrame:
+        """
+        Applies schema-based discretization to a completed dataframe.
+
+        This method works around the behavior of `discretize_categorical_values`
+        (which returns a full int32 array) by:
+        1. Calling it on the full, ordered feature array.
+        2. Extracting *only* the valid discretized categorical columns.
+        3. Updating the original float dataframe with these integer values.
+        """
+        # If no categorical features are defined, return the df as-is.
+        if not self.cat_info:
+            return imputed_df
+
+        try:
+            # 1. Ensure DataFrame columns match the schema order
+            # This is critical for the index-based categorical_info
+            df_ordered: pd.DataFrame = imputed_df[self.ordered_features] # type: ignore
+
+            # 2. Convert to NumPy array
+            array_ordered = df_ordered.to_numpy()
+
+            # 3. Apply discretization utility (which returns a full int32 array)
+            # This array has *correct* categorical values but *truncated* continuous values.
+            discretized_array_int32 = discretize_categorical_values(
+                array_ordered,
+                self.cat_info,
+                start_at_zero=True # Assuming 0-based indexing
+            )
+
+            # 4. Create a new DF from the int32 array, keeping the categorical columns.
+            df_discretized_cats = pd.DataFrame(
+                discretized_array_int32,
+                columns=self.ordered_features,
+                index=df_ordered.index # <-- Critical: align index
+            )[self.categorical_features] # <-- Select only cat features
+
+            # 5. "Rejoin": Start with a fresh copy of the *original* imputed DF (which has correct continuous floats).
+            final_df = df_ordered.copy()
+
+            # 6. Use .update() to "paste" the integer categorical values
+            # over the old float categorical values. Continuous floats are unaffected.
+            final_df.update(df_discretized_cats)
+
+            return final_df
+
+        except Exception as e:
+            _LOGGER.error(f"Failed during post-processing discretization:\n\tInput DF shape: {imputed_df.shape}\n\tSchema features: {len(self.ordered_features)}\n\tCategorical info keys: {list(self.cat_info.keys())}\n{e}")
+            raise
+
+    def _run_mice(self,
+                  df: pd.DataFrame,
+                  df_name: str) -> tuple[mf.ImputationKernel, list[pd.DataFrame], list[str]]:
+        """
+        Runs the MICE kernel and applies schema-based post-processing.
+
+        Parameters:
+            df (pd.DataFrame): The input dataframe *with NaNs*. Should only contain feature columns.
+            df_name (str): The base name for the dataset.
+
+        Returns:
+            tuple[mf.ImputationKernel, list[pd.DataFrame], list[str]]:
+                - The trained MICE kernel
+                - A list of imputed and processed DataFrames
+                - A list of names for the new DataFrames
+        """
+        # Ensure input df only contains features from the schema and is in the correct order.
+        try:
+            df_feats = df[self.ordered_features]
+        except KeyError as e:
+            _LOGGER.error(f"Input DataFrame is missing required schema columns: {e}")
+            raise
+
+        # 1. Initialize kernel
+        kernel = mf.ImputationKernel(
+            data=df_feats,
+            num_datasets=self.resulting_datasets,
+            random_state=self.random_state
+        )
+
+        _LOGGER.info("➡️ Schema-based MICE imputation running...")
+
+        # 2. Perform MICE
+        kernel.mice(self.iterations)
+
+        # 3. Retrieve, process, and collect datasets
+        imputed_datasets = []
+        for i in range(self.resulting_datasets):
+            # complete_data returns a pd.DataFrame
+            completed_df = kernel.complete_data(dataset=i)
+
+            # Apply our new discretization and ordering
+            processed_df = self._post_process(completed_df)
+            imputed_datasets.append(processed_df)
+
+        if not imputed_datasets:
+            _LOGGER.error("No imputed datasets were generated.")
+            raise ValueError()
+
+        # 4. Generate names
+        if self.resulting_datasets == 1:
+            imputed_dataset_names = [f"{df_name}_MICE"]
+        else:
+            imputed_dataset_names = [f"{df_name}_MICE_{i+1}" for i in range(self.resulting_datasets)]
+
+        # 5. Validate indexes
+        for imputed_df, subname in zip(imputed_datasets, imputed_dataset_names):
+            assert imputed_df.shape[0] == df.shape[0], f"❌ Row count mismatch in dataset {subname}"
+            assert all(imputed_df.index == df.index), f"❌ Index mismatch in dataset {subname}"
+
+        _LOGGER.info("Schema-based MICE imputation complete.")
+
+        return kernel, imputed_datasets, imputed_dataset_names
+
+    def run_pipeline(self,
+                     df_path_or_dir: Union[str,Path],
+                     save_datasets_dir: Union[str,Path],
+                     save_metrics_dir: Union[str,Path],
+                     ):
+        """
+        Runs the complete MICE imputation pipeline.
+
+        This method automates the entire workflow:
+        1. Loads data from a CSV file path or a directory with CSV files.
+        2. Separates features and targets based on the `FeatureSchema`.
+        3. Runs the MICE algorithm on the feature set.
+        4. Applies schema-based post-processing to discretize categorical features.
+        5. Saves the final, processed, and imputed dataset(s) (re-joined with targets) to `save_datasets_dir`.
+        6. Generates and saves convergence and distribution plots for all imputed columns to `save_metrics_dir`.

+        Parameters
+        ----------
+        df_path_or_dir : [str,Path]
+            Path to a single CSV file or a directory containing multiple CSV files to impute.
+        save_datasets_dir : [str,Path]
+            Directory where the final imputed and processed dataset(s) will be saved as CSVs.
+        save_metrics_dir : [str,Path]
+            Directory where convergence and distribution plots will be saved.
+        """
+        # Check paths
+        save_datasets_path = make_fullpath(save_datasets_dir, make=True)
+        save_metrics_path = make_fullpath(save_metrics_dir, make=True)
+
+        input_path = make_fullpath(df_path_or_dir)
+        if input_path.is_file():
+            all_file_paths = [input_path]
+        else:
+            all_file_paths = list(list_csv_paths(input_path).values())
+
+        for df_path in all_file_paths:
+
+            df, df_name = load_dataframe(df_path=df_path, kind="pandas")
+
+            df_features: pd.DataFrame = df[self.schema.feature_names] # type: ignore
+            df_targets = df.drop(columns=self.schema.feature_names)
+
+            imputed_column_names = _get_na_column_names(df=df_features)
+
+            kernel, imputed_datasets, imputed_dataset_names = self._run_mice(df=df_features, df_name=df_name)
+
+            save_imputed_datasets(save_dir=save_datasets_path, imputed_datasets=imputed_datasets, df_targets=df_targets, imputed_dataset_names=imputed_dataset_names)
+
+            get_convergence_diagnostic(kernel=kernel, imputed_dataset_names=imputed_dataset_names, column_names=imputed_column_names, root_dir=save_metrics_path)
+
+            get_imputed_distributions(kernel=kernel, df_name=df_name, root_dir=save_metrics_path, column_names=imputed_column_names)
+
+
 def info():
     _script_info(__all__)
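The new `MiceImputer` class above wraps kernel creation, imputation, schema-based discretization, and plotting behind a single `run_pipeline()` call. A minimal usage sketch follows; the `FeatureSchema` construction is a hypothetical stand-in (this diff only shows that the schema must expose `feature_names`, `categorical_feature_names`, and `categorical_index_map`; the real class lives in `ml_tools/_schema.py`), and the paths are placeholders:

```python
# Sketch only: driving the new MiceImputer from ml_tools/MICE_imputation.py.
# The FeatureSchema arguments below are assumptions based on the attributes the
# imputer reads from it; consult ml_tools/_schema.py for the actual constructor.
from ml_tools._schema import FeatureSchema
from ml_tools.MICE_imputation import MiceImputer

schema = FeatureSchema(
    feature_names=("age", "income", "color"),   # assumed parameter name
    categorical_feature_names=("color",),        # assumed parameter name
    categorical_index_map={2: 3},                # assumed meaning: feature index -> category count
)

imputer = MiceImputer(schema=schema, iterations=20, resulting_datasets=1, random_state=101)
imputer.run_pipeline(
    df_path_or_dir="data/raw_csvs",           # a single CSV or a directory of CSVs
    save_datasets_dir="data/imputed",         # imputed datasets (re-joined with targets) land here
    save_metrics_dir="reports/mice_metrics",  # convergence and distribution plots land here
)
```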
ml_tools/ML_configuration.py
ADDED
@@ -0,0 +1,108 @@
+from typing import Optional
+from ._script_info import _script_info
+
+
+__all__ = [
+    "ClassificationMetricsFormat",
+    "MultiClassificationMetricsFormat"
+]
+
+
+class ClassificationMetricsFormat:
+    """
+    Optional configuration for classification tasks, use in the '.evaluate()' method of the MLTrainer.
+    """
+    def __init__(self,
+                 cmap: str="Blues",
+                 class_map: Optional[dict[str,int]]=None,
+                 ROC_PR_line: str='darkorange',
+                 calibration_bins: int=15,
+                 font_size: int=16) -> None:
+        """
+        Initializes the formatting configuration for single-label classification metrics.
+
+        Args:
+            cmap (str): The matplotlib colormap name for the confusion matrix
+                and report heatmap. Defaults to "Blues".
+                - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
+                - Diverging options: 'coolwarm', 'viridis', 'plasma', 'inferno'
+
+            class_map (dict[str,int] | None): A dictionary mapping
+                class string names to their integer indices (e.g., {'cat': 0, 'dog': 1}).
+                This is used to label the axes of the confusion matrix and classification
+                report correctly. Defaults to None.
+
+            ROC_PR_line (str): The color name or hex code for the line plotted
+                on the ROC and Precision-Recall curves. Defaults to 'darkorange'.
+                - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
+                - Hex codes: '#FF6347', '#4682B4'
+
+            calibration_bins (int): The number of bins to use when
+                creating the calibration (reliability) plot. Defaults to 15.
+
+            font_size (int): The base font size to apply to the plots. Defaults to 16.
+        """
+        self.cmap = cmap
+        self.class_map = class_map
+        self.ROC_PR_line = ROC_PR_line
+        self.calibration_bins = calibration_bins
+        self.font_size = font_size
+
+    def __repr__(self) -> str:
+        parts = [
+            f"cmap='{self.cmap}'",
+            f"class_map={self.class_map}",
+            f"ROC_PR_line='{self.ROC_PR_line}'",
+            f"calibration_bins={self.calibration_bins}",
+            f"font_size={self.font_size}"
+        ]
+        return f"ClassificationMetricsFormat({', '.join(parts)})"
+
+
+class MultiClassificationMetricsFormat:
+    """
+    Optional configuration for multi-label classification tasks, use in the '.evaluate()' method of the MLTrainer.
+    """
+    def __init__(self,
+                 threshold: float=0.5,
+                 ROC_PR_line: str='darkorange',
+                 cmap: str = "Blues",
+                 font_size: int = 16) -> None:
+        """
+        Initializes the formatting configuration for multi-label classification metrics.
+
+        Args:
+            threshold (float): The probability threshold (0.0 to 1.0) used
+                to convert sigmoid outputs into binary (0 or 1) predictions for
+                calculating the confusion matrix and overall metrics. Defaults to 0.5.
+
+            ROC_PR_line (str): The color name or hex code for the line plotted
+                on the ROC and Precision-Recall curves (one for each label).
+                Defaults to 'darkorange'.
+                - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
+                - Hex codes: '#FF6347', '#4682B4'
+
+            cmap (str): The matplotlib colormap name for the per-label
+                confusion matrices. Defaults to "Blues".
+                - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
+                - Diverging options: 'coolwarm', 'viridis', 'plasma', 'inferno'
+
+            font_size (int): The base font size to apply to the plots. Defaults to 16.
+        """
+        self.threshold = threshold
+        self.cmap = cmap
+        self.ROC_PR_line = ROC_PR_line
+        self.font_size = font_size
+
+    def __repr__(self) -> str:
+        parts = [
+            f"threshold={self.threshold}",
+            f"ROC_PR_line='{self.ROC_PR_line}'",
+            f"cmap='{self.cmap}'",
+            f"font_size={self.font_size}"
+        ]
+        return f"MultiClassificationMetricsFormat({', '.join(parts)})"
+
+
+def info():
+    _script_info(__all__)
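The two format classes above only store plotting options; per their docstrings they are meant to be passed to the MLTrainer's `.evaluate()` method. A short sketch of building one, assuming nothing about `evaluate()` beyond what the docstrings state (the keyword name in the commented call is hypothetical):

```python
# Sketch: constructing the new single-label format object from ml_tools/ML_configuration.py.
from ml_tools.ML_configuration import ClassificationMetricsFormat

fmt = ClassificationMetricsFormat(
    cmap="Greens",                    # colormap for the confusion matrix / report heatmap
    class_map={"cat": 0, "dog": 1},   # class name -> integer index, used to label axes
    ROC_PR_line="crimson",            # line color for ROC and Precision-Recall curves
    calibration_bins=10,              # bins for the calibration (reliability) plot
    font_size=14,
)
print(fmt)  # __repr__ echoes the stored options

# Hypothetical hand-off; the diff only says the object is "use[d] in the
# '.evaluate()' method of the MLTrainer", the parameter name is not shown:
# trainer.evaluate(..., classification_metrics_format=fmt)
```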