dragon-ml-toolbox 10.7.0__py3-none-any.whl → 10.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-10.7.0.dist-info → dragon_ml_toolbox-10.9.0.dist-info}/METADATA +1 -1
- {dragon_ml_toolbox-10.7.0.dist-info → dragon_ml_toolbox-10.9.0.dist-info}/RECORD +15 -15
- ml_tools/ML_datasetmaster.py +27 -20
- ml_tools/ML_evaluation.py +6 -4
- ml_tools/ML_models.py +11 -7
- ml_tools/ML_scaler.py +6 -4
- ml_tools/SQL.py +4 -2
- ml_tools/ensemble_evaluation.py +48 -1
- ml_tools/keys.py +26 -3
- ml_tools/path_manager.py +33 -0
- ml_tools/utilities.py +242 -20
- {dragon_ml_toolbox-10.7.0.dist-info → dragon_ml_toolbox-10.9.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-10.7.0.dist-info → dragon_ml_toolbox-10.9.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-10.7.0.dist-info → dragon_ml_toolbox-10.9.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-10.7.0.dist-info → dragon_ml_toolbox-10.9.0.dist-info}/top_level.txt +0 -0
{dragon_ml_toolbox-10.7.0.dist-info → dragon_ml_toolbox-10.9.0.dist-info}/RECORD
CHANGED

@@ -1,36 +1,36 @@
-dragon_ml_toolbox-10.
-dragon_ml_toolbox-10.
+dragon_ml_toolbox-10.9.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
+dragon_ml_toolbox-10.9.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=lY4_rJPnLnMu7YBQaY-_iz1JRDcLdQzNCyeLAF1glJY,1837
 ml_tools/ETL_cleaning.py,sha256=lSP5q6-ukGhJBPV8dlsqJvPXAzj4du_0J-SbtEd0Pjg,19292
 ml_tools/ETL_engineering.py,sha256=a6KCWH6kRatZtjaFEF_o917ApPMK5_vRD-BjfCDAl-E,49400
 ml_tools/GUI_tools.py,sha256=kEQWg-bog3pB5tI22gMGKWaCGHnz9TB2Lvvfhf5F2CI,45412
 ml_tools/MICE_imputation.py,sha256=kVSythWfxJFR4-2mtcYCWQaQ1Oz5yyx_SJu5gjnS7H8,11670
 ml_tools/ML_callbacks.py,sha256=JPvEw_cW5tYNJ2rMSgnNrKLuni_UrmuhDFaOw-u2SvA,13926
-ml_tools/ML_datasetmaster.py,sha256=
-ml_tools/ML_evaluation.py,sha256=
+ml_tools/ML_datasetmaster.py,sha256=BMmdCVAZ-HSnnSPLzKla2TdZKvHkHj4t9A0V1Ba3i-I,30821
+ml_tools/ML_evaluation.py,sha256=q4_RsBjmidc_yDX-DQvpJW8RCHrOCJbgXKBORQdt-TM,16111
 ml_tools/ML_evaluation_multi.py,sha256=2jTSNFCu8cz5C05EusnrDyffs59M2Fq3UXSHxo2TR1A,12515
 ml_tools/ML_inference.py,sha256=SGDPiPxs_OYDKKRZziFMyaWcC8A37c70W9t-dMP5niI,23066
-ml_tools/ML_models.py,sha256=
+ml_tools/ML_models.py,sha256=FliuqGhxP7AWHCweTLlfssXFOjwvFhIYJsgj_w_-EI4,27901
 ml_tools/ML_optimization.py,sha256=a2Uxe1g-y4I-gFa8ENIM8QDS-Pz3hoPRRaVXAWMbyQA,13491
-ml_tools/ML_scaler.py,sha256=
+ml_tools/ML_scaler.py,sha256=IrZsAr1xjvuLi8s5IKR-qbk2mS_awl3mn_xoXg5TJyA,7535
 ml_tools/ML_trainer.py,sha256=xw1zMgYpdqwsTt604xe3GTQNvpg6z6Ze-avmitGBFeU,23539
 ml_tools/PSO_optimization.py,sha256=q0VYpssQGbPum7xdnkDXlJQKhZMYZo8acHpKhajPK3c,22954
 ml_tools/RNN_forecast.py,sha256=8rNZr-eWOBXMiDQV22e_tQTPM5LM2IFggEAa1FaoXaI,1965
-ml_tools/SQL.py,sha256=
+ml_tools/SQL.py,sha256=givoz6CGWRUdqnBem3VGZxzGdo3ZbX00kyHNjzI8kWE,10803
 ml_tools/VIF_factor.py,sha256=MkMh_RIdsN2XUPzKNGRiEcmB17R_MmvGV4ezpL5zD2E,10403
 ml_tools/__init__.py,sha256=q0y9faQ6e17XCQ7eUiCZ1FJ4Bg5EQqLjZ9f_l5REUUY,41
 ml_tools/_logger.py,sha256=wcImAiXEZKPNcwM30qBh3t7HvoPURonJY0nrgMGF0sM,4719
 ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
 ml_tools/custom_logger.py,sha256=ry43hk54K6xKo8jRAgq1sFxUpOA9T0LIJ7sw0so2BW0,5880
 ml_tools/data_exploration.py,sha256=4McT2BR9muK4JVVTKUfvRyThe0m_o2vpy9RJ1f_1FeY,28692
-ml_tools/ensemble_evaluation.py,sha256=
+ml_tools/ensemble_evaluation.py,sha256=FGHSe8LBI8_w8LjNeJWOcYQ1UK_mc6fVah8gmSvNVGg,26853
 ml_tools/ensemble_inference.py,sha256=EFHnbjbu31fcVp88NBx8lWAVdu2Gpg9MY9huVZJHFfM,9350
 ml_tools/ensemble_learning.py,sha256=3s0kH4i_naj0IVl_T4knst-Hwg4TScWjEdsXX5KAi7I,21929
 ml_tools/handle_excel.py,sha256=He4UT15sCGhaG-JKfs7uYVAubxWjrqgJ6U7OhMR2fuE,14005
-ml_tools/keys.py,sha256=
+ml_tools/keys.py,sha256=FDpbS3Jb0pjrVvvp2_8nZi919mbob_-xwuy5OOtKM_A,1848
 ml_tools/optimization_tools.py,sha256=P3I6lIpvZ8Xf2kX5FvvBKBmrK2pB6idBpkTzfUJxTeE,5073
-ml_tools/path_manager.py,sha256=
-ml_tools/utilities.py,sha256=
-dragon_ml_toolbox-10.
-dragon_ml_toolbox-10.
-dragon_ml_toolbox-10.
-dragon_ml_toolbox-10.
+ml_tools/path_manager.py,sha256=wLJlz3Y9_1-LB9em4B2VYDCVuTOX2eOc7D6hbbebjgM,14990
+ml_tools/utilities.py,sha256=30z0x1aDLyBGzF98_tgSaxwFafYwQS-GTFzXHopBSGc,29105
+dragon_ml_toolbox-10.9.0.dist-info/METADATA,sha256=NK8z4StYOVR0ByF_l-vNjyrFgbb2qddBa6lOzlQsZrg,6968
+dragon_ml_toolbox-10.9.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-10.9.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-10.9.0.dist-info/RECORD,,
ml_tools/ML_datasetmaster.py
CHANGED
@@ -15,6 +15,8 @@ from ._logger import _LOGGER
 from ._script_info import _script_info
 from .custom_logger import save_list_strings
 from .ML_scaler import PytorchScaler
+from .keys import DatasetKeys
+

 __all__ = [
     "DatasetMaker",
@@ -91,6 +93,7 @@ class _BaseDatasetMaker(ABC):
         self.scaler: Optional[PytorchScaler] = None
         self._id: Optional[str] = None
         self._feature_names: List[str] = []
+        self._target_names: List[str] = []
         self._X_train_shape = (0,0)
         self._X_test_shape = (0,0)
         self._y_train_shape = (0,)
@@ -142,6 +145,10 @@ class _BaseDatasetMaker(ABC):
     @property
     def feature_names(self) -> list[str]:
         return self._feature_names
+
+    @property
+    def target_names(self) -> list[str]:
+        return self._target_names

     @property
     def id(self) -> Optional[str]:
@@ -162,10 +169,17 @@ class _BaseDatasetMaker(ABC):
         """Saves a list of feature names as a text file"""
         save_list_strings(list_strings=self._feature_names,
                           directory=directory,
-                          filename=
-                          verbose=verbose)
+                          filename=DatasetKeys.FEATURE_NAMES,
+                          verbose=verbose)
+
+    def save_target_names(self, directory: Union[str, Path], verbose: bool=True) -> None:
+        """Saves a list of target names as a text file"""
+        save_list_strings(list_strings=self._target_names,
+                          directory=directory,
+                          filename=DatasetKeys.TARGET_NAMES,
+                          verbose=verbose)

-    def save_scaler(self, save_dir: Union[str, Path]):
+    def save_scaler(self, save_dir: Union[str, Path], verbose: bool=True) -> None:
         """
         Saves the fitted PytorchScaler's state to a .pth file.

@@ -178,14 +192,15 @@ class _BaseDatasetMaker(ABC):
             _LOGGER.error("No scaler was fitted or provided.")
             raise RuntimeError()
         if not self.id:
-            _LOGGER.error("Must set the `id` before saving scaler.")
+            _LOGGER.error("Must set the dataset `id` before saving scaler.")
             raise ValueError()
         save_path = make_fullpath(save_dir, make=True, enforce="directory")
         sanitized_id = sanitize_filename(self.id)
-        filename = f"
+        filename = f"{DatasetKeys.SCALER_PREFIX}{sanitized_id}.pth"
         filepath = save_path / filename
-        self.scaler.save(filepath)
-
+        self.scaler.save(filepath, verbose=False)
+        if verbose:
+            _LOGGER.info(f"Scaler for dataset '{self.id}' saved to '{filepath.name}'.")


 # Single target dataset
@@ -203,7 +218,7 @@ class DatasetMaker(_BaseDatasetMaker):
     `train_dataset` -> PyTorch Dataset
     `test_dataset` -> PyTorch Dataset
     `feature_names` -> list[str]
-    `
+    `target_names` -> list[str]
     `id` -> str

     The ID can be manually set to any string if needed, it is the target name by default.
@@ -231,8 +246,8 @@ class DatasetMaker(_BaseDatasetMaker):
         features = pandas_df.iloc[:, :-1]
         target = pandas_df.iloc[:, -1]
         self._feature_names = features.columns.tolist()
-        self.
-        self._id = self.
+        self._target_names = [str(target.name)]
+        self._id = self._target_names[0]

         # --- 2. Split ---
         X_train, X_test, y_train, y_test = train_test_split(
@@ -249,12 +264,8 @@ class DatasetMaker(_BaseDatasetMaker):
         )

         # --- 4. Create Datasets ---
-        self._train_ds = _PytorchDataset(X_train_final, y_train.values, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=
-        self._test_ds = _PytorchDataset(X_test_final, y_test.values, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=
-
-    @property
-    def target_name(self) -> str:
-        return self._target_name
+        self._train_ds = _PytorchDataset(X_train_final, y_train.values, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
+        self._test_ds = _PytorchDataset(X_test_final, y_test.values, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)


 # --- New Multi-Target Class ---
@@ -303,10 +314,6 @@ class DatasetMakerMulti(_BaseDatasetMaker):
         self._train_ds = _PytorchDataset(X_train_final, y_train, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
         self._test_ds = _PytorchDataset(X_test_final, y_test, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)

-    @property
-    def target_names(self) -> list[str]:
-        return self._target_names
-

 # --- Private Base Class ---
 class _BaseMaker(ABC):
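The dataset-maker base class now exposes target names alongside feature names and can persist both, plus a quieter scaler save. A minimal usage sketch against the new surface (the maker's constructor arguments are not part of this diff, so the helper below takes an already-built maker; the output directory is hypothetical):

```python
from pathlib import Path
from ml_tools.ML_datasetmaster import DatasetMaker

def export_dataset_artifacts(maker: DatasetMaker, out_dir: Path) -> None:
    """Persist the per-dataset artifacts that 10.9.0 modules expect to find together."""
    print(maker.feature_names)                # existing property
    print(maker.target_names)                 # new: list[str], shared by single- and multi-target makers
    maker.save_target_names(out_dir)          # new: writes target names using DatasetKeys.TARGET_NAMES
    maker.save_scaler(out_dir, verbose=True)  # filename is now DatasetKeys.SCALER_PREFIX + sanitized id + ".pth"
```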
ml_tools/ML_evaluation.py
CHANGED
@@ -22,6 +22,7 @@ from .path_manager import make_fullpath
 from ._logger import _LOGGER
 from typing import Union, Optional, List
 from ._script_info import _script_info
+from .keys import SHAPKeys


 __all__ = [
@@ -333,7 +334,8 @@ def shap_summary_plot(model,
     plt.close()

     # Save Summary Data to CSV
-
+    shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
+    summary_path = save_dir_path / shap_summary_filename
     # Ensure the array is 1D before creating the DataFrame
     mean_abs_shap = np.abs(shap_values).mean(axis=0).flatten()

@@ -341,9 +343,9 @@ def shap_summary_plot(model,
         feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]

     summary_df = pd.DataFrame({
-
-
-    }).sort_values(
+        SHAPKeys.FEATURE_COLUMN: feature_names,
+        SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
+    }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)

     summary_df.to_csv(summary_path, index=False)

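`shap_summary_plot` now derives the summary CSV's name and column headers from `SHAPKeys` instead of hard-coded strings. A small sketch of reading that file back with the same constants (the output directory name is an assumption):

```python
from pathlib import Path
import pandas as pd
from ml_tools.keys import SHAPKeys

save_dir = Path("evaluation_output")                   # assumed; whatever was passed to shap_summary_plot
summary_csv = save_dir / (SHAPKeys.SAVENAME + ".csv")  # resolves to shap_summary.csv

summary = pd.read_csv(summary_csv)
top10 = summary.nlargest(10, SHAPKeys.SHAP_VALUE_COLUMN)[SHAPKeys.FEATURE_COLUMN]
print(top10.tolist())
```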
ml_tools/ML_models.py
CHANGED
@@ -6,7 +6,7 @@ import json
 from ._logger import _LOGGER
 from .path_manager import make_fullpath
 from ._script_info import _script_info
-from .keys import
+from .keys import PytorchModelArchitectureKeys


 __all__ = [
@@ -29,11 +29,14 @@ class _ArchitectureHandlerMixin:
             raise AttributeError()

         path_dir = make_fullpath(directory, make=True, enforce="directory")
-
+
+        json_filename = PytorchModelArchitectureKeys.SAVENAME + ".json"
+
+        full_path = path_dir / json_filename

         config = {
-
-
+            PytorchModelArchitectureKeys.MODEL: self.__class__.__name__,
+            PytorchModelArchitectureKeys.CONFIG: self.get_architecture_config() # type: ignore
         }

         with open(full_path, 'w') as f:
@@ -48,7 +51,8 @@ class _ArchitectureHandlerMixin:
         user_path = make_fullpath(file_or_dir)

         if user_path.is_dir():
-
+            json_filename = PytorchModelArchitectureKeys.SAVENAME + ".json"
+            target_path = make_fullpath(user_path / json_filename, enforce="file")
         elif user_path.is_file():
             target_path = user_path
         else:
@@ -58,8 +62,8 @@ class _ArchitectureHandlerMixin:
         with open(target_path, 'r') as f:
             saved_data = json.load(f)

-        saved_class_name = saved_data[
-        config = saved_data[
+        saved_class_name = saved_data[PytorchModelArchitectureKeys.MODEL]
+        config = saved_data[PytorchModelArchitectureKeys.CONFIG]

         if saved_class_name != cls.__name__:
             _LOGGER.error(f"Model class mismatch. File specifies '{saved_class_name}', but '{cls.__name__}' was expected.")
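With `PytorchModelArchitectureKeys`, the mixin saves the architecture as `architecture.json` with a `model_class` entry and a `config` entry. A minimal sketch of reading such a file outside the mixin (the directory is hypothetical):

```python
import json
from pathlib import Path
from ml_tools.keys import PytorchModelArchitectureKeys

def read_architecture(model_dir: Path) -> tuple[str, dict]:
    """Return (class name, constructor config) from a saved architecture file."""
    arch_file = model_dir / (PytorchModelArchitectureKeys.SAVENAME + ".json")  # architecture.json
    with open(arch_file, "r") as f:
        saved = json.load(f)
    return saved[PytorchModelArchitectureKeys.MODEL], saved[PytorchModelArchitectureKeys.CONFIG]
```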
ml_tools/ML_scaler.py
CHANGED
@@ -149,7 +149,7 @@ class PytorchScaler:

         return data_clone

-    def save(self, filepath: Union[str, Path]):
+    def save(self, filepath: Union[str, Path], verbose: bool=True):
         """
         Saves the scaler's state (mean, std, indices) to a .pth file.

@@ -163,10 +163,11 @@ class PytorchScaler:
             'continuous_feature_indices': self.continuous_feature_indices
         }
         torch.save(state, path_obj)
-
+        if verbose:
+            _LOGGER.info(f"PytorchScaler state saved to '{path_obj.name}'.")

     @staticmethod
-    def load(filepath: Union[str, Path]) -> 'PytorchScaler':
+    def load(filepath: Union[str, Path], verbose: bool=True) -> 'PytorchScaler':
         """
         Loads a scaler's state from a .pth file.

@@ -178,7 +179,8 @@ class PytorchScaler:
         """
         path_obj = make_fullpath(filepath, enforce="file")
         state = torch.load(path_obj)
-
+        if verbose:
+            _LOGGER.info(f"PytorchScaler state loaded from '{path_obj.name}'.")
         return PytorchScaler(
             mean=state['mean'],
             std=state['std'],
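`PytorchScaler.save` and `PytorchScaler.load` gain a `verbose` flag so callers (such as `save_scaler` above) can suppress the per-file log lines. A minimal sketch, with hypothetical file paths:

```python
from ml_tools.ML_scaler import PytorchScaler

def resave_quietly(src: str, dst: str) -> None:
    scaler = PytorchScaler.load(src, verbose=False)  # new: skip the "state loaded" log line
    scaler.save(dst, verbose=True)                   # default behaviour still logs the save
```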
ml_tools/SQL.py
CHANGED
@@ -4,7 +4,7 @@ from pathlib import Path
 from typing import Union, Dict, Any, Optional, List, Literal
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .path_manager import make_fullpath
+from .path_manager import make_fullpath, sanitize_filename


 __all__ = [
@@ -94,11 +94,13 @@ class DatabaseManager:
         if not self.cursor:
             _LOGGER.error("Database connection is not open.")
             raise sqlite3.Error()
+
+        sanitized_table_name = sanitize_filename(table_name)

         columns_def = ", ".join([f'"{col_name}" {col_type}' for col_name, col_type in schema.items()])
         exists_clause = "IF NOT EXISTS" if if_not_exists else ""

-        query = f"CREATE TABLE {exists_clause} {
+        query = f"CREATE TABLE {exists_clause} {sanitized_table_name} ({columns_def})"

         _LOGGER.info(f"➡️ Executing: {query}")
         self.cursor.execute(query)
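The table-creation path now runs `table_name` through `sanitize_filename` before interpolating it into the CREATE TABLE statement. The standalone sketch below illustrates the same defensive pattern with plain `sqlite3`; it is not the package's method, and `sanitize_identifier` is only a stand-in for `sanitize_filename`, whose exact rules are not shown in this diff:

```python
import re
import sqlite3

def sanitize_identifier(name: str) -> str:
    # stand-in: collapse anything that is not alphanumeric or underscore
    return re.sub(r"[^0-9a-zA-Z_]", "_", name)

def create_table(conn: sqlite3.Connection, table_name: str, schema: dict[str, str]) -> None:
    columns_def = ", ".join(f'"{col}" {ctype}' for col, ctype in schema.items())
    query = f"CREATE TABLE IF NOT EXISTS {sanitize_identifier(table_name)} ({columns_def})"
    conn.execute(query)

conn = sqlite3.connect(":memory:")
create_table(conn, "results 2024; DROP TABLE x", {"id": "INTEGER", "score": "REAL"})
```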
ml_tools/ensemble_evaluation.py
CHANGED
@@ -25,6 +25,7 @@ from typing import Union, Optional, Literal
 from .path_manager import sanitize_filename, make_fullpath
 from ._script_info import _script_info
 from ._logger import _LOGGER
+from .keys import SHAPKeys


 __all__ = [
@@ -472,7 +473,7 @@ def get_shap_values(
         save_dir: Directory to save visualizations.
     """
     sanitized_target_name = sanitize_filename(target_name)
-    global_save_path = make_fullpath(save_dir, make=True)
+    global_save_path = make_fullpath(save_dir, make=True, enforce="directory")

     def _apply_plot_style():
         styles = ['seaborn', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8', 'default']
@@ -539,6 +540,15 @@ def get_shap_values(
                 plot_type=plot_type,
                 title=f"{model_name} - {target_name} (Class {class_name})"
             )
+
+            # Save the summary data for the current class
+            summary_save_path = global_save_path / f"SHAP_{sanitized_target_name}_{class_name}.csv"
+            _save_summary_csv(
+                shap_values_for_summary=class_shap,
+                feature_names=feature_names,
+                save_path=summary_save_path
+            )
+
         else:
             values = shap_values[1] if isinstance(shap_values, list) else shap_values
             for plot_type in ["bar", "dot"]:
@@ -549,6 +559,15 @@ def get_shap_values(
                 plot_type=plot_type,
                 title=f"{model_name} - {target_name}"
             )
+
+            # Save the summary data for the positive class
+            shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
+            summary_save_path = global_save_path / shap_summary_filename
+            _save_summary_csv(
+                shap_values_for_summary=values,
+                feature_names=feature_names,
+                save_path=summary_save_path
+            )

     def _plot_for_regression(shap_values):
         for plot_type in ["bar", "dot"]:
@@ -559,6 +578,34 @@ def get_shap_values(
                 plot_type=plot_type,
                 title=f"{model_name} - {target_name}"
             )
+
+        # Save the summary data to a CSV file
+        shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
+        summary_save_path = global_save_path / shap_summary_filename
+        _save_summary_csv(
+            shap_values_for_summary=shap_values,
+            feature_names=feature_names,
+            save_path=summary_save_path
+        )
+
+    def _save_summary_csv(shap_values_for_summary: np.ndarray, feature_names: list[str], save_path: Path):
+        """Calculates and saves the SHAP summary data to a CSV file."""
+        mean_abs_shap = np.abs(shap_values_for_summary).mean(axis=0)
+
+        # Create default feature names if none are provided
+        current_feature_names = feature_names
+        if current_feature_names is None:
+            current_feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
+
+        summary_df = pd.DataFrame({
+            SHAPKeys.FEATURE_COLUMN: feature_names,
+            SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
+        }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
+
+        summary_df.to_csv(save_path, index=False)
+        # print(f"📝 SHAP summary data saved as '{save_path.name}'")
+
+
     #START_O

     explainer = shap.TreeExplainer(model)
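The new `_save_summary_csv` helper writes, for each target (and for each class in the multiclass branch), the mean absolute SHAP value per feature sorted in descending order. A self-contained sketch of that computation with toy numbers:

```python
import numpy as np
import pandas as pd
from ml_tools.keys import SHAPKeys

def shap_summary_frame(shap_values: np.ndarray, feature_names: list[str]) -> pd.DataFrame:
    """Mean |SHAP| per feature for a (n_samples, n_features) array, sorted descending."""
    mean_abs = np.abs(shap_values).mean(axis=0)
    return (pd.DataFrame({SHAPKeys.FEATURE_COLUMN: feature_names,
                          SHAPKeys.SHAP_VALUE_COLUMN: mean_abs})
              .sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False))

frame = shap_summary_frame(np.array([[0.2, -1.5], [0.1, 2.0]]), ["age", "dose"])
print(frame)  # "dose" ranks first (1.75), then "age" (0.15)
```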
ml_tools/keys.py
CHANGED
@@ -38,11 +38,34 @@ class PyTorchInferenceKeys:
     PROBABILITIES = "probabilities"


-class
-    """Keys for saving and loading
+class PytorchModelArchitectureKeys:
+    """Keys for saving and loading model architecture."""
     MODEL = 'model_class'
     CONFIG = "config"
-    SAVENAME = "architecture
+    SAVENAME = "architecture"
+
+
+class PytorchArtifactPathKeys:
+    """Keys for model artifact paths."""
+    FEATURES_PATH = "feature_names_path"
+    TARGETS_PATH = "target_names_path"
+    ARCHITECTURE_PATH = "model_architecture_path"
+    WEIGHTS_PATH = "model_weights_path"
+    SCALER_PATH = "scaler_path"
+
+
+class DatasetKeys:
+    """Keys for saving dataset artifacts"""
+    FEATURE_NAMES = "feature_names"
+    TARGET_NAMES = "target_names"
+    SCALER_PREFIX = "scaler_"
+
+
+class SHAPKeys:
+    """Keys for SHAP functions"""
+    FEATURE_COLUMN = "feature"
+    SHAP_VALUE_COLUMN = "mean_abs_shap_value"
+    SAVENAME = "shap_summary"


 class _OneHotOtherPlaceholder:
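The new key classes centralize the artifact names that the other modules in this release build their filenames from. A quick sketch of the strings they resolve to:

```python
from ml_tools.keys import DatasetKeys, PytorchModelArchitectureKeys, SHAPKeys

dataset_id = "toxicity"                                     # hypothetical dataset id
print(PytorchModelArchitectureKeys.SAVENAME + ".json")      # architecture.json
print(DatasetKeys.SCALER_PREFIX + dataset_id + ".pth")      # scaler_toxicity.pth
print(SHAPKeys.SAVENAME + ".csv")                           # shap_summary.csv
print(DatasetKeys.FEATURE_NAMES, DatasetKeys.TARGET_NAMES)  # feature_names target_names
```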
ml_tools/path_manager.py
CHANGED
@@ -13,6 +13,7 @@ __all__ = [
     "sanitize_filename",
     "list_csv_paths",
     "list_files_by_extension",
+    "list_subdirectories"
 ]


@@ -385,5 +386,37 @@ def list_files_by_extension(directory: Union[str,Path], extension: str, verbose:
     return name_path_dict


+def list_subdirectories(root_dir: Union[str,Path], verbose: bool=True) -> dict[str, Path]:
+    """
+    Scans a directory and returns a dictionary of its immediate subdirectories.
+
+    Args:
+        root_dir (str | Path): The path to the directory to scan.
+        verbose (bool): If True, prints the number of directories found.
+
+    Returns:
+        dict[str, Path]: A dictionary mapping subdirectory names (str) to their full Path objects.
+    """
+    root_path = make_fullpath(root_dir, enforce="directory")
+
+    directories = [p.resolve() for p in root_path.iterdir() if p.is_dir()]
+
+    if len(directories) < 1:
+        _LOGGER.error(f"No subdirectories found inside '{root_path}'")
+        raise IOError()
+
+    if verbose:
+        count = len(directories)
+        # Use pluralization for better readability
+        plural = 'ies' if count != 1 else 'y'
+        print(f"Found {count} subdirector{plural} in '{root_path.name}'.")
+
+    # Create a dictionary where the key is the directory's name (a string)
+    # and the value is the full Path object.
+    dir_map = {p.name: p for p in directories}
+
+    return dir_map
+
+
 def info():
     _script_info(__all__)
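A minimal usage sketch of the new `list_subdirectories` helper (the directory name is hypothetical); it returns a name-to-Path mapping of immediate subdirectories and raises `IOError` if there are none:

```python
from ml_tools.path_manager import list_subdirectories

runs = list_subdirectories("training_results", verbose=True)  # assumed directory of model runs
for name, path in runs.items():
    print(name, "->", path)
```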
ml_tools/utilities.py
CHANGED
@@ -6,9 +6,10 @@ from pathlib import Path
 from typing import Literal, Union, Sequence, Optional, Any, Iterator, Tuple, overload
 import joblib
 from joblib.externals.loky.process_executor import TerminatedWorkerError
-from .path_manager import sanitize_filename, make_fullpath, list_csv_paths
+from .path_manager import sanitize_filename, make_fullpath, list_csv_paths, list_files_by_extension, list_subdirectories
 from ._script_info import _script_info
 from ._logger import _LOGGER
+from .keys import DatasetKeys, PytorchModelArchitectureKeys, PytorchArtifactPathKeys, SHAPKeys


 # Keep track of available tools
@@ -24,7 +25,9 @@ __all__ = [
     "deserialize_object",
     "distribute_dataset_by_target",
     "train_dataset_orchestrator",
-    "train_dataset_yielder"
+    "train_dataset_yielder",
+    "find_model_artifacts",
+    "select_features_by_shap"
 ]


@@ -32,6 +35,7 @@ __all__ = [
 @overload
 def load_dataframe(
     df_path: Union[str, Path],
+    use_columns: Optional[list[str]] = None,
     kind: Literal["pandas"] = "pandas",
     all_strings: bool = False,
     verbose: bool = True
@@ -42,7 +46,8 @@ def load_dataframe(
 @overload
 def load_dataframe(
     df_path: Union[str, Path],
-
+    use_columns: Optional[list[str]] = None,
+    kind: Literal["polars"] = "polars",
     all_strings: bool = False,
     verbose: bool = True
 ) -> Tuple[pl.DataFrame, str]:
@@ -50,6 +55,7 @@ def load_dataframe(

 def load_dataframe(
     df_path: Union[str, Path],
+    use_columns: Optional[list[str]] = None,
     kind: Literal["pandas", "polars"] = "pandas",
     all_strings: bool = False,
     verbose: bool = True
@@ -58,11 +64,13 @@ def load_dataframe(
     Load a CSV file into a DataFrame and extract its base name.

     Can load data as either a pandas or a polars DataFrame. Allows for loading all
-    columns as string types to prevent type inference errors.
+    columns or a subset of columns as string types to prevent type inference errors.

     Args:
         df_path (str, Path):
             The path to the CSV file.
+        use_columns (list[str] | None):
+            If provided, only these columns will be loaded from the CSV.
         kind ("pandas", "polars"):
             The type of DataFrame to load. Defaults to "pandas".
         all_strings (bool):
@@ -76,28 +84,43 @@ def load_dataframe(

     Raises:
         FileNotFoundError: If the file does not exist at the given path.
-        ValueError: If the DataFrame is empty
+        ValueError: If the DataFrame is empty, an invalid 'kind' is provided, or a column in 'use_columns' is not found in the file.
     """
     path = make_fullpath(df_path)

     df_name = path.stem

-
-    if
-
-
-
-
-
-
-
+    try:
+        if kind == "pandas":
+            pd_kwargs: dict[str,Any]
+            pd_kwargs = {'encoding': 'utf-8'}
+            if use_columns:
+                pd_kwargs['usecols'] = use_columns
+            if all_strings:
+                pd_kwargs['dtype'] = str
+
+            df = pd.read_csv(path, **pd_kwargs)
+
+        elif kind == "polars":
+            pl_kwargs: dict[str,Any]
+            pl_kwargs = {}
+            if use_columns:
+                pl_kwargs['columns'] = use_columns
+
+            if all_strings:
+                pl_kwargs['infer_schema'] = False
+            else:
+                pl_kwargs['infer_schema_length'] = 1000
+
+            df = pl.read_csv(path, **pl_kwargs)
+
         else:
-
-
+            _LOGGER.error(f"Invalid kind '{kind}'. Must be one of 'pandas' or 'polars'.")
+            raise ValueError()

-
-        _LOGGER.error(f"
-        raise
+    except (ValueError, pl.exceptions.ColumnNotFoundError) as e:
+        _LOGGER.error(f"Failed to load '{df_name}'. A specified column may not exist in the file.")
+        raise e

     # This check works for both pandas and polars DataFrames
     if df.shape[0] == 0:
@@ -109,7 +132,6 @@ def load_dataframe(

     return df, df_name # type: ignore

-
 def yield_dataframes_from_dir(datasets_dir: Union[str,Path], verbose: bool=True):
     """
     Iterates over all CSV files in a given directory, loading each into a Pandas DataFrame.
|
|
|
560
582
|
yield (df_features, df_target, feature_names, target_col)
|
|
561
583
|
|
|
562
584
|
|
|
585
|
+
def find_model_artifacts(target_directory: Union[str,Path], load_scaler: bool, verbose: bool=False) -> list[dict[str,Any]]:
|
|
586
|
+
"""
|
|
587
|
+
Scans subdirectories to find paths to model weights, target names, feature names, and model architecture. Optionally an scaler path if `load_scaler` is True.
|
|
588
|
+
|
|
589
|
+
This function operates on a specific directory structure. It expects the
|
|
590
|
+
`target_directory` to contain one or more subdirectories, where each
|
|
591
|
+
subdirectory represents a single trained model result.
|
|
592
|
+
|
|
593
|
+
The expected directory structure for each model is as follows:
|
|
594
|
+
```
|
|
595
|
+
target_directory
|
|
596
|
+
├── model_1
|
|
597
|
+
│ ├── *.pth
|
|
598
|
+
│ ├── scaler_*.pth (Required if `load_scaler` is True)
|
|
599
|
+
│ ├── feature_names.txt
|
|
600
|
+
│ ├── target_names.txt
|
|
601
|
+
│ └── architecture.json
|
|
602
|
+
└── model_2/
|
|
603
|
+
└── ...
|
|
604
|
+
```
|
|
605
|
+
|
|
606
|
+
Args:
|
|
607
|
+
target_directory (str | Path): The path to the root directory that contains model subdirectories.
|
|
608
|
+
load_scaler (bool): If True, the function requires and searches for a scaler file (`.pth`) in each model subdirectory.
|
|
609
|
+
verbose (bool): If True, enables detailed logging during the file paths search process.
|
|
610
|
+
|
|
611
|
+
Returns:
|
|
612
|
+
(list[dict[str, Path]]): A list of dictionaries, where each dictionary
|
|
613
|
+
corresponds to a model found in a subdirectory. The dictionary
|
|
614
|
+
maps standardized keys to the absolute paths of the model's
|
|
615
|
+
artifacts (weights, architecture, features, targets, and scaler).
|
|
616
|
+
The scaler path will be `None` if `load_scaler` is False.
|
|
617
|
+
"""
|
|
618
|
+
# validate directory
|
|
619
|
+
root_path = make_fullpath(target_directory, enforce="directory")
|
|
620
|
+
|
|
621
|
+
# store results
|
|
622
|
+
all_artifacts: list[dict] = list()
|
|
623
|
+
|
|
624
|
+
# find model directories
|
|
625
|
+
result_dirs_dict = list_subdirectories(root_dir=root_path, verbose=verbose)
|
|
626
|
+
for dir_name, dir_path in result_dirs_dict.items():
|
|
627
|
+
# find files
|
|
628
|
+
model_pth_dict = list_files_by_extension(directory=dir_path, extension="pth", verbose=verbose)
|
|
629
|
+
|
|
630
|
+
# restriction
|
|
631
|
+
if load_scaler:
|
|
632
|
+
if len(model_pth_dict) != 2:
|
|
633
|
+
_LOGGER.error(f"Directory {dir_path} should contain exactly 2 '.pth' files: scaler and weights.")
|
|
634
|
+
raise IOError()
|
|
635
|
+
else:
|
|
636
|
+
if len(model_pth_dict) != 1:
|
|
637
|
+
_LOGGER.error(f"Directory {dir_path} should contain exactly 1 '.pth' file: weights.")
|
|
638
|
+
raise IOError()
|
|
639
|
+
|
|
640
|
+
##### Scaler and Weights #####
|
|
641
|
+
scaler_path = None
|
|
642
|
+
weights_path = None
|
|
643
|
+
|
|
644
|
+
# load weights and scaler if present
|
|
645
|
+
for pth_filename, pth_path in model_pth_dict.items():
|
|
646
|
+
if load_scaler and pth_filename.lower().startswith(DatasetKeys.SCALER_PREFIX):
|
|
647
|
+
scaler_path = pth_path
|
|
648
|
+
else:
|
|
649
|
+
weights_path = pth_path
|
|
650
|
+
|
|
651
|
+
# validation
|
|
652
|
+
if not weights_path:
|
|
653
|
+
_LOGGER.error(f"Error parsing the model weights path from '{dir_name}'")
|
|
654
|
+
raise IOError()
|
|
655
|
+
|
|
656
|
+
if load_scaler and not scaler_path:
|
|
657
|
+
_LOGGER.error(f"Error parsing the scaler path from '{dir_name}'")
|
|
658
|
+
raise IOError()
|
|
659
|
+
|
|
660
|
+
##### Target and Feature names #####
|
|
661
|
+
target_names_path = None
|
|
662
|
+
feature_names_path = None
|
|
663
|
+
|
|
664
|
+
# load feature and target names
|
|
665
|
+
model_txt_dict = list_files_by_extension(directory=dir_path, extension="txt", verbose=verbose)
|
|
666
|
+
|
|
667
|
+
for txt_filename, txt_path in model_txt_dict.items():
|
|
668
|
+
if txt_filename == DatasetKeys.FEATURE_NAMES:
|
|
669
|
+
feature_names_path = txt_path
|
|
670
|
+
elif txt_filename == DatasetKeys.TARGET_NAMES:
|
|
671
|
+
target_names_path = txt_path
|
|
672
|
+
|
|
673
|
+
# validation
|
|
674
|
+
if not target_names_path or not feature_names_path:
|
|
675
|
+
_LOGGER.error(f"Error parsing features path or targets path from '{dir_name}'")
|
|
676
|
+
raise IOError()
|
|
677
|
+
|
|
678
|
+
##### load model architecture path #####
|
|
679
|
+
architecture_path = None
|
|
680
|
+
|
|
681
|
+
model_json_dict = list_files_by_extension(directory=dir_path, extension="json", verbose=verbose)
|
|
682
|
+
|
|
683
|
+
for json_filename, json_path in model_json_dict.items():
|
|
684
|
+
if json_filename == PytorchModelArchitectureKeys.SAVENAME:
|
|
685
|
+
architecture_path = json_path
|
|
686
|
+
|
|
687
|
+
# validation
|
|
688
|
+
if not architecture_path:
|
|
689
|
+
_LOGGER.error(f"Error parsing the model architecture path from '{dir_name}'")
|
|
690
|
+
raise IOError()
|
|
691
|
+
|
|
692
|
+
##### Paths dictionary #####
|
|
693
|
+
parsing_dict = {
|
|
694
|
+
PytorchArtifactPathKeys.WEIGHTS_PATH: weights_path,
|
|
695
|
+
PytorchArtifactPathKeys.ARCHITECTURE_PATH: architecture_path,
|
|
696
|
+
PytorchArtifactPathKeys.FEATURES_PATH: feature_names_path,
|
|
697
|
+
PytorchArtifactPathKeys.TARGETS_PATH: target_names_path,
|
|
698
|
+
PytorchArtifactPathKeys.SCALER_PATH: scaler_path
|
|
699
|
+
}
|
|
700
|
+
|
|
701
|
+
all_artifacts.append(parsing_dict)
|
|
702
|
+
|
|
703
|
+
return all_artifacts
|
|
704
|
+
|
|
705
|
+
|
|
706
|
+
def select_features_by_shap(
|
|
707
|
+
root_directory: Union[str, Path],
|
|
708
|
+
shap_threshold: float = 1.0,
|
|
709
|
+
verbose: bool = True) -> list[str]:
|
|
710
|
+
"""
|
|
711
|
+
Scans subdirectories to find SHAP summary CSVs, then extracts feature
|
|
712
|
+
names whose mean absolute SHAP value meets a specified threshold.
|
|
713
|
+
|
|
714
|
+
This function is useful for automated feature selection based on feature
|
|
715
|
+
importance scores aggregated from multiple models.
|
|
716
|
+
|
|
717
|
+
Args:
|
|
718
|
+
root_directory (Union[str, Path]):
|
|
719
|
+
The path to the root directory that contains model subdirectories.
|
|
720
|
+
shap_threshold (float):
|
|
721
|
+
The minimum mean absolute SHAP value for a feature to be included
|
|
722
|
+
in the final list.
|
|
723
|
+
|
|
724
|
+
Returns:
|
|
725
|
+
list[str]:
|
|
726
|
+
A single, sorted list of unique feature names that meet the
|
|
727
|
+
threshold criteria across all found files.
|
|
728
|
+
"""
|
|
729
|
+
if verbose:
|
|
730
|
+
_LOGGER.info(f"Starting feature selection with SHAP threshold >= {shap_threshold}")
|
|
731
|
+
root_path = make_fullpath(root_directory, enforce="directory")
|
|
732
|
+
|
|
733
|
+
# --- Step 2: Directory and File Discovery ---
|
|
734
|
+
subdirectories = list_subdirectories(root_dir=root_path, verbose=False)
|
|
735
|
+
|
|
736
|
+
shap_filename = SHAPKeys.SAVENAME + ".csv"
|
|
737
|
+
|
|
738
|
+
valid_csv_paths = []
|
|
739
|
+
for dir_name, dir_path in subdirectories.items():
|
|
740
|
+
expected_path = dir_path / shap_filename
|
|
741
|
+
if expected_path.is_file():
|
|
742
|
+
valid_csv_paths.append(expected_path)
|
|
743
|
+
else:
|
|
744
|
+
_LOGGER.warning(f"No '{shap_filename}' found in subdirectory '{dir_name}'.")
|
|
745
|
+
|
|
746
|
+
if not valid_csv_paths:
|
|
747
|
+
_LOGGER.error(f"Process halted: No '{shap_filename}' files were found in any subdirectory.")
|
|
748
|
+
return []
|
|
749
|
+
|
|
750
|
+
if verbose:
|
|
751
|
+
_LOGGER.info(f"Found {len(valid_csv_paths)} SHAP summary files to process.")
|
|
752
|
+
|
|
753
|
+
# --- Step 3: Data Processing and Feature Extraction ---
|
|
754
|
+
master_feature_set = set()
|
|
755
|
+
for csv_path in valid_csv_paths:
|
|
756
|
+
try:
|
|
757
|
+
df, _ = load_dataframe(csv_path, kind="pandas", verbose=False)
|
|
758
|
+
|
|
759
|
+
# Validate required columns
|
|
760
|
+
required_cols = {SHAPKeys.FEATURE_COLUMN, SHAPKeys.SHAP_VALUE_COLUMN}
|
|
761
|
+
if not required_cols.issubset(df.columns):
|
|
762
|
+
_LOGGER.warning(f"Skipping '{csv_path}': missing required columns.")
|
|
763
|
+
continue
|
|
764
|
+
|
|
765
|
+
# Filter by threshold and extract features
|
|
766
|
+
filtered_df = df[df[SHAPKeys.SHAP_VALUE_COLUMN] >= shap_threshold]
|
|
767
|
+
features = filtered_df[SHAPKeys.FEATURE_COLUMN].tolist()
|
|
768
|
+
master_feature_set.update(features)
|
|
769
|
+
|
|
770
|
+
except (ValueError, pd.errors.EmptyDataError):
|
|
771
|
+
_LOGGER.warning(f"Skipping '{csv_path}' because it is empty or malformed.")
|
|
772
|
+
continue
|
|
773
|
+
except Exception as e:
|
|
774
|
+
_LOGGER.error(f"An unexpected error occurred while processing '{csv_path}': {e}")
|
|
775
|
+
continue
|
|
776
|
+
|
|
777
|
+
# --- Step 4: Finalize and Return ---
|
|
778
|
+
final_features = sorted(list(master_feature_set))
|
|
779
|
+
if verbose:
|
|
780
|
+
_LOGGER.info(f"Selected {len(final_features)} unique features across all files.")
|
|
781
|
+
|
|
782
|
+
return final_features
|
|
783
|
+
|
|
784
|
+
|
|
563
785
|
def info():
|
|
564
786
|
_script_info(__all__)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|