dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1901
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
ml_tools/_core/_ML_inference.py
DELETED
|
@@ -1,646 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from torch import nn
|
|
3
|
-
import numpy as np
|
|
4
|
-
from pathlib import Path
|
|
5
|
-
from typing import Union, Literal, Dict, Any, Optional
|
|
6
|
-
from abc import ABC, abstractmethod
|
|
7
|
-
|
|
8
|
-
from ._ML_finalize_handler import FinalizedFileHandler
|
|
9
|
-
from ._ML_scaler import DragonScaler
|
|
10
|
-
from ._script_info import _script_info
|
|
11
|
-
from ._logger import get_logger
|
|
12
|
-
from ._path_manager import make_fullpath
|
|
13
|
-
from ._keys import PyTorchInferenceKeys, PyTorchCheckpointKeys, MLTaskKeys, ScalerKeys, MagicWords
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
_LOGGER = get_logger("Inference Handler")
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
__all__ = [
|
|
20
|
-
"DragonInferenceHandler",
|
|
21
|
-
"multi_inference_regression",
|
|
22
|
-
"multi_inference_classification"
|
|
23
|
-
]
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
class _BaseInferenceHandler(ABC):
|
|
27
|
-
"""
|
|
28
|
-
Abstract base class for PyTorch inference handlers.
|
|
29
|
-
|
|
30
|
-
Manages common tasks like loading a model's state dictionary via FinalizedFileHandler,
|
|
31
|
-
validating the target device, and preprocessing input features.
|
|
32
|
-
"""
|
|
33
|
-
def __init__(self,
|
|
34
|
-
model: nn.Module,
|
|
35
|
-
state_dict: Union[str, Path],
|
|
36
|
-
device: str = 'cpu',
|
|
37
|
-
scaler: Optional[Union[str, Path]] = None,
|
|
38
|
-
task: Optional[str] = None):
|
|
39
|
-
"""
|
|
40
|
-
Initializes the handler.
|
|
41
|
-
|
|
42
|
-
Args:
|
|
43
|
-
model (nn.Module): An instantiated PyTorch model.
|
|
44
|
-
state_dict (str | Path): Path to the saved .pth model state_dict file.
|
|
45
|
-
device (str): The device to run inference on ('cpu', 'cuda', 'mps').
|
|
46
|
-
scaler (str | Path | None): An optional scaler or path to a saved scaler state.
|
|
47
|
-
task (str | None): The specific machine learning task. If None, it attempts to read it from the finalized-file.
|
|
48
|
-
"""
|
|
49
|
-
self.model = model
|
|
50
|
-
self.device = self._validate_device(device)
|
|
51
|
-
self._classification_threshold = 0.5
|
|
52
|
-
self._loaded_threshold: bool = False
|
|
53
|
-
self._loaded_class_map: bool = False
|
|
54
|
-
self._class_map: Optional[dict[str,int]] = None
|
|
55
|
-
self._idx_to_class: Optional[Dict[int, str]] = None
|
|
56
|
-
|
|
57
|
-
# --- 1. Load File Handler ---
|
|
58
|
-
# This loads the content on CPU and validates structure
|
|
59
|
-
self._file_handler = FinalizedFileHandler(state_dict)
|
|
60
|
-
|
|
61
|
-
# Silence warnings of the filehandler internally
|
|
62
|
-
self._file_handler._verbose = False
|
|
63
|
-
|
|
64
|
-
# --- 2. Task Resolution ---
|
|
65
|
-
file_task = self._file_handler.task
|
|
66
|
-
|
|
67
|
-
if task is None:
|
|
68
|
-
# User didn't provide task, must be in file
|
|
69
|
-
if file_task == MagicWords.UNKNOWN:
|
|
70
|
-
_LOGGER.error(f"Task not specified in arguments and not found in file '{make_fullpath(state_dict).name}'.")
|
|
71
|
-
raise ValueError()
|
|
72
|
-
self.task = file_task
|
|
73
|
-
_LOGGER.info(f"Task '{self.task}' detected from file.")
|
|
74
|
-
else:
|
|
75
|
-
# User provided task
|
|
76
|
-
if file_task != MagicWords.UNKNOWN and file_task != task:
|
|
77
|
-
_LOGGER.warning(f"Provided task '{task}' differs from file metadata task '{file_task}'. Using provided task '{task}'.")
|
|
78
|
-
self.task = task
|
|
79
|
-
|
|
80
|
-
# --- 3. Load Model Weights ---
|
|
81
|
-
# Weights are already loaded in file_handler (on CPU)
|
|
82
|
-
try:
|
|
83
|
-
self.model.load_state_dict(self._file_handler.model_state_dict)
|
|
84
|
-
except RuntimeError as e:
|
|
85
|
-
_LOGGER.error(f"State dict mismatch: {e}")
|
|
86
|
-
raise
|
|
87
|
-
|
|
88
|
-
# --- 4. Load Metadata (Thresholds, Class Maps) ---
|
|
89
|
-
if self._file_handler.classification_threshold is not None:
|
|
90
|
-
self._classification_threshold = self._file_handler.classification_threshold
|
|
91
|
-
self._loaded_threshold = True
|
|
92
|
-
|
|
93
|
-
if self._file_handler.class_map is not None:
|
|
94
|
-
self.set_class_map(self._file_handler.class_map)
|
|
95
|
-
# set_class_map sets _loaded_class_map to True
|
|
96
|
-
|
|
97
|
-
# --- 5. Move to Device ---
|
|
98
|
-
self.model.to(self.device)
|
|
99
|
-
self.model.eval()
|
|
100
|
-
_LOGGER.info(f"Model loaded and moved to {self.device} in evaluation mode.")
|
|
101
|
-
|
|
102
|
-
# --- 6. Load Scalers ---
|
|
103
|
-
self.feature_scaler: Optional[DragonScaler] = None
|
|
104
|
-
self.target_scaler: Optional[DragonScaler] = None
|
|
105
|
-
|
|
106
|
-
if scaler is not None:
|
|
107
|
-
if isinstance(scaler, (str, Path)):
|
|
108
|
-
path_obj = make_fullpath(scaler, enforce="file")
|
|
109
|
-
loaded_scaler_data = torch.load(path_obj)
|
|
110
|
-
|
|
111
|
-
if isinstance(loaded_scaler_data, dict) and (ScalerKeys.FEATURE_SCALER in loaded_scaler_data or ScalerKeys.TARGET_SCALER in loaded_scaler_data):
|
|
112
|
-
if ScalerKeys.FEATURE_SCALER in loaded_scaler_data:
|
|
113
|
-
self.feature_scaler = DragonScaler.load(loaded_scaler_data[ScalerKeys.FEATURE_SCALER], verbose=False)
|
|
114
|
-
_LOGGER.info("Loaded DragonScaler state for feature scaling.")
|
|
115
|
-
if ScalerKeys.TARGET_SCALER in loaded_scaler_data:
|
|
116
|
-
self.target_scaler = DragonScaler.load(loaded_scaler_data[ScalerKeys.TARGET_SCALER], verbose=False)
|
|
117
|
-
_LOGGER.info("Loaded DragonScaler state for target scaling.")
|
|
118
|
-
else:
|
|
119
|
-
_LOGGER.warning("Loaded scaler file does not contain separate feature/target scalers. Assuming it is a feature scaler (legacy format).")
|
|
120
|
-
self.feature_scaler = DragonScaler.load(loaded_scaler_data)
|
|
121
|
-
else:
|
|
122
|
-
_LOGGER.error("Scaler must be a file path (str or Path) to a saved DragonScaler state file.")
|
|
123
|
-
raise ValueError()
|
|
124
|
-
|
|
125
|
-
def _validate_device(self, device: str) -> torch.device:
|
|
126
|
-
"""Validates the selected device and returns a torch.device object."""
|
|
127
|
-
device_lower = device.lower()
|
|
128
|
-
if "cuda" in device_lower and not torch.cuda.is_available():
|
|
129
|
-
_LOGGER.warning("CUDA not available, switching to CPU.")
|
|
130
|
-
device_lower = "cpu"
|
|
131
|
-
elif device_lower == "mps" and not torch.backends.mps.is_available():
|
|
132
|
-
_LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
|
|
133
|
-
device_lower = "cpu"
|
|
134
|
-
return torch.device(device_lower)
|
|
135
|
-
|
|
136
|
-
def set_class_map(self, class_map: Dict[str, int], force_overwrite: bool = False):
|
|
137
|
-
"""
|
|
138
|
-
Sets the class name mapping to translate predicted integer labels back into string names.
|
|
139
|
-
|
|
140
|
-
Args:
|
|
141
|
-
class_map (Dict[str, int]): The class_to_idx dictionary.
|
|
142
|
-
force_overwrite (bool): If True, allows overwriting a map that was loaded from a configuration file.
|
|
143
|
-
"""
|
|
144
|
-
if self._loaded_class_map:
|
|
145
|
-
warning_message = f"A '{PyTorchCheckpointKeys.CLASS_MAP}' was loaded from the model configuration file."
|
|
146
|
-
if not force_overwrite:
|
|
147
|
-
warning_message += " Use 'force_overwrite=True' if you are sure you want to modify it. This will not affect the value from the file."
|
|
148
|
-
_LOGGER.warning(warning_message)
|
|
149
|
-
return
|
|
150
|
-
else:
|
|
151
|
-
warning_message += " Overwriting it for this inference instance."
|
|
152
|
-
_LOGGER.warning(warning_message)
|
|
153
|
-
|
|
154
|
-
self._class_map = class_map
|
|
155
|
-
self._idx_to_class = {v: k for k, v in class_map.items()}
|
|
156
|
-
self._loaded_class_map = True
|
|
157
|
-
_LOGGER.info("InferenceHandler: Class map set for label-to-name translation.")
|
|
158
|
-
|
|
159
|
-
@abstractmethod
|
|
160
|
-
def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
161
|
-
"""Core batch prediction method. Must be implemented by subclasses."""
|
|
162
|
-
pass
|
|
163
|
-
|
|
164
|
-
@abstractmethod
|
|
165
|
-
def predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
166
|
-
"""Core single-sample prediction method. Must be implemented by subclasses."""
|
|
167
|
-
pass
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
class DragonInferenceHandler(_BaseInferenceHandler):
|
|
171
|
-
"""
|
|
172
|
-
Handles loading a PyTorch model's state dictionary and performing inference for tabular data.
|
|
173
|
-
"""
|
|
174
|
-
def __init__(self,
|
|
175
|
-
model: nn.Module,
|
|
176
|
-
state_dict: Union[str, Path],
|
|
177
|
-
task: Optional[Literal["regression",
|
|
178
|
-
"binary classification",
|
|
179
|
-
"multiclass classification",
|
|
180
|
-
"multitarget regression",
|
|
181
|
-
"multilabel binary classification"]] = None,
|
|
182
|
-
device: str = 'cpu',
|
|
183
|
-
scaler: Optional[Union[str, Path]] = None):
|
|
184
|
-
"""
|
|
185
|
-
Initializes the handler for single-target tasks.
|
|
186
|
-
|
|
187
|
-
Args:
|
|
188
|
-
model (nn.Module): An instantiated PyTorch model architecture.
|
|
189
|
-
state_dict (str | Path): Path to the saved .pth model state_dict file.
|
|
190
|
-
task (str, optional): The type of task. If None, it will be detected from file.
|
|
191
|
-
device (str): The device to run inference on ('cpu', 'cuda', 'mps').
|
|
192
|
-
scaler (str | Path | None): A path to a saved DragonScaler state.
|
|
193
|
-
|
|
194
|
-
Note: class_map (Dict[int, str]) will be loaded from the model file, to set or override it use `.set_class_map()`.
|
|
195
|
-
"""
|
|
196
|
-
# Call the parent constructor to handle model loading, device, and scaler
|
|
197
|
-
# The parent constructor resolves 'task'
|
|
198
|
-
super().__init__(model=model,
|
|
199
|
-
state_dict=state_dict,
|
|
200
|
-
device=device,
|
|
201
|
-
scaler=scaler,
|
|
202
|
-
task=task)
|
|
203
|
-
|
|
204
|
-
# --- Validation of resolved task ---
|
|
205
|
-
valid_tasks = [
|
|
206
|
-
MLTaskKeys.REGRESSION,
|
|
207
|
-
MLTaskKeys.BINARY_CLASSIFICATION,
|
|
208
|
-
MLTaskKeys.MULTICLASS_CLASSIFICATION,
|
|
209
|
-
MLTaskKeys.MULTITARGET_REGRESSION,
|
|
210
|
-
MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION
|
|
211
|
-
]
|
|
212
|
-
|
|
213
|
-
if self.task not in valid_tasks:
|
|
214
|
-
_LOGGER.error(f"'task' recognized as '{self.task}', but this inference handler only supports: {valid_tasks}.")
|
|
215
|
-
raise ValueError()
|
|
216
|
-
|
|
217
|
-
self.target_ids: Optional[list[str]] = None
|
|
218
|
-
self._target_ids_set: bool = False
|
|
219
|
-
|
|
220
|
-
# --- Attempt to load target names from FinalizedFileHandler ---
|
|
221
|
-
if self._file_handler.target_names is not None:
|
|
222
|
-
self.set_target_ids(self._file_handler.target_names)
|
|
223
|
-
elif self._file_handler.target_name is not None:
|
|
224
|
-
self.set_target_ids([self._file_handler.target_name])
|
|
225
|
-
else:
|
|
226
|
-
_LOGGER.warning("No target names found in file metadata.")
|
|
227
|
-
|
|
228
|
-
def _preprocess_input(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
|
|
229
|
-
"""
|
|
230
|
-
Converts input to a torch.Tensor, applies FEATURE scaling if a scaler is
|
|
231
|
-
present, and moves it to the correct device.
|
|
232
|
-
"""
|
|
233
|
-
if isinstance(features, np.ndarray):
|
|
234
|
-
features_tensor = torch.from_numpy(features).float()
|
|
235
|
-
else:
|
|
236
|
-
features_tensor = features.float()
|
|
237
|
-
|
|
238
|
-
if self.feature_scaler:
|
|
239
|
-
features_tensor = self.feature_scaler.transform(features_tensor)
|
|
240
|
-
|
|
241
|
-
return features_tensor.to(self.device)
|
|
242
|
-
|
|
243
|
-
def set_target_ids(self, target_names: list[str], force_overwrite: bool=False):
|
|
244
|
-
"""
|
|
245
|
-
Assigns the provided list of strings as the target variable names.
|
|
246
|
-
|
|
247
|
-
If target IDs have already been set, this method will log a warning.
|
|
248
|
-
|
|
249
|
-
Args:
|
|
250
|
-
target_names (list[str]): A list of target names.
|
|
251
|
-
force_overwrite (bool): If True, allows the method to overwrite previously set target IDs.
|
|
252
|
-
"""
|
|
253
|
-
if self._target_ids_set:
|
|
254
|
-
warning_message = "Target IDs was previously set."
|
|
255
|
-
if not force_overwrite:
|
|
256
|
-
warning_message += " Use `force_overwrite=True` to overwrite."
|
|
257
|
-
_LOGGER.warning(warning_message)
|
|
258
|
-
return
|
|
259
|
-
else:
|
|
260
|
-
warning_message += " Overwriting..."
|
|
261
|
-
_LOGGER.warning(warning_message)
|
|
262
|
-
|
|
263
|
-
self.target_ids = target_names
|
|
264
|
-
self._target_ids_set = True
|
|
265
|
-
_LOGGER.info("Target IDs set.")
|
|
266
|
-
|
|
267
|
-
def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
268
|
-
"""
|
|
269
|
-
Core batch prediction method.
|
|
270
|
-
|
|
271
|
-
Args:
|
|
272
|
-
features (np.ndarray | torch.Tensor): A 2D array/tensor of input features.
|
|
273
|
-
|
|
274
|
-
Returns:
|
|
275
|
-
Dict: A dictionary containing the raw output tensors from the model.
|
|
276
|
-
"""
|
|
277
|
-
if features.ndim != 2:
|
|
278
|
-
_LOGGER.error("Input for batch prediction must be a 2D array or tensor.")
|
|
279
|
-
raise ValueError()
|
|
280
|
-
|
|
281
|
-
input_tensor = self._preprocess_input(features)
|
|
282
|
-
|
|
283
|
-
with torch.no_grad():
|
|
284
|
-
output = self.model(input_tensor)
|
|
285
|
-
|
|
286
|
-
# --- Target Scaling Logic (Inverse Transform) ---
|
|
287
|
-
# Only for regression tasks and if a target scaler exists
|
|
288
|
-
if self.target_scaler:
|
|
289
|
-
if self.task not in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
|
|
290
|
-
# raise error
|
|
291
|
-
_LOGGER.error("Target scaler is only applicable for regression tasks. A target scaler was provided for a non-regression task.")
|
|
292
|
-
raise ValueError()
|
|
293
|
-
|
|
294
|
-
# Ensure output is 2D (N, Targets) for the scaler
|
|
295
|
-
original_shape = output.shape
|
|
296
|
-
if output.ndim == 1:
|
|
297
|
-
output = output.reshape(-1, 1)
|
|
298
|
-
|
|
299
|
-
# Apply inverse transform (de-scale)
|
|
300
|
-
output = self.target_scaler.inverse_transform(output)
|
|
301
|
-
|
|
302
|
-
# Restore original shape if necessary (though usually we want 2D or 1D flat)
|
|
303
|
-
if len(original_shape) == 1:
|
|
304
|
-
output = output.flatten()
|
|
305
|
-
|
|
306
|
-
# --- Task Specific Formatting ---
|
|
307
|
-
if self.task == MLTaskKeys.MULTICLASS_CLASSIFICATION:
|
|
308
|
-
probs = torch.softmax(output, dim=1)
|
|
309
|
-
labels = torch.argmax(probs, dim=1)
|
|
310
|
-
return {
|
|
311
|
-
PyTorchInferenceKeys.LABELS: labels,
|
|
312
|
-
PyTorchInferenceKeys.PROBABILITIES: probs
|
|
313
|
-
}
|
|
314
|
-
|
|
315
|
-
elif self.task == MLTaskKeys.BINARY_CLASSIFICATION:
|
|
316
|
-
if output.ndim == 2 and output.shape[1] == 1:
|
|
317
|
-
output = output.squeeze(1)
|
|
318
|
-
|
|
319
|
-
probs = torch.sigmoid(output)
|
|
320
|
-
labels = (probs >= self._classification_threshold).int()
|
|
321
|
-
return {
|
|
322
|
-
PyTorchInferenceKeys.LABELS: labels,
|
|
323
|
-
PyTorchInferenceKeys.PROBABILITIES: probs
|
|
324
|
-
}
|
|
325
|
-
|
|
326
|
-
elif self.task == MLTaskKeys.REGRESSION:
|
|
327
|
-
# For single-target regression, ensure output is flattened
|
|
328
|
-
return {PyTorchInferenceKeys.PREDICTIONS: output.flatten()}
|
|
329
|
-
|
|
330
|
-
elif self.task == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
|
|
331
|
-
probs = torch.sigmoid(output)
|
|
332
|
-
labels = (probs >= self._classification_threshold).int()
|
|
333
|
-
return {
|
|
334
|
-
PyTorchInferenceKeys.LABELS: labels,
|
|
335
|
-
PyTorchInferenceKeys.PROBABILITIES: probs
|
|
336
|
-
}
|
|
337
|
-
|
|
338
|
-
elif self.task == MLTaskKeys.MULTITARGET_REGRESSION:
|
|
339
|
-
return {PyTorchInferenceKeys.PREDICTIONS: output}
|
|
340
|
-
|
|
341
|
-
else:
|
|
342
|
-
_LOGGER.error(f"Unrecognized task '{self.task}'.")
|
|
343
|
-
raise ValueError()
|
|
344
|
-
|
|
345
|
-
def predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
346
|
-
"""
|
|
347
|
-
Core single-sample prediction method for single-target models.
|
|
348
|
-
|
|
349
|
-
Args:
|
|
350
|
-
features (np.ndarray | torch.Tensor): A 1D array/tensor of input features.
|
|
351
|
-
|
|
352
|
-
Returns:
|
|
353
|
-
Dict: A dictionary containing the raw output tensors for a single sample.
|
|
354
|
-
"""
|
|
355
|
-
if features.ndim == 1:
|
|
356
|
-
features = features.reshape(1, -1) # Reshape to a batch of one
|
|
357
|
-
|
|
358
|
-
if features.shape[0] != 1:
|
|
359
|
-
_LOGGER.error("The 'predict()' method is for a single sample. Use 'predict_batch()' for multiple samples.")
|
|
360
|
-
raise ValueError()
|
|
361
|
-
|
|
362
|
-
batch_results = self.predict_batch(features)
|
|
363
|
-
|
|
364
|
-
# Extract the first (and only) result from the batch output
|
|
365
|
-
single_results = {key: value[0] for key, value in batch_results.items()}
|
|
366
|
-
return single_results
|
|
367
|
-
|
|
368
|
-
# --- NumPy Convenience Wrappers (on CPU) ---
|
|
369
|
-
|
|
370
|
-
def predict_batch_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, np.ndarray]:
|
|
371
|
-
"""
|
|
372
|
-
Convenience wrapper for predict_batch that returns NumPy arrays
|
|
373
|
-
and adds string labels for classification tasks if a class_map is set.
|
|
374
|
-
"""
|
|
375
|
-
tensor_results = self.predict_batch(features)
|
|
376
|
-
numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
|
|
377
|
-
|
|
378
|
-
# Add string names for classification if map exists
|
|
379
|
-
is_classification = self.task in [
|
|
380
|
-
MLTaskKeys.BINARY_CLASSIFICATION,
|
|
381
|
-
MLTaskKeys.MULTICLASS_CLASSIFICATION
|
|
382
|
-
]
|
|
383
|
-
|
|
384
|
-
if is_classification and self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
|
|
385
|
-
int_labels = numpy_results[PyTorchInferenceKeys.LABELS] # This is a (B,) array
|
|
386
|
-
numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [ # type: ignore
|
|
387
|
-
self._idx_to_class.get(label_id, "Unknown")
|
|
388
|
-
for label_id in int_labels
|
|
389
|
-
]
|
|
390
|
-
|
|
391
|
-
return numpy_results
|
|
392
|
-
|
|
393
|
-
def predict_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
|
|
394
|
-
"""
|
|
395
|
-
Convenience wrapper for predict that returns NumPy arrays or scalars
|
|
396
|
-
and adds string labels for classification tasks if a class_map is set.
|
|
397
|
-
"""
|
|
398
|
-
tensor_results = self.predict(features)
|
|
399
|
-
|
|
400
|
-
if self.task == MLTaskKeys.REGRESSION:
|
|
401
|
-
# .item() implicitly moves to CPU and returns a Python scalar
|
|
402
|
-
return {PyTorchInferenceKeys.PREDICTIONS: tensor_results[PyTorchInferenceKeys.PREDICTIONS].item()}
|
|
403
|
-
|
|
404
|
-
elif self.task in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
|
|
405
|
-
int_label = tensor_results[PyTorchInferenceKeys.LABELS].item()
|
|
406
|
-
label_name = "Unknown"
|
|
407
|
-
if self._idx_to_class:
|
|
408
|
-
label_name = self._idx_to_class.get(int_label, "Unknown") # type: ignore
|
|
409
|
-
|
|
410
|
-
return {
|
|
411
|
-
PyTorchInferenceKeys.LABELS: int_label,
|
|
412
|
-
PyTorchInferenceKeys.LABEL_NAMES: label_name,
|
|
413
|
-
PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
|
|
414
|
-
}
|
|
415
|
-
|
|
416
|
-
elif self.task in [MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION, MLTaskKeys.MULTITARGET_REGRESSION]:
|
|
417
|
-
# For multi-target models, the output is always an array.
|
|
418
|
-
numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
|
|
419
|
-
return numpy_results
|
|
420
|
-
else:
|
|
421
|
-
# should never happen
|
|
422
|
-
_LOGGER.error(f"Unrecognized task '{self.task}'.")
|
|
423
|
-
raise ValueError()
|
|
424
|
-
|
|
425
|
-
def quick_predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
    """
    Convenience wrapper to get the mapping {target_name: prediction} or {target_name: label}.

    `target_ids` must be implemented.

    Args:
        features (Union[np.ndarray, torch.Tensor]): Input sample to predict on.

    Returns:
        (Dict[str, Any]): Mapping of target name to prediction/label.

    Raises:
        AttributeError: If `target_ids` has not been implemented.
        ValueError: If the handler's task is unrecognized.
    """
    if self.target_ids is None:
        # fix: was an f-string with no placeholders (F541)
        _LOGGER.error("'target_ids' has not been implemented.")
        raise AttributeError()

    if self.task == MLTaskKeys.REGRESSION:
        result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS]
        return {self.target_ids[0]: result}

    elif self.task in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
        result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS]
        return {self.target_ids[0]: result}

    elif self.task == MLTaskKeys.MULTITARGET_REGRESSION:
        # One prediction per target; pair them up positionally.
        values = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS].flatten().tolist()
        return dict(zip(self.target_ids, values))

    elif self.task == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
        # One binary label per target; pair them up positionally.
        values = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS].flatten().tolist()
        return dict(zip(self.target_ids, values))

    else:
        # should never happen
        _LOGGER.error(f"Unrecognized task '{self.task}'.")
        raise ValueError()
