dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
  2. dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
  3. ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
  4. ml_tools/ETL_cleaning/_basic_clean.py +351 -0
  5. ml_tools/ETL_cleaning/_clean_tools.py +128 -0
  6. ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
  7. ml_tools/ETL_cleaning/_imprimir.py +13 -0
  8. ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
  9. ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
  10. ml_tools/ETL_engineering/_imprimir.py +24 -0
  11. ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
  12. ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
  13. ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
  14. ml_tools/GUI_tools/_imprimir.py +12 -0
  15. ml_tools/IO_tools/_IO_loggers.py +235 -0
  16. ml_tools/IO_tools/_IO_save_load.py +151 -0
  17. ml_tools/IO_tools/_IO_utils.py +140 -0
  18. ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
  19. ml_tools/IO_tools/_imprimir.py +14 -0
  20. ml_tools/MICE/_MICE_imputation.py +132 -0
  21. ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
  22. ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
  23. ml_tools/MICE/_imprimir.py +11 -0
  24. ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
  25. ml_tools/ML_callbacks/_base.py +101 -0
  26. ml_tools/ML_callbacks/_checkpoint.py +232 -0
  27. ml_tools/ML_callbacks/_early_stop.py +208 -0
  28. ml_tools/ML_callbacks/_imprimir.py +12 -0
  29. ml_tools/ML_callbacks/_scheduler.py +197 -0
  30. ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
  31. ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
  32. ml_tools/ML_chain/_dragon_chain.py +140 -0
  33. ml_tools/ML_chain/_imprimir.py +11 -0
  34. ml_tools/ML_configuration/__init__.py +90 -0
  35. ml_tools/ML_configuration/_base_model_config.py +69 -0
  36. ml_tools/ML_configuration/_finalize.py +366 -0
  37. ml_tools/ML_configuration/_imprimir.py +47 -0
  38. ml_tools/ML_configuration/_metrics.py +593 -0
  39. ml_tools/ML_configuration/_models.py +206 -0
  40. ml_tools/ML_configuration/_training.py +124 -0
  41. ml_tools/ML_datasetmaster/__init__.py +28 -0
  42. ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
  43. ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
  44. ml_tools/ML_datasetmaster/_imprimir.py +15 -0
  45. ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
  46. ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
  47. ml_tools/ML_evaluation/__init__.py +53 -0
  48. ml_tools/ML_evaluation/_classification.py +629 -0
  49. ml_tools/ML_evaluation/_feature_importance.py +409 -0
  50. ml_tools/ML_evaluation/_imprimir.py +25 -0
  51. ml_tools/ML_evaluation/_loss.py +92 -0
  52. ml_tools/ML_evaluation/_regression.py +273 -0
  53. ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
  54. ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
  55. ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
  56. ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
  57. ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
  58. ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
  59. ml_tools/ML_finalize_handler/__init__.py +10 -0
  60. ml_tools/ML_finalize_handler/_imprimir.py +8 -0
  61. ml_tools/ML_inference/__init__.py +22 -0
  62. ml_tools/ML_inference/_base_inference.py +166 -0
  63. ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
  64. ml_tools/ML_inference/_dragon_inference.py +332 -0
  65. ml_tools/ML_inference/_imprimir.py +11 -0
  66. ml_tools/ML_inference/_multi_inference.py +180 -0
  67. ml_tools/ML_inference_sequence/__init__.py +10 -0
  68. ml_tools/ML_inference_sequence/_imprimir.py +8 -0
  69. ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
  70. ml_tools/ML_inference_vision/__init__.py +10 -0
  71. ml_tools/ML_inference_vision/_imprimir.py +8 -0
  72. ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
  73. ml_tools/ML_models/__init__.py +32 -0
  74. ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
  75. ml_tools/ML_models/_base_mlp_attention.py +198 -0
  76. ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
  77. ml_tools/ML_models/_dragon_tabular.py +248 -0
  78. ml_tools/ML_models/_imprimir.py +18 -0
  79. ml_tools/ML_models/_mlp_attention.py +134 -0
  80. ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
  81. ml_tools/ML_models_sequence/__init__.py +10 -0
  82. ml_tools/ML_models_sequence/_imprimir.py +8 -0
  83. ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
  84. ml_tools/ML_models_vision/__init__.py +29 -0
  85. ml_tools/ML_models_vision/_base_wrapper.py +254 -0
  86. ml_tools/ML_models_vision/_image_classification.py +182 -0
  87. ml_tools/ML_models_vision/_image_segmentation.py +108 -0
  88. ml_tools/ML_models_vision/_imprimir.py +16 -0
  89. ml_tools/ML_models_vision/_object_detection.py +135 -0
  90. ml_tools/ML_optimization/__init__.py +21 -0
  91. ml_tools/ML_optimization/_imprimir.py +13 -0
  92. ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
  93. ml_tools/ML_optimization/_single_dragon.py +203 -0
  94. ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
  95. ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
  96. ml_tools/ML_scaler/__init__.py +10 -0
  97. ml_tools/ML_scaler/_imprimir.py +8 -0
  98. ml_tools/ML_trainer/__init__.py +20 -0
  99. ml_tools/ML_trainer/_base_trainer.py +297 -0
  100. ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
  101. ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
  102. ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
  103. ml_tools/ML_trainer/_imprimir.py +10 -0
  104. ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
  105. ml_tools/ML_utilities/_artifact_finder.py +382 -0
  106. ml_tools/ML_utilities/_imprimir.py +16 -0
  107. ml_tools/ML_utilities/_inspection.py +325 -0
  108. ml_tools/ML_utilities/_train_tools.py +205 -0
  109. ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
  110. ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
  111. ml_tools/ML_vision_transformers/_imprimir.py +14 -0
  112. ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
  113. ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
  114. ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
  115. ml_tools/PSO_optimization/_imprimir.py +10 -0
  116. ml_tools/SQL/__init__.py +7 -0
  117. ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
  118. ml_tools/SQL/_imprimir.py +8 -0
  119. ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
  120. ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
  121. ml_tools/VIF/_imprimir.py +10 -0
  122. ml_tools/_core/__init__.py +7 -1
  123. ml_tools/_core/_logger.py +8 -18
  124. ml_tools/_core/_schema_load_ops.py +43 -0
  125. ml_tools/_core/_script_info.py +2 -2
  126. ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
  127. ml_tools/data_exploration/_analysis.py +214 -0
  128. ml_tools/data_exploration/_cleaning.py +566 -0
  129. ml_tools/data_exploration/_features.py +583 -0
  130. ml_tools/data_exploration/_imprimir.py +32 -0
  131. ml_tools/data_exploration/_plotting.py +487 -0
  132. ml_tools/data_exploration/_schema_ops.py +176 -0
  133. ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
  134. ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
  135. ml_tools/ensemble_evaluation/_imprimir.py +14 -0
  136. ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
  137. ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
  138. ml_tools/ensemble_inference/_imprimir.py +9 -0
  139. ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
  140. ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
  141. ml_tools/ensemble_learning/_imprimir.py +10 -0
  142. ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
  143. ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
  144. ml_tools/excel_handler/_imprimir.py +13 -0
  145. ml_tools/{keys.py → keys/__init__.py} +4 -1
  146. ml_tools/keys/_imprimir.py +11 -0
  147. ml_tools/{_core → keys}/_keys.py +2 -0
  148. ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
  149. ml_tools/math_utilities/_imprimir.py +11 -0
  150. ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
  151. ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
  152. ml_tools/optimization_tools/_imprimir.py +13 -0
  153. ml_tools/optimization_tools/_optimization_bounds.py +236 -0
  154. ml_tools/optimization_tools/_optimization_plots.py +218 -0
  155. ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
  156. ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
  157. ml_tools/path_manager/_imprimir.py +15 -0
  158. ml_tools/path_manager/_path_tools.py +346 -0
  159. ml_tools/plot_fonts/__init__.py +8 -0
  160. ml_tools/plot_fonts/_imprimir.py +8 -0
  161. ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
  162. ml_tools/schema/__init__.py +15 -0
  163. ml_tools/schema/_feature_schema.py +223 -0
  164. ml_tools/schema/_gui_schema.py +191 -0
  165. ml_tools/schema/_imprimir.py +10 -0
  166. ml_tools/{serde.py → serde/__init__.py} +4 -2
  167. ml_tools/serde/_imprimir.py +10 -0
  168. ml_tools/{_core → serde}/_serde.py +3 -8
  169. ml_tools/{utilities.py → utilities/__init__.py} +11 -6
  170. ml_tools/utilities/_imprimir.py +18 -0
  171. ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
  172. ml_tools/utilities/_utility_tools.py +192 -0
  173. dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
  174. ml_tools/ML_chaining_inference.py +0 -8
  175. ml_tools/ML_configuration.py +0 -86
  176. ml_tools/ML_configuration_pytab.py +0 -14
  177. ml_tools/ML_datasetmaster.py +0 -10
  178. ml_tools/ML_evaluation.py +0 -16
  179. ml_tools/ML_evaluation_multi.py +0 -12
  180. ml_tools/ML_finalize_handler.py +0 -8
  181. ml_tools/ML_inference.py +0 -12
  182. ml_tools/ML_models.py +0 -14
  183. ml_tools/ML_models_advanced.py +0 -14
  184. ml_tools/ML_models_pytab.py +0 -14
  185. ml_tools/ML_optimization.py +0 -14
  186. ml_tools/ML_optimization_pareto.py +0 -8
  187. ml_tools/ML_scaler.py +0 -8
  188. ml_tools/ML_sequence_datasetmaster.py +0 -8
  189. ml_tools/ML_sequence_evaluation.py +0 -10
  190. ml_tools/ML_sequence_inference.py +0 -8
  191. ml_tools/ML_sequence_models.py +0 -8
  192. ml_tools/ML_trainer.py +0 -12
  193. ml_tools/ML_vision_datasetmaster.py +0 -12
  194. ml_tools/ML_vision_evaluation.py +0 -10
  195. ml_tools/ML_vision_inference.py +0 -8
  196. ml_tools/ML_vision_models.py +0 -18
  197. ml_tools/SQL.py +0 -8
  198. ml_tools/_core/_ETL_cleaning.py +0 -694
  199. ml_tools/_core/_IO_tools.py +0 -498
  200. ml_tools/_core/_ML_callbacks.py +0 -702
  201. ml_tools/_core/_ML_configuration.py +0 -1332
  202. ml_tools/_core/_ML_configuration_pytab.py +0 -102
  203. ml_tools/_core/_ML_evaluation.py +0 -867
  204. ml_tools/_core/_ML_evaluation_multi.py +0 -544
  205. ml_tools/_core/_ML_inference.py +0 -646
  206. ml_tools/_core/_ML_models.py +0 -668
  207. ml_tools/_core/_ML_models_pytab.py +0 -693
  208. ml_tools/_core/_ML_trainer.py +0 -2323
  209. ml_tools/_core/_ML_utilities.py +0 -886
  210. ml_tools/_core/_ML_vision_models.py +0 -644
  211. ml_tools/_core/_data_exploration.py +0 -1901
  212. ml_tools/_core/_optimization_tools.py +0 -493
  213. ml_tools/_core/_schema.py +0 -359
  214. ml_tools/plot_fonts.py +0 -8
  215. ml_tools/schema.py +0 -12
  216. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
  217. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
  218. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  219. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
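Most of the changes above are a reorganization of `ml_tools` from flat modules into subpackages: each former `*.py` module becomes a package with an `__init__.py` plus private implementation files (and a small `_imprimir.py`), and the old `_core/_*.py` modules are removed. As a rough sketch of what this usually means for downstream code, assuming the new `__init__.py` files re-export the same public names (not verified against the 20.0.0 wheel):

```python
# Hypothetical import sketch for the 19.13.0 -> 20.0.0 restructuring.
# Assumption: the new package __init__.py files re-export the public API
# that the old flat modules exposed.

# 19.13.0: ml_tools/ML_utilities.py was a top-level module.
# 20.0.0:  ml_tools/ML_utilities/ is a package (__init__.py, _artifact_finder.py,
#          _inspection.py, _train_tools.py), so the same import should still work:
from ml_tools.ML_utilities import ArtifactFinder, build_optimizer_params

# Code that imported from the removed private core modules would need updating:
# from ml_tools._core._ML_utilities import build_optimizer_params  # gone in 20.0.0
```

If that sketch holds, only imports that reached into `ml_tools._core` need to change.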
ml_tools/_core/_ML_utilities.py +0 -886
@@ -1,886 +0,0 @@
- import pandas as pd
- from pathlib import Path
- from typing import Union, Any, Optional, Dict, List, Iterable
- import torch
- from torch import nn
-
- from ._path_manager import make_fullpath, list_subdirectories, list_files_by_extension
- from ._script_info import _script_info
- from ._logger import get_logger
- from ._keys import DatasetKeys, PytorchModelArchitectureKeys, PytorchArtifactPathKeys, SHAPKeys, UtilityKeys, PyTorchCheckpointKeys
- from ._utilities import load_dataframe
- from ._IO_tools import save_list_strings, custom_logger, load_list_strings
- from ._serde import serialize_object_filename
- from ._schema import FeatureSchema
-
-
- _LOGGER = get_logger("Torch Utilities")
-
-
- __all__ = [
-     "ArtifactFinder",
-     "find_model_artifacts_multi",
-     "build_optimizer_params",
-     "get_model_parameters",
-     "inspect_model_architecture",
-     "inspect_pth_file",
-     "set_parameter_requires_grad",
-     "save_pretrained_transforms",
-     "select_features_by_shap"
- ]
-
-
- class ArtifactFinder:
-     """
-     Finds, processes, and returns model training artifacts from a target directory.
-
-     The expected directory structure is:
-
-     ```
-     directory
-     ├── *.pth
-     ├── scaler_*.pth (Required if `load_scaler` is True)
-     ├── feature_names.txt
-     ├── target_names.txt
-     ├── architecture.json
-     └── FeatureSchema.json (Required if `load_schema` is True)
-     ```
-     """
-     def __init__(self,
-                  directory: Union[str, Path],
-                  load_scaler: bool,
-                  load_schema: bool,
-                  strict: bool=False,
-                  verbose: bool=True) -> None:
-         """
-         Args:
-             directory (str | Path): The path to the directory that contains training artifacts.
-             load_scaler (bool): If True, requires and searches for a scaler file `scaler_*.pth`.
-             load_schema (bool): If True, requires and searches for a FeatureSchema file `FeatureSchema.json`.
-             strict (bool): If True, raises an error if any artifact is missing. If False, returns None for missing artifacts silently.
-             verbose (bool): Displays the missing artifacts in the directory or a success message.
-         """
-         # validate directory
-         dir_path = make_fullpath(directory, enforce="directory")
-
-         parsing_dict = _find_model_artifacts(target_directory=dir_path, load_scaler=load_scaler, verbose=False, strict=strict)
-
-         self._weights_path = parsing_dict[PytorchArtifactPathKeys.WEIGHTS_PATH]
-         self._feature_names_path = parsing_dict[PytorchArtifactPathKeys.FEATURES_PATH]
-         self._target_names_path = parsing_dict[PytorchArtifactPathKeys.TARGETS_PATH]
-         self._model_architecture_path = parsing_dict[PytorchArtifactPathKeys.ARCHITECTURE_PATH]
-         self._scaler_path = None
-         self._schema = None
-         self._strict = strict
-
-         if load_scaler:
-             self._scaler_path = parsing_dict[PytorchArtifactPathKeys.SCALER_PATH]
-
-         if load_schema:
-             try:
-                 self._schema = FeatureSchema.from_json(directory=dir_path)
-             except Exception:
-                 if strict:
-                     # FeatureSchema logs its own error details
-                     # _LOGGER.error(f"Failed to load FeatureSchema from '{dir_path.name}': {e}")
-                     raise FileNotFoundError()
-                 else:
-                     # _LOGGER.warning(f"Could not load FeatureSchema from '{dir_path.name}': {e}")
-                     self._schema = None
-
-         # Process feature names
-         if self._feature_names_path is not None:
-             self._feature_names = self._process_text(self._feature_names_path)
-         else:
-             self._feature_names = None
-         # Process target names
-         if self._target_names_path is not None:
-             self._target_names = self._process_text(self._target_names_path)
-         else:
-             self._target_names = None
-
-         if verbose:
-             # log missing artifacts
-             missing_artifacts = []
-             if self._feature_names is None:
-                 missing_artifacts.append("Feature Names")
-             if self._target_names is None:
-                 missing_artifacts.append("Target Names")
-             if self._weights_path is None:
-                 missing_artifacts.append("Weights File")
-             if self._model_architecture_path is None:
-                 missing_artifacts.append("Model Architecture File")
-             if load_scaler and self._scaler_path is None:
-                 missing_artifacts.append("Scaler File")
-             if load_schema and self._schema is None:
-                 missing_artifacts.append("FeatureSchema File")
-
-             if missing_artifacts:
-                 _LOGGER.warning(f"Missing artifacts in '{dir_path.name}': {', '.join(missing_artifacts)}.")
-             else:
-                 _LOGGER.info(f"All artifacts successfully loaded from '{dir_path.name}'.")
-
-     def _process_text(self, text_file_path: Path):
-         list_strings = load_list_strings(text_file=text_file_path, verbose=False)
-         return list_strings
-
-     @property
-     def feature_names(self) -> Union[list[str], None]:
-         """Returns the feature names as a list of strings."""
-         if self._strict and not self._feature_names:
-             _LOGGER.error("No feature names loaded for Strict mode.")
-             raise ValueError()
-         return self._feature_names
-
-     @property
-     def target_names(self) -> Union[list[str], None]:
-         """Returns the target names as a list of strings."""
-         if self._strict and not self._target_names:
-             _LOGGER.error("No target names loaded for Strict mode.")
-             raise ValueError()
-         return self._target_names
-
-     @property
-     def weights_path(self) -> Union[Path, None]:
-         """Returns the path to the state dictionary pth file."""
-         if self._strict and self._weights_path is None:
-             _LOGGER.error("No weights file loaded for Strict mode.")
-             raise ValueError()
-         return self._weights_path
-
-     @property
-     def model_architecture_path(self) -> Union[Path, None]:
-         """Returns the path to the model architecture json file."""
-         if self._strict and self._model_architecture_path is None:
-             _LOGGER.error("No model architecture file loaded for Strict mode.")
-             raise ValueError()
-         return self._model_architecture_path
-
-     @property
-     def scaler_path(self) -> Union[Path, None]:
-         """Returns the path to the scaler file."""
-         if self._strict and self._scaler_path is None:
-             _LOGGER.error("No scaler file loaded for Strict mode.")
-             raise ValueError()
-         else:
-             return self._scaler_path
-
-     @property
-     def feature_schema(self) -> Union[FeatureSchema, None]:
-         """Returns the FeatureSchema object."""
-         if self._strict and self._schema is None:
-             _LOGGER.error("No FeatureSchema loaded for Strict mode.")
-             raise ValueError()
-         else:
-             return self._schema
-
-     def __repr__(self) -> str:
-         dir_name = self._weights_path.parent.name if self._weights_path else "Unknown"
-         n_features = len(self._feature_names) if self._feature_names else "None"
-         n_targets = len(self._target_names) if self._target_names else "None"
-         scaler_status = self._scaler_path.name if self._scaler_path else "None"
-         schema_status = "Loaded" if self._schema else "None"
-
-         return (
-             f"{self.__class__.__name__}\n"
-             f" directory='{dir_name}'\n"
-             f" weights='{self._weights_path.name if self._weights_path else 'None'}'\n"
-             f" architecture='{self._model_architecture_path.name if self._model_architecture_path else 'None'}'\n"
-             f" scaler='{scaler_status}'\n"
-             f" schema='{schema_status}'\n"
-             f" features={n_features}\n"
-             f" targets={n_targets}"
-         )
-
-
- def _find_model_artifacts(target_directory: Union[str,Path], load_scaler: bool, verbose: bool=True, strict:bool=True) -> dict[str, Union[Path, None]]:
-     """
-     Scans a directory to find paths to model weights, target names, feature names, and model architecture. Optionally an scaler path if `load_scaler` is True.
-
-     The expected directory structure is as follows:
-
-     ```
-     target_directory
-     ├── *.pth
-     ├── scaler_*.pth (Required if `load_scaler` is True)
-     ├── feature_names.txt
-     ├── target_names.txt
-     └── architecture.json
-     ```
-
-     Args:
-         target_directory (str | Path): The path to the directory that contains training artifacts.
-         load_scaler (bool): If True, the function requires and searches for a scaler file `scaler_*.pth`.
-         verbose (bool): If True, enables detailed logging during the search process.
-         strict (bool): If True, raises errors on missing files. If False, returns None for missing files.
-     """
-     # validate directory
-     dir_path = make_fullpath(target_directory, enforce="directory")
-     dir_name = dir_path.name
-
-     # find files
-     model_pth_dict = list_files_by_extension(directory=dir_path, extension="pth", verbose=False, raise_on_empty=False)
-
-     if not model_pth_dict:
-         pth_msg=f"No '.pth' files found in directory: {dir_name}."
-         if strict:
-             _LOGGER.error(pth_msg)
-             raise IOError()
-         else:
-             if verbose:
-                 _LOGGER.warning(pth_msg)
-             model_pth_dict = None
-
-     # restriction
-     if model_pth_dict is not None:
-         valid_count = False
-         msg = ""
-
-         if load_scaler:
-             if len(model_pth_dict) == 2:
-                 valid_count = True
-             else:
-                 msg = f"Directory '{dir_name}' should contain exactly 2 '.pth' files: scaler and weights. Found {len(model_pth_dict)}."
-         else:
-             if len(model_pth_dict) == 1:
-                 valid_count = True
-             else:
-                 msg = f"Directory '{dir_name}' should contain exactly 1 '.pth' file for weights. Found {len(model_pth_dict)}."
-
-         # Respect strict mode for count mismatch
-         if not valid_count:
-             if strict:
-                 _LOGGER.error(msg)
-                 raise IOError()
-             else:
-                 if verbose:
-                     _LOGGER.warning(msg)
-                 # Invalidate dictionary
-                 model_pth_dict = None
-
-     ##### Scaler and Weights #####
-     scaler_path = None
-     weights_path = None
-
-     # load weights and scaler if present
-     if model_pth_dict is not None:
-         for pth_filename, pth_path in model_pth_dict.items():
-             if load_scaler and pth_filename.lower().startswith(DatasetKeys.SCALER_PREFIX):
-                 scaler_path = pth_path
-             else:
-                 weights_path = pth_path
-
-     # validation
-     if not weights_path and strict:
-         _LOGGER.error(f"Error parsing the model weights path from '{dir_name}'")
-         raise IOError()
-
-     if strict and load_scaler and not scaler_path:
-         _LOGGER.error(f"Error parsing the scaler path from '{dir_name}'")
-         raise IOError()
-
-     ##### Target and Feature names #####
-     target_names_path = None
-     feature_names_path = None
-
-     # load feature and target names
-     model_txt_dict = list_files_by_extension(directory=dir_path, extension="txt", verbose=False, raise_on_empty=False)
-
-     # if the directory has no txt files, the loop is skipped
-     for txt_filename, txt_path in model_txt_dict.items():
-         if txt_filename == DatasetKeys.FEATURE_NAMES:
-             feature_names_path = txt_path
-         elif txt_filename == DatasetKeys.TARGET_NAMES:
-             target_names_path = txt_path
-
-     # validation per case
-     if strict and not target_names_path:
-         _LOGGER.error(f"Error parsing the target names path from '{dir_name}'")
-         raise IOError()
-     elif verbose and not target_names_path:
-         _LOGGER.warning(f"Target names file not found in '{dir_name}'.")
-
-     if strict and not feature_names_path:
-         _LOGGER.error(f"Error parsing the feature names path from '{dir_name}'")
-         raise IOError()
-     elif verbose and not feature_names_path:
-         _LOGGER.warning(f"Feature names file not found in '{dir_name}'.")
-
-     ##### load model architecture path #####
-     architecture_path = None
-
-     model_json_dict = list_files_by_extension(directory=dir_path, extension="json", verbose=False, raise_on_empty=False)
-
-     # if the directory has no json files, the loop is skipped
-     for json_filename, json_path in model_json_dict.items():
-         if json_filename == PytorchModelArchitectureKeys.SAVENAME:
-             architecture_path = json_path
-
-     # validation
-     if strict and not architecture_path:
-         _LOGGER.error(f"Error parsing the model architecture path from '{dir_name}'")
-         raise IOError()
-     elif verbose and not architecture_path:
-         _LOGGER.warning(f"Model architecture file not found in '{dir_name}'.")
-
-     ##### Paths dictionary #####
-     parsing_dict = {
-         PytorchArtifactPathKeys.WEIGHTS_PATH: weights_path,
-         PytorchArtifactPathKeys.ARCHITECTURE_PATH: architecture_path,
-         PytorchArtifactPathKeys.FEATURES_PATH: feature_names_path,
-         PytorchArtifactPathKeys.TARGETS_PATH: target_names_path,
-     }
-
-     if load_scaler:
-         parsing_dict[PytorchArtifactPathKeys.SCALER_PATH] = scaler_path
-
-     return parsing_dict
-
-
- def find_model_artifacts_multi(target_directory: Union[str,Path], load_scaler: bool, verbose: bool=False) -> list[dict[str, Path]]:
-     """
-     Scans subdirectories to find paths to model weights, target names, feature names, and model architecture. Optionally an scaler path if `load_scaler` is True.
-
-     This function operates on a specific directory structure. It expects the
-     `target_directory` to contain one or more subdirectories, where each
-     subdirectory represents a single trained model result.
-
-     This function works using a strict mode, meaning that it will raise errors if
-     any required artifacts are missing in a model's subdirectory.
-
-     The expected directory structure for each model is as follows:
-     ```
-     target_directory
-     ├── model_1
-     │   ├── *.pth
-     │   ├── scaler_*.pth (Required if `load_scaler` is True)
-     │   ├── feature_names.txt
-     │   ├── target_names.txt
-     │   └── architecture.json
-     └── model_2/
-         └── ...
-     ```
-
-     Args:
-         target_directory (str | Path): The path to the root directory that contains model subdirectories.
-         load_scaler (bool): If True, the function requires and searches for a scaler file (`.pth`) in each model subdirectory.
-         verbose (bool): If True, enables detailed logging during the file paths search process.
-
-     Returns:
-         (list[dict[str, Path]]): A list of dictionaries, where each dictionary
-         corresponds to a model found in a subdirectory. The dictionary
-         maps standardized keys to the absolute paths of the model's
-         artifacts (weights, architecture, features, targets, and scaler).
-     """
-     # validate directory
-     root_path = make_fullpath(target_directory, enforce="directory")
-
-     # store results
-     all_artifacts: list[dict[str, Path]] = list()
-
-     # find model directories
-     result_dirs_dict = list_subdirectories(root_dir=root_path, verbose=verbose, raise_on_empty=True)
-     for _dir_name, dir_path in result_dirs_dict.items():
-
-         parsing_dict = _find_model_artifacts(target_directory=dir_path,
-                                              load_scaler=load_scaler,
-                                              verbose=verbose,
-                                              strict=True)
-
-         # parsing_dict is guaranteed to have all required paths due to strict=True
-         all_artifacts.append(parsing_dict) # type: ignore
-
-     return all_artifacts
-
-
- def build_optimizer_params(model: nn.Module, weight_decay: float = 0.01) -> List[Dict[str, Any]]:
-     """
-     Groups model parameters to apply weight decay only to weights (matrices/embeddings),
-     while excluding biases and normalization parameters (scales/shifts).
-
-     This function uses a robust hybrid strategy:
-     1. It excludes parameters matching standard names (e.g., "bias", "norm").
-     2. It excludes any parameter with < 2 dimensions (vector parameters), which
-        automatically catches unnamed BatchNorm/LayerNorm weights in Sequential containers.
-
-     Args:
-         model (nn.Module):
-             The PyTorch model.
-         weight_decay (float):
-             The L2 regularization coefficient for the weights.
-             (Default: 0.01)
-
-     Returns:
-         List[Dict[str, Any]]: A list of parameter groups formatted for PyTorch optimizers.
-         - Group 0: 'params' = Weights (decay applied)
-         - Group 1: 'params' = Biases/Norms (decay = 0.0)
-     """
-     # 1. Hard-coded strings for explicit safety
-     no_decay_strings = {"bias", "LayerNorm", "BatchNorm", "GroupNorm", "norm.weight"}
-
-     decay_params = []
-     no_decay_params = []
-
-     # 2. Iterate only over trainable parameters
-     for name, param in model.named_parameters():
-         if not param.requires_grad:
-             continue
-
-         # Check 1: Name match
-         is_blacklisted_name = any(nd in name for nd in no_decay_strings)
-
-         # Check 2: Dimensionality (Robust fallback)
-         # Weights/Embeddings are 2D+, Biases/Norm Scales are 1D
-         is_1d = param.ndim < 2
-
-         if is_blacklisted_name or is_1d:
-             no_decay_params.append(param)
-         else:
-             decay_params.append(param)
-
-     _LOGGER.info(f"Weight decay configured:\n Decaying parameters: {len(decay_params)}\n Non-decaying parameters: {len(no_decay_params)}")
-
-     return [
-         {
-             'params': decay_params,
-             'weight_decay': weight_decay,
-         },
-         {
-             'params': no_decay_params,
-             'weight_decay': 0.0,
-         }
-     ]
-
-
- def get_model_parameters(model: nn.Module, save_dir: Optional[Union[str,Path]]=None) -> Dict[str, int]:
-     """
-     Calculates the total and trainable parameters of a PyTorch model.
-
-     Args:
-         model (nn.Module): The PyTorch model to inspect.
-         save_dir: Optional directory to save the output as a JSON file.
-
-     Returns:
-         Dict[str, int]: A dictionary containing:
-         - "total_params": The total number of parameters.
-         - "trainable_params": The number of trainable parameters (where requires_grad=True).
-     """
-     total_params = sum(p.numel() for p in model.parameters())
-     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-     report = {
-         UtilityKeys.TOTAL_PARAMS: total_params,
-         UtilityKeys.TRAINABLE_PARAMS: trainable_params
-     }
-
-     if save_dir is not None:
-         output_dir = make_fullpath(save_dir, make=True, enforce="directory")
-         custom_logger(data=report,
-                       save_directory=output_dir,
-                       log_name=UtilityKeys.MODEL_PARAMS_FILE,
-                       add_timestamp=False,
-                       dict_as="json")
-
-     return report
-
-
- def inspect_model_architecture(
-     model: nn.Module,
-     save_dir: Union[str, Path]
- ) -> None:
-     """
-     Saves a human-readable text summary of a model's instantiated
-     architecture, including parameter counts.
-
-     Args:
-         model (nn.Module): The PyTorch model to inspect.
-         save_dir (str | Path): Directory to save the text file.
-     """
-     # --- 1. Validate path ---
-     output_dir = make_fullpath(save_dir, make=True, enforce="directory")
-     architecture_filename = UtilityKeys.MODEL_ARCHITECTURE_FILE + ".txt"
-     filepath = output_dir / architecture_filename
-
-     # --- 2. Get parameter counts from existing function ---
-     try:
-         params_report = get_model_parameters(model) # Get dict, don't save
-         total = params_report.get(UtilityKeys.TOTAL_PARAMS, 'N/A')
-         trainable = params_report.get(UtilityKeys.TRAINABLE_PARAMS, 'N/A')
-         header = (
-             f"Model: {model.__class__.__name__}\n"
-             f"Total Parameters: {total:,}\n"
-             f"Trainable Parameters: {trainable:,}\n"
-             f"{'='*80}\n\n"
-         )
-     except Exception as e:
-         _LOGGER.warning(f"Could not get model parameters: {e}")
-         header = f"Model: {model.__class__.__name__}\n{'='*80}\n\n"
-
-     # --- 3. Get architecture string ---
-     arch_string = str(model)
-
-     # --- 4. Write to file ---
-     try:
-         with open(filepath, 'w', encoding='utf-8') as f:
-             f.write(header)
-             f.write(arch_string)
-         _LOGGER.info(f"Model architecture summary saved to '{filepath.name}'")
-     except Exception as e:
-         _LOGGER.error(f"Failed to write model architecture file: {e}")
-         raise
-
-
- def inspect_pth_file(
-     pth_path: Union[str, Path],
-     save_dir: Union[str, Path],
- ) -> None:
-     """
-     Inspects a .pth file (e.g., checkpoint) and saves a human-readable
-     JSON summary of its contents.
-
-     Args:
-         pth_path (str | Path): The path to the .pth file to inspect.
-         save_dir (str | Path): The directory to save the JSON report.
-
-     Returns:
-         Dict (str, Any): A dictionary containing the inspection report.
-
-     Raises:
-         ValueError: If the .pth file is empty or in an unrecognized format.
-     """
-     # --- 1. Validate paths ---
-     pth_file = make_fullpath(pth_path, enforce="file")
-     output_dir = make_fullpath(save_dir, make=True, enforce="directory")
-     pth_name = pth_file.stem
-
-     # --- 2. Load data ---
-     try:
-         # Load onto CPU to avoid GPU memory issues
-         loaded_data = torch.load(pth_file, map_location=torch.device('cpu'))
-     except Exception as e:
-         _LOGGER.error(f"Failed to load .pth file '{pth_file}': {e}")
-         raise
-
-     # --- 3. Initialize Report ---
-     report = {
-         "top_level_type": str(type(loaded_data)),
-         "top_level_summary": {},
-         "model_state_analysis": None,
-         "notes": []
-     }
-
-     # --- 4. Parse loaded data ---
-     if isinstance(loaded_data, dict):
-         # --- Case 1: Loaded data is a dictionary (most common case) ---
-         # "main loop" that iterates over *everything* first.
-         for key, value in loaded_data.items():
-             key_summary = {}
-             val_type = str(type(value))
-             key_summary["type"] = val_type
-
-             if isinstance(value, torch.Tensor):
-                 key_summary["shape"] = list(value.shape)
-                 key_summary["dtype"] = str(value.dtype)
-             elif isinstance(value, dict):
-                 key_summary["key_count"] = len(value)
-                 key_summary["key_preview"] = list(value.keys())[:5]
-             elif isinstance(value, (int, float, str, bool)):
-                 key_summary["value_preview"] = str(value)
-             elif isinstance(value, (list, tuple)):
-                 key_summary["value_preview"] = str(value)[:100]
-
-             report["top_level_summary"][key] = key_summary
-
-         # Now, try to find the model state_dict within the dict
-         if PyTorchCheckpointKeys.MODEL_STATE in loaded_data and isinstance(loaded_data[PyTorchCheckpointKeys.MODEL_STATE], dict):
-             report["notes"].append(f"Found standard checkpoint key: '{PyTorchCheckpointKeys.MODEL_STATE}'. Analyzing as model state_dict.")
-             state_dict = loaded_data[PyTorchCheckpointKeys.MODEL_STATE]
-             report["model_state_analysis"] = _generate_weight_report(state_dict)
-
-         elif all(isinstance(v, torch.Tensor) for v in loaded_data.values()):
-             report["notes"].append("File dictionary contains only tensors. Analyzing entire dictionary as model state_dict.")
-             state_dict = loaded_data
-             report["model_state_analysis"] = _generate_weight_report(state_dict)
-
-         else:
-             report["notes"].append("Could not identify a single model state_dict. See top_level_summary for all contents. No detailed weight analysis will be performed.")
-
-     elif isinstance(loaded_data, nn.Module):
-         # --- Case 2: Loaded data is a full pickled model ---
-         # _LOGGER.warning("Loading a full, pickled nn.Module is not recommended. Inspecting its state_dict().")
-         report["notes"].append("File is a full, pickled nn.Module. This is not recommended. Extracting state_dict() for analysis.")
-         state_dict = loaded_data.state_dict()
-         report["model_state_analysis"] = _generate_weight_report(state_dict)
-
-     else:
-         # --- Case 3: Unrecognized format (e.g., single tensor, list) ---
-         _LOGGER.error(f"Could not parse .pth file. Loaded data is of type {type(loaded_data)}, not a dict or nn.Module.")
-         raise ValueError()
-
-     # --- 5. Save Report ---
-     custom_logger(data=report,
-                   save_directory=output_dir,
-                   log_name=UtilityKeys.PTH_FILE + pth_name,
-                   add_timestamp=False,
-                   dict_as="json")
-
-
- def _generate_weight_report(state_dict: dict) -> dict:
-     """
-     Internal helper to analyze a state_dict and return a structured report.
-
-     Args:
-         state_dict (dict): The model state_dict to analyze.
-
-     Returns:
-         dict: A report containing total parameters and a per-parameter breakdown.
-     """
-     weight_report = {}
-     total_params = 0
-     if not isinstance(state_dict, dict):
-         _LOGGER.warning(f"Attempted to generate weight report on non-dict type: {type(state_dict)}")
-         return {"error": "Input was not a dictionary."}
-
-     for key, tensor in state_dict.items():
-         if not isinstance(tensor, torch.Tensor):
-             _LOGGER.warning(f"Skipping key '{key}' in state_dict: value is not a tensor (type: {type(tensor)}).")
-             weight_report[key] = {
-                 "type": str(type(tensor)),
-                 "value_preview": str(tensor)[:50] # Show a preview
-             }
-             continue
-         weight_report[key] = {
-             "shape": list(tensor.shape),
-             "dtype": str(tensor.dtype),
-             "requires_grad": tensor.requires_grad,
-             "num_elements": tensor.numel()
-         }
-         total_params += tensor.numel()
-
-     return {
-         "total_parameters": total_params,
-         "parameter_key_count": len(weight_report),
-         "parameters": weight_report
-     }
-
-
- def set_parameter_requires_grad(
-     model: nn.Module,
-     unfreeze_last_n_params: int,
- ) -> int:
-     """
-     Freezes or unfreezes parameters in a model based on unfreeze_last_n_params.
-
-     - N = 0: Freezes ALL parameters.
-     - N > 0 and N < total: Freezes ALL parameters, then unfreezes the last N.
-     - N >= total: Unfreezes ALL parameters.
-
-     Note: 'N' refers to individual parameter tensors (e.g., `layer.weight`
-     or `layer.bias`), not modules or layers. For example, to unfreeze
-     the final nn.Linear layer, you would use N=2 (for its weight and bias).
-
-     Args:
-         model (nn.Module): The model to modify.
-         unfreeze_last_n_params (int):
-             The number of parameter tensors to unfreeze, starting from
-             the end of the model.
-
-     Returns:
-         int: The total number of individual parameters (elements) that were set to `requires_grad=True`.
-     """
-     if unfreeze_last_n_params < 0:
-         _LOGGER.error(f"unfreeze_last_n_params must be >= 0, but got {unfreeze_last_n_params}")
-         raise ValueError()
-
-     # --- Step 1: Get all parameter tensors ---
-     all_params = list(model.parameters())
-     total_param_tensors = len(all_params)
-
-     # --- Case 1: N = 0 (Freeze ALL parameters) ---
-     # early exit for the "freeze all" case.
-     if unfreeze_last_n_params == 0:
-         params_frozen = _set_params_grad(all_params, requires_grad=False)
-         _LOGGER.warning(f"Froze all {total_param_tensors} parameter tensors ({params_frozen} total elements).")
-         return 0 # 0 parameters unfrozen
-
-     # --- Case 2: N >= total (Unfreeze ALL parameters) ---
-     if unfreeze_last_n_params >= total_param_tensors:
-         if unfreeze_last_n_params > total_param_tensors:
-             _LOGGER.warning(f"Requested to unfreeze {unfreeze_last_n_params} params, but model only has {total_param_tensors}. Unfreezing all.")
-
-         params_unfrozen = _set_params_grad(all_params, requires_grad=True)
-         _LOGGER.info(f"Unfroze all {total_param_tensors} parameter tensors ({params_unfrozen} total elements) for training.")
-         return params_unfrozen
-
-     # --- Case 3: 0 < N < total (Standard: Freeze all, unfreeze last N) ---
-     # Freeze ALL
-     params_frozen = _set_params_grad(all_params, requires_grad=False)
-     _LOGGER.info(f"Froze {params_frozen} parameters.")
-
-     # Unfreeze the last N
-     params_to_unfreeze = all_params[-unfreeze_last_n_params:]
-
-     # these are all False, so the helper will set them to True
-     params_unfrozen = _set_params_grad(params_to_unfreeze, requires_grad=True)
-
-     _LOGGER.info(f"Unfroze the last {unfreeze_last_n_params} parameter tensors ({params_unfrozen} total elements) for training.")
-
-     return params_unfrozen
-
-
- def _set_params_grad(
-     params: Iterable[nn.Parameter],
-     requires_grad: bool
- ) -> int:
-     """
-     A helper function to set the `requires_grad` attribute for an iterable
-     of parameters and return the total number of elements changed.
-     """
-     params_changed = 0
-     for param in params:
-         if param.requires_grad != requires_grad:
-             param.requires_grad = requires_grad
-             params_changed += param.numel()
-     return params_changed
-
-
- def save_pretrained_transforms(model: nn.Module, output_dir: Union[str, Path]):
-     """
-     Checks a model for the 'self._pretrained_default_transforms' attribute, if found,
-     serializes the returned transform object as a .joblib file.
-
-     Used for wrapper vision models when initialized with pre-trained weights.
-
-     This saves the callable transform object itself for
-     later use, such as passing it directly to the 'transform_source'
-     argument of the PyTorchVisionInferenceHandler.
-
-     Args:
-         model (nn.Module): The model instance to check.
-         output_dir (str | Path): The directory where the transform file will be saved.
-     """
-     output_filename = "pretrained_model_transformations"
-
-     # 1. Check for the "secret attribute"
-     if not hasattr(model, '_pretrained_default_transforms'):
-         _LOGGER.warning(f"Model of type {type(model).__name__} does not have the required attribute. No transformations saved.")
-         return
-
-     # 2. Get the transform object
-     try:
-         transform_obj = model._pretrained_default_transforms
-     except Exception as e:
-         _LOGGER.error(f"Error calling the required attribute on model: {e}")
-         return
-
-     # 3. Check if the object is actually there
-     if transform_obj is None:
-         _LOGGER.warning(f"Model {type(model).__name__} has the required attribute but returned None. No transforms saved.")
-         return
-
-     # 4. Serialize and save using serde
-     try:
-         serialize_object_filename(
-             obj=transform_obj,
-             save_dir=output_dir,
-             filename=output_filename,
-             verbose=True,
-             raise_on_error=True
-         )
-         # _LOGGER.info(f"Successfully saved pretrained transforms to '{output_dir}'.")
-     except Exception as e:
-         _LOGGER.error(f"Failed to serialize transformations: {e}")
-         raise
-
-
- def select_features_by_shap(
-     root_directory: Union[str, Path],
-     shap_threshold: float,
-     log_feature_names_directory: Optional[Union[str, Path]],
-     verbose: bool = True) -> list[str]:
-     """
-     Scans subdirectories to find SHAP summary CSVs, then extracts feature
-     names whose mean absolute SHAP value meets a specified threshold.
-
-     This function is useful for automated feature selection based on feature
-     importance scores aggregated from multiple models.
-
-     Args:
-         root_directory (str | Path):
-             The path to the root directory that contains model subdirectories.
-         shap_threshold (float):
-             The minimum mean absolute SHAP value for a feature to be included
-             in the final list.
-         log_feature_names_directory (str | Path | None):
-             If given, saves the chosen feature names as a .txt file in this directory.
-
-     Returns:
-         list[str]:
-             A single, sorted list of unique feature names that meet the
-             threshold criteria across all found files.
-     """
-     if verbose:
-         _LOGGER.info(f"Starting feature selection with SHAP threshold >= {shap_threshold}")
-     root_path = make_fullpath(root_directory, enforce="directory")
-
-     # --- Step 2: Directory and File Discovery ---
-     subdirectories = list_subdirectories(root_dir=root_path, verbose=False, raise_on_empty=True)
-
-     shap_filename = SHAPKeys.SAVENAME + ".csv"
-
-     valid_csv_paths = []
-     for dir_name, dir_path in subdirectories.items():
-         expected_path = dir_path / shap_filename
-         if expected_path.is_file():
-             valid_csv_paths.append(expected_path)
-         else:
-             _LOGGER.warning(f"No '{shap_filename}' found in subdirectory '{dir_name}'.")
-
-     if not valid_csv_paths:
-         _LOGGER.error(f"Process halted: No '{shap_filename}' files were found in any subdirectory.")
-         return []
-
-     if verbose:
-         _LOGGER.info(f"Found {len(valid_csv_paths)} SHAP summary files to process.")
-
-     # --- Step 3: Data Processing and Feature Extraction ---
-     master_feature_set = set()
-     for csv_path in valid_csv_paths:
-         try:
-             df, _ = load_dataframe(csv_path, kind="pandas", verbose=False)
-
-             # Validate required columns
-             required_cols = {SHAPKeys.FEATURE_COLUMN, SHAPKeys.SHAP_VALUE_COLUMN}
-             if not required_cols.issubset(df.columns):
-                 _LOGGER.warning(f"Skipping '{csv_path}': missing required columns.")
-                 continue
-
-             # Filter by threshold and extract features
-             filtered_df = df[df[SHAPKeys.SHAP_VALUE_COLUMN] >= shap_threshold]
-             features = filtered_df[SHAPKeys.FEATURE_COLUMN].tolist()
-             master_feature_set.update(features)
-
-         except (ValueError, pd.errors.EmptyDataError):
-             _LOGGER.warning(f"Skipping '{csv_path}' because it is empty or malformed.")
-             continue
-         except Exception as e:
-             _LOGGER.error(f"An unexpected error occurred while processing '{csv_path}': {e}")
-             continue
-
-     # --- Step 4: Finalize and Return ---
-     final_features = sorted(list(master_feature_set))
-     if verbose:
-         _LOGGER.info(f"Selected {len(final_features)} unique features across all files.")
-
-     if log_feature_names_directory is not None:
-         save_names_path = make_fullpath(log_feature_names_directory, make=True, enforce="directory")
-         save_list_strings(list_strings=final_features,
-                           directory=save_names_path,
-                           filename=DatasetKeys.FEATURE_NAMES,
-                           verbose=verbose)
-
-     return final_features
-
-
- def info():
-     _script_info(__all__)
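
The hunk above removes `_core/_ML_utilities.py` wholesale; per items 104-108 in the file list, its helpers reappear under the `ml_tools/ML_utilities` package (`_artifact_finder.py`, `_inspection.py`, `_train_tools.py`). As an illustration of the API documented in the removed docstrings, here is a minimal, hedged usage sketch of `build_optimizer_params`; the import path and unchanged signature are assumptions, not verified against 20.0.0:

```python
# Hypothetical usage sketch based on the docstring of the removed
# build_optimizer_params; import path and signature are assumed unchanged.
import torch
from torch import nn
from ml_tools.ML_utilities import build_optimizer_params  # assumed re-export

model = nn.Sequential(
    nn.Linear(16, 32),
    nn.BatchNorm1d(32),   # 1D scale/shift parameters -> no weight decay
    nn.ReLU(),
    nn.Linear(32, 1),
)

# Group 0: 2D+ weight matrices with weight_decay applied.
# Group 1: biases and normalization parameters with weight_decay=0.0.
param_groups = build_optimizer_params(model, weight_decay=0.01)
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
```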