dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1901
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
ml_tools/IO_tools/_IO_loggers.py (new file, matching `+235 -0` above):

```diff
@@ -0,0 +1,235 @@
+import json
+import csv
+import traceback
+from pathlib import Path
+from typing import Union, Any, Literal
+from datetime import datetime
+
+from ..path_manager import sanitize_filename, make_fullpath
+from .._core import get_logger
+
+
+_LOGGER = get_logger("IO logger")
+
+
+__all__ = [
+    "custom_logger",
+    "train_logger"
+]
+
+
+def custom_logger(
+    data: Union[
+        list[Any],
+        dict[Any, Any],
+        str,
+        BaseException
+    ],
+    save_directory: Union[str, Path],
+    log_name: str,
+    add_timestamp: bool=True,
+    dict_as: Literal['auto', 'json', 'csv'] = 'auto',
+) -> None:
+    """
+    Logs various data types to corresponding output formats:
+
+    - list[Any] → .txt
+        Each element is written on a new line.
+
+    - dict[str, list[Any]] → .csv (if dict_as='auto' or 'csv')
+        Dictionary is treated as tabular data; keys become columns, values become rows.
+
+    - dict[str, scalar] → .json (if dict_as='auto' or 'json')
+        Dictionary is treated as structured data and serialized as JSON.
+
+    - str → .log
+        Plain text string is written to a .log file.
+
+    - BaseException → .log
+        Full traceback is logged for debugging purposes.
+
+    Args:
+        data (Any): The data to be logged. Must be one of the supported types.
+        save_directory (str | Path): Directory where the log will be saved. Created if it does not exist.
+        log_name (str): Base name for the log file.
+        add_timestamp (bool): Whether to add a timestamp to the filename.
+        dict_as ('auto'|'json'|'csv'):
+            - 'auto': Guesses format (JSON or CSV) based on dictionary content.
+            - 'json': Forces .json format for any dictionary.
+            - 'csv': Forces .csv format. Will fail if dict values are not all lists.
+
+    Raises:
+        ValueError: If the data type is unsupported.
+    """
+    try:
+        if not isinstance(data, BaseException) and not data:
+            _LOGGER.warning("Empty data received. No log file will be saved.")
+            return
+
+        save_path = make_fullpath(save_directory, make=True)
+
+        sanitized_log_name = sanitize_filename(log_name)
+
+        if add_timestamp:
+            timestamp = datetime.now().strftime(r"%Y%m%d_%H%M%S")
+            base_path = save_path / f"{sanitized_log_name}_{timestamp}"
+        else:
+            base_path = save_path / sanitized_log_name
+
+        # Router
+        if isinstance(data, list):
+            _log_list_to_txt(data, base_path.with_suffix(".txt"))
+
+        elif isinstance(data, dict):
+            if dict_as == 'json':
+                _log_dict_to_json(data, base_path.with_suffix(".json"))
+
+            elif dict_as == 'csv':
+                # This will raise a ValueError if data is not all lists
+                _log_dict_to_csv(data, base_path.with_suffix(".csv"))
+
+            else:  # 'auto' mode
+                if all(isinstance(v, list) for v in data.values()):
+                    _log_dict_to_csv(data, base_path.with_suffix(".csv"))
+                else:
+                    _log_dict_to_json(data, base_path.with_suffix(".json"))
+
+        elif isinstance(data, str):
+            _log_string_to_log(data, base_path.with_suffix(".log"))
+
+        elif isinstance(data, BaseException):
+            _log_exception_to_log(data, base_path.with_suffix(".log"))
+
+        else:
+            _LOGGER.error("Unsupported data type. Must be list, dict, str, or BaseException.")
+            raise ValueError()
+
+        _LOGGER.info(f"Log saved as: '{base_path.name}'")
+
+    except Exception:
+        _LOGGER.exception(f"Log not saved.")
+
+
+def _log_list_to_txt(data: list[Any], path: Path) -> None:
+    log_lines = []
+    for item in data:
+        try:
+            log_lines.append(str(item).strip())
+        except Exception:
+            log_lines.append(f"(unrepresentable item of type {type(item)})")
+
+    with open(path, 'w', encoding='utf-8') as f:
+        f.write('\n'.join(log_lines))
+
+
+def _log_dict_to_csv(data: dict[Any, list[Any]], path: Path) -> None:
+    sanitized_dict = {}
+    max_length = max(len(v) for v in data.values()) if data else 0
+
+    for key, value in data.items():
+        if not isinstance(value, list):
+            _LOGGER.error(f"Dictionary value for key '{key}' must be a list.")
+            raise ValueError()
+
+        sanitized_key = str(key).strip().replace('\n', '_').replace('\r', '_')
+        padded_value = value + [None] * (max_length - len(value))
+        sanitized_dict[sanitized_key] = padded_value
+
+    # The `newline=''` argument is important to prevent extra blank rows
+    with open(path, 'w', newline='', encoding='utf-8') as csv_file:
+        writer = csv.writer(csv_file)
+
+        # 1. Write the header row from the sanitized dictionary keys
+        header = list(sanitized_dict.keys())
+        writer.writerow(header)
+
+        # 2. Transpose columns to rows and write them
+        # zip(*sanitized_dict.values()) elegantly converts the column data
+        # (lists in the dict) into row-by-row tuples.
+        rows_to_write = zip(*sanitized_dict.values())
+        writer.writerows(rows_to_write)
+
+
+def _log_string_to_log(data: str, path: Path) -> None:
+    with open(path, 'w', encoding='utf-8') as f:
+        f.write(data.strip() + '\n')
+
+
+def _log_exception_to_log(exc: BaseException, path: Path) -> None:
+    with open(path, 'w', encoding='utf-8') as f:
+        f.write("Exception occurred:\n")
+        traceback.print_exception(type(exc), exc, exc.__traceback__, file=f)
+
+
+def _log_dict_to_json(data: dict[Any, Any], path: Path) -> None:
+    with open(path, 'w', encoding='utf-8') as f:
+        json.dump(data, f, indent=4, ensure_ascii=False)
+
+
+def train_logger(train_config: Union[dict, Any],
+                 model_parameters: Union[dict, Any],
+                 train_history: Union[dict, None],
+                 save_directory: Union[str, Path]):
+    """
+    Logs training data to JSON, adding a timestamp to the filename.
+
+    Args:
+        train_config (dict | Any): Training configuration parameters. If object, must have a `.to_log()` method returning a dict.
+        model_parameters (dict | Any): Model parameters. If object, must have a `.to_log()` method returning a dict.
+        train_history (dict | None): Training history log.
+        save_directory (str | Path): Directory to save the log file.
+    """
+    # train_config should be a dict or a custom object with the ".to_log()" method
+    if not isinstance(train_config, dict):
+        if hasattr(train_config, "to_log") and callable(getattr(train_config, "to_log")):
+            train_config_dict: dict = train_config.to_log()
+            if not isinstance(train_config_dict, dict):
+                _LOGGER.error("'train_config.to_log()' did not return a dictionary.")
+                raise ValueError()
+        else:
+            _LOGGER.error("'train_config' must be a dict or an object with a 'to_log()' method.")
+            raise ValueError()
+    else:
+        # check for empty dict
+        if not train_config:
+            _LOGGER.error("'train_config' dictionary is empty.")
+            raise ValueError()
+
+        train_config_dict = train_config
+
+    # model_parameters should be a dict or a custom object with the ".to_log()" method
+    if not isinstance(model_parameters, dict):
+        if hasattr(model_parameters, "to_log") and callable(getattr(model_parameters, "to_log")):
+            model_parameters_dict: dict = model_parameters.to_log()
+            if not isinstance(model_parameters_dict, dict):
+                _LOGGER.error("'model_parameters.to_log()' did not return a dictionary.")
+                raise ValueError()
+        else:
+            _LOGGER.error("'model_parameters' must be a dict or an object with a 'to_log()' method.")
+            raise ValueError()
+    else:
+        # check for empty dict
+        if not model_parameters:
+            _LOGGER.error("'model_parameters' dictionary is empty.")
+            raise ValueError()
+
+        model_parameters_dict = model_parameters
+
+    # make base dictionary
+    data: dict = train_config_dict | model_parameters_dict
+
+    # add training history if provided and is not empty
+    if train_history is not None:
+        if not train_history:
+            _LOGGER.error("'train_history' dictionary was provided but is empty.")
+            raise ValueError()
+        data.update(train_history)
+
+    custom_logger(
+        data=data,
+        save_directory=save_directory,
+        log_name="Training_Log",
+        add_timestamp=True,
+        dict_as='json'
+    )
```
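The new `_IO_loggers.py` module routes each supported data type to a file format (`list` → .txt, `dict` → .csv or .json, `str`/`BaseException` → .log). A minimal usage sketch, assuming the 20.0.0 wheel is installed and `ml_tools.IO_tools` re-exports these names as shown in the `__init__.py` diff below; the `logs` directory is hypothetical:

```python
from ml_tools.IO_tools import custom_logger

# dict of equal-length lists -> treated as tabular data, routed to .csv in 'auto' mode
custom_logger(
    data={"epoch": [1, 2, 3], "loss": [0.9, 0.5, 0.3]},
    save_directory="logs",   # hypothetical output directory, created if missing
    log_name="history",
)

# dict of scalars -> routed to .json in 'auto' mode
custom_logger(
    data={"lr": 1e-3, "batch_size": 32},
    save_directory="logs",
    log_name="config",
)

# a caught exception -> a .log file containing the full traceback
try:
    1 / 0
except ZeroDivisionError as exc:
    custom_logger(data=exc, save_directory="logs", log_name="crash")
```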
ml_tools/IO_tools/_IO_save_load.py (new file, matching `+151 -0` above):

```diff
@@ -0,0 +1,151 @@
+from typing import Any, Union, Literal, overload
+from pathlib import Path
+import json
+
+from ..path_manager import sanitize_filename, make_fullpath
+from .._core import get_logger
+
+from ._IO_utils import _RobustEncoder
+
+
+_LOGGER = get_logger("IO Save/Load")
+
+
+__all__ = [
+    "save_json",
+    "load_json",
+    "save_list_strings",
+    "load_list_strings"
+]
+
+
+def save_json(
+    data: Union[dict[Any, Any], list[Any]],
+    directory: Union[str, Path],
+    filename: str,
+    verbose: bool = True
+) -> None:
+    """
+    Saves a dictionary or list as a JSON file.
+
+    Args:
+        data (dict | list): The data to save.
+        directory (str | Path): The directory to save the file in.
+        filename (str): The name of the file (extension .json will be added if missing).
+        verbose (bool): Whether to log success messages.
+    """
+    target_dir = make_fullpath(directory, make=True, enforce="directory")
+    sanitized_name = sanitize_filename(filename)
+
+    if not sanitized_name.endswith(".json"):
+        sanitized_name += ".json"
+
+    full_path = target_dir / sanitized_name
+
+    try:
+        with open(full_path, 'w', encoding='utf-8') as f:
+            # Using _RobustEncoder ensures compatibility with non-standard types (like 'type' objects)
+            json.dump(data, f, indent=4, ensure_ascii=False, cls=_RobustEncoder)
+
+        if verbose:
+            _LOGGER.info(f"JSON file saved as '{full_path.name}'.")
+
+    except Exception as e:
+        _LOGGER.error(f"Failed to save JSON to '{full_path}': {e}")
+        raise
+
+
+# 1. Define Overloads (for the type checker)
+@overload
+def load_json(
+    file_path: Union[str, Path],
+    expected_type: Literal["dict"] = "dict",
+    verbose: bool = True
+) -> dict[Any, Any]: ...
+
+@overload
+def load_json(
+    file_path: Union[str, Path],
+    expected_type: Literal["list"],
+    verbose: bool = True
+) -> list[Any]: ...
+
+
+def load_json(
+    file_path: Union[str, Path],
+    expected_type: Literal["dict", "list"] = "dict",
+    verbose: bool = True
+) -> Union[dict[Any, Any], list[Any]]:
+    """
+    Loads a JSON file.
+
+    Args:
+        file_path (str | Path): The path to the JSON file.
+        expected_type ('dict' | 'list'): Strict check for the root type of the JSON.
+        verbose (bool): Whether to log success/failure messages.
+
+    Returns:
+        dict | list: The loaded JSON data.
+    """
+    target_path = make_fullpath(file_path, enforce="file")
+
+    # Map string literals to actual python types
+    type_map = {"dict": dict, "list": list}
+    target_type = type_map.get(expected_type, dict)
+
+    try:
+        with open(target_path, 'r', encoding='utf-8') as f:
+            data = json.load(f)
+
+        if not isinstance(data, target_type):
+            _LOGGER.error(f"JSON root is type {type(data)}, expected {expected_type}.")
+            raise ValueError()
+
+        if verbose:
+            _LOGGER.info(f"Loaded JSON data from '{target_path.name}'.")
+
+        return data
+
+    except json.JSONDecodeError as e:
+        _LOGGER.error(f"Failed to decode JSON from '{target_path}': {e.msg}")
+        raise ValueError()
+
+    except Exception as e:
+        _LOGGER.error(f"Error loading JSON from '{target_path}': {e}")
+        raise
+
+
+def save_list_strings(list_strings: list[str], directory: Union[str,Path], filename: str, verbose: bool=True):
+    """Saves a list of strings as a text file."""
+    target_dir = make_fullpath(directory, make=True, enforce="directory")
+    sanitized_name = sanitize_filename(filename)
+
+    if not sanitized_name.endswith(".txt"):
+        sanitized_name = sanitized_name + ".txt"
+
+    full_path = target_dir / sanitized_name
+    with open(full_path, 'w') as f:
+        for string_data in list_strings:
+            f.write(f"{string_data}\n")
+
+    if verbose:
+        _LOGGER.info(f"Text file saved as '{full_path.name}'.")
+
+
+def load_list_strings(text_file: Union[str,Path], verbose: bool=True) -> list[str]:
+    """Loads a text file as a list of strings."""
+    target_path = make_fullpath(text_file, enforce="file")
+    loaded_strings = []
+
+    with open(target_path, 'r') as f:
+        loaded_strings = [line.strip() for line in f]
+
+    if len(loaded_strings) == 0:
+        _LOGGER.error("The text file is empty.")
+        raise ValueError()
+
+    if verbose:
+        _LOGGER.info(f"Loaded '{target_path.name}' as list of strings.")
+
+    return loaded_strings
```
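`save_json` and `load_json` form a round trip, and the `@overload` signatures let a type checker narrow the return type from the `expected_type` literal. A short sketch under the same installation assumption; paths are hypothetical:

```python
from ml_tools.IO_tools import save_json, load_json

# ".json" is appended automatically; _RobustEncoder stringifies non-serializable values
save_json({"model": "mlp", "layers": [64, 32]}, directory="artifacts", filename="config")

# expected_type enforces the JSON root type at runtime and narrows it statically:
# Literal["dict"] -> dict[Any, Any], Literal["list"] -> list[Any]
config = load_json("artifacts/config.json", expected_type="dict")
assert config["layers"] == [64, 32]
```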
ml_tools/IO_tools/_IO_utils.py (new file, matching `+140 -0` above):

```diff
@@ -0,0 +1,140 @@
+import json
+from collections import Counter
+from itertools import zip_longest
+from pathlib import Path
+from typing import Union
+
+from ..path_manager import make_fullpath
+from .._core import get_logger
+
+
+_LOGGER = get_logger("IO tools")
+
+
+__all__ = [
+    "compare_lists",
+    "_RobustEncoder"
+]
+
+
+class _RobustEncoder(json.JSONEncoder):
+    """
+    Custom JSON encoder to handle non-serializable objects.
+
+    This handles:
+    1. `type` objects (e.g., <class 'int'>) which result from
+       `check_type_only=True`.
+    2. Any other custom class or object by falling back to its
+       string representation.
+    """
+    def default(self, o):
+        if isinstance(o, type):
+            return str(o)
+        try:
+            return super().default(o)
+        except TypeError:
+            return str(o)
+
+
+def compare_lists(
+    list_A: list,
+    list_B: list,
+    save_dir: Union[str, Path],
+    strict: bool = False,
+    check_type_only: bool = False
+) -> dict:
+    """
+    Compares two lists and saves a JSON report of the differences.
+
+    Args:
+        list_A (list): The first list to compare.
+        list_B (list): The second list to compare.
+        save_dir (str | Path): The directory where the resulting report will be saved.
+        strict (bool):
+            - If False: Performs a "bag" comparison. Order does not matter, but duplicates do.
+            - If True: Performs a strict, positional comparison.
+
+        check_type_only (bool):
+            - If False: Compares items using `==` (`__eq__` operator).
+            - If True: Compares only the `type()` of the items.
+
+    Returns:
+        dict: A dictionary detailing the differences. (saved to `save_dir`).
+    """
+    MISSING_A_KEY = "missing_in_A"
+    MISSING_B_KEY = "missing_in_B"
+    MISMATCH_KEY = "mismatch"
+
+    results: dict[str, list] = {MISSING_A_KEY: [], MISSING_B_KEY: []}
+
+    # make directory
+    save_path = make_fullpath(input_path=save_dir, make=True, enforce="directory")
+
+    if strict:
+        # --- STRICT (Positional) Mode ---
+        results[MISMATCH_KEY] = []
+        sentinel = object()
+
+        if check_type_only:
+            compare_func = lambda a, b: type(a) == type(b)
+        else:
+            compare_func = lambda a, b: a == b
+
+        for index, (item_a, item_b) in enumerate(
+            zip_longest(list_A, list_B, fillvalue=sentinel)
+        ):
+            if item_a is sentinel:
+                results[MISSING_A_KEY].append({"index": index, "item": item_b})
+            elif item_b is sentinel:
+                results[MISSING_B_KEY].append({"index": index, "item": item_a})
+            elif not compare_func(item_a, item_b):
+                results[MISMATCH_KEY].append(
+                    {
+                        "index": index,
+                        "list_A_item": item_a,
+                        "list_B_item": item_b,
+                    }
+                )
+
+    else:
+        # --- NON-STRICT (Bag) Mode ---
+        if check_type_only:
+            # Types are hashable, we can use Counter (O(N))
+            types_A_counts = Counter(type(item) for item in list_A)
+            types_B_counts = Counter(type(item) for item in list_B)
+
+            diff_A_B = types_A_counts - types_B_counts
+            for item_type, count in diff_A_B.items():
+                results[MISSING_B_KEY].extend([item_type] * count)
+
+            diff_B_A = types_B_counts - types_A_counts
+            for item_type, count in diff_B_A.items():
+                results[MISSING_A_KEY].extend([item_type] * count)
+
+        else:
+            # Items may be unhashable. Use O(N*M) .remove() method
+            temp_B = list(list_B)
+            missing_in_B = []
+
+            for item_a in list_A:
+                try:
+                    temp_B.remove(item_a)
+                except ValueError:
+                    missing_in_B.append(item_a)
+
+            results[MISSING_A_KEY] = temp_B
+            results[MISSING_B_KEY] = missing_in_B
+
+    # --- Save the Report ---
+    try:
+        full_path = save_path / "list_comparison.json"
+
+        # Write the report dictionary to the JSON file
+        with open(full_path, 'w', encoding='utf-8') as f:
+            json.dump(results, f, indent=4, cls=_RobustEncoder)
+
+    except Exception as e:
+        _LOGGER.error(f"Failed to save comparison report to {save_path}: \n{e}")
+
+    return results
```
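`compare_lists` has two modes: a bag comparison (order ignored, duplicates counted) and a strict positional one. A sketch of the expected report shape, assuming a hypothetical `reports` directory for the `list_comparison.json` output:

```python
from ml_tools.IO_tools import compare_lists

# Bag mode (strict=False): the 1 and one of the 2s in list_A have no match
# in list_B, and the 4 in list_B has no match in list_A.
report = compare_lists([1, 2, 2, 3], [2, 3, 4], save_dir="reports")
# report == {"missing_in_A": [4], "missing_in_B": [1, 2]}

# Strict mode: positional comparison; mismatches are recorded per index.
report = compare_lists([1, 2, 3], [1, 9, 3], save_dir="reports", strict=True)
# report["mismatch"] == [{"index": 1, "list_A_item": 2, "list_B_item": 9}]
```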
ml_tools/IO_tools.py → ml_tools/IO_tools/__init__.py (matching `+13 -5` above; the first removed line and the bare removed lines are truncated in the source diff):

```diff
@@ -1,14 +1,22 @@
-from .
+from ._IO_utils import (
+    compare_lists,
+)
+
+from ._IO_loggers import (
     custom_logger,
     train_logger,
-
-
+)
+
+from ._IO_save_load import (
     save_json,
     load_json,
-
-
+    save_list_strings,
+    load_list_strings,
 )
 
+from ._imprimir import info
+
+
 __all__ = [
     "custom_logger",
     "train_logger",
```
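The net effect of the `__init__.py` change is that the flat `IO_tools.py` module becomes a package while the public import surface stays put. Existing caller code such as the following should be unaffected (a sketch, not an exhaustive list of the re-exported names):

```python
# Same imports before and after the 19.13.0 -> 20.0.0 split: the package
# __init__ now re-exports the public names from the new private submodules.
from ml_tools.IO_tools import custom_logger, train_logger, save_json, load_json
```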
ml_tools/MICE/_MICE_imputation.py (new file, matching `+132 -0` above):

```diff
@@ -0,0 +1,132 @@
+import pandas as pd
+import miceforest as mf
+from pathlib import Path
+from typing import Optional, Union
+
+from ..utilities import load_dataframe
+from ..math_utilities import threshold_binary_values
+
+from ..path_manager import make_fullpath, list_csv_paths
+from .._core import get_logger
+
+from ._dragon_mice import (
+    _save_imputed_datasets,
+    get_convergence_diagnostic,
+    get_imputed_distributions,
+)
+
+_LOGGER = get_logger("MICE")
+
+
+__all__ = [
+    "apply_mice",
+    "run_mice_pipeline",
+]
+
+
+def apply_mice(df: pd.DataFrame, df_name: str, binary_columns: Optional[list[str]]=None, resulting_datasets: int=1, iterations: int=20, random_state: int=101):
+
+    # Initialize kernel with number of imputed datasets to generate
+    kernel = mf.ImputationKernel(
+        data=df,
+        num_datasets=resulting_datasets,
+        random_state=random_state
+    )
+
+    _LOGGER.info("➡️ MICE imputation running...")
+
+    # Perform MICE with n iterations per dataset
+    kernel.mice(iterations)
+
+    # Retrieve the imputed datasets
+    imputed_datasets = [kernel.complete_data(dataset=i) for i in range(resulting_datasets)]
+
+    if imputed_datasets is None or len(imputed_datasets) == 0:
+        _LOGGER.error("No imputed datasets were generated. Check the MICE process.")
+        raise ValueError()
+
+    # threshold binary columns
+    if binary_columns is not None:
+        invalid_binary_columns = set(binary_columns) - set(df.columns)
+        if invalid_binary_columns:
+            _LOGGER.warning(f"These 'binary columns' are not in the dataset:")
+            for invalid_binary_col in invalid_binary_columns:
+                print(f" - {invalid_binary_col}")
+        valid_binary_columns = [col for col in binary_columns if col not in invalid_binary_columns]
+        for imputed_df in imputed_datasets:
+            for binary_column_name in valid_binary_columns:
+                imputed_df[binary_column_name] = threshold_binary_values(imputed_df[binary_column_name]) # type: ignore
+
+    if resulting_datasets == 1:
+        imputed_dataset_names = [f"{df_name}_MICE"]
+    else:
+        imputed_dataset_names = [f"{df_name}_MICE_{i+1}" for i in range(resulting_datasets)]
+
+    # Ensure indexes match
+    for imputed_df, subname in zip(imputed_datasets, imputed_dataset_names):
+        assert imputed_df.shape[0] == df.shape[0], f"❌ Row count mismatch in dataset {subname}" # type: ignore
+        assert all(imputed_df.index == df.index), f"❌ Index mismatch in dataset {subname}" # type: ignore
+    # print("✅ All imputed datasets match the original DataFrame indexes.")
+
+    _LOGGER.info("MICE imputation complete.")
+
+    return kernel, imputed_datasets, imputed_dataset_names
+
+
+# Get names of features that had missing values before imputation
+def _get_na_column_names(df: pd.DataFrame):
+    return [col for col in df.columns if df[col].isna().any()]
+
+
+def run_mice_pipeline(df_path_or_dir: Union[str,Path], target_columns: list[str],
+                      save_datasets_dir: Union[str,Path], save_metrics_dir: Union[str,Path],
+                      binary_columns: Optional[list[str]]=None,
+                      resulting_datasets: int=1,
+                      iterations: int=20,
+                      random_state: int=101):
+    """
+    DEPRECATED: Use DragonMICE class instead.
+
+    Call functions in sequence for each dataset in the provided path or directory:
+    1. Load dataframe
+    2. Apply MICE
+    3. Save imputed dataset(s)
+    4. Save convergence metrics
+    5. Save distribution metrics
+
+    Target columns must be skipped from the imputation. Binary columns will be thresholded after imputation.
+    """
+    # Check paths
+    save_datasets_path = make_fullpath(save_datasets_dir, make=True)
+    save_metrics_path = make_fullpath(save_metrics_dir, make=True)
+
+    input_path = make_fullpath(df_path_or_dir)
+    if input_path.is_file():
+        all_file_paths = [input_path]
+    else:
+        all_file_paths = list(list_csv_paths(input_path, raise_on_empty=True).values())
+
+    for df_path in all_file_paths:
+        df: pd.DataFrame
+        df, df_name = load_dataframe(df_path=df_path, kind="pandas") # type: ignore
+
+        df, df_targets = _skip_targets(df, target_columns)
+
+        kernel, imputed_datasets, imputed_dataset_names = apply_mice(df=df, df_name=df_name, binary_columns=binary_columns, resulting_datasets=resulting_datasets, iterations=iterations, random_state=random_state)
+
+        _save_imputed_datasets(save_dir=save_datasets_path, imputed_datasets=imputed_datasets, df_targets=df_targets, imputed_dataset_names=imputed_dataset_names)
+
+        imputed_column_names = _get_na_column_names(df=df)
+
+        get_convergence_diagnostic(kernel=kernel, imputed_dataset_names=imputed_dataset_names, column_names=imputed_column_names, root_dir=save_metrics_path)
+
+        get_imputed_distributions(kernel=kernel, df_name=df_name, root_dir=save_metrics_path, column_names=imputed_column_names)
+
+
+def _skip_targets(df: pd.DataFrame, target_cols: list[str]):
+    valid_targets = [col for col in target_cols if col in df.columns]
+    df_targets = df[valid_targets]
+    df_feats = df.drop(columns=valid_targets)
+    return df_feats, df_targets
```
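`run_mice_pipeline` is kept but marked deprecated in favor of the new `DragonMICE` class. A usage sketch, assuming `ml_tools.MICE` re-exports it (as the other split packages do); the paths and column names are hypothetical:

```python
from ml_tools.MICE import run_mice_pipeline

# For each CSV under data/raw: load it, set the target columns aside,
# run MICE (20 iterations by default), threshold the binary columns,
# then save the imputed dataset(s) plus convergence/distribution metrics.
run_mice_pipeline(
    df_path_or_dir="data/raw",          # a single CSV or a directory of CSVs
    target_columns=["target"],          # excluded from imputation, passed to the save step
    save_datasets_dir="data/imputed",
    save_metrics_dir="reports/mice",
    binary_columns=["is_active"],       # thresholded back to binary after imputation
    resulting_datasets=2,               # yields <name>_MICE_1 and <name>_MICE_2
)
```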