dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
  2. dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
  3. ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
  4. ml_tools/ETL_cleaning/_basic_clean.py +351 -0
  5. ml_tools/ETL_cleaning/_clean_tools.py +128 -0
  6. ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
  7. ml_tools/ETL_cleaning/_imprimir.py +13 -0
  8. ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
  9. ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
  10. ml_tools/ETL_engineering/_imprimir.py +24 -0
  11. ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
  12. ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
  13. ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
  14. ml_tools/GUI_tools/_imprimir.py +12 -0
  15. ml_tools/IO_tools/_IO_loggers.py +235 -0
  16. ml_tools/IO_tools/_IO_save_load.py +151 -0
  17. ml_tools/IO_tools/_IO_utils.py +140 -0
  18. ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
  19. ml_tools/IO_tools/_imprimir.py +14 -0
  20. ml_tools/MICE/_MICE_imputation.py +132 -0
  21. ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
  22. ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
  23. ml_tools/MICE/_imprimir.py +11 -0
  24. ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
  25. ml_tools/ML_callbacks/_base.py +101 -0
  26. ml_tools/ML_callbacks/_checkpoint.py +232 -0
  27. ml_tools/ML_callbacks/_early_stop.py +208 -0
  28. ml_tools/ML_callbacks/_imprimir.py +12 -0
  29. ml_tools/ML_callbacks/_scheduler.py +197 -0
  30. ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
  31. ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
  32. ml_tools/ML_chain/_dragon_chain.py +140 -0
  33. ml_tools/ML_chain/_imprimir.py +11 -0
  34. ml_tools/ML_configuration/__init__.py +90 -0
  35. ml_tools/ML_configuration/_base_model_config.py +69 -0
  36. ml_tools/ML_configuration/_finalize.py +366 -0
  37. ml_tools/ML_configuration/_imprimir.py +47 -0
  38. ml_tools/ML_configuration/_metrics.py +593 -0
  39. ml_tools/ML_configuration/_models.py +206 -0
  40. ml_tools/ML_configuration/_training.py +124 -0
  41. ml_tools/ML_datasetmaster/__init__.py +28 -0
  42. ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
  43. ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
  44. ml_tools/ML_datasetmaster/_imprimir.py +15 -0
  45. ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
  46. ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
  47. ml_tools/ML_evaluation/__init__.py +53 -0
  48. ml_tools/ML_evaluation/_classification.py +629 -0
  49. ml_tools/ML_evaluation/_feature_importance.py +409 -0
  50. ml_tools/ML_evaluation/_imprimir.py +25 -0
  51. ml_tools/ML_evaluation/_loss.py +92 -0
  52. ml_tools/ML_evaluation/_regression.py +273 -0
  53. ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
  54. ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
  55. ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
  56. ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
  57. ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
  58. ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
  59. ml_tools/ML_finalize_handler/__init__.py +10 -0
  60. ml_tools/ML_finalize_handler/_imprimir.py +8 -0
  61. ml_tools/ML_inference/__init__.py +22 -0
  62. ml_tools/ML_inference/_base_inference.py +166 -0
  63. ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
  64. ml_tools/ML_inference/_dragon_inference.py +332 -0
  65. ml_tools/ML_inference/_imprimir.py +11 -0
  66. ml_tools/ML_inference/_multi_inference.py +180 -0
  67. ml_tools/ML_inference_sequence/__init__.py +10 -0
  68. ml_tools/ML_inference_sequence/_imprimir.py +8 -0
  69. ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
  70. ml_tools/ML_inference_vision/__init__.py +10 -0
  71. ml_tools/ML_inference_vision/_imprimir.py +8 -0
  72. ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
  73. ml_tools/ML_models/__init__.py +32 -0
  74. ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
  75. ml_tools/ML_models/_base_mlp_attention.py +198 -0
  76. ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
  77. ml_tools/ML_models/_dragon_tabular.py +248 -0
  78. ml_tools/ML_models/_imprimir.py +18 -0
  79. ml_tools/ML_models/_mlp_attention.py +134 -0
  80. ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
  81. ml_tools/ML_models_sequence/__init__.py +10 -0
  82. ml_tools/ML_models_sequence/_imprimir.py +8 -0
  83. ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
  84. ml_tools/ML_models_vision/__init__.py +29 -0
  85. ml_tools/ML_models_vision/_base_wrapper.py +254 -0
  86. ml_tools/ML_models_vision/_image_classification.py +182 -0
  87. ml_tools/ML_models_vision/_image_segmentation.py +108 -0
  88. ml_tools/ML_models_vision/_imprimir.py +16 -0
  89. ml_tools/ML_models_vision/_object_detection.py +135 -0
  90. ml_tools/ML_optimization/__init__.py +21 -0
  91. ml_tools/ML_optimization/_imprimir.py +13 -0
  92. ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
  93. ml_tools/ML_optimization/_single_dragon.py +203 -0
  94. ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
  95. ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
  96. ml_tools/ML_scaler/__init__.py +10 -0
  97. ml_tools/ML_scaler/_imprimir.py +8 -0
  98. ml_tools/ML_trainer/__init__.py +20 -0
  99. ml_tools/ML_trainer/_base_trainer.py +297 -0
  100. ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
  101. ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
  102. ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
  103. ml_tools/ML_trainer/_imprimir.py +10 -0
  104. ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
  105. ml_tools/ML_utilities/_artifact_finder.py +382 -0
  106. ml_tools/ML_utilities/_imprimir.py +16 -0
  107. ml_tools/ML_utilities/_inspection.py +325 -0
  108. ml_tools/ML_utilities/_train_tools.py +205 -0
  109. ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
  110. ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
  111. ml_tools/ML_vision_transformers/_imprimir.py +14 -0
  112. ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
  113. ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
  114. ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
  115. ml_tools/PSO_optimization/_imprimir.py +10 -0
  116. ml_tools/SQL/__init__.py +7 -0
  117. ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
  118. ml_tools/SQL/_imprimir.py +8 -0
  119. ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
  120. ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
  121. ml_tools/VIF/_imprimir.py +10 -0
  122. ml_tools/_core/__init__.py +7 -1
  123. ml_tools/_core/_logger.py +8 -18
  124. ml_tools/_core/_schema_load_ops.py +43 -0
  125. ml_tools/_core/_script_info.py +2 -2
  126. ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
  127. ml_tools/data_exploration/_analysis.py +214 -0
  128. ml_tools/data_exploration/_cleaning.py +566 -0
  129. ml_tools/data_exploration/_features.py +583 -0
  130. ml_tools/data_exploration/_imprimir.py +32 -0
  131. ml_tools/data_exploration/_plotting.py +487 -0
  132. ml_tools/data_exploration/_schema_ops.py +176 -0
  133. ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
  134. ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
  135. ml_tools/ensemble_evaluation/_imprimir.py +14 -0
  136. ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
  137. ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
  138. ml_tools/ensemble_inference/_imprimir.py +9 -0
  139. ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
  140. ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
  141. ml_tools/ensemble_learning/_imprimir.py +10 -0
  142. ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
  143. ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
  144. ml_tools/excel_handler/_imprimir.py +13 -0
  145. ml_tools/{keys.py → keys/__init__.py} +4 -1
  146. ml_tools/keys/_imprimir.py +11 -0
  147. ml_tools/{_core → keys}/_keys.py +2 -0
  148. ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
  149. ml_tools/math_utilities/_imprimir.py +11 -0
  150. ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
  151. ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
  152. ml_tools/optimization_tools/_imprimir.py +13 -0
  153. ml_tools/optimization_tools/_optimization_bounds.py +236 -0
  154. ml_tools/optimization_tools/_optimization_plots.py +218 -0
  155. ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
  156. ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
  157. ml_tools/path_manager/_imprimir.py +15 -0
  158. ml_tools/path_manager/_path_tools.py +346 -0
  159. ml_tools/plot_fonts/__init__.py +8 -0
  160. ml_tools/plot_fonts/_imprimir.py +8 -0
  161. ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
  162. ml_tools/schema/__init__.py +15 -0
  163. ml_tools/schema/_feature_schema.py +223 -0
  164. ml_tools/schema/_gui_schema.py +191 -0
  165. ml_tools/schema/_imprimir.py +10 -0
  166. ml_tools/{serde.py → serde/__init__.py} +4 -2
  167. ml_tools/serde/_imprimir.py +10 -0
  168. ml_tools/{_core → serde}/_serde.py +3 -8
  169. ml_tools/{utilities.py → utilities/__init__.py} +11 -6
  170. ml_tools/utilities/_imprimir.py +18 -0
  171. ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
  172. ml_tools/utilities/_utility_tools.py +192 -0
  173. dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
  174. ml_tools/ML_chaining_inference.py +0 -8
  175. ml_tools/ML_configuration.py +0 -86
  176. ml_tools/ML_configuration_pytab.py +0 -14
  177. ml_tools/ML_datasetmaster.py +0 -10
  178. ml_tools/ML_evaluation.py +0 -16
  179. ml_tools/ML_evaluation_multi.py +0 -12
  180. ml_tools/ML_finalize_handler.py +0 -8
  181. ml_tools/ML_inference.py +0 -12
  182. ml_tools/ML_models.py +0 -14
  183. ml_tools/ML_models_advanced.py +0 -14
  184. ml_tools/ML_models_pytab.py +0 -14
  185. ml_tools/ML_optimization.py +0 -14
  186. ml_tools/ML_optimization_pareto.py +0 -8
  187. ml_tools/ML_scaler.py +0 -8
  188. ml_tools/ML_sequence_datasetmaster.py +0 -8
  189. ml_tools/ML_sequence_evaluation.py +0 -10
  190. ml_tools/ML_sequence_inference.py +0 -8
  191. ml_tools/ML_sequence_models.py +0 -8
  192. ml_tools/ML_trainer.py +0 -12
  193. ml_tools/ML_vision_datasetmaster.py +0 -12
  194. ml_tools/ML_vision_evaluation.py +0 -10
  195. ml_tools/ML_vision_inference.py +0 -8
  196. ml_tools/ML_vision_models.py +0 -18
  197. ml_tools/SQL.py +0 -8
  198. ml_tools/_core/_ETL_cleaning.py +0 -694
  199. ml_tools/_core/_IO_tools.py +0 -498
  200. ml_tools/_core/_ML_callbacks.py +0 -702
  201. ml_tools/_core/_ML_configuration.py +0 -1332
  202. ml_tools/_core/_ML_configuration_pytab.py +0 -102
  203. ml_tools/_core/_ML_evaluation.py +0 -867
  204. ml_tools/_core/_ML_evaluation_multi.py +0 -544
  205. ml_tools/_core/_ML_inference.py +0 -646
  206. ml_tools/_core/_ML_models.py +0 -668
  207. ml_tools/_core/_ML_models_pytab.py +0 -693
  208. ml_tools/_core/_ML_trainer.py +0 -2323
  209. ml_tools/_core/_ML_utilities.py +0 -886
  210. ml_tools/_core/_ML_vision_models.py +0 -644
  211. ml_tools/_core/_data_exploration.py +0 -1901
  212. ml_tools/_core/_optimization_tools.py +0 -493
  213. ml_tools/_core/_schema.py +0 -359
  214. ml_tools/plot_fonts.py +0 -8
  215. ml_tools/schema.py +0 -12
  216. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
  217. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
  218. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  219. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_utilities/_inspection.py (new file)
@@ -0,0 +1,325 @@
+ import pandas as pd
+ from pathlib import Path
+ from typing import Union, Optional
+ import torch
+ from torch import nn
+
+ from ..utilities import load_dataframe
+ from ..IO_tools import save_list_strings, custom_logger
+
+ from ..path_manager import make_fullpath, list_subdirectories
+ from .._core import get_logger
+ from ..keys._keys import DatasetKeys, SHAPKeys, UtilityKeys, PyTorchCheckpointKeys
+
+
+ _LOGGER = get_logger("ML Inspection")
+
+
+ __all__ = [
+     "get_model_parameters",
+     "inspect_model_architecture",
+     "inspect_pth_file",
+     "select_features_by_shap"
+ ]
+
+
+ def get_model_parameters(model: nn.Module, save_dir: Optional[Union[str, Path]] = None) -> dict[str, int]:
+     """
+     Calculates the total and trainable parameters of a PyTorch model.
+
+     Args:
+         model (nn.Module): The PyTorch model to inspect.
+         save_dir (str | Path | None): Optional directory to save the output as a JSON file.
+
+     Returns:
+         dict[str, int]: A dictionary containing:
+             - "total_params": The total number of parameters.
+             - "trainable_params": The number of trainable parameters (where requires_grad=True).
+     """
+     total_params = sum(p.numel() for p in model.parameters())
+     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+     report = {
+         UtilityKeys.TOTAL_PARAMS: total_params,
+         UtilityKeys.TRAINABLE_PARAMS: trainable_params
+     }
+
+     if save_dir is not None:
+         output_dir = make_fullpath(save_dir, make=True, enforce="directory")
+         custom_logger(data=report,
+                       save_directory=output_dir,
+                       log_name=UtilityKeys.MODEL_PARAMS_FILE,
+                       add_timestamp=False,
+                       dict_as="json")
+
+     return report
+
+
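For orientation, a minimal usage sketch of the helper above (it assumes `get_model_parameters` is re-exported from `ml_tools.ML_utilities`, as the new `__init__.py` in this diff suggests; the toy model is illustrative):

    import torch.nn as nn
    from ml_tools.ML_utilities import get_model_parameters

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
    report = get_model_parameters(model, save_dir="reports/")
    # Per the docstring: {"total_params": 577, "trainable_params": 577}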
+ def inspect_model_architecture(
+     model: nn.Module,
+     save_dir: Union[str, Path]
+ ) -> None:
+     """
+     Saves a human-readable text summary of a model's instantiated
+     architecture, including parameter counts.
+
+     Args:
+         model (nn.Module): The PyTorch model to inspect.
+         save_dir (str | Path): Directory to save the text file.
+     """
+     # --- 1. Validate path ---
+     output_dir = make_fullpath(save_dir, make=True, enforce="directory")
+     architecture_filename = UtilityKeys.MODEL_ARCHITECTURE_FILE + ".txt"
+     filepath = output_dir / architecture_filename
+
+     # --- 2. Get parameter counts from existing function ---
+     try:
+         params_report = get_model_parameters(model)  # get the dict, don't save
+         total = params_report.get(UtilityKeys.TOTAL_PARAMS, 'N/A')
+         trainable = params_report.get(UtilityKeys.TRAINABLE_PARAMS, 'N/A')
+         header = (
+             f"Model: {model.__class__.__name__}\n"
+             f"Total Parameters: {total:,}\n"
+             f"Trainable Parameters: {trainable:,}\n"
+             f"{'='*80}\n\n"
+         )
+     except Exception as e:
+         _LOGGER.warning(f"Could not get model parameters: {e}")
+         header = f"Model: {model.__class__.__name__}\n{'='*80}\n\n"
+
+     # --- 3. Get architecture string ---
+     arch_string = str(model)
+
+     # --- 4. Write to file ---
+     try:
+         with open(filepath, 'w', encoding='utf-8') as f:
+             f.write(header)
+             f.write(arch_string)
+         _LOGGER.info(f"Model architecture summary saved to '{filepath.name}'")
+     except Exception as e:
+         _LOGGER.error(f"Failed to write model architecture file: {e}")
+         raise
+
+
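Continuing the sketch above, the summary lands in a .txt file whose name comes from UtilityKeys.MODEL_ARCHITECTURE_FILE:

    from ml_tools.ML_utilities import inspect_model_architecture

    # Writes the parameter-count header followed by str(model) to reports/
    inspect_model_architecture(model, save_dir="reports/")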
+ def inspect_pth_file(
+     pth_path: Union[str, Path],
+     save_dir: Union[str, Path],
+ ) -> None:
+     """
+     Inspects a .pth file (e.g., a checkpoint) and saves a human-readable
+     JSON summary of its contents.
+
+     Args:
+         pth_path (str | Path): The path to the .pth file to inspect.
+         save_dir (str | Path): The directory to save the JSON report.
+
+     Returns:
+         None. The inspection report is written to `save_dir` as a JSON file.
+
+     Raises:
+         ValueError: If the .pth file is empty or in an unrecognized format.
+     """
+     # --- 1. Validate paths ---
+     pth_file = make_fullpath(pth_path, enforce="file")
+     output_dir = make_fullpath(save_dir, make=True, enforce="directory")
+     pth_name = pth_file.stem
+
+     # --- 2. Load data ---
+     try:
+         # Load onto CPU to avoid GPU memory issues
+         loaded_data = torch.load(pth_file, map_location=torch.device('cpu'))
+     except Exception as e:
+         _LOGGER.error(f"Failed to load .pth file '{pth_file}': {e}")
+         raise
+
+     # --- 3. Initialize report ---
+     report = {
+         "top_level_type": str(type(loaded_data)),
+         "top_level_summary": {},
+         "model_state_analysis": None,
+         "notes": []
+     }
+
+     # --- 4. Parse loaded data ---
+     if isinstance(loaded_data, dict):
+         # --- Case 1: Loaded data is a dictionary (most common case) ---
+         # "Main loop" that iterates over *everything* first.
+         for key, value in loaded_data.items():
+             key_summary = {}
+             val_type = str(type(value))
+             key_summary["type"] = val_type
+
+             if isinstance(value, torch.Tensor):
+                 key_summary["shape"] = list(value.shape)
+                 key_summary["dtype"] = str(value.dtype)
+             elif isinstance(value, dict):
+                 key_summary["key_count"] = len(value)
+                 key_summary["key_preview"] = list(value.keys())[:5]
+             elif isinstance(value, (int, float, str, bool)):
+                 key_summary["value_preview"] = str(value)
+             elif isinstance(value, (list, tuple)):
+                 key_summary["value_preview"] = str(value)[:100]
+
+             report["top_level_summary"][key] = key_summary
+
+         # Now, try to find the model state_dict within the dict
+         if PyTorchCheckpointKeys.MODEL_STATE in loaded_data and isinstance(loaded_data[PyTorchCheckpointKeys.MODEL_STATE], dict):
+             report["notes"].append(f"Found standard checkpoint key: '{PyTorchCheckpointKeys.MODEL_STATE}'. Analyzing as model state_dict.")
+             state_dict = loaded_data[PyTorchCheckpointKeys.MODEL_STATE]
+             report["model_state_analysis"] = _generate_weight_report(state_dict)
+
+         elif all(isinstance(v, torch.Tensor) for v in loaded_data.values()):
+             report["notes"].append("File dictionary contains only tensors. Analyzing entire dictionary as model state_dict.")
+             state_dict = loaded_data
+             report["model_state_analysis"] = _generate_weight_report(state_dict)
+
+         else:
+             report["notes"].append("Could not identify a single model state_dict. See top_level_summary for all contents. No detailed weight analysis will be performed.")
+
+     elif isinstance(loaded_data, nn.Module):
+         # --- Case 2: Loaded data is a full pickled model ---
+         # _LOGGER.warning("Loading a full, pickled nn.Module is not recommended. Inspecting its state_dict().")
+         report["notes"].append("File is a full, pickled nn.Module. This is not recommended. Extracting state_dict() for analysis.")
+         state_dict = loaded_data.state_dict()
+         report["model_state_analysis"] = _generate_weight_report(state_dict)
+
+     else:
+         # --- Case 3: Unrecognized format (e.g., single tensor, list) ---
+         _LOGGER.error(f"Could not parse .pth file. Loaded data is of type {type(loaded_data)}, not a dict or nn.Module.")
+         raise ValueError()
+
+     # --- 5. Save report ---
+     custom_logger(data=report,
+                   save_directory=output_dir,
+                   log_name=UtilityKeys.PTH_FILE + pth_name,
+                   add_timestamp=False,
+                   dict_as="json")
+
+
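A hedged usage sketch (the checkpoint path is illustrative):

    from ml_tools.ML_utilities import inspect_pth_file

    # Loads the checkpoint on CPU and writes a JSON report into reports/,
    # summarizing top-level keys and, when one is found, the model state_dict.
    inspect_pth_file("checkpoints/epoch_10.pth", save_dir="reports/")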
+ def _generate_weight_report(state_dict: dict) -> dict:
+     """
+     Internal helper to analyze a state_dict and return a structured report.
+
+     Args:
+         state_dict (dict): The model state_dict to analyze.
+
+     Returns:
+         dict: A report containing total parameters and a per-parameter breakdown.
+     """
+     weight_report = {}
+     total_params = 0
+     if not isinstance(state_dict, dict):
+         _LOGGER.warning(f"Attempted to generate weight report on non-dict type: {type(state_dict)}")
+         return {"error": "Input was not a dictionary."}
+
+     for key, tensor in state_dict.items():
+         if not isinstance(tensor, torch.Tensor):
+             _LOGGER.warning(f"Skipping key '{key}' in state_dict: value is not a tensor (type: {type(tensor)}).")
+             weight_report[key] = {
+                 "type": str(type(tensor)),
+                 "value_preview": str(tensor)[:50]  # show a preview
+             }
+             continue
+         weight_report[key] = {
+             "shape": list(tensor.shape),
+             "dtype": str(tensor.dtype),
+             "requires_grad": tensor.requires_grad,
+             "num_elements": tensor.numel()
+         }
+         total_params += tensor.numel()
+
+     return {
+         "total_parameters": total_params,
+         "parameter_key_count": len(weight_report),
+         "parameters": weight_report
+     }
+
+
+ def select_features_by_shap(
+     root_directory: Union[str, Path],
+     shap_threshold: float,
+     log_feature_names_directory: Optional[Union[str, Path]],
+     verbose: bool = True) -> list[str]:
+     """
+     Scans subdirectories to find SHAP summary CSVs, then extracts feature
+     names whose mean absolute SHAP value meets a specified threshold.
+
+     This function is useful for automated feature selection based on feature
+     importance scores aggregated from multiple models.
+
+     Args:
+         root_directory (str | Path):
+             The path to the root directory that contains model subdirectories.
+         shap_threshold (float):
+             The minimum mean absolute SHAP value for a feature to be included
+             in the final list.
+         log_feature_names_directory (str | Path | None):
+             If given, saves the chosen feature names as a .txt file in this directory.
+         verbose (bool):
+             Whether to log progress messages. (Default: True)
+
+     Returns:
+         list[str]:
+             A single, sorted list of unique feature names that meet the
+             threshold criteria across all found files.
+     """
+     if verbose:
+         _LOGGER.info(f"Starting feature selection with SHAP threshold >= {shap_threshold}")
+
+     # --- Step 1: Validate root directory ---
+     root_path = make_fullpath(root_directory, enforce="directory")
+
+     # --- Step 2: Directory and file discovery ---
+     subdirectories = list_subdirectories(root_dir=root_path, verbose=False, raise_on_empty=True)
+
+     shap_filename = SHAPKeys.SAVENAME + ".csv"
+
+     valid_csv_paths = []
+     for dir_name, dir_path in subdirectories.items():
+         expected_path = dir_path / shap_filename
+         if expected_path.is_file():
+             valid_csv_paths.append(expected_path)
+         else:
+             _LOGGER.warning(f"No '{shap_filename}' found in subdirectory '{dir_name}'.")
+
+     if not valid_csv_paths:
+         _LOGGER.error(f"Process halted: No '{shap_filename}' files were found in any subdirectory.")
+         return []
+
+     if verbose:
+         _LOGGER.info(f"Found {len(valid_csv_paths)} SHAP summary files to process.")
+
+     # --- Step 3: Data processing and feature extraction ---
+     master_feature_set = set()
+     for csv_path in valid_csv_paths:
+         try:
+             df, _ = load_dataframe(csv_path, kind="pandas", verbose=False)
+
+             # Validate required columns
+             required_cols = {SHAPKeys.FEATURE_COLUMN, SHAPKeys.SHAP_VALUE_COLUMN}
+             if not required_cols.issubset(df.columns):
+                 _LOGGER.warning(f"Skipping '{csv_path}': missing required columns.")
+                 continue
+
+             # Filter by threshold and extract features
+             filtered_df = df[df[SHAPKeys.SHAP_VALUE_COLUMN] >= shap_threshold]
+             features = filtered_df[SHAPKeys.FEATURE_COLUMN].tolist()
+             master_feature_set.update(features)
+
+         except (ValueError, pd.errors.EmptyDataError):
+             _LOGGER.warning(f"Skipping '{csv_path}' because it is empty or malformed.")
+             continue
+         except Exception as e:
+             _LOGGER.error(f"An unexpected error occurred while processing '{csv_path}': {e}")
+             continue
+
+     # --- Step 4: Finalize and return ---
+     final_features = sorted(list(master_feature_set))
+     if verbose:
+         _LOGGER.info(f"Selected {len(final_features)} unique features across all files.")
+
+     if log_feature_names_directory is not None:
+         save_names_path = make_fullpath(log_feature_names_directory, make=True, enforce="directory")
+         save_list_strings(list_strings=final_features,
+                           directory=save_names_path,
+                           filename=DatasetKeys.FEATURE_NAMES,
+                           verbose=verbose)
+
+     return final_features
+
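A usage sketch under assumed inputs (the directory layout and the `shap_summary.csv` file name are illustrative; the real name comes from SHAPKeys.SAVENAME):

    from ml_tools.ML_utilities import select_features_by_shap

    # experiments/
    #   model_a/shap_summary.csv   <- hypothetical layout
    #   model_b/shap_summary.csv
    selected = select_features_by_shap(
        root_directory="experiments/",
        shap_threshold=0.05,
        log_feature_names_directory="experiments/selected/",
    )
    # 'selected' is the sorted union of features meeting the threshold in any file.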
ml_tools/ML_utilities/_train_tools.py (new file)
@@ -0,0 +1,205 @@
+ from pathlib import Path
+ from typing import Union, Any, Iterable
+ from torch import nn
+
+ from ..serde import serialize_object_filename
+
+ from .._core import get_logger
+
+
+ _LOGGER = get_logger("Torch Utilities")
+
+
+ __all__ = [
+     "build_optimizer_params",
+     "set_parameter_requires_grad",
+     "save_pretrained_transforms",
+ ]
+
+
+ def build_optimizer_params(model: nn.Module, weight_decay: float = 0.01) -> list[dict[str, Any]]:
+     """
+     Groups model parameters to apply weight decay only to weights (matrices/embeddings),
+     while excluding biases and normalization parameters (scales/shifts).
+
+     This function uses a robust hybrid strategy:
+     1. It excludes parameters matching standard names (e.g., "bias", "norm").
+     2. It excludes any parameter with fewer than 2 dimensions (vector parameters), which
+        automatically catches unnamed BatchNorm/LayerNorm weights in Sequential containers.
+
+     Args:
+         model (nn.Module):
+             The PyTorch model.
+         weight_decay (float):
+             The L2 regularization coefficient for the weights.
+             (Default: 0.01)
+
+     Returns:
+         list[dict[str, Any]]: A list of parameter groups formatted for PyTorch optimizers.
+             - Group 0: 'params' = weights (decay applied)
+             - Group 1: 'params' = biases/norms (decay = 0.0)
+     """
+     # 1. Hard-coded strings for explicit safety
+     no_decay_strings = {"bias", "LayerNorm", "BatchNorm", "GroupNorm", "norm.weight"}
+
+     decay_params = []
+     no_decay_params = []
+
+     # 2. Iterate only over trainable parameters
+     for name, param in model.named_parameters():
+         if not param.requires_grad:
+             continue
+
+         # Check 1: Name match
+         is_blacklisted_name = any(nd in name for nd in no_decay_strings)
+
+         # Check 2: Dimensionality (robust fallback)
+         # Weights/embeddings are 2D+, biases/norm scales are 1D
+         is_1d = param.ndim < 2
+
+         if is_blacklisted_name or is_1d:
+             no_decay_params.append(param)
+         else:
+             decay_params.append(param)
+
+     _LOGGER.info(f"Weight decay configured:\n    Decaying parameters: {len(decay_params)}\n    Non-decaying parameters: {len(no_decay_params)}")
+
+     return [
+         {
+             'params': decay_params,
+             'weight_decay': weight_decay,
+         },
+         {
+             'params': no_decay_params,
+             'weight_decay': 0.0,
+         }
+     ]
+
+
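A short sketch of how the returned groups are consumed by a standard PyTorch optimizer (the model is a toy):

    import torch
    from torch import nn

    model = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.Linear(32, 1))
    param_groups = build_optimizer_params(model, weight_decay=0.01)

    # Decay applies only to the two 2D Linear weight matrices; all biases and
    # the 1D BatchNorm scale/shift land in the zero-decay group.
    optimizer = torch.optim.AdamW(param_groups, lr=1e-3)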
+ def set_parameter_requires_grad(
+     model: nn.Module,
+     unfreeze_last_n_params: int,
+ ) -> int:
+     """
+     Freezes or unfreezes parameters in a model based on `unfreeze_last_n_params`.
+
+     - N = 0: Freezes ALL parameters.
+     - 0 < N < total: Freezes ALL parameters, then unfreezes the last N.
+     - N >= total: Unfreezes ALL parameters.
+
+     Note: 'N' refers to individual parameter tensors (e.g., `layer.weight`
+     or `layer.bias`), not modules or layers. For example, to unfreeze
+     the final nn.Linear layer, you would use N=2 (for its weight and bias).
+
+     Args:
+         model (nn.Module): The model to modify.
+         unfreeze_last_n_params (int):
+             The number of parameter tensors to unfreeze, starting from
+             the end of the model.
+
+     Returns:
+         int: The total number of individual parameters (elements) that were set to `requires_grad=True`.
+     """
+     if unfreeze_last_n_params < 0:
+         _LOGGER.error(f"unfreeze_last_n_params must be >= 0, but got {unfreeze_last_n_params}")
+         raise ValueError()
+
+     # --- Step 1: Get all parameter tensors ---
+     all_params = list(model.parameters())
+     total_param_tensors = len(all_params)
+
+     # --- Case 1: N = 0 (freeze ALL parameters) ---
+     # Early exit for the "freeze all" case.
+     if unfreeze_last_n_params == 0:
+         params_frozen = _set_params_grad(all_params, requires_grad=False)
+         _LOGGER.warning(f"Froze all {total_param_tensors} parameter tensors ({params_frozen} total elements).")
+         return 0  # 0 parameters unfrozen
+
+     # --- Case 2: N >= total (unfreeze ALL parameters) ---
+     if unfreeze_last_n_params >= total_param_tensors:
+         if unfreeze_last_n_params > total_param_tensors:
+             _LOGGER.warning(f"Requested to unfreeze {unfreeze_last_n_params} params, but model only has {total_param_tensors}. Unfreezing all.")
+
+         params_unfrozen = _set_params_grad(all_params, requires_grad=True)
+         _LOGGER.info(f"Unfroze all {total_param_tensors} parameter tensors ({params_unfrozen} total elements) for training.")
+         return params_unfrozen
+
+     # --- Case 3: 0 < N < total (standard: freeze all, unfreeze last N) ---
+     # Freeze ALL
+     params_frozen = _set_params_grad(all_params, requires_grad=False)
+     _LOGGER.info(f"Froze {params_frozen} parameters.")
+
+     # Unfreeze the last N
+     params_to_unfreeze = all_params[-unfreeze_last_n_params:]
+
+     # These are all False at this point, so the helper will set them to True
+     params_unfrozen = _set_params_grad(params_to_unfreeze, requires_grad=True)
+
+     _LOGGER.info(f"Unfroze the last {unfreeze_last_n_params} parameter tensors ({params_unfrozen} total elements) for training.")
+
+     return params_unfrozen
+
+
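For example, in a transfer-learning setup (torchvision's resnet18 is used purely for illustration):

    from torchvision import models

    model = models.resnet18(weights="IMAGENET1K_V1")
    # Freeze the backbone and unfreeze only the final fc layer:
    # its weight and bias are the last 2 parameter tensors.
    n_elements = set_parameter_requires_grad(model, unfreeze_last_n_params=2)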
+ def _set_params_grad(
+     params: Iterable[nn.Parameter],
+     requires_grad: bool
+ ) -> int:
+     """
+     A helper function to set the `requires_grad` attribute for an iterable
+     of parameters and return the total number of elements changed.
+     """
+     params_changed = 0
+     for param in params:
+         if param.requires_grad != requires_grad:
+             param.requires_grad = requires_grad
+             params_changed += param.numel()
+     return params_changed
+
+
+ def save_pretrained_transforms(model: nn.Module, output_dir: Union[str, Path]):
+     """
+     Checks a model for the '_pretrained_default_transforms' attribute and,
+     if found, serializes the transform object as a .joblib file.
+
+     Used for wrapper vision models when initialized with pre-trained weights.
+
+     This saves the callable transform object itself for
+     later use, such as passing it directly to the 'transform_source'
+     argument of the PyTorchVisionInferenceHandler.
+
+     Args:
+         model (nn.Module): The model instance to check.
+         output_dir (str | Path): The directory where the transform file will be saved.
+     """
+     output_filename = "pretrained_model_transformations"
+
+     # 1. Check for the "secret attribute"
+     if not hasattr(model, '_pretrained_default_transforms'):
+         _LOGGER.warning(f"Model of type {type(model).__name__} does not have the required attribute. No transformations saved.")
+         return
+
+     # 2. Get the transform object
+     try:
+         transform_obj = model._pretrained_default_transforms
+     except Exception as e:
+         _LOGGER.error(f"Error accessing the required attribute on model: {e}")
+         return
+
+     # 3. Check if the object is actually there
+     if transform_obj is None:
+         _LOGGER.warning(f"Model {type(model).__name__} has the required attribute but returned None. No transforms saved.")
+         return
+
+     # 4. Serialize and save using serde
+     try:
+         serialize_object_filename(
+             obj=transform_obj,
+             save_dir=output_dir,
+             filename=output_filename,
+             verbose=True,
+             raise_on_error=True
+         )
+         # _LOGGER.info(f"Successfully saved pretrained transforms to '{output_dir}'.")
+     except Exception as e:
+         _LOGGER.error(f"Failed to serialize transformations: {e}")
+         raise
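The intended round trip, per the docstring (a sketch; `model` is any ML_models_vision wrapper built with pre-trained weights):

    save_pretrained_transforms(model, output_dir="artifacts/")
    # Produces artifacts/pretrained_model_transformations.joblib, suitable for
    # the 'transform_source' argument of the vision inference handler.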
ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py}
@@ -1,18 +1,21 @@
- from ._core._ML_vision_transformers import (
-     TRANSFORM_REGISTRY,
+ from ._core_transforms import (
      ResizeAspectFill,
      LetterboxResize,
      HistogramEqualization,
      RandomHistogramEqualization,
-     create_offline_augmentations,
-     info
  )

+ from ._offline_augmentation import create_offline_augmentations
+
+ from ._imprimir import info
+
+
  __all__ = [
-     "TRANSFORM_REGISTRY",
+     # Custom Transforms
      "ResizeAspectFill",
      "LetterboxResize",
      "HistogramEqualization",
      "RandomHistogramEqualization",
-     "create_offline_augmentations"
+     # Offline Augmentation
+     "create_offline_augmentations",
  ]
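Net effect for downstream imports (a hedged before/after sketch; TRANSFORM_REGISTRY is dropped from the public __all__ in 20.0.0):

    # 19.13.0
    from ml_tools.ML_vision_transformers import TRANSFORM_REGISTRY, LetterboxResize

    # 20.0.0 - the transform classes and create_offline_augmentations keep their
    # import path; TRANSFORM_REGISTRY is no longer re-exported here.
    from ml_tools.ML_vision_transformers import LetterboxResize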