dragon-ml-toolbox 11.1.0__py3-none-any.whl → 12.0.0__py3-none-any.whl

This diff shows the published contents of two public releases of the package, as they appear in their respective registries, and is provided for informational purposes only.


@@ -1,14 +1,14 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 11.1.0
+Version: 12.0.0
 Summary: A collection of tools for data science and machine learning projects.
-Author-email: Karl Loza <luigiloza@gmail.com>
+Author-email: "Karl L. Loza Vidaurre" <luigiloza@gmail.com>
 License-Expression: MIT
 Project-URL: Homepage, https://github.com/DrAg0n-BoRn/ML_tools
 Project-URL: Changelog, https://github.com/DrAg0n-BoRn/ML_tools/blob/master/CHANGELOG.md
 Classifier: Programming Language :: Python :: 3
 Classifier: Operating System :: OS Independent
-Requires-Python: >=3.10
+Requires-Python: ==3.12
 Description-Content-Type: text/markdown
 License-File: LICENSE
 License-File: LICENSE-THIRD-PARTY.md
@@ -47,9 +47,6 @@ Requires-Dist: lightgbm<=4.5.0; extra == "mice"
 Requires-Dist: shap; extra == "mice"
 Requires-Dist: colorlog; extra == "mice"
 Requires-Dist: pyarrow; extra == "mice"
-Provides-Extra: pytorch
-Requires-Dist: torch; extra == "pytorch"
-Requires-Dist: torchvision; extra == "pytorch"
 Provides-Extra: excel
 Requires-Dist: pandas; extra == "excel"
 Requires-Dist: openpyxl; extra == "excel"
@@ -68,9 +65,6 @@ Requires-Dist: lightgbm; extra == "gui-boost"
 Provides-Extra: gui-torch
 Requires-Dist: numpy; extra == "gui-torch"
 Requires-Dist: FreeSimpleGUI>=5.2; extra == "gui-torch"
-Provides-Extra: plot
-Requires-Dist: matplotlib; extra == "plot"
-Requires-Dist: seaborn; extra == "plot"
 Provides-Extra: pyinstaller
 Requires-Dist: pyinstaller; extra == "pyinstaller"
 Provides-Extra: nuitka
@@ -90,7 +84,7 @@ A collection of Python utilities for data science and machine learning, structur
 
 ## Installation
 
-**Python 3.10+**
+**Python 3.12**
 
 ### Via PyPI
 
@@ -100,22 +94,22 @@ Install the latest stable release from PyPI:
 pip install dragon-ml-toolbox
 ```
 
-### Via GitHub (Editable)
+### Via conda-forge
 
-Clone the repository and install in editable mode with optional dependencies:
+Install from the conda-forge channel:
 
 ```bash
-git clone https://github.com/DrAg0n-BoRn/ML_tools.git
-cd ML_tools
-pip install -e .
+conda install -c conda-forge dragon-ml-toolbox
 ```
 
-### Via conda-forge
+### Via GitHub (Editable)
 
-Install from the conda-forge channel:
+Clone the repository and install in editable mode:
 
 ```bash
-conda install -c conda-forge dragon-ml-toolbox
+git clone https://github.com/DrAg0n-BoRn/ML_tools.git
+cd ML_tools
+pip install -e .
 ```
 
 ## Modular Installation
@@ -128,13 +122,7 @@ Installs a comprehensive set of tools for typical data science workflows, includ
 pip install "dragon-ml-toolbox[ML]"
 ```
 
-To install the standard CPU-only versions of Torch and Torchvision:
-
-```Bash
-pip install "dragon-ml-toolbox[pytorch]"
-```
-
-⚠️ To make use of GPU acceleration (highly recommended), follow the official instructions: [PyTorch website](https://pytorch.org/get-started/locally/)
+⚠️ PyTorch is required; follow the official instructions: [PyTorch website](https://pytorch.org/get-started/locally/)
 
 #### Modules:
 
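Regarding the ⚠️ note above: a minimal CPU-only install is sketched below for reference. The exact command depends on OS and CUDA version, so the selector on the PyTorch website remains authoritative.

```Bash
pip install torch --index-url https://download.pytorch.org/whl/cpu
```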
@@ -147,6 +135,7 @@ ensemble_inference
 ensemble_learning
 ETL_cleaning
 ETL_engineering
+math_utilities
 ML_callbacks
 ML_datasetmaster
 ML_evaluation_multi
@@ -156,10 +145,12 @@ ML_models
 ML_optimization
 ML_scaler
 ML_trainer
+ML_utilities
 optimization_tools
 path_manager
 PSO_optimization
 RNN_forecast
+serde
 SQL
 utilities
 ```
@@ -179,7 +170,9 @@ pip install "dragon-ml-toolbox[mice]"
 ```Bash
 constants
 custom_logger
+math_utilities
 MICE_imputation
+serde
 VIF_factor
 path_manager
 utilities
@@ -208,16 +201,12 @@ path_manager
 
 ### 🎰 GUI for Boosting Algorithms (XGBoost, LightGBM) [gui-boost]
 
-For GUIs that include plotting functionality, you must also install the [plot] extra.
+GUI tools for running inference with XGBoost and LightGBM models.
 
 ```Bash
 pip install "dragon-ml-toolbox[gui-boost]"
 ```
 
-```Bash
-pip install "dragon-ml-toolbox[gui-boost,plot]"
-```
-
 #### Modules:
 
 ```Bash
@@ -226,22 +215,19 @@ custom_logger
 GUI_tools
 ensemble_inference
 path_manager
+serde
 ```
 
 ---
 
 ### 🤖 GUI for PyTorch Models [gui-torch]
 
-For GUIs that include plotting functionality, you must also install the [plot] extra.
+GUI tools for running inference with PyTorch models.
 
 ```Bash
 pip install "dragon-ml-toolbox[gui-torch]"
 ```
 
-```Bash
-pip install "dragon-ml-toolbox[gui-torch,plot]"
-```
-
 #### Modules:
 
 ```Bash
@@ -273,6 +259,6 @@ pip install "dragon-ml-toolbox[nuitka]"
 After installation, import modules like this:
 
 ```python
-from ml_tools.utilities import serialize_object, deserialize_object
+from ml_tools.serde import serialize_object, deserialize_object
 from ml_tools import custom_logger
 ```
dragon_ml_toolbox-12.0.0.dist-info/RECORD ADDED
@@ -0,0 +1,40 @@
+dragon_ml_toolbox-12.0.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
+dragon_ml_toolbox-12.0.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=iy2r_R7wjzsCbz_Q_jMsp_jfZ6oP8XW9QhwzRBH0mGY,1904
+ml_tools/ETL_cleaning.py,sha256=PLRSR-VYnt1nNT9XrcWq40SE0VzHCw7DQ8v9czfSQsU,20366
+ml_tools/ETL_engineering.py,sha256=l0I6Og9o4s6EODdk0kZXjbbC-a3vVPYy1FopP2BkQSQ,54909
+ml_tools/GUI_tools.py,sha256=Va6ig-dHULPVRwQYYtH3fvY5XPIoqRcJpRW8oXC55Hw,45413
+ml_tools/MICE_imputation.py,sha256=eNN7JuT43bydAJ5E2k2A5sDjYDu3X8kCHtMdFBkzjR0,11699
+ml_tools/ML_callbacks.py,sha256=-XRIZEy3CPJWTHcoReyIw53FZlTs3pWcTVVnncTQQSc,13909
+ml_tools/ML_datasetmaster.py,sha256=t6q6mU9lz2rYKTVPKjA7yZ5ImV7_NykiciHaYnqIEpA,30822
+ml_tools/ML_evaluation.py,sha256=tLswOPgH4G1KExSMn0876YtNkbxPh-W3J4MYOjomMWA,16208
+ml_tools/ML_evaluation_multi.py,sha256=6OZyQ4SM9ALh38mOABmiHgIQDWcovsD_iOo7Bg9YZCE,12516
+ml_tools/ML_inference.py,sha256=ymFvncFsU10PExq87xnEj541DKV5ck0nMuK8ToJHzVQ,23067
+ml_tools/ML_models.py,sha256=pSCV6KbmVnPZr49Kbyg7g25CYaWBWJr6IinBHKgVKGw,28042
+ml_tools/ML_optimization.py,sha256=r1lAQiztTtRuh13rWj1iqbXvWO0LCqbzlkRdy3gEWo4,18124
+ml_tools/ML_scaler.py,sha256=tw6onj9o8_kk3FQYb930HUzvv1zsFZe2YZJdF3LtHkU,7538
+ml_tools/ML_trainer.py,sha256=_g48w5Ak-wQr5fGHdJqlcpnzv3gWyL1ghkOhy9VOZbo,23930
+ml_tools/ML_utilities.py,sha256=35DfZzAwfDwVwfRECD8X_2ynsU2NCpTdNJSmza6oAzQ,8712
+ml_tools/PSO_optimization.py,sha256=fVHeemqilBS0zrGV25E5yKwDlGdd2ZKa18d8CZ6Q6Fk,22961
+ml_tools/RNN_forecast.py,sha256=Qa2KoZfdAvSjZ4yE78N4BFXtr3tTr0Gx7tQJZPotsh0,1967
+ml_tools/SQL.py,sha256=vXLPGfVVg8bfkbBE3HVfyEclVbdJy0TBhuQONtMwSCQ,11234
+ml_tools/VIF_factor.py,sha256=dizjK0zmgOMuLBnJ66y5Sll5do6wjGWhAPVzJF1uwhQ,10404
+ml_tools/__init__.py,sha256=q0y9faQ6e17XCQ7eUiCZ1FJ4Bg5EQqLjZ9f_l5REUUY,41
+ml_tools/_logger.py,sha256=dlp5cGbzooK9YSNSZYB4yjZrOaQUGW8PTrM411AOvL8,4717
+ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
+ml_tools/constants.py,sha256=3br5Rk9cL2IUo638eJuMOGdbGQaWssaUecYEvSeRBLM,3322
+ml_tools/custom_logger.py,sha256=OZqG7FR_UE6byzY3RDmlj08a336ZU-4DzNBMPLr_d5c,5881
+ml_tools/data_exploration.py,sha256=qpRUCQEVUmkxjx7DAztT6yIdI___xNV5NVPMBqCp3Mk,38870
+ml_tools/ensemble_evaluation.py,sha256=FGHSe8LBI8_w8LjNeJWOcYQ1UK_mc6fVah8gmSvNVGg,26853
+ml_tools/ensemble_inference.py,sha256=0yLmLNj45RVVoSCLH1ZYJG9IoAhTkWUqEZmLOQTFGTY,9348
+ml_tools/ensemble_learning.py,sha256=aTPeKthO4zRWBEaQJOUj8jEqVHiHjjOMXuiEWjI9NxM,21946
+ml_tools/handle_excel.py,sha256=pfdAPb9ywegFkM9T54bRssDOsX-K7rSeV0RaMz7lEAo,14006
+ml_tools/keys.py,sha256=FDpbS3Jb0pjrVvvp2_8nZi919mbob_-xwuy5OOtKM_A,1848
+ml_tools/math_utilities.py,sha256=CUkyBuExFOnEHp9J1Xsh6H4xILwYOBilwFccM9J_Dxo,7870
+ml_tools/optimization_tools.py,sha256=P3I6lIpvZ8Xf2kX5FvvBKBmrK2pB6idBpkTzfUJxTeE,5073
+ml_tools/path_manager.py,sha256=CyDU16pOKmC82jPubqJPT6EBt-u-3rGVbxyPIZCvDDY,18432
+ml_tools/serde.py,sha256=k0qAwfMf13lVBQSgq5u9MSXEoo31iOA2-Ncm8XgMCMI,3974
+ml_tools/utilities.py,sha256=gef62GLK7ev5BWkkQekeJoVZqwf2mIuOlOfyCw6WdtE,13882
+dragon_ml_toolbox-12.0.0.dist-info/METADATA,sha256=piCOJTB5V7QKGXqbYiu3GjdNLeyrpzV-42tIxVxBRBU,6166
+dragon_ml_toolbox-12.0.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-12.0.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-12.0.0.dist-info/RECORD,,
ml_tools/ETL_cleaning.py CHANGED
@@ -2,6 +2,7 @@ import polars as pl
 import pandas as pd
 from pathlib import Path
 from typing import Union, List, Dict
+
 from .path_manager import sanitize_filename, make_fullpath
 from .data_exploration import drop_macro
 from .utilities import save_dataframe, load_dataframe
ml_tools/ETL_engineering.py CHANGED
@@ -2,6 +2,7 @@ import polars as pl
 import re
 from pathlib import Path
 from typing import Literal, Union, Optional, Any, Callable, List, Dict, Tuple
+
 from .utilities import load_dataframe, save_dataframe
 from .path_manager import make_fullpath
 from ._script_info import _script_info
@@ -370,8 +371,20 @@ class AutoDummifier:
         Column names are auto-generated by Polars as
         '{original_col_name}_{category_value}'.
         """
-        # Ensure the column is treated as a string before creating dummies
-        return column.cast(pl.Utf8).to_dummies(drop_first=self.drop_first)
+        # Store the original column name to construct the potential null column name
+        col_name = column.name
+
+        # Create the dummy variables from the series
+        dummies = column.cast(pl.Utf8).to_dummies(drop_first=self.drop_first)
+
+        # Define the name of the column that Polars creates for null values
+        null_col_name = f"{col_name}_null"
+
+        # Check if the null column exists and drop it if it does
+        if null_col_name in dummies.columns:
+            return dummies.drop(null_col_name)
+
+        return dummies
 
 
 class MultiBinaryDummifier:
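The `AutoDummifier` change above works around Polars' `to_dummies`, which emits a dedicated indicator column for nulls. A minimal standalone sketch of that behavior (illustrative, not package code):

```python
import polars as pl

# A string column containing a null: to_dummies() adds a "{name}_null" indicator
s = pl.Series("color", ["red", "blue", None])
dummies = s.cast(pl.Utf8).to_dummies()
assert "color_null" in dummies.columns

# Dropping that column reproduces what the patched AutoDummifier returns
cleaned = dummies.drop("color_null")
```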
@@ -388,7 +401,7 @@ class MultiBinaryDummifier:
         A list of strings, where each string is a keyword to search for. A separate
         binary column will be created for each keyword.
         case_insensitive (bool):
-            If True, keyword matching ignores case. Defaults to True.
+            If True, keyword matching ignores case.
     """
     def __init__(self, keywords: List[str], case_insensitive: bool = True):
         if not isinstance(keywords, list) or not all(isinstance(k, str) for k in keywords):
@@ -531,7 +544,7 @@ class NumberExtractor:
         round_digits (int | None):
             If the dtype is 'float', you can specify the number of decimal
             places to round the result to. This parameter is ignored if
-            dtype is 'int'. Defaults to None (no rounding).
+            dtype is 'int'.
     """
     def __init__(
         self,
@@ -657,7 +670,7 @@ class MultiNumberExtractor:
             # Define the core extraction logic for the i-th number
             extraction_expr = (
                 column.str.extract_all(self.regex_pattern)
-                .list.get(i)
+                .list.get(i, null_on_oob=True)
                 .cast(self.polars_dtype, strict=False)
             )
 
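The `null_on_oob=True` argument makes an out-of-range list index yield null instead of raising. A standalone sketch of the difference (illustrative, not package code):

```python
import polars as pl

df = pl.DataFrame({"raw": ["10x20", "5"]})
# "5" contains only one number, so index 1 is out of bounds;
# null_on_oob=True turns that into a null instead of an error.
out = df.select(
    pl.col("raw").str.extract_all(r"\d+")
    .list.get(1, null_on_oob=True)
    .cast(pl.Float64, strict=False)
    .alias("second_number")
)
```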
@@ -944,8 +957,7 @@ class RatioCalculator:
 
 class TriRatioCalculator:
     """
-    A transformer that handles three-part ("A:B:C") and two-part ("A:C")
-    ratios, enforcing a strict output structure.
+    A transformer that handles three-part ("A:B:C") ratios, enforcing a strict output structure.
 
     - Three-part ratios produce A/B and A/C.
     - Two-part ratios are assumed to be A:C and produce None for A/B.
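A worked example of the documented contract (illustrative values; the class internals are not shown in this diff):

```python
# "12:3:4" -> A/B = 12/3 = 4.0 and A/C = 12/4 = 3.0
# "12:4"   -> treated as A:C, so A/B = None and A/C = 12/4 = 3.0
```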
ml_tools/GUI_tools.py CHANGED
@@ -4,8 +4,9 @@ import traceback
 import FreeSimpleGUI as sg
 from functools import wraps
 from typing import Any, Dict, Tuple, List, Literal, Union, Optional, Callable
-from ._script_info import _script_info
 import numpy as np
+
+from ._script_info import _script_info
 from ._logger import _LOGGER
 from .keys import _OneHotOtherPlaceholder
 
ml_tools/MICE_imputation.py CHANGED
@@ -3,13 +3,16 @@ import miceforest as mf
 from pathlib import Path
 import matplotlib.pyplot as plt
 import numpy as np
-from .utilities import load_dataframe, merge_dataframes, save_dataframe, threshold_binary_values
-from .path_manager import sanitize_filename, make_fullpath, list_csv_paths
 from plotnine import ggplot, labs, theme, element_blank  # type: ignore
 from typing import Optional, Union
+
+from .utilities import load_dataframe, merge_dataframes, save_dataframe
+from .math_utilities import threshold_binary_values
+from .path_manager import sanitize_filename, make_fullpath, list_csv_paths
 from ._logger import _LOGGER
 from ._script_info import _script_info
 
+
 __all__ = [
     "apply_mice",
     "save_imputed_datasets",
ml_tools/ML_callbacks.py CHANGED
@@ -1,13 +1,13 @@
 import numpy as np
 import torch
 from tqdm.auto import tqdm
+from typing import Union, Literal, Optional
+from pathlib import Path
+
 from .path_manager import make_fullpath, sanitize_filename
 from .keys import PyTorchLogKeys
 from ._logger import _LOGGER
-from typing import Optional
 from ._script_info import _script_info
-from typing import Union, Literal
-from pathlib import Path
 
 
 __all__ = [
ml_tools/ML_datasetmaster.py CHANGED
@@ -10,6 +10,7 @@ from torchvision.datasets import ImageFolder
 from torchvision import transforms
 import matplotlib.pyplot as plt
 from pathlib import Path
+
 from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
ml_tools/ML_evaluation.py CHANGED
@@ -18,9 +18,10 @@ from sklearn.metrics import (
 import torch
 import shap
 from pathlib import Path
+from typing import Union, Optional, List
+
 from .path_manager import make_fullpath
 from ._logger import _LOGGER
-from typing import Union, Optional, List
 from ._script_info import _script_info
 from .keys import SHAPKeys
 
ml_tools/ML_evaluation_multi.py CHANGED
@@ -25,6 +25,7 @@ from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
 
+
 __all__ = [
     "multi_target_regression_metrics",
     "multi_label_classification_metrics",
ml_tools/ML_inference.py CHANGED
@@ -11,6 +11,7 @@ from ._logger import _LOGGER
 from .path_manager import make_fullpath
 from .keys import PyTorchInferenceKeys
 
+
 __all__ = [
     "PyTorchInferenceHandler",
     "PyTorchInferenceHandlerMulti",
ml_tools/ML_models.py CHANGED
@@ -3,6 +3,7 @@ from torch import nn
 from typing import List, Union, Tuple, Dict, Any
 from pathlib import Path
 import json
+
 from ._logger import _LOGGER
 from .path_manager import make_fullpath
 from ._script_info import _script_info
@@ -155,6 +156,7 @@ class _BaseAttention(_BaseMLP):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         # By default, models inheriting this do not have the flag.
+        self.attention = None
         self.has_interpretable_attention = False
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -165,7 +167,7 @@ class _BaseAttention(_BaseMLP):
     def forward_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """Returns logits and attention weights."""
         # This logic is now shared and defined in one place
-        x, attention_weights = self.attention(x)
+        x, attention_weights = self.attention(x)  # type: ignore
         x = self.mlp(x)
         logits = self.output_layer(x)
         return logits, attention_weights
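The new `self.attention = None` default is what forces the `# type: ignore` above: the base class now declares the attribute, and concrete subclasses are expected to overwrite it with a real module. A hypothetical sketch of the pattern (`AttentionMLP` and `_AttentionLayer` are illustrative names, not from this diff):

```python
class AttentionMLP(_BaseAttention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the None placeholder set by _BaseAttention.__init__
        self.attention = _AttentionLayer(self.in_features)  # hypothetical layer
        self.has_interpretable_attention = True
```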
ml_tools/ML_optimization.py CHANGED
@@ -18,7 +18,8 @@ from .ML_inference import PyTorchInferenceHandler
 from .keys import PyTorchInferenceKeys
 from .SQL import DatabaseManager
 from .optimization_tools import _save_result
-from .utilities import threshold_binary_values, save_dataframe
+from .utilities import save_dataframe
+from .math_utilities import threshold_binary_values
 
 
 __all__ = [
ml_tools/ML_scaler.py CHANGED
@@ -2,14 +2,17 @@ import torch
 from torch.utils.data import Dataset, DataLoader
 from pathlib import Path
 from typing import Union, List, Optional
+
 from ._logger import _LOGGER
 from ._script_info import _script_info
 from .path_manager import make_fullpath
 
+
 __all__ = [
     "PytorchScaler"
 ]
 
+
 class PytorchScaler:
     """
     Standardizes continuous features in a PyTorch dataset by subtracting the
ml_tools/ML_utilities.py ADDED
@@ -0,0 +1,219 @@
+import pandas as pd
+from pathlib import Path
+from typing import Union, Any
+
+from .path_manager import make_fullpath, list_subdirectories, list_files_by_extension
+from ._script_info import _script_info
+from ._logger import _LOGGER
+from .keys import DatasetKeys, PytorchModelArchitectureKeys, PytorchArtifactPathKeys, SHAPKeys
+from .utilities import load_dataframe
+
+
+__all__ = [
+    "find_model_artifacts",
+    "select_features_by_shap"
+]
+
+
+def find_model_artifacts(target_directory: Union[str, Path], load_scaler: bool, verbose: bool = False) -> list[dict[str, Any]]:
+    """
+    Scans subdirectories to find paths to model weights, target names, feature names, and model architecture; optionally a scaler path if `load_scaler` is True.
+
+    This function operates on a specific directory structure. It expects the
+    `target_directory` to contain one or more subdirectories, where each
+    subdirectory represents a single trained model result.
+
+    The expected directory structure for each model is as follows:
+    ```
+    target_directory
+    ├── model_1
+    │   ├── *.pth
+    │   ├── scaler_*.pth    (Required if `load_scaler` is True)
+    │   ├── feature_names.txt
+    │   ├── target_names.txt
+    │   └── architecture.json
+    └── model_2/
+        └── ...
+    ```
+
+    Args:
+        target_directory (str | Path): The path to the root directory that contains model subdirectories.
+        load_scaler (bool): If True, the function requires and searches for a scaler file (`.pth`) in each model subdirectory.
+        verbose (bool): If True, enables detailed logging during the artifact search.
+
+    Returns:
+        (list[dict[str, Path]]): A list of dictionaries, where each dictionary
+            corresponds to a model found in a subdirectory. The dictionary
+            maps standardized keys to the absolute paths of the model's
+            artifacts (weights, architecture, features, targets, and scaler).
+            The scaler path will be `None` if `load_scaler` is False.
+    """
+    # validate directory
+    root_path = make_fullpath(target_directory, enforce="directory")
+
+    # store results
+    all_artifacts: list[dict] = list()
+
+    # find model directories
+    result_dirs_dict = list_subdirectories(root_dir=root_path, verbose=verbose)
+    for dir_name, dir_path in result_dirs_dict.items():
+        # find files
+        model_pth_dict = list_files_by_extension(directory=dir_path, extension="pth", verbose=verbose)
+
+        # restriction
+        if load_scaler:
+            if len(model_pth_dict) != 2:
+                _LOGGER.error(f"Directory {dir_path} should contain exactly 2 '.pth' files: scaler and weights.")
+                raise IOError()
+        else:
+            if len(model_pth_dict) != 1:
+                _LOGGER.error(f"Directory {dir_path} should contain exactly 1 '.pth' file: weights.")
+                raise IOError()
+
+        ##### Scaler and Weights #####
+        scaler_path = None
+        weights_path = None
+
+        # load weights and scaler if present
+        for pth_filename, pth_path in model_pth_dict.items():
+            if load_scaler and pth_filename.lower().startswith(DatasetKeys.SCALER_PREFIX):
+                scaler_path = pth_path
+            else:
+                weights_path = pth_path
+
+        # validation
+        if not weights_path:
+            _LOGGER.error(f"Error parsing the model weights path from '{dir_name}'")
+            raise IOError()
+
+        if load_scaler and not scaler_path:
+            _LOGGER.error(f"Error parsing the scaler path from '{dir_name}'")
+            raise IOError()
+
+        ##### Target and Feature names #####
+        target_names_path = None
+        feature_names_path = None
+
+        # load feature and target names
+        model_txt_dict = list_files_by_extension(directory=dir_path, extension="txt", verbose=verbose)
+
+        for txt_filename, txt_path in model_txt_dict.items():
+            if txt_filename == DatasetKeys.FEATURE_NAMES:
+                feature_names_path = txt_path
+            elif txt_filename == DatasetKeys.TARGET_NAMES:
+                target_names_path = txt_path
+
+        # validation
+        if not target_names_path or not feature_names_path:
+            _LOGGER.error(f"Error parsing features path or targets path from '{dir_name}'")
+            raise IOError()
+
+        ##### load model architecture path #####
+        architecture_path = None
+
+        model_json_dict = list_files_by_extension(directory=dir_path, extension="json", verbose=verbose)
+
+        for json_filename, json_path in model_json_dict.items():
+            if json_filename == PytorchModelArchitectureKeys.SAVENAME:
+                architecture_path = json_path
+
+        # validation
+        if not architecture_path:
+            _LOGGER.error(f"Error parsing the model architecture path from '{dir_name}'")
+            raise IOError()
+
+        ##### Paths dictionary #####
+        parsing_dict = {
+            PytorchArtifactPathKeys.WEIGHTS_PATH: weights_path,
+            PytorchArtifactPathKeys.ARCHITECTURE_PATH: architecture_path,
+            PytorchArtifactPathKeys.FEATURES_PATH: feature_names_path,
+            PytorchArtifactPathKeys.TARGETS_PATH: target_names_path,
+            PytorchArtifactPathKeys.SCALER_PATH: scaler_path
+        }
+
+        all_artifacts.append(parsing_dict)
+
+    return all_artifacts
+
+
+def select_features_by_shap(
+    root_directory: Union[str, Path],
+    shap_threshold: float,
+    verbose: bool = True) -> list[str]:
+    """
+    Scans subdirectories to find SHAP summary CSVs, then extracts feature
+    names whose mean absolute SHAP value meets a specified threshold.
+
+    This function is useful for automated feature selection based on feature
+    importance scores aggregated from multiple models.
+
+    Args:
+        root_directory (Union[str, Path]):
+            The path to the root directory that contains model subdirectories.
+        shap_threshold (float):
+            The minimum mean absolute SHAP value for a feature to be included
+            in the final list.
+
+    Returns:
+        list[str]:
+            A single, sorted list of unique feature names that meet the
+            threshold criteria across all found files.
+    """
+    if verbose:
+        _LOGGER.info(f"Starting feature selection with SHAP threshold >= {shap_threshold}")
+    root_path = make_fullpath(root_directory, enforce="directory")
+
+    # --- Step 2: Directory and File Discovery ---
+    subdirectories = list_subdirectories(root_dir=root_path, verbose=False)
+
+    shap_filename = SHAPKeys.SAVENAME + ".csv"
+
+    valid_csv_paths = []
+    for dir_name, dir_path in subdirectories.items():
+        expected_path = dir_path / shap_filename
+        if expected_path.is_file():
+            valid_csv_paths.append(expected_path)
+        else:
+            _LOGGER.warning(f"No '{shap_filename}' found in subdirectory '{dir_name}'.")
+
+    if not valid_csv_paths:
+        _LOGGER.error(f"Process halted: No '{shap_filename}' files were found in any subdirectory.")
+        return []
+
+    if verbose:
+        _LOGGER.info(f"Found {len(valid_csv_paths)} SHAP summary files to process.")
+
+    # --- Step 3: Data Processing and Feature Extraction ---
+    master_feature_set = set()
+    for csv_path in valid_csv_paths:
+        try:
+            df, _ = load_dataframe(csv_path, kind="pandas", verbose=False)
+
+            # Validate required columns
+            required_cols = {SHAPKeys.FEATURE_COLUMN, SHAPKeys.SHAP_VALUE_COLUMN}
+            if not required_cols.issubset(df.columns):
+                _LOGGER.warning(f"Skipping '{csv_path}': missing required columns.")
+                continue
+
+            # Filter by threshold and extract features
+            filtered_df = df[df[SHAPKeys.SHAP_VALUE_COLUMN] >= shap_threshold]
+            features = filtered_df[SHAPKeys.FEATURE_COLUMN].tolist()
+            master_feature_set.update(features)
+
+        except (ValueError, pd.errors.EmptyDataError):
+            _LOGGER.warning(f"Skipping '{csv_path}' because it is empty or malformed.")
+            continue
+        except Exception as e:
+            _LOGGER.error(f"An unexpected error occurred while processing '{csv_path}': {e}")
+            continue
+
+    # --- Step 4: Finalize and Return ---
+    final_features = sorted(list(master_feature_set))
+    if verbose:
+        _LOGGER.info(f"Selected {len(final_features)} unique features across all files.")
+
+    return final_features
+
+
+def info():
+    _script_info(__all__)
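A short usage sketch for the two new helpers, assuming a `results/` directory laid out as in the `find_model_artifacts` docstring (the path and threshold are placeholders):

```python
from ml_tools.ML_utilities import find_model_artifacts, select_features_by_shap

# Collect per-model artifact paths (weights, architecture, features, targets, scaler)
artifacts = find_model_artifacts("results", load_scaler=True, verbose=True)

# Union of features whose mean absolute SHAP value is >= 0.01 across all summaries
selected = select_features_by_shap("results", shap_threshold=0.01)
print(selected)
```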
ml_tools/PSO_optimization.py CHANGED
@@ -4,18 +4,17 @@ import xgboost as xgb
 import lightgbm as lgb
 from typing import Literal, Union, Tuple, Dict, Optional
 from copy import deepcopy
-from .utilities import (
-    threshold_binary_values,
-    threshold_binary_values_batch,
-    deserialize_object)
-from .path_manager import sanitize_filename, make_fullpath, list_files_by_extension
 import torch
 from tqdm import trange
+from contextlib import nullcontext
+
+from .serde import deserialize_object
+from .math_utilities import threshold_binary_values, threshold_binary_values_batch
+from .path_manager import sanitize_filename, make_fullpath, list_files_by_extension
 from ._logger import _LOGGER
 from .keys import EnsembleKeys
 from ._script_info import _script_info
 from .SQL import DatabaseManager
-from contextlib import nullcontext
 from .optimization_tools import _save_result
 