dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
  2. dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
  3. ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
  4. ml_tools/ETL_cleaning/_basic_clean.py +351 -0
  5. ml_tools/ETL_cleaning/_clean_tools.py +128 -0
  6. ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
  7. ml_tools/ETL_cleaning/_imprimir.py +13 -0
  8. ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
  9. ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
  10. ml_tools/ETL_engineering/_imprimir.py +24 -0
  11. ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
  12. ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
  13. ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
  14. ml_tools/GUI_tools/_imprimir.py +12 -0
  15. ml_tools/IO_tools/_IO_loggers.py +235 -0
  16. ml_tools/IO_tools/_IO_save_load.py +151 -0
  17. ml_tools/IO_tools/_IO_utils.py +140 -0
  18. ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
  19. ml_tools/IO_tools/_imprimir.py +14 -0
  20. ml_tools/MICE/_MICE_imputation.py +132 -0
  21. ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
  22. ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
  23. ml_tools/MICE/_imprimir.py +11 -0
  24. ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
  25. ml_tools/ML_callbacks/_base.py +101 -0
  26. ml_tools/ML_callbacks/_checkpoint.py +232 -0
  27. ml_tools/ML_callbacks/_early_stop.py +208 -0
  28. ml_tools/ML_callbacks/_imprimir.py +12 -0
  29. ml_tools/ML_callbacks/_scheduler.py +197 -0
  30. ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
  31. ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
  32. ml_tools/ML_chain/_dragon_chain.py +140 -0
  33. ml_tools/ML_chain/_imprimir.py +11 -0
  34. ml_tools/ML_configuration/__init__.py +90 -0
  35. ml_tools/ML_configuration/_base_model_config.py +69 -0
  36. ml_tools/ML_configuration/_finalize.py +366 -0
  37. ml_tools/ML_configuration/_imprimir.py +47 -0
  38. ml_tools/ML_configuration/_metrics.py +593 -0
  39. ml_tools/ML_configuration/_models.py +206 -0
  40. ml_tools/ML_configuration/_training.py +124 -0
  41. ml_tools/ML_datasetmaster/__init__.py +28 -0
  42. ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
  43. ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
  44. ml_tools/ML_datasetmaster/_imprimir.py +15 -0
  45. ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
  46. ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
  47. ml_tools/ML_evaluation/__init__.py +53 -0
  48. ml_tools/ML_evaluation/_classification.py +629 -0
  49. ml_tools/ML_evaluation/_feature_importance.py +409 -0
  50. ml_tools/ML_evaluation/_imprimir.py +25 -0
  51. ml_tools/ML_evaluation/_loss.py +92 -0
  52. ml_tools/ML_evaluation/_regression.py +273 -0
  53. ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
  54. ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
  55. ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
  56. ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
  57. ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
  58. ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
  59. ml_tools/ML_finalize_handler/__init__.py +10 -0
  60. ml_tools/ML_finalize_handler/_imprimir.py +8 -0
  61. ml_tools/ML_inference/__init__.py +22 -0
  62. ml_tools/ML_inference/_base_inference.py +166 -0
  63. ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
  64. ml_tools/ML_inference/_dragon_inference.py +332 -0
  65. ml_tools/ML_inference/_imprimir.py +11 -0
  66. ml_tools/ML_inference/_multi_inference.py +180 -0
  67. ml_tools/ML_inference_sequence/__init__.py +10 -0
  68. ml_tools/ML_inference_sequence/_imprimir.py +8 -0
  69. ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
  70. ml_tools/ML_inference_vision/__init__.py +10 -0
  71. ml_tools/ML_inference_vision/_imprimir.py +8 -0
  72. ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
  73. ml_tools/ML_models/__init__.py +32 -0
  74. ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
  75. ml_tools/ML_models/_base_mlp_attention.py +198 -0
  76. ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
  77. ml_tools/ML_models/_dragon_tabular.py +248 -0
  78. ml_tools/ML_models/_imprimir.py +18 -0
  79. ml_tools/ML_models/_mlp_attention.py +134 -0
  80. ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
  81. ml_tools/ML_models_sequence/__init__.py +10 -0
  82. ml_tools/ML_models_sequence/_imprimir.py +8 -0
  83. ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
  84. ml_tools/ML_models_vision/__init__.py +29 -0
  85. ml_tools/ML_models_vision/_base_wrapper.py +254 -0
  86. ml_tools/ML_models_vision/_image_classification.py +182 -0
  87. ml_tools/ML_models_vision/_image_segmentation.py +108 -0
  88. ml_tools/ML_models_vision/_imprimir.py +16 -0
  89. ml_tools/ML_models_vision/_object_detection.py +135 -0
  90. ml_tools/ML_optimization/__init__.py +21 -0
  91. ml_tools/ML_optimization/_imprimir.py +13 -0
  92. ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
  93. ml_tools/ML_optimization/_single_dragon.py +203 -0
  94. ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
  95. ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
  96. ml_tools/ML_scaler/__init__.py +10 -0
  97. ml_tools/ML_scaler/_imprimir.py +8 -0
  98. ml_tools/ML_trainer/__init__.py +20 -0
  99. ml_tools/ML_trainer/_base_trainer.py +297 -0
  100. ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
  101. ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
  102. ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
  103. ml_tools/ML_trainer/_imprimir.py +10 -0
  104. ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
  105. ml_tools/ML_utilities/_artifact_finder.py +382 -0
  106. ml_tools/ML_utilities/_imprimir.py +16 -0
  107. ml_tools/ML_utilities/_inspection.py +325 -0
  108. ml_tools/ML_utilities/_train_tools.py +205 -0
  109. ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
  110. ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
  111. ml_tools/ML_vision_transformers/_imprimir.py +14 -0
  112. ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
  113. ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
  114. ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
  115. ml_tools/PSO_optimization/_imprimir.py +10 -0
  116. ml_tools/SQL/__init__.py +7 -0
  117. ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
  118. ml_tools/SQL/_imprimir.py +8 -0
  119. ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
  120. ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
  121. ml_tools/VIF/_imprimir.py +10 -0
  122. ml_tools/_core/__init__.py +7 -1
  123. ml_tools/_core/_logger.py +8 -18
  124. ml_tools/_core/_schema_load_ops.py +43 -0
  125. ml_tools/_core/_script_info.py +2 -2
  126. ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
  127. ml_tools/data_exploration/_analysis.py +214 -0
  128. ml_tools/data_exploration/_cleaning.py +566 -0
  129. ml_tools/data_exploration/_features.py +583 -0
  130. ml_tools/data_exploration/_imprimir.py +32 -0
  131. ml_tools/data_exploration/_plotting.py +487 -0
  132. ml_tools/data_exploration/_schema_ops.py +176 -0
  133. ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
  134. ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
  135. ml_tools/ensemble_evaluation/_imprimir.py +14 -0
  136. ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
  137. ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
  138. ml_tools/ensemble_inference/_imprimir.py +9 -0
  139. ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
  140. ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
  141. ml_tools/ensemble_learning/_imprimir.py +10 -0
  142. ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
  143. ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
  144. ml_tools/excel_handler/_imprimir.py +13 -0
  145. ml_tools/{keys.py → keys/__init__.py} +4 -1
  146. ml_tools/keys/_imprimir.py +11 -0
  147. ml_tools/{_core → keys}/_keys.py +2 -0
  148. ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
  149. ml_tools/math_utilities/_imprimir.py +11 -0
  150. ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
  151. ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
  152. ml_tools/optimization_tools/_imprimir.py +13 -0
  153. ml_tools/optimization_tools/_optimization_bounds.py +236 -0
  154. ml_tools/optimization_tools/_optimization_plots.py +218 -0
  155. ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
  156. ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
  157. ml_tools/path_manager/_imprimir.py +15 -0
  158. ml_tools/path_manager/_path_tools.py +346 -0
  159. ml_tools/plot_fonts/__init__.py +8 -0
  160. ml_tools/plot_fonts/_imprimir.py +8 -0
  161. ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
  162. ml_tools/schema/__init__.py +15 -0
  163. ml_tools/schema/_feature_schema.py +223 -0
  164. ml_tools/schema/_gui_schema.py +191 -0
  165. ml_tools/schema/_imprimir.py +10 -0
  166. ml_tools/{serde.py → serde/__init__.py} +4 -2
  167. ml_tools/serde/_imprimir.py +10 -0
  168. ml_tools/{_core → serde}/_serde.py +3 -8
  169. ml_tools/{utilities.py → utilities/__init__.py} +11 -6
  170. ml_tools/utilities/_imprimir.py +18 -0
  171. ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
  172. ml_tools/utilities/_utility_tools.py +192 -0
  173. dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
  174. ml_tools/ML_chaining_inference.py +0 -8
  175. ml_tools/ML_configuration.py +0 -86
  176. ml_tools/ML_configuration_pytab.py +0 -14
  177. ml_tools/ML_datasetmaster.py +0 -10
  178. ml_tools/ML_evaluation.py +0 -16
  179. ml_tools/ML_evaluation_multi.py +0 -12
  180. ml_tools/ML_finalize_handler.py +0 -8
  181. ml_tools/ML_inference.py +0 -12
  182. ml_tools/ML_models.py +0 -14
  183. ml_tools/ML_models_advanced.py +0 -14
  184. ml_tools/ML_models_pytab.py +0 -14
  185. ml_tools/ML_optimization.py +0 -14
  186. ml_tools/ML_optimization_pareto.py +0 -8
  187. ml_tools/ML_scaler.py +0 -8
  188. ml_tools/ML_sequence_datasetmaster.py +0 -8
  189. ml_tools/ML_sequence_evaluation.py +0 -10
  190. ml_tools/ML_sequence_inference.py +0 -8
  191. ml_tools/ML_sequence_models.py +0 -8
  192. ml_tools/ML_trainer.py +0 -12
  193. ml_tools/ML_vision_datasetmaster.py +0 -12
  194. ml_tools/ML_vision_evaluation.py +0 -10
  195. ml_tools/ML_vision_inference.py +0 -8
  196. ml_tools/ML_vision_models.py +0 -18
  197. ml_tools/SQL.py +0 -8
  198. ml_tools/_core/_ETL_cleaning.py +0 -694
  199. ml_tools/_core/_IO_tools.py +0 -498
  200. ml_tools/_core/_ML_callbacks.py +0 -702
  201. ml_tools/_core/_ML_configuration.py +0 -1332
  202. ml_tools/_core/_ML_configuration_pytab.py +0 -102
  203. ml_tools/_core/_ML_evaluation.py +0 -867
  204. ml_tools/_core/_ML_evaluation_multi.py +0 -544
  205. ml_tools/_core/_ML_inference.py +0 -646
  206. ml_tools/_core/_ML_models.py +0 -668
  207. ml_tools/_core/_ML_models_pytab.py +0 -693
  208. ml_tools/_core/_ML_trainer.py +0 -2323
  209. ml_tools/_core/_ML_utilities.py +0 -886
  210. ml_tools/_core/_ML_vision_models.py +0 -644
  211. ml_tools/_core/_data_exploration.py +0 -1901
  212. ml_tools/_core/_optimization_tools.py +0 -493
  213. ml_tools/_core/_schema.py +0 -359
  214. ml_tools/plot_fonts.py +0 -8
  215. ml_tools/schema.py +0 -12
  216. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
  217. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
  218. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  219. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,166 @@
1
+ import torch
2
+ from torch import nn
3
+ import numpy as np
4
+ from pathlib import Path
5
+ from typing import Union, Optional
6
+ from abc import ABC, abstractmethod
7
+
8
+ from ..ML_finalize_handler import FinalizedFileHandler
9
+ from ..ML_scaler import DragonScaler
10
+
11
+ from .._core import get_logger
12
+ from ..path_manager import make_fullpath
13
+ from ..keys._keys import PyTorchCheckpointKeys, ScalerKeys, MagicWords
14
+
15
+
16
+ _LOGGER = get_logger("Inference Handler")
17
+
18
+
19
+ __all__ = [
20
+ "_BaseInferenceHandler",
21
+ ]
22
+
23
+
24
class _BaseInferenceHandler(ABC):
    """
    Abstract base class for PyTorch inference handlers.

    Manages common tasks like loading a model's state dictionary via FinalizedFileHandler,
    validating the target device, and preprocessing input features.

    Subclasses must implement `predict_batch` and `predict`.
    """
    def __init__(self,
                 model: nn.Module,
                 state_dict: Union[str, Path],
                 device: str = 'cpu',
                 scaler: Optional[Union[str, Path]] = None,
                 task: Optional[str] = None):
        """
        Initializes the handler.

        Args:
            model (nn.Module): An instantiated PyTorch model.
            state_dict (str | Path): Path to the saved .pth model state_dict file.
            device (str): The device to run inference on ('cpu', 'cuda', 'mps').
            scaler (str | Path | None): An optional scaler or path to a saved scaler state.
            task (str | None): The specific machine learning task. If None, it attempts to read it from the finalized-file.

        Raises:
            ValueError: If no task can be resolved, or if `scaler` is not a str/Path.
            RuntimeError: If the state dict does not match the model architecture.
        """
        self.model = model
        self.device = self._validate_device(device)
        # Default decision threshold for sigmoid-based tasks; may be overridden
        # below by metadata stored in the finalized file.
        self._classification_threshold = 0.5
        self._loaded_threshold: bool = False
        self._loaded_class_map: bool = False
        self._class_map: Optional[dict[str,int]] = None
        self._idx_to_class: Optional[dict[int, str]] = None

        # --- 1. Load File Handler ---
        # This loads the content on CPU and validates structure
        self._file_handler = FinalizedFileHandler(state_dict)

        # Silence warnings of the filehandler internally
        self._file_handler._verbose = False

        # --- 2. Task Resolution ---
        file_task = self._file_handler.task

        if task is None:
            # User didn't provide task, must be in file
            if file_task == MagicWords.UNKNOWN:
                _LOGGER.error(f"Task not specified in arguments and not found in file '{make_fullpath(state_dict).name}'.")
                raise ValueError()
            self.task = file_task
            _LOGGER.info(f"Task '{self.task}' detected from file.")
        else:
            # User provided task; a mismatch with file metadata is only warned
            # about — the explicit argument wins.
            if file_task != MagicWords.UNKNOWN and file_task != task:
                _LOGGER.warning(f"Provided task '{task}' differs from file metadata task '{file_task}'. Using provided task '{task}'.")
            self.task = task

        # --- 3. Load Model Weights ---
        # Weights are already loaded in file_handler (on CPU)
        try:
            self.model.load_state_dict(self._file_handler.model_state_dict)
        except RuntimeError as e:
            # Architecture/state-dict mismatch: log, then re-raise unchanged.
            _LOGGER.error(f"State dict mismatch: {e}")
            raise

        # --- 4. Load Metadata (Thresholds, Class Maps) ---
        if self._file_handler.classification_threshold is not None:
            self._classification_threshold = self._file_handler.classification_threshold
            self._loaded_threshold = True

        if self._file_handler.class_map is not None:
            self.set_class_map(self._file_handler.class_map)
            # set_class_map sets _loaded_class_map to True

        # --- 5. Move to Device ---
        self.model.to(self.device)
        self.model.eval()
        _LOGGER.info(f"Model loaded and moved to {self.device} in evaluation mode.")

        # --- 6. Load Scalers ---
        self.feature_scaler: Optional[DragonScaler] = None
        self.target_scaler: Optional[DragonScaler] = None

        if scaler is not None:
            if isinstance(scaler, (str, Path)):
                path_obj = make_fullpath(scaler, enforce="file")
                # NOTE(review): torch.load defaults to weights_only=True in newer
                # PyTorch releases — confirm scaler state files still deserialize.
                loaded_scaler_data = torch.load(path_obj)

                # New format: a dict holding separate feature/target scaler states.
                if isinstance(loaded_scaler_data, dict) and (ScalerKeys.FEATURE_SCALER in loaded_scaler_data or ScalerKeys.TARGET_SCALER in loaded_scaler_data):
                    if ScalerKeys.FEATURE_SCALER in loaded_scaler_data:
                        self.feature_scaler = DragonScaler.load(loaded_scaler_data[ScalerKeys.FEATURE_SCALER], verbose=False)
                        _LOGGER.info("Loaded DragonScaler state for feature scaling.")
                    if ScalerKeys.TARGET_SCALER in loaded_scaler_data:
                        self.target_scaler = DragonScaler.load(loaded_scaler_data[ScalerKeys.TARGET_SCALER], verbose=False)
                        _LOGGER.info("Loaded DragonScaler state for target scaling.")
                else:
                    _LOGGER.warning("Loaded scaler file does not contain separate feature/target scalers. Assuming it is a feature scaler (legacy format).")
                    self.feature_scaler = DragonScaler.load(loaded_scaler_data)
            else:
                _LOGGER.error("Scaler must be a file path (str or Path) to a saved DragonScaler state file.")
                raise ValueError()

    def _validate_device(self, device: str) -> torch.device:
        """Validates the selected device and returns a torch.device object.

        Falls back to CPU (with a warning) when the requested CUDA or MPS
        backend is not available on this machine.
        """
        device_lower = device.lower()
        if "cuda" in device_lower and not torch.cuda.is_available():
            _LOGGER.warning("CUDA not available, switching to CPU.")
            device_lower = "cpu"
        elif device_lower == "mps" and not torch.backends.mps.is_available():
            _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
            device_lower = "cpu"
        return torch.device(device_lower)

    def set_class_map(self, class_map: dict[str, int], force_overwrite: bool = False):
        """
        Sets the class name mapping to translate predicted integer labels back into string names.

        If a class map was already loaded (e.g. from the model file), it is only
        replaced when `force_overwrite=True`; otherwise a warning is logged and
        the call is a no-op.

        Args:
            class_map (Dict[str, int]): The class_to_idx dictionary.
            force_overwrite (bool): If True, allows overwriting a map that was loaded from a configuration file.
        """
        if self._loaded_class_map:
            warning_message = f"A '{PyTorchCheckpointKeys.CLASS_MAP}' was loaded from the model configuration file."
            if not force_overwrite:
                warning_message += " Use 'force_overwrite=True' if you are sure you want to modify it. This will not affect the value from the file."
                _LOGGER.warning(warning_message)
                return
            else:
                warning_message += " Overwriting it for this inference instance."
                _LOGGER.warning(warning_message)

        self._class_map = class_map
        # Reverse mapping used at inference time to translate int labels to names.
        self._idx_to_class = {v: k for k, v in class_map.items()}
        self._loaded_class_map = True
        _LOGGER.info("InferenceHandler: Class map set for label-to-name translation.")

    @abstractmethod
    def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Core batch prediction method. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def predict(self, features: Union[np.ndarray, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Core single-sample prediction method. Must be implemented by subclasses."""
        pass
@@ -1,14 +1,14 @@
1
1
  import torch
2
2
  import numpy as np
3
- from typing import List, Dict, Union, Any
3
+ from typing import Union, Any
4
4
 
5
- from ._ML_inference import DragonInferenceHandler
6
- from ._keys import MLTaskKeys, PyTorchInferenceKeys
7
- from ._logger import get_logger
8
- from ._script_info import _script_info
5
+ from ..keys._keys import MLTaskKeys, PyTorchInferenceKeys
6
+ from .._core import get_logger
9
7
 
8
+ from ._dragon_inference import DragonInferenceHandler
10
9
 
11
- _LOGGER = get_logger("Chain Inference")
10
+
11
+ _LOGGER = get_logger("DragonChainInference")
12
12
 
13
13
 
14
14
  __all__ = [
@@ -25,7 +25,7 @@ class DragonChainInference:
25
25
  'Classifier Chains' where subsequent models depend on the predictions
26
26
  of previous models.
27
27
  """
28
- def __init__(self, handlers: List[DragonInferenceHandler]):
28
+ def __init__(self, handlers: list[DragonInferenceHandler]):
29
29
  """
30
30
  Args:
31
31
  handlers (List[DragonInferenceHandler]): An ordered list of inference handlers.
@@ -59,11 +59,11 @@ class DragonChainInference:
59
59
  seen_targets.add(tid)
60
60
 
61
61
  @property
62
- def target_ids(self) -> List[str]:
62
+ def target_ids(self) -> list[str]:
63
63
  """Returns a unified list of all target_ids in the chain order."""
64
64
  return self._all_target_ids
65
65
 
66
- def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
66
+ def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> dict[str, torch.Tensor]:
67
67
  """
68
68
  Runs the inference chain on a batch of features.
69
69
 
@@ -71,7 +71,7 @@ class DragonChainInference:
71
71
  features (np.ndarray | torch.Tensor): The initial input features (2D).
72
72
 
73
73
  Returns:
74
- Dict[str, torch.Tensor]: A dictionary mapping every target_id in the chain
74
+ dict[str, torch.Tensor]: A dictionary mapping every target_id in the chain
75
75
  to its predicted tensor.
76
76
  """
77
77
  # We perform operations on CPU or let the handlers manage device transfer internally.
@@ -86,7 +86,7 @@ class DragonChainInference:
86
86
  features = features.unsqueeze(0)
87
87
 
88
88
  current_features = features
89
- results: Dict[str, torch.Tensor] = {}
89
+ results: dict[str, torch.Tensor] = {}
90
90
 
91
91
  for i, handler in enumerate(self.handlers):
92
92
  # 1. Predict
@@ -131,7 +131,7 @@ class DragonChainInference:
131
131
 
132
132
  return results
133
133
 
134
- def predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
134
+ def predict(self, features: Union[np.ndarray, torch.Tensor]) -> dict[str, torch.Tensor]:
135
135
  """
136
136
  Runs the chain on a single sample, returning PyTorch Tensors.
137
137
 
@@ -156,7 +156,7 @@ class DragonChainInference:
156
156
  single_results = {k: v[0] for k, v in batch_results.items()}
157
157
  return single_results
158
158
 
159
- def predict_batch_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, np.ndarray]:
159
+ def predict_batch_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> dict[str, np.ndarray]:
160
160
  """
161
161
  Convenience wrapper that returns NumPy arrays instead of Tensors.
162
162
  Useful for final consumption of the chain results.
@@ -164,7 +164,7 @@ class DragonChainInference:
164
164
  tensor_results = self.predict_batch(features)
165
165
  return {k: v.cpu().numpy() for k, v in tensor_results.items()}
166
166
 
167
- def predict_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
167
+ def predict_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> dict[str, Any]:
168
168
  """
169
169
  Runs the chain on a single sample, returning Python scalars or NumPy arrays.
170
170
 
@@ -187,6 +187,3 @@ class DragonChainInference:
187
187
 
188
188
  return numpy_results
189
189
 
190
-
191
- def info():
192
- _script_info(__all__)
@@ -0,0 +1,332 @@
1
+ import torch
2
+ from torch import nn
3
+ import numpy as np
4
+ from pathlib import Path
5
+ from typing import Union, Literal, Any, Optional
6
+
7
+ from .._core import get_logger
8
+ from ..keys._keys import PyTorchInferenceKeys, PyTorchCheckpointKeys, MLTaskKeys
9
+
10
+ from ._base_inference import _BaseInferenceHandler
11
+
12
+
13
+ _LOGGER = get_logger("DragonInference")
14
+
15
+
16
+ __all__ = [
17
+ "DragonInferenceHandler",
18
+ ]
19
+
20
+
21
+ class DragonInferenceHandler(_BaseInferenceHandler):
22
+ """
23
+ Handles loading a PyTorch model's state dictionary and performing inference for tabular data.
24
+ """
25
    def __init__(self,
                 model: nn.Module,
                 state_dict: Union[str, Path],
                 task: Optional[Literal["regression",
                                        "binary classification",
                                        "multiclass classification",
                                        "multitarget regression",
                                        "multilabel binary classification"]] = None,
                 device: str = 'cpu',
                 scaler: Optional[Union[str, Path]] = None):
        """
        Initializes the handler for single-target tasks.

        Args:
            model (nn.Module): An instantiated PyTorch model architecture.
            state_dict (str | Path): Path to the saved .pth model state_dict file.
            task (str, optional): The type of task. If None, it will be detected from file.
            device (str): The device to run inference on ('cpu', 'cuda', 'mps').
            scaler (str | Path | None): A path to a saved DragonScaler state.

        Raises:
            ValueError: If the resolved task is not one of the tabular tasks
                supported by this handler.

        Note: class_map (Dict[int, str]) will be loaded from the model file, to set or override it use `.set_class_map()`.
        """
        # Call the parent constructor to handle model loading, device, and scaler
        # The parent constructor resolves 'task'
        super().__init__(model=model,
                         state_dict=state_dict,
                         device=device,
                         scaler=scaler,
                         task=task)

        # --- Validation of resolved task ---
        valid_tasks = [
            MLTaskKeys.REGRESSION,
            MLTaskKeys.BINARY_CLASSIFICATION,
            MLTaskKeys.MULTICLASS_CLASSIFICATION,
            MLTaskKeys.MULTITARGET_REGRESSION,
            MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION
        ]

        if self.task not in valid_tasks:
            _LOGGER.error(f"'task' recognized as '{self.task}', but this inference handler only supports: {valid_tasks}.")
            raise ValueError()

        # Names of the model's target column(s); filled from file metadata when
        # available, otherwise left None until set_target_ids() is called.
        self.target_ids: Optional[list[str]] = None
        self._target_ids_set: bool = False

        # --- Attempt to load target names from FinalizedFileHandler ---
        if self._file_handler.target_names is not None:
            # Multi-target metadata takes precedence over the single-target field.
            self.set_target_ids(self._file_handler.target_names)
        elif self._file_handler.target_name is not None:
            self.set_target_ids([self._file_handler.target_name])
        else:
            _LOGGER.warning("No target names found in file metadata.")
78
+
79
+ def _preprocess_input(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
80
+ """
81
+ Converts input to a torch.Tensor, applies FEATURE scaling if a scaler is
82
+ present, and moves it to the correct device.
83
+ """
84
+ if isinstance(features, np.ndarray):
85
+ features_tensor = torch.from_numpy(features).float()
86
+ else:
87
+ features_tensor = features.float()
88
+
89
+ if self.feature_scaler:
90
+ features_tensor = self.feature_scaler.transform(features_tensor)
91
+
92
+ return features_tensor.to(self.device)
93
+
94
+ def set_target_ids(self, target_names: list[str], force_overwrite: bool=False):
95
+ """
96
+ Assigns the provided list of strings as the target variable names.
97
+
98
+ If target IDs have already been set, this method will log a warning.
99
+
100
+ Args:
101
+ target_names (list[str]): A list of target names.
102
+ force_overwrite (bool): If True, allows the method to overwrite previously set target IDs.
103
+ """
104
+ if self._target_ids_set:
105
+ warning_message = "Target IDs was previously set."
106
+ if not force_overwrite:
107
+ warning_message += " Use `force_overwrite=True` to overwrite."
108
+ _LOGGER.warning(warning_message)
109
+ return
110
+ else:
111
+ warning_message += " Overwriting..."
112
+ _LOGGER.warning(warning_message)
113
+
114
+ self.target_ids = target_names
115
+ self._target_ids_set = True
116
+ _LOGGER.info("Target IDs set.")
117
+
118
    def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Core batch prediction method.

        Args:
            features (np.ndarray | torch.Tensor): A 2D array/tensor of input features.

        Returns:
            Dict: A dictionary containing the raw output tensors from the model.
                Classification tasks yield LABELS and PROBABILITIES; regression
                tasks yield PREDICTIONS.

        Raises:
            ValueError: If `features` is not 2D, if a target scaler was supplied
                for a non-regression task, or if the task is unrecognized.
        """
        if features.ndim != 2:
            _LOGGER.error("Input for batch prediction must be a 2D array or tensor.")
            raise ValueError()

        # Converts to float32, applies feature scaling, moves to self.device.
        input_tensor = self._preprocess_input(features)

        with torch.no_grad():
            output = self.model(input_tensor)

        # --- Target Scaling Logic (Inverse Transform) ---
        # Only for regression tasks and if a target scaler exists
        if self.target_scaler:
            if self.task not in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
                # raise error
                _LOGGER.error("Target scaler is only applicable for regression tasks. A target scaler was provided for a non-regression task.")
                raise ValueError()

            # Ensure output is 2D (N, Targets) for the scaler
            original_shape = output.shape
            if output.ndim == 1:
                output = output.reshape(-1, 1)

            # Apply inverse transform (de-scale)
            output = self.target_scaler.inverse_transform(output)

            # Restore original shape if necessary (though usually we want 2D or 1D flat)
            if len(original_shape) == 1:
                output = output.flatten()

        # --- Task Specific Formatting ---
        if self.task == MLTaskKeys.MULTICLASS_CLASSIFICATION:
            # Per-class logits -> softmax probabilities -> argmax labels.
            probs = torch.softmax(output, dim=1)
            labels = torch.argmax(probs, dim=1)
            return {
                PyTorchInferenceKeys.LABELS: labels,
                PyTorchInferenceKeys.PROBABILITIES: probs
            }

        elif self.task == MLTaskKeys.BINARY_CLASSIFICATION:
            # Flatten a (N, 1) logit column to (N,) before sigmoid.
            if output.ndim == 2 and output.shape[1] == 1:
                output = output.squeeze(1)

            probs = torch.sigmoid(output)
            labels = (probs >= self._classification_threshold).int()
            return {
                PyTorchInferenceKeys.LABELS: labels,
                PyTorchInferenceKeys.PROBABILITIES: probs
            }

        elif self.task == MLTaskKeys.REGRESSION:
            # For single-target regression, ensure output is flattened
            return {PyTorchInferenceKeys.PREDICTIONS: output.flatten()}

        elif self.task == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
            # Independent sigmoid per label; each thresholded separately.
            probs = torch.sigmoid(output)
            labels = (probs >= self._classification_threshold).int()
            return {
                PyTorchInferenceKeys.LABELS: labels,
                PyTorchInferenceKeys.PROBABILITIES: probs
            }

        elif self.task == MLTaskKeys.MULTITARGET_REGRESSION:
            # Keep the (N, Targets) shape.
            return {PyTorchInferenceKeys.PREDICTIONS: output}

        else:
            # Defensive: __init__ validated the task, so this should be unreachable.
            _LOGGER.error(f"Unrecognized task '{self.task}'.")
            raise ValueError()
195
+
196
+ def predict(self, features: Union[np.ndarray, torch.Tensor]) -> dict[str, torch.Tensor]:
197
+ """
198
+ Core single-sample prediction method for single-target models.
199
+
200
+ Args:
201
+ features (np.ndarray | torch.Tensor): A 1D array/tensor of input features.
202
+
203
+ Returns:
204
+ Dict: A dictionary containing the raw output tensors for a single sample.
205
+ """
206
+ if features.ndim == 1:
207
+ features = features.reshape(1, -1) # Reshape to a batch of one
208
+
209
+ if features.shape[0] != 1:
210
+ _LOGGER.error("The 'predict()' method is for a single sample. Use 'predict_batch()' for multiple samples.")
211
+ raise ValueError()
212
+
213
+ batch_results = self.predict_batch(features)
214
+
215
+ # Extract the first (and only) result from the batch output
216
+ single_results = {key: value[0] for key, value in batch_results.items()}
217
+ return single_results
218
+
219
+ # --- NumPy Convenience Wrappers (on CPU) ---
220
+
221
+ def predict_batch_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> dict[str, np.ndarray]:
222
+ """
223
+ Convenience wrapper for predict_batch that returns NumPy arrays
224
+ and adds string labels for classification tasks if a class_map is set.
225
+ """
226
+ tensor_results = self.predict_batch(features)
227
+ numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
228
+
229
+ # Add string names for classification if map exists
230
+ is_classification = self.task in [
231
+ MLTaskKeys.BINARY_CLASSIFICATION,
232
+ MLTaskKeys.MULTICLASS_CLASSIFICATION
233
+ ]
234
+
235
+ if is_classification and self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
236
+ int_labels = numpy_results[PyTorchInferenceKeys.LABELS] # This is a (B,) array
237
+ numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [ # type: ignore
238
+ self._idx_to_class.get(label_id, "Unknown")
239
+ for label_id in int_labels
240
+ ]
241
+
242
+ return numpy_results
243
+
244
+ def predict_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> dict[str, Any]:
245
+ """
246
+ Convenience wrapper for predict that returns NumPy arrays or scalars
247
+ and adds string labels for classification tasks if a class_map is set.
248
+ """
249
+ tensor_results = self.predict(features)
250
+
251
+ if self.task == MLTaskKeys.REGRESSION:
252
+ # .item() implicitly moves to CPU and returns a Python scalar
253
+ return {PyTorchInferenceKeys.PREDICTIONS: tensor_results[PyTorchInferenceKeys.PREDICTIONS].item()}
254
+
255
+ elif self.task in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
256
+ int_label = tensor_results[PyTorchInferenceKeys.LABELS].item()
257
+ label_name = "Unknown"
258
+ if self._idx_to_class:
259
+ label_name = self._idx_to_class.get(int_label, "Unknown") # type: ignore
260
+
261
+ return {
262
+ PyTorchInferenceKeys.LABELS: int_label,
263
+ PyTorchInferenceKeys.LABEL_NAMES: label_name,
264
+ PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
265
+ }
266
+
267
+ elif self.task in [MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION, MLTaskKeys.MULTITARGET_REGRESSION]:
268
+ # For multi-target models, the output is always an array.
269
+ numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
270
+ return numpy_results
271
+ else:
272
+ # should never happen
273
+ _LOGGER.error(f"Unrecognized task '{self.task}'.")
274
+ raise ValueError()
275
+
276
+ def quick_predict(self, features: Union[np.ndarray, torch.Tensor]) -> dict[str, Any]:
277
+ """
278
+ Convenience wrapper to get the mapping {target_name: prediction} or {target_name: label}
279
+
280
+ `target_ids` must be implemented.
281
+ """
282
+ if self.target_ids is None:
283
+ _LOGGER.error(f"'target_ids' has not been implemented.")
284
+ raise AttributeError()
285
+
286
+ if self.task == MLTaskKeys.REGRESSION:
287
+ result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS]
288
+ return {self.target_ids[0]: result}
289
+
290
+ elif self.task in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
291
+ result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS]
292
+ return {self.target_ids[0]: result}
293
+
294
+ elif self.task == MLTaskKeys.MULTITARGET_REGRESSION:
295
+ result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS].flatten().tolist()
296
+ return {key: value for key, value in zip(self.target_ids, result)}
297
+
298
+ elif self.task == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
299
+ result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS].flatten().tolist()
300
+ return {key: value for key, value in zip(self.target_ids, result)}
301
+
302
+ else:
303
+ # should never happen
304
+ _LOGGER.error(f"Unrecognized task '{self.task}'.")
305
+ raise ValueError()
306
+
307
+ def set_classification_threshold(self, threshold: float, force_overwrite: bool=False):
308
+ """
309
+ Sets the classification threshold for the current inference instance.
310
+
311
+ If a threshold was previously loaded from a model configuration, this
312
+ method will log a warning and refuse to update the value. This
313
+ prevents accidentally overriding a setting from a loaded checkpoint.
314
+
315
+ To bypass this safety check set `force_overwrite` to `True`.
316
+
317
+ Args:
318
+ threshold (float): The new classification threshold value to set.
319
+ force_overwrite (bool): If True, allows overwriting a threshold that was loaded from a configuration file.
320
+ """
321
+ if self._loaded_threshold:
322
+ warning_message = f"The current '{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}={self._classification_threshold}' was loaded and set from a model configuration file."
323
+ if not force_overwrite:
324
+ warning_message += " Use 'force_overwrite' if you are sure you want to modify it. This will not affect the value from the file."
325
+ _LOGGER.warning(warning_message)
326
+ return
327
+ else:
328
+ warning_message += f" Overwriting it to {threshold}."
329
+ _LOGGER.warning(warning_message)
330
+
331
+ self._classification_threshold = threshold
332
+
@@ -0,0 +1,11 @@
1
+ from .._core import _imprimir_disponibles
2
+
3
# Names reported by info(); passed to the shared _imprimir_disponibles helper.
_GRUPOS: list[str] = [
    "DragonInferenceHandler",
    "DragonChainInference",
    "multi_inference_regression",
    "multi_inference_classification"
]
9
+
10
def info() -> None:
    """Report the names in ``_GRUPOS`` using the shared ``_imprimir_disponibles`` helper."""
    _imprimir_disponibles(_GRUPOS)