dragon-ml-toolbox 13.0.0__py3-none-any.whl → 14.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/METADATA +12 -2
  2. dragon_ml_toolbox-14.7.0.dist-info/RECORD +49 -0
  3. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
  4. ml_tools/MICE_imputation.py +207 -5
  5. ml_tools/ML_configuration.py +108 -0
  6. ml_tools/ML_datasetmaster.py +241 -260
  7. ml_tools/ML_evaluation.py +229 -76
  8. ml_tools/ML_evaluation_multi.py +45 -16
  9. ml_tools/ML_inference.py +0 -1
  10. ml_tools/ML_models.py +135 -55
  11. ml_tools/ML_models_advanced.py +323 -0
  12. ml_tools/ML_optimization.py +49 -36
  13. ml_tools/ML_trainer.py +498 -29
  14. ml_tools/ML_utilities.py +351 -4
  15. ml_tools/ML_vision_datasetmaster.py +1492 -0
  16. ml_tools/ML_vision_evaluation.py +260 -0
  17. ml_tools/ML_vision_inference.py +428 -0
  18. ml_tools/ML_vision_models.py +641 -0
  19. ml_tools/ML_vision_transformers.py +203 -0
  20. ml_tools/PSO_optimization.py +5 -1
  21. ml_tools/_ML_vision_recipe.py +88 -0
  22. ml_tools/__init__.py +1 -0
  23. ml_tools/_schema.py +96 -0
  24. ml_tools/custom_logger.py +37 -14
  25. ml_tools/data_exploration.py +576 -138
  26. ml_tools/ensemble_evaluation.py +53 -10
  27. ml_tools/keys.py +43 -1
  28. ml_tools/math_utilities.py +1 -1
  29. ml_tools/optimization_tools.py +65 -86
  30. ml_tools/serde.py +78 -17
  31. ml_tools/utilities.py +192 -3
  32. dragon_ml_toolbox-13.0.0.dist-info/RECORD +0 -41
  33. ml_tools/ML_simple_optimization.py +0 -413
  34. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/WHEEL +0 -0
  35. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/licenses/LICENSE +0 -0
  36. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dragon-ml-toolbox
3
- Version: 13.0.0
3
+ Version: 14.7.0
4
4
  Summary: A collection of tools for data science and machine learning projects.
5
5
  Author-email: "Karl L. Loza Vidaurre" <luigiloza@gmail.com>
6
6
  License-Expression: MIT
@@ -34,6 +34,10 @@ Requires-Dist: Pillow; extra == "ml"
34
34
  Requires-Dist: evotorch; extra == "ml"
35
35
  Requires-Dist: pyarrow; extra == "ml"
36
36
  Requires-Dist: colorlog; extra == "ml"
37
+ Requires-Dist: torchmetrics; extra == "ml"
38
+ Provides-Extra: py-tab
39
+ Requires-Dist: pytorch_tabular; extra == "py-tab"
40
+ Requires-Dist: omegaconf; extra == "py-tab"
37
41
  Provides-Extra: mice
38
42
  Requires-Dist: numpy<2.0; extra == "mice"
39
43
  Requires-Dist: pandas; extra == "mice"
@@ -137,15 +141,22 @@ ETL_cleaning
137
141
  ETL_engineering
138
142
  math_utilities
139
143
  ML_callbacks
144
+ ML_configuration
140
145
  ML_datasetmaster
141
146
  ML_evaluation_multi
142
147
  ML_evaluation
143
148
  ML_inference
144
149
  ML_models
150
+ ML_models_advanced # Requires the extra flag [py-tab]
145
151
  ML_optimization
146
152
  ML_scaler
147
153
  ML_trainer
148
154
  ML_utilities
155
+ ML_vision_datasetmaster
156
+ ML_vision_evaluation
157
+ ML_vision_inference
158
+ ML_vision_models
159
+ ML_vision_transformers
149
160
  optimization_tools
150
161
  path_manager
151
162
  PSO_optimization
@@ -191,7 +202,6 @@ pip install "dragon-ml-toolbox[excel]"
191
202
  #### Modules:
192
203
 
193
204
  ```Bash
194
- constants
195
205
  custom_logger
196
206
  handle_excel
197
207
  path_manager
@@ -0,0 +1,49 @@
1
+ dragon_ml_toolbox-14.7.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
2
+ dragon_ml_toolbox-14.7.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=gkOdNDbKYpIJezwSo2CEnISkLeYfYHv9t8b5K2-P69A,2687
3
+ ml_tools/ETL_cleaning.py,sha256=2VBRllV8F-ZiPylPp8Az2gwn5ztgazN0BH5OKnRUhV0,20402
4
+ ml_tools/ETL_engineering.py,sha256=KfYqgsxupAx6e_TxwO1LZXeu5mFkIhVXJrNjP3CzIZc,54927
5
+ ml_tools/GUI_tools.py,sha256=Va6ig-dHULPVRwQYYtH3fvY5XPIoqRcJpRW8oXC55Hw,45413
6
+ ml_tools/MICE_imputation.py,sha256=KLJXGQLKJ6AuWWttAG-LCCaxpS-ygM4dXPiguHDaL6Y,20815
7
+ ml_tools/ML_callbacks.py,sha256=elD2Yr030sv_6gX_m9GVd6HTyrbmt34nFS8lrgS4HtM,15808
8
+ ml_tools/ML_configuration.py,sha256=DaYmm7Yklcu1emLyo-pRQG74SK4YEkCYFRT6_aV3rqA,4417
9
+ ml_tools/ML_datasetmaster.py,sha256=Zi5jBnBI_U6tD8mpCVL5bQcsqsGEMAzMsCVI_wFD2QU,30175
10
+ ml_tools/ML_evaluation.py,sha256=EvlgFeMQeZ1RSEMtNd-nv7W0d0SVcR4n6cwW5UG16DU,25358
11
+ ml_tools/ML_evaluation_multi.py,sha256=bQZ2gJY-dBzKQxvtd-B6wVaGBdFpQGVBr7tQZFokp5E,17166
12
+ ml_tools/ML_inference.py,sha256=YJ953bhNWsdlPRtJQh3h2ACfMIgp8dQ9KtL9Azar-5s,23489
13
+ ml_tools/ML_models.py,sha256=PqOcNlws7vCJMbiVCKqlPuktxvskZVUHG3VfU-Yshf8,31415
14
+ ml_tools/ML_models_advanced.py,sha256=vk3PZBSu3DVso2S1rKTxxdS43XG8Q5FnasIL3-rMajc,12410
15
+ ml_tools/ML_optimization.py,sha256=P0zkhKAwTpkorIBtR0AOIDcyexo5ngmvFUzo3DfNO-E,22692
16
+ ml_tools/ML_scaler.py,sha256=tw6onj9o8_kk3FQYb930HUzvv1zsFZe2YZJdF3LtHkU,7538
17
+ ml_tools/ML_trainer.py,sha256=salZxfv3RWRCiinp5S9xeUsHysMbMQ52EecR8GyEbaM,51461
18
+ ml_tools/ML_utilities.py,sha256=eYe2N-65FTzaOHF5gmiJl-HmicyzhqcdvlDiIivr5_g,22993
19
+ ml_tools/ML_vision_datasetmaster.py,sha256=bmHDC6SsBUxDSFjqQGuyzGfKuf1Imi1Ng6O2-dYF7I4,62607
20
+ ml_tools/ML_vision_evaluation.py,sha256=t12R7i1RkOCt9zu1_lxSBr8OH6A6Get0k8ftDLctn6I,10486
21
+ ml_tools/ML_vision_inference.py,sha256=He3KV3VJAm8PwO-fOq4b9VO8UXFr-GmpuCnoHXf4VZI,20588
22
+ ml_tools/ML_vision_models.py,sha256=WqiRN9JAjv--BcwkDrooXAs4Qo26JHPCHh3JSPm4kMI,26226
23
+ ml_tools/ML_vision_transformers.py,sha256=h332O9BjDMgxrBc0I-bJwJODWlcp7nJHbX1QS2etwBk,7738
24
+ ml_tools/PSO_optimization.py,sha256=T-HWHMRJUnPvPwixdU5jif3_rnnI36TzcL8u3oSCwuA,22960
25
+ ml_tools/RNN_forecast.py,sha256=Qa2KoZfdAvSjZ4yE78N4BFXtr3tTr0Gx7tQJZPotsh0,1967
26
+ ml_tools/SQL.py,sha256=vXLPGfVVg8bfkbBE3HVfyEclVbdJy0TBhuQONtMwSCQ,11234
27
+ ml_tools/VIF_factor.py,sha256=at5IVqPvicja2-DNSTSIIy3SkzDWCmLzo3qTG_qr5n8,10422
28
+ ml_tools/_ML_vision_recipe.py,sha256=zrgxFUvTJqQVuwR7jWlbIC2FD29u6eNFPkTRoJ7yEZI,3178
29
+ ml_tools/__init__.py,sha256=kJiankjz9_qXu7gU92mYqYg_anLvt-B6RtW0mMH8uGo,76
30
+ ml_tools/_logger.py,sha256=dlp5cGbzooK9YSNSZYB4yjZrOaQUGW8PTrM411AOvL8,4717
31
+ ml_tools/_schema.py,sha256=yu6aWmn_2Z4_AxAtJGDDCIa96y6JcUp-vgnCS013Qmw,3908
32
+ ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
33
+ ml_tools/constants.py,sha256=3br5Rk9cL2IUo638eJuMOGdbGQaWssaUecYEvSeRBLM,3322
34
+ ml_tools/custom_logger.py,sha256=TGc0Ww2Xlqj2XE3q4bP43hV7T3qnb5ci9f0pYHXF5TY,11226
35
+ ml_tools/data_exploration.py,sha256=bwHzFJ-IAo5GN3T53F-1J_pXUg8VHS91sG_90utAsfg,69911
36
+ ml_tools/ensemble_evaluation.py,sha256=2sJ3jD6yBNPRNwSokyaLKqKHi0QhF13ChoFe5yd4zwg,28368
37
+ ml_tools/ensemble_inference.py,sha256=0yLmLNj45RVVoSCLH1ZYJG9IoAhTkWUqEZmLOQTFGTY,9348
38
+ ml_tools/ensemble_learning.py,sha256=vsIED7nlheYI4w2SBzP6SC1AnNeMfn-2A1Gqw5EfxsM,21964
39
+ ml_tools/handle_excel.py,sha256=pfdAPb9ywegFkM9T54bRssDOsX-K7rSeV0RaMz7lEAo,14006
40
+ ml_tools/keys.py,sha256=-OiL9G0RIOKQk6BwETKIP3LWz2s5-x6lZW2YitJa4mY,3330
41
+ ml_tools/math_utilities.py,sha256=xeKq1quR_3DYLgowcp4Uam_4s3JltUyOnqMOGuAiYWU,8802
42
+ ml_tools/optimization_tools.py,sha256=TYFQ2nSnp7xxs-VyoZISWgnGJghFbsWasHjruegyJRs,12763
43
+ ml_tools/path_manager.py,sha256=CyDU16pOKmC82jPubqJPT6EBt-u-3rGVbxyPIZCvDDY,18432
44
+ ml_tools/serde.py,sha256=c8uDYjYry_VrLvoG4ixqDj5pij88lVn6Tu4NHcPkwDU,6943
45
+ ml_tools/utilities.py,sha256=aWqvYzmxlD74PD5Yqu1VuTekDJeYLQrmPIU_VeVyRp0,22526
46
+ dragon_ml_toolbox-14.7.0.dist-info/METADATA,sha256=NTifVXiC2zr5RhzCUTuUMEcU-wfswXxoYOO6N3UXFmM,6492
47
+ dragon_ml_toolbox-14.7.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
48
+ dragon_ml_toolbox-14.7.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
49
+ dragon_ml_toolbox-14.7.0.dist-info/RECORD,,
@@ -27,3 +27,13 @@ This project depends on the following third-party packages. Each is governed by
27
27
  - [plotnine](https://github.com/has2k1/plotnine/blob/main/LICENSE)
28
28
  - [tqdm](https://github.com/tqdm/tqdm/blob/master/LICENSE)
29
29
  - [pyarrow](https://github.com/apache/arrow/blob/main/LICENSE.txt)
30
+ - [colorlog](https://github.com/borntyping/python-colorlog/blob/main/LICENSE)
31
+ - [evotorch](https://github.com/nnaisense/evotorch/blob/master/LICENSE)
32
+ - [FreeSimpleGUI](https://github.com/spyoungtech/FreeSimpleGUI/blob/main/license.txt)
33
+ - [nuitka](https://github.com/Nuitka/Nuitka/blob/main/LICENSE.txt)
34
+ - [omegaconf](https://github.com/omry/omegaconf/blob/master/LICENSE)
35
+ - [ordered-set](https://github.com/rspeer/ordered-set/blob/master/MIT-LICENSE)
36
+ - [pyinstaller](https://github.com/pyinstaller/pyinstaller/blob/develop/COPYING.txt)
37
+ - [pytorch_tabular](https://github.com/manujosephv/pytorch_tabular/blob/main/LICENSE)
38
+ - [torchmetrics](https://github.com/Lightning-AI/torchmetrics/blob/master/LICENSE)
39
+ - [zstandard](https://github.com/indygreg/python-zstandard/blob/main/LICENSE)
@@ -7,19 +7,20 @@ from plotnine import ggplot, labs, theme, element_blank # type: ignore
7
7
  from typing import Optional, Union
8
8
 
9
9
  from .utilities import load_dataframe, merge_dataframes, save_dataframe_filename
10
- from .math_utilities import threshold_binary_values
10
+ from .math_utilities import threshold_binary_values, discretize_categorical_values
11
11
  from .path_manager import sanitize_filename, make_fullpath, list_csv_paths
12
12
  from ._logger import _LOGGER
13
13
  from ._script_info import _script_info
14
+ from ._schema import FeatureSchema
14
15
 
15
16
 
16
17
  __all__ = [
18
+ "MiceImputer",
17
19
  "apply_mice",
18
20
  "save_imputed_datasets",
19
- "get_na_column_names",
20
21
  "get_convergence_diagnostic",
21
22
  "get_imputed_distributions",
22
- "run_mice_pipeline"
23
+ "run_mice_pipeline",
23
24
  ]
24
25
 
25
26
 
@@ -79,7 +80,7 @@ def save_imputed_datasets(save_dir: Union[str, Path], imputed_datasets: list, df
79
80
 
80
81
 
81
82
  #Get names of features that had missing values before imputation
82
- def get_na_column_names(df: pd.DataFrame):
83
+ def _get_na_column_names(df: pd.DataFrame):
83
84
  return [col for col in df.columns if df[col].isna().any()]
84
85
 
85
86
 
@@ -264,7 +265,7 @@ def run_mice_pipeline(df_path_or_dir: Union[str,Path], target_columns: list[str]
264
265
 
265
266
  save_imputed_datasets(save_dir=save_datasets_path, imputed_datasets=imputed_datasets, df_targets=df_targets, imputed_dataset_names=imputed_dataset_names)
266
267
 
267
- imputed_column_names = get_na_column_names(df=df)
268
+ imputed_column_names = _get_na_column_names(df=df)
268
269
 
269
270
  get_convergence_diagnostic(kernel=kernel, imputed_dataset_names=imputed_dataset_names, column_names=imputed_column_names, root_dir=save_metrics_path)
270
271
 
@@ -278,5 +279,206 @@ def _skip_targets(df: pd.DataFrame, target_cols: list[str]):
278
279
  return df_feats, df_targets
279
280
 
280
281
 
282
+ # modern implementation
283
+ class MiceImputer:
284
+ """
285
+ A modern MICE imputation pipeline that uses a FeatureSchema
286
+ to correctly discretize categorical features after imputation.
287
+ """
288
+ def __init__(self,
289
+ schema: FeatureSchema,
290
+ iterations: int=20,
291
+ resulting_datasets: int = 1,
292
+ random_state: int = 101):
293
+
294
+ self.schema = schema
295
+ self.random_state = random_state
296
+ self.iterations = iterations
297
+ self.resulting_datasets = resulting_datasets
298
+
299
+ # --- Store schema info ---
300
+
301
+ # 1. Categorical info
302
+ if not self.schema.categorical_index_map:
303
+ _LOGGER.warning("FeatureSchema has no 'categorical_index_map'. No discretization will be applied.")
304
+ self.cat_info = {}
305
+ else:
306
+ self.cat_info = self.schema.categorical_index_map
307
+
308
+ # 2. Ordered feature names (critical for index mapping)
309
+ self.ordered_features = list(self.schema.feature_names)
310
+
311
+ # 3. Names of categorical features
312
+ self.categorical_features = list(self.schema.categorical_feature_names)
313
+
314
+ _LOGGER.info(f"MiceImputer initialized. Found {len(self.cat_info)} categorical features to discretize.")
315
+
316
+ def _post_process(self, imputed_df: pd.DataFrame) -> pd.DataFrame:
317
+ """
318
+ Applies schema-based discretization to a completed dataframe.
319
+
320
+ This method works around the behavior of `discretize_categorical_values`
321
+ (which returns a full int32 array) by:
322
+ 1. Calling it on the full, ordered feature array.
323
+ 2. Extracting *only* the valid discretized categorical columns.
324
+ 3. Updating the original float dataframe with these integer values.
325
+ """
326
+ # If no categorical features are defined, return the df as-is.
327
+ if not self.cat_info:
328
+ return imputed_df
329
+
330
+ try:
331
+ # 1. Ensure DataFrame columns match the schema order
332
+ # This is critical for the index-based categorical_info
333
+ df_ordered: pd.DataFrame = imputed_df[self.ordered_features] # type: ignore
334
+
335
+ # 2. Convert to NumPy array
336
+ array_ordered = df_ordered.to_numpy()
337
+
338
+ # 3. Apply discretization utility (which returns a full int32 array)
339
+ # This array has *correct* categorical values but *truncated* continuous values.
340
+ discretized_array_int32 = discretize_categorical_values(
341
+ array_ordered,
342
+ self.cat_info,
343
+ start_at_zero=True # Assuming 0-based indexing
344
+ )
345
+
346
+ # 4. Create a new DF from the int32 array, keeping the categorical columns.
347
+ df_discretized_cats = pd.DataFrame(
348
+ discretized_array_int32,
349
+ columns=self.ordered_features,
350
+ index=df_ordered.index # <-- Critical: align index
351
+ )[self.categorical_features] # <-- Select only cat features
352
+
353
+ # 5. "Rejoin": Start with a fresh copy of the *original* imputed DF (which has correct continuous floats).
354
+ final_df = df_ordered.copy()
355
+
356
+ # 6. Use .update() to "paste" the integer categorical values
357
+ # over the old float categorical values. Continuous floats are unaffected.
358
+ final_df.update(df_discretized_cats)
359
+
360
+ return final_df
361
+
362
+ except Exception as e:
363
+ _LOGGER.error(f"Failed during post-processing discretization:\n\tInput DF shape: {imputed_df.shape}\n\tSchema features: {len(self.ordered_features)}\n\tCategorical info keys: {list(self.cat_info.keys())}\n{e}")
364
+ raise
365
+
366
+ def _run_mice(self,
367
+ df: pd.DataFrame,
368
+ df_name: str) -> tuple[mf.ImputationKernel, list[pd.DataFrame], list[str]]:
369
+ """
370
+ Runs the MICE kernel and applies schema-based post-processing.
371
+
372
+ Parameters:
373
+ df (pd.DataFrame): The input dataframe *with NaNs*. Should only contain feature columns.
374
+ df_name (str): The base name for the dataset.
375
+
376
+ Returns:
377
+ tuple[mf.ImputationKernel, list[pd.DataFrame], list[str]]:
378
+ - The trained MICE kernel
379
+ - A list of imputed and processed DataFrames
380
+ - A list of names for the new DataFrames
381
+ """
382
+ # Ensure input df only contains features from the schema and is in the correct order.
383
+ try:
384
+ df_feats = df[self.ordered_features]
385
+ except KeyError as e:
386
+ _LOGGER.error(f"Input DataFrame is missing required schema columns: {e}")
387
+ raise
388
+
389
+ # 1. Initialize kernel
390
+ kernel = mf.ImputationKernel(
391
+ data=df_feats,
392
+ num_datasets=self.resulting_datasets,
393
+ random_state=self.random_state
394
+ )
395
+
396
+ _LOGGER.info("➡️ Schema-based MICE imputation running...")
397
+
398
+ # 2. Perform MICE
399
+ kernel.mice(self.iterations)
400
+
401
+ # 3. Retrieve, process, and collect datasets
402
+ imputed_datasets = []
403
+ for i in range(self.resulting_datasets):
404
+ # complete_data returns a pd.DataFrame
405
+ completed_df = kernel.complete_data(dataset=i)
406
+
407
+ # Apply our new discretization and ordering
408
+ processed_df = self._post_process(completed_df)
409
+ imputed_datasets.append(processed_df)
410
+
411
+ if not imputed_datasets:
412
+ _LOGGER.error("No imputed datasets were generated.")
413
+ raise ValueError()
414
+
415
+ # 4. Generate names
416
+ if self.resulting_datasets == 1:
417
+ imputed_dataset_names = [f"{df_name}_MICE"]
418
+ else:
419
+ imputed_dataset_names = [f"{df_name}_MICE_{i+1}" for i in range(self.resulting_datasets)]
420
+
421
+ # 5. Validate indexes
422
+ for imputed_df, subname in zip(imputed_datasets, imputed_dataset_names):
423
+ assert imputed_df.shape[0] == df.shape[0], f"❌ Row count mismatch in dataset {subname}"
424
+ assert all(imputed_df.index == df.index), f"❌ Index mismatch in dataset {subname}"
425
+
426
+ _LOGGER.info("Schema-based MICE imputation complete.")
427
+
428
+ return kernel, imputed_datasets, imputed_dataset_names
429
+
430
+ def run_pipeline(self,
431
+ df_path_or_dir: Union[str,Path],
432
+ save_datasets_dir: Union[str,Path],
433
+ save_metrics_dir: Union[str,Path],
434
+ ):
435
+ """
436
+ Runs the complete MICE imputation pipeline.
437
+
438
+ This method automates the entire workflow:
439
+ 1. Loads data from a CSV file path or a directory with CSV files.
440
+ 2. Separates features and targets based on the `FeatureSchema`.
441
+ 3. Runs the MICE algorithm on the feature set.
442
+ 4. Applies schema-based post-processing to discretize categorical features.
443
+ 5. Saves the final, processed, and imputed dataset(s) (re-joined with targets) to `save_datasets_dir`.
444
+ 6. Generates and saves convergence and distribution plots for all imputed columns to `save_metrics_dir`.
445
+
446
+ Parameters
447
+ ----------
448
+ df_path_or_dir :[str,Path]
449
+ Path to a single CSV file or a directory containing multiple CSV files to impute.
450
+ save_datasets_dir : [str,Path]
451
+ Directory where the final imputed and processed dataset(s) will be saved as CSVs.
452
+ save_metrics_dir : [str,Path]
453
+ Directory where convergence and distribution plots will be saved.
454
+ """
455
+ # Check paths
456
+ save_datasets_path = make_fullpath(save_datasets_dir, make=True)
457
+ save_metrics_path = make_fullpath(save_metrics_dir, make=True)
458
+
459
+ input_path = make_fullpath(df_path_or_dir)
460
+ if input_path.is_file():
461
+ all_file_paths = [input_path]
462
+ else:
463
+ all_file_paths = list(list_csv_paths(input_path).values())
464
+
465
+ for df_path in all_file_paths:
466
+
467
+ df, df_name = load_dataframe(df_path=df_path, kind="pandas")
468
+
469
+ df_features: pd.DataFrame = df[self.schema.feature_names] # type: ignore
470
+ df_targets = df.drop(columns=self.schema.feature_names)
471
+
472
+ imputed_column_names = _get_na_column_names(df=df_features)
473
+
474
+ kernel, imputed_datasets, imputed_dataset_names = self._run_mice(df=df_features, df_name=df_name)
475
+
476
+ save_imputed_datasets(save_dir=save_datasets_path, imputed_datasets=imputed_datasets, df_targets=df_targets, imputed_dataset_names=imputed_dataset_names)
477
+
478
+ get_convergence_diagnostic(kernel=kernel, imputed_dataset_names=imputed_dataset_names, column_names=imputed_column_names, root_dir=save_metrics_path)
479
+
480
+ get_imputed_distributions(kernel=kernel, df_name=df_name, root_dir=save_metrics_path, column_names=imputed_column_names)
481
+
482
+
281
483
  def info():
282
484
  _script_info(__all__)
@@ -0,0 +1,108 @@
1
+ from typing import Optional
2
+ from ._script_info import _script_info
3
+
4
+
5
+ __all__ = [
6
+ "ClassificationMetricsFormat",
7
+ "MultiClassificationMetricsFormat"
8
+ ]
9
+
10
+
11
+ class ClassificationMetricsFormat:
12
+ """
13
+ Optional configuration for classification tasks, use in the '.evaluate()' method of the MLTrainer.
14
+ """
15
+ def __init__(self,
16
+ cmap: str="Blues",
17
+ class_map: Optional[dict[str,int]]=None,
18
+ ROC_PR_line: str='darkorange',
19
+ calibration_bins: int=15,
20
+ font_size: int=16) -> None:
21
+ """
22
+ Initializes the formatting configuration for single-label classification metrics.
23
+
24
+ Args:
25
+ cmap (str): The matplotlib colormap name for the confusion matrix
26
+ and report heatmap. Defaults to "Blues".
27
+ - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
28
+ - Diverging options: 'coolwarm', 'viridis', 'plasma', 'inferno'
29
+
30
+ class_map (dict[str,int] | None): A dictionary mapping
31
+ class string names to their integer indices (e.g., {'cat': 0, 'dog': 1}).
32
+ This is used to label the axes of the confusion matrix and classification
33
+ report correctly. Defaults to None.
34
+
35
+ ROC_PR_line (str): The color name or hex code for the line plotted
36
+ on the ROC and Precision-Recall curves. Defaults to 'darkorange'.
37
+ - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
38
+ - Hex codes: '#FF6347', '#4682B4'
39
+
40
+ calibration_bins (int): The number of bins to use when
41
+ creating the calibration (reliability) plot. Defaults to 15.
42
+
43
+ font_size (int): The base font size to apply to the plots. Defaults to 16.
44
+ """
45
+ self.cmap = cmap
46
+ self.class_map = class_map
47
+ self.ROC_PR_line = ROC_PR_line
48
+ self.calibration_bins = calibration_bins
49
+ self.font_size = font_size
50
+
51
+ def __repr__(self) -> str:
52
+ parts = [
53
+ f"cmap='{self.cmap}'",
54
+ f"class_map={self.class_map}",
55
+ f"ROC_PR_line='{self.ROC_PR_line}'",
56
+ f"calibration_bins={self.calibration_bins}",
57
+ f"font_size={self.font_size}"
58
+ ]
59
+ return f"ClassificationMetricsFormat({', '.join(parts)})"
60
+
61
+
62
+ class MultiClassificationMetricsFormat:
63
+ """
64
+ Optional configuration for multi-label classification tasks, use in the '.evaluate()' method of the MLTrainer.
65
+ """
66
+ def __init__(self,
67
+ threshold: float=0.5,
68
+ ROC_PR_line: str='darkorange',
69
+ cmap: str = "Blues",
70
+ font_size: int = 16) -> None:
71
+ """
72
+ Initializes the formatting configuration for multi-label classification metrics.
73
+
74
+ Args:
75
+ threshold (float): The probability threshold (0.0 to 1.0) used
76
+ to convert sigmoid outputs into binary (0 or 1) predictions for
77
+ calculating the confusion matrix and overall metrics. Defaults to 0.5.
78
+
79
+ ROC_PR_line (str): The color name or hex code for the line plotted
80
+ on the ROC and Precision-Recall curves (one for each label).
81
+ Defaults to 'darkorange'.
82
+ - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
83
+ - Hex codes: '#FF6347', '#4682B4'
84
+
85
+ cmap (str): The matplotlib colormap name for the per-label
86
+ confusion matrices. Defaults to "Blues".
87
+ - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
88
+ - Diverging options: 'coolwarm', 'viridis', 'plasma', 'inferno'
89
+
90
+ font_size (int): The base font size to apply to the plots. Defaults to 16.
91
+ """
92
+ self.threshold = threshold
93
+ self.cmap = cmap
94
+ self.ROC_PR_line = ROC_PR_line
95
+ self.font_size = font_size
96
+
97
+ def __repr__(self) -> str:
98
+ parts = [
99
+ f"threshold={self.threshold}",
100
+ f"ROC_PR_line='{self.ROC_PR_line}'",
101
+ f"cmap='{self.cmap}'",
102
+ f"font_size={self.font_size}"
103
+ ]
104
+ return f"MultiClassificationMetricsFormat({', '.join(parts)})"
105
+
106
+
107
+ def info():
108
+ _script_info(__all__)