|
|
455
|
-
|
|
456
|
-
def set_classification_threshold(self, threshold: float, force_overwrite: bool=False):
    """
    Sets the classification threshold for the current inference instance.

    When the active threshold came from a loaded model configuration, this
    method warns and refuses to change it, so a checkpoint setting is not
    overridden by accident. Pass `force_overwrite=True` to bypass that
    safety check (the value stored in the file itself is never touched).

    Args:
        threshold (float): The new classification threshold value to set.
        force_overwrite (bool): If True, allows overwriting a threshold that was loaded from a configuration file.
    """
    if self._loaded_threshold:
        msg = f"The current '{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}={self._classification_threshold}' was loaded and set from a model configuration file."
        if force_overwrite:
            # Explicitly requested: warn, then fall through to the assignment.
            msg += f" Overwriting it to {threshold}."
            _LOGGER.warning(msg)
        else:
            msg += " Use 'force_overwrite' if you are sure you want to modify it. This will not affect the value from the file."
            _LOGGER.warning(msg)
            return

    self._classification_threshold = threshold
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
def multi_inference_regression(handlers: list[DragonInferenceHandler],
                               feature_vector: Union[np.ndarray, torch.Tensor],
                               output: Literal["numpy","torch"]="numpy") -> dict[str,Any]:
    """
    Performs regression inference using multiple models on a single feature vector.

    Every handler in `handlers` predicts its own regression target from the
    identical input, and the per-target results are collected into a single
    dictionary keyed by each handler's `target_id`.

    Behavior depends on the input rank:
    - 1D input: each target ID maps to a single value.
    - 2D input: each target ID maps to an array/tensor of values.

    Args:
        handlers (list[DragonInferenceHandler]): A list of initialized inference
            handlers. Each handler must have a unique `target_id` and be configured with `task="regression"`.
        feature_vector (Union[np.ndarray, torch.Tensor]): An input sample (1D) or a batch of samples (2D) to be fed into each regression model.
        output (Literal["numpy", "torch"], optional): The desired format for the output predictions.
            - "numpy": Returns predictions as Python scalars or NumPy arrays.
            - "torch": Returns predictions as PyTorch tensors.

    Returns:
        (dict[str, Any]): A dictionary mapping each handler's `target_id` to its
            predicted regression values.

    Raises:
        AttributeError: If any handler in the list is missing a `target_id`.
        ValueError: If any handler's `task` is not 'regression' or if the input `feature_vector` is not 1D or 2D.
    """
    # Promote a single 1D sample to a batch of one; remember so we can unwrap later.
    is_single_sample = feature_vector.ndim == 1
    if is_single_sample:
        feature_vector = feature_vector.reshape(1, -1)

    # Anything other than a 2D batch at this point is invalid input.
    if feature_vector.ndim != 2:
        _LOGGER.error("Input feature_vector must be a 1D or 2D array/tensor.")
        raise ValueError()

    results: dict[str, Any] = {}
    for handler in handlers:
        # --- validation ---
        if handler.target_ids is None:
            _LOGGER.error("All inference handlers must have a 'target_ids' attribute.")
            raise AttributeError()
        if handler.task != MLTaskKeys.REGRESSION:
            _LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_ids[0]}' is for '{handler.task}', only single target regression tasks are supported.")
            raise ValueError()

        target_key = handler.target_ids[0]

        # --- inference ---
        if output == "numpy":
            # NumPy path: full array for a batch, Python scalar for one sample.
            batch_preds = handler.predict_batch_numpy(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]
            results[target_key] = batch_preds.item() if is_single_sample else batch_preds
        else:  # output == "torch"
            # Torch path: tensors stay on the model's device; 0-dim tensor for one sample.
            batch_preds = handler.predict_batch(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]
            results[target_key] = batch_preds[0] if is_single_sample else batch_preds

    return results
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
def multi_inference_classification(
    handlers: list[DragonInferenceHandler],
    feature_vector: Union[np.ndarray, torch.Tensor],
    output: Literal["numpy","torch"]="numpy"
) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Performs classification inference on a single sample or a batch.

    Every handler in `handlers` classifies its own target from the identical
    input. Two dictionaries are returned, both keyed by each handler's
    `target_id`: one with the predicted labels and one with the probabilities.

    Behavior depends on the input rank:
    - 1D input: each target ID maps to a single label and one probability array.
    - 2D input: each target ID maps to an array of labels and an array of probability arrays.

    Args:
        handlers (list[DragonInferenceHandler]): A list of initialized inference handlers. Each must have a unique `target_id` and be configured
            with `task="classification"`.
        feature_vector (Union[np.ndarray, torch.Tensor]): An input sample (1D)
            or a batch of samples (2D) for prediction.
        output (Literal["numpy", "torch"], optional): The desired format for the
            output predictions.

    Returns:
        (tuple[dict[str, Any], dict[str, Any]]): A tuple containing two dictionaries:
            1. A dictionary mapping `target_id` to the predicted label(s).
            2. A dictionary mapping `target_id` to the prediction probabilities.

    Raises:
        AttributeError: If any handler in the list is missing a `target_id`.
        ValueError: If any handler's `task` is not 'classification' or if the input `feature_vector` is not 1D or 2D.
    """
    # Promote a single 1D sample to a batch of one; remember so we can unwrap later.
    is_single_sample = feature_vector.ndim == 1
    if is_single_sample:
        feature_vector = feature_vector.reshape(1, -1)

    if feature_vector.ndim != 2:
        _LOGGER.error("Input feature_vector must be a 1D or 2D array/tensor.")
        raise ValueError()

    labels_results: dict[str, Any] = {}
    probs_results: dict[str, Any] = {}

    for handler in handlers:
        # --- validation ---
        if handler.target_ids is None:
            _LOGGER.error("All inference handlers must have a 'target_id' attribute.")
            raise AttributeError()
        if handler.task not in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
            _LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_ids[0]}' is for '{handler.task}', but this function only supports binary and multiclass classification.")
            raise ValueError()

        # --- inference: NumPy arrays or Torch tensors depending on `output` ---
        if output == "numpy":
            batch_out = handler.predict_batch_numpy(feature_vector)
        else:  # torch
            batch_out = handler.predict_batch(feature_vector)

        labels = batch_out[PyTorchInferenceKeys.LABELS]
        probabilities = batch_out[PyTorchInferenceKeys.PROBABILITIES]
        target_key = handler.target_ids[0]

        if is_single_sample:
            # Unwrap: Python int scalar for "numpy", 0-dim tensor for "torch";
            # probabilities stay an array/tensor of per-class values.
            labels_results[target_key] = labels.item() if output == "numpy" else labels[0]
            probs_results[target_key] = probabilities[0]
        else:
            labels_results[target_key] = labels
            probs_results[target_key] = probabilities

    return labels_results, probs_results
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
def info():
    # Delegate to the shared `_script_info` helper with this module's public
    # names (`__all__`) — presumably prints/logs the module API; confirm in
    # `_script_info`'s definition.
    _script_info(__all__)
|