dragon-ml-toolbox 3.12.6__py3-none-any.whl → 4.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic. Click here for more details.

@@ -0,0 +1,253 @@
1
+ Metadata-Version: 2.4
2
+ Name: dragon-ml-toolbox
3
+ Version: 4.1.0
4
+ Summary: A collection of tools for data science and machine learning projects.
5
+ Author-email: Karl Loza <luigiloza@gmail.com>
6
+ License-Expression: MIT
7
+ Project-URL: Homepage, https://github.com/DrAg0n-BoRn/ML_tools
8
+ Project-URL: Changelog, https://github.com/DrAg0n-BoRn/ML_tools/blob/master/CHANGELOG.md
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: Operating System :: OS Independent
11
+ Requires-Python: >=3.10
12
+ Description-Content-Type: text/markdown
13
+ License-File: LICENSE
14
+ License-File: LICENSE-THIRD-PARTY.md
15
+ Provides-Extra: base
16
+ Requires-Dist: pandas; extra == "base"
17
+ Requires-Dist: numpy; extra == "base"
18
+ Requires-Dist: polars; extra == "base"
19
+ Requires-Dist: joblib; extra == "base"
20
+ Provides-Extra: ml
21
+ Requires-Dist: numpy; extra == "ml"
22
+ Requires-Dist: pandas; extra == "ml"
23
+ Requires-Dist: polars; extra == "ml"
24
+ Requires-Dist: joblib; extra == "ml"
25
+ Requires-Dist: scikit-learn; extra == "ml"
26
+ Requires-Dist: matplotlib; extra == "ml"
27
+ Requires-Dist: seaborn; extra == "ml"
28
+ Requires-Dist: imbalanced-learn; extra == "ml"
29
+ Requires-Dist: ipython; extra == "ml"
30
+ Requires-Dist: ipykernel; extra == "ml"
31
+ Requires-Dist: notebook; extra == "ml"
32
+ Requires-Dist: jupyterlab; extra == "ml"
33
+ Requires-Dist: ipywidgets; extra == "ml"
34
+ Requires-Dist: xgboost; extra == "ml"
35
+ Requires-Dist: lightgbm; extra == "ml"
36
+ Requires-Dist: shap; extra == "ml"
37
+ Requires-Dist: tqdm; extra == "ml"
38
+ Requires-Dist: Pillow; extra == "ml"
39
+ Provides-Extra: mice
40
+ Requires-Dist: numpy<2.0; extra == "mice"
41
+ Requires-Dist: pandas; extra == "mice"
42
+ Requires-Dist: polars; extra == "mice"
43
+ Requires-Dist: joblib; extra == "mice"
44
+ Requires-Dist: miceforest>=6.0.0; extra == "mice"
45
+ Requires-Dist: plotnine>=0.12; extra == "mice"
46
+ Requires-Dist: matplotlib; extra == "mice"
47
+ Requires-Dist: statsmodels; extra == "mice"
48
+ Requires-Dist: lightgbm<=4.5.0; extra == "mice"
49
+ Requires-Dist: shap; extra == "mice"
50
+ Provides-Extra: pytorch
51
+ Requires-Dist: torch; extra == "pytorch"
52
+ Requires-Dist: torchvision; extra == "pytorch"
53
+ Provides-Extra: excel
54
+ Requires-Dist: pandas; extra == "excel"
55
+ Requires-Dist: openpyxl; extra == "excel"
56
+ Requires-Dist: ipython; extra == "excel"
57
+ Requires-Dist: ipykernel; extra == "excel"
58
+ Requires-Dist: notebook; extra == "excel"
59
+ Requires-Dist: jupyterlab; extra == "excel"
60
+ Requires-Dist: ipywidgets; extra == "excel"
61
+ Provides-Extra: gui-boost
62
+ Requires-Dist: numpy; extra == "gui-boost"
63
+ Requires-Dist: joblib; extra == "gui-boost"
64
+ Requires-Dist: FreeSimpleGUI>=5.2; extra == "gui-boost"
65
+ Requires-Dist: pyinstaller; extra == "gui-boost"
66
+ Requires-Dist: xgboost; extra == "gui-boost"
67
+ Requires-Dist: lightgbm; extra == "gui-boost"
68
+ Provides-Extra: gui-torch
69
+ Requires-Dist: numpy; extra == "gui-torch"
70
+ Requires-Dist: FreeSimpleGUI>=5.2; extra == "gui-torch"
71
+ Requires-Dist: pyinstaller; extra == "gui-torch"
72
+ Provides-Extra: plot
73
+ Requires-Dist: matplotlib; extra == "plot"
74
+ Requires-Dist: seaborn; extra == "plot"
75
+ Dynamic: license-file
76
+
77
+ # dragon-ml-toolbox
78
+
79
+ A collection of Python utilities for data science and machine learning, structured as a modular package for easy reuse and installation. This package has no base dependencies, allowing for lightweight and customized virtual environments.
80
+
81
+ ### Features:
82
+
83
+ - Modular scripts for data exploration, logging, machine learning, and more.
84
+ - Designed for seamless integration as a Git submodule or installable Python package.
85
+
86
+ ## Installation
87
+
88
+ **Python 3.10+**
89
+
90
+ ### Via PyPI
91
+
92
+ Install the latest stable release from PyPI:
93
+
94
+ ```bash
95
+ pip install dragon-ml-toolbox
96
+ ```
97
+
98
+ ### Via GitHub (Editable)
99
+
100
+ Clone the repository and install in editable mode. To include optional dependencies, append the desired extras (e.g. `pip install -e ".[ml]"`):
101
+
102
+ ```bash
103
+ git clone https://github.com/DrAg0n-BoRn/ML_tools.git
104
+ cd ML_tools
105
+ pip install -e .
106
+ ```
107
+
108
+ ### Via conda-forge
109
+
110
+ Install from the conda-forge channel:
111
+
112
+ ```bash
113
+ conda install -c conda-forge dragon-ml-toolbox
114
+ ```
115
+
116
+ ## Modular Installation
117
+
118
+ ### 📦 Core Machine Learning Toolbox [ML]
119
+
120
+ Installs a comprehensive set of tools for typical data science workflows, including data manipulation, modeling, and evaluation. PyTorch is required but is not included in this extra; install it separately as shown below.
121
+
122
+ ```Bash
123
+ pip install "dragon-ml-toolbox[ML]"
124
+ ```
125
+
126
+ To install the default PyPI builds of Torch and Torchvision (whether these builds include GPU support depends on the platform):
127
+
128
+ ```Bash
129
+ pip install "dragon-ml-toolbox[pytorch]"
130
+ ```
131
+
132
+ ⚠️ To make use of GPU acceleration (highly recommended), follow the official instructions: [PyTorch website](https://pytorch.org/get-started/locally/)
133
+
134
+ #### Modules:
135
+
136
+ ```bash
137
+ custom_logger
138
+ data_exploration
139
+ datasetmaster
140
+ ensemble_learning
141
+ ensemble_inference
142
+ ETL_engineering
143
+ ML_callbacks
144
+ ML_evaluation
145
+ ML_trainer
146
+ ML_inference
147
+ path_manager
148
+ PSO_optimization
149
+ SQL
150
+ RNN_forecast
151
+ utilities
152
+ ```
153
+
154
+ ### 🔬 MICE Imputation and Variance Inflation Factor [mice]
155
+
156
+ ⚠️ Important: This group has strict version requirements (e.g. `numpy<2.0` and `lightgbm<=4.5.0`) that may conflict with other groups. It is highly recommended to install this group in a separate virtual environment.
157
+
158
+ ```Bash
159
+ pip install "dragon-ml-toolbox[mice]"
160
+ ```
161
+
162
+ #### Modules:
163
+
164
+ ```bash
165
+ custom_logger
166
+ MICE_imputation
167
+ VIF_factor
168
+ path_manager
169
+ utilities
170
+ ```
171
+
172
+ ### 📋 Excel File Handling [excel]
173
+
174
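+ Installs dependencies required to process and handle .xlsx files (via openpyxl). Reading legacy .xls files with pandas additionally requires `xlrd`, which is not included in this extra.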
175
+
176
+ ```Bash
177
+ pip install "dragon-ml-toolbox[excel]"
178
+ ```
179
+
180
+ #### Modules:
181
+
182
+ ```bash
183
+ custom_logger
184
+ handle_excel
185
+ path_manager
186
+ ```
187
+
188
+ ### 🎰 GUI for Boosting Algorithms (XGBoost, LightGBM) [gui-boost]
189
+
190
+ For GUIs that include plotting functionality, you must also install the [plot] extra.
191
+
192
+ ```Bash
193
+ pip install "dragon-ml-toolbox[gui-boost]"
194
+ ```
195
+
196
+ ```Bash
197
+ pip install "dragon-ml-toolbox[gui-boost,plot]"
198
+ ```
199
+
200
+ #### Modules:
201
+
202
+ ```bash
203
+ GUI_tools
204
+ ensemble_inference
205
+ path_manager
206
+ ```
207
+
208
+ ### 🤖 GUI for PyTorch Models [gui-torch]
209
+
210
+ For GUIs that include plotting functionality, you must also install the [plot] extra.
211
+
212
+ ```Bash
213
+ pip install "dragon-ml-toolbox[gui-torch]"
214
+ ```
215
+
216
+ ```Bash
217
+ pip install "dragon-ml-toolbox[gui-torch,plot]"
218
+ ```
219
+
220
+ #### Modules:
221
+
222
+ ```bash
223
+ GUI_tools
224
+ ML_inference
225
+ path_manager
226
+ ```
227
+
228
+ ### 🎫 Base Tools [base]
229
+
230
+ General purpose functions and classes.
231
+
232
+ ```Bash
233
+ pip install "dragon-ml-toolbox[base]"
234
+ ```
235
+
236
+ #### Modules:
237
+
238
+ ```bash
239
+ ETL_engineering
240
+ custom_logger
241
+ SQL
242
+ utilities
243
+ path_manager
244
+ ```
245
+
246
+ ## Usage
247
+
248
+ After installation, import modules like this:
249
+
250
+ ```python
251
+ from ml_tools.utilities import serialize_object, deserialize_object
252
+ from ml_tools.custom_logger import custom_logger
253
+ ```
@@ -0,0 +1,30 @@
1
+ dragon_ml_toolbox-4.1.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
2
+ dragon_ml_toolbox-4.1.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=lY4_rJPnLnMu7YBQaY-_iz1JRDcLdQzNCyeLAF1glJY,1837
3
+ ml_tools/ETL_engineering.py,sha256=m_IY-4hSp5X5TfJbWQ-MJNRxkxl4fcsxOnsivMs8tiM,39506
4
+ ml_tools/GUI_tools.py,sha256=n4ZZ5kEjwK5rkOCFJE41HeLFfjhpJVLUSzk9Kd9Kr_0,45410
5
+ ml_tools/MICE_imputation.py,sha256=b6ZTs8RedXFifOpuMCzr68xM16mCBVh1Ua6kcGfiVtg,11462
6
+ ml_tools/ML_callbacks.py,sha256=0a-Rbr0Xp_B1FNopOKBBmuJ4MqazS5JgDiT7wx1dHvE,13161
7
+ ml_tools/ML_evaluation.py,sha256=4dVqe6JF1Ukmk1sAcY8E5EG1oB1_oy2HXE5OT-pZwCs,10273
8
+ ml_tools/ML_inference.py,sha256=Fh-X2UQn3AznWBjf-7iPSxwE-EzkGQm1VEIRUAkURmE,5336
9
+ ml_tools/ML_trainer.py,sha256=dJjMfCEEM07Txy9KEH-2srZ3CZUa4lFWTJhpNWQ4Ndk,14974
10
+ ml_tools/PSO_optimization.py,sha256=xtnPute5pkS_w-VvqOBgRLgke09mjfacGC2m9DiipHE,27626
11
+ ml_tools/RNN_forecast.py,sha256=2CyjBLSYYc3xLHxwLXUmP5Qv8AmV1OB_EndETNX1IBk,1956
12
+ ml_tools/SQL.py,sha256=9zzS6AFEJM9aj6nE31hDe8S9TqLonk-J1amwZoiHNbk,10468
13
+ ml_tools/VIF_factor.py,sha256=2nUMupfUoogf8o6ghoFZk_OwWhFXU0R3C9Gj0HOlI14,10415
14
+ ml_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
+ ml_tools/_logger.py,sha256=TpgYguxO-CWYqqgLW0tqFjtwZ58PE_W2OCfWNGZr0n0,1175
16
+ ml_tools/_pytorch_models.py,sha256=ewPPsTHgmRPzMMWwObZOdH1vxm2Ij2VWZP38NC6zSH4,10135
17
+ ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
18
+ ml_tools/custom_logger.py,sha256=a3ywSCQT7j5ypR-usnKh2l861d_aVJ93ZRVqxrHsBBw,4112
19
+ ml_tools/data_exploration.py,sha256=rJhvxUqVbEuB_7HG-PfLH3vaA7hrZEtbVHg9QO9VS4A,22837
20
+ ml_tools/datasetmaster.py,sha256=_tNC2v98eCQGr3nMW_EFs83TRgRme8Uc7ttg1vosmQU,30106
21
+ ml_tools/ensemble_inference.py,sha256=0SNX3YAz5bpvtwYmqEwqyWeIJP2Pb-v-bemENRSO7qg,9426
22
+ ml_tools/ensemble_learning.py,sha256=Zi1oy6G2FWnTI5hBwjlexwF3JKALFS2FN6F8HAlVi_s,35391
23
+ ml_tools/handle_excel.py,sha256=J9iwIqMZemoxK49J5osSwp9Ge0h9YTKyYGbOm53hcno,13007
24
+ ml_tools/keys.py,sha256=kK9UF-hek2VcPGFILCKl5geoN6flmMOu7IzhdEA6z5Y,1068
25
+ ml_tools/path_manager.py,sha256=ElDa25bntANujTjY7xN4ZfCDiZp-9Ud3x0aJSJptZBY,13419
26
+ ml_tools/utilities.py,sha256=mz-M351DzxWxnYVcLX-7ZQ6c-RGoCV9g4VTS9Qif2Es,18348
27
+ dragon_ml_toolbox-4.1.0.dist-info/METADATA,sha256=eJQwYS8B7RMy4H8DveKsDVmj4ikBSJb_hkuTSzmObz4,6278
28
+ dragon_ml_toolbox-4.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
29
+ dragon_ml_toolbox-4.1.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
30
+ dragon_ml_toolbox-4.1.0.dist-info/RECORD,,
@@ -1,8 +1,8 @@
1
1
  import polars as pl
2
2
  import re
3
3
  from typing import Literal, Union, Optional, Any, Callable, List, Dict, Tuple
4
- from .utilities import _script_info
5
- from .logger import _LOGGER
4
+ from ._script_info import _script_info
5
+ from ._logger import _LOGGER
6
6
 
7
7
 
8
8
  __all__ = [
ml_tools/GUI_tools.py CHANGED
@@ -4,9 +4,9 @@ import traceback
4
4
  import FreeSimpleGUI as sg
5
5
  from functools import wraps
6
6
  from typing import Any, Dict, Tuple, List, Literal, Union, Optional, Callable
7
- from .utilities import _script_info
7
+ from ._script_info import _script_info
8
8
  import numpy as np
9
- from .logger import _LOGGER
9
+ from ._logger import _LOGGER
10
10
  from .keys import _OneHotOtherPlaceholder
11
11
 
12
12
 
@@ -3,11 +3,12 @@ import miceforest as mf
3
3
  from pathlib import Path
4
4
  import matplotlib.pyplot as plt
5
5
  import numpy as np
6
- from .utilities import load_dataframe, list_csv_paths, sanitize_filename, _script_info, merge_dataframes, save_dataframe, threshold_binary_values, make_fullpath
6
+ from .utilities import load_dataframe, merge_dataframes, save_dataframe, threshold_binary_values
7
+ from .path_manager import sanitize_filename, make_fullpath, list_csv_paths
7
8
  from plotnine import ggplot, labs, theme, element_blank # type: ignore
8
9
  from typing import Optional, Union
9
- from .logger import _LOGGER
10
-
10
+ from ._logger import _LOGGER
11
+ from ._script_info import _script_info
11
12
 
12
13
  __all__ = [
13
14
  "apply_mice",
ml_tools/ML_callbacks.py CHANGED
@@ -1,10 +1,11 @@
1
1
  import numpy as np
2
2
  import torch
3
3
  from tqdm.auto import tqdm
4
- from .utilities import make_fullpath
4
+ from .path_manager import make_fullpath
5
5
  from .keys import LogKeys
6
- from .logger import _LOGGER
6
+ from ._logger import _LOGGER
7
7
  from typing import Optional
8
+ from ._script_info import _script_info
8
9
 
9
10
 
10
11
  __all__ = [
@@ -270,7 +271,7 @@ class ModelCheckpoint(Callback):
270
271
  self.last_best_filepath = new_filepath
271
272
 
272
273
  def _save_rolling_checkpoints(self, epoch, logs):
273
- """Saves the latest model and keeps only the last 5."""
274
+ """Saves the latest model and keeps only the most recent ones."""
274
275
  filename = f"epoch_{epoch}.pth"
275
276
  filepath = self.save_dir / filename
276
277
 
@@ -334,4 +335,7 @@ class LRScheduler(Callback):
334
335
  if current_lr != self.previous_lr:
335
336
  _LOGGER.info(f"Epoch {epoch}: Learning rate changed to {current_lr:.6f}")
336
337
  self.previous_lr = current_lr
337
-
338
+
339
+
340
def info():
    """Pass this module's public names (``__all__``) to the shared ``_script_info`` helper."""
    exported_names = __all__
    _script_info(exported_names)
ml_tools/ML_evaluation.py CHANGED
@@ -14,9 +14,10 @@ from sklearn.metrics import (
14
14
  import torch
15
15
  import shap
16
16
  from pathlib import Path
17
- from .utilities import make_fullpath
18
- from .logger import _LOGGER
17
+ from .path_manager import make_fullpath
18
+ from ._logger import _LOGGER
19
19
  from typing import Union, Optional
20
+ from ._script_info import _script_info
20
21
 
21
22
 
22
23
  __all__ = [
@@ -62,7 +63,7 @@ def plot_losses(history: dict, save_dir: Optional[Union[str, Path]] = None):
62
63
  plt.tight_layout()
63
64
 
64
65
  if save_dir:
65
- save_dir_path = make_fullpath(save_dir, make=True)
66
+ save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
66
67
  save_path = save_dir_path / "loss_plot.svg"
67
68
  plt.savefig(save_path)
68
69
  _LOGGER.info(f"📉 Loss plot saved as '{save_path.name}'")
@@ -88,7 +89,7 @@ def classification_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optio
88
89
  print(report)
89
90
 
90
91
  if save_dir:
91
- save_dir_path = make_fullpath(save_dir, make=True)
92
+ save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
92
93
  # Save text report
93
94
  report_path = save_dir_path / "classification_report.txt"
94
95
  report_path.write_text(report, encoding="utf-8")
@@ -158,7 +159,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Optiona
158
159
  print(report_string)
159
160
 
160
161
  if save_dir:
161
- save_dir_path = make_fullpath(save_dir, make=True)
162
+ save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
162
163
  # Save text report
163
164
  report_path = save_dir_path / "regression_report.txt"
164
165
  report_path.write_text(report_string)
@@ -220,7 +221,7 @@ def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain
220
221
  _LOGGER.info("Using SHAP values for the positive class (class 1) for plots.")
221
222
 
222
223
  if save_dir:
223
- save_dir_path = make_fullpath(save_dir, make=True)
224
+ save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
224
225
  # Save Bar Plot
225
226
  bar_path = save_dir_path / "shap_bar_plot.svg"
226
227
  shap.summary_plot(shap_values_for_plot, instances_to_explain, feature_names=feature_names, plot_type="bar", show=False)
@@ -253,3 +254,7 @@ def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain
253
254
  else:
254
255
  _LOGGER.info("No save directory provided. Displaying SHAP dot plot.")
255
256
  shap.summary_plot(shap_values_for_plot, instances_to_explain, feature_names=feature_names, plot_type="dot")
257
+
258
+
259
def info():
    """Hand the names listed in ``__all__`` to the package-level ``_script_info`` helper."""
    public_api = __all__
    _script_info(public_api)
@@ -0,0 +1,131 @@
1
+ import torch
2
+ from torch import nn
3
+ import numpy as np
4
+ from pathlib import Path
5
+ from typing import Union, Literal, Dict, Any
6
+
7
+ from ._script_info import _script_info
8
+ from ._logger import _LOGGER
9
+ from .path_manager import make_fullpath
10
+ from .keys import PyTorchInferenceKeys
11
+
12
# Names exported by `from ml_tools.ML_inference import *`.
__all__ = ["PyTorchInferenceHandler"]
15
+
16
class PyTorchInferenceHandler:
    """
    Handles loading a PyTorch model's state dictionary and performing inference
    for either regression or classification tasks.
    """
    def __init__(self,
                 model: nn.Module,
                 state_dict: Union[str, Path],
                 task: Literal["classification", "regression"],
                 device: str = 'cpu'):
        """
        Initializes the handler by loading a model's state_dict.

        Args:
            model (nn.Module): An instantiated PyTorch model with the correct architecture.
            state_dict (str | Path): The path to the saved .pth model state_dict file.
            task (str): The type of task, 'regression' or 'classification'.
            device (str): The device to run inference on ('cpu', 'cuda', 'mps').
                Falls back to 'cpu' (with a warning) when the requested
                accelerator is not available.

        Raises:
            Exception: Any error raised while loading or applying the state_dict
                is logged and then re-raised unchanged.
        """
        self.model = model
        self.task = task
        self.device = self._validate_device(device)

        model_p = make_fullpath(state_dict, enforce="file")

        try:
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code. Only load trusted checkpoints. Since a plain
            # state_dict is expected here, consider weights_only=True (torch>=1.13).
            state = torch.load(model_p, map_location=self.device)
            # Apply the loaded weights to the caller-supplied architecture.
            self.model.load_state_dict(state)
            self.model.to(self.device)
            self.model.eval()  # Set the model to evaluation mode
            _LOGGER.info(f"✅ Model state loaded from '{model_p.name}' and set to evaluation mode.")
        except Exception as e:
            _LOGGER.error(f"❌ Failed to load model state from '{model_p}': {e}")
            raise

    def _validate_device(self, device: str) -> torch.device:
        """
        Validates the selected device and returns a torch.device object.

        The device string is matched case-insensitively. If CUDA or MPS is
        requested but not available, a warning is logged and CPU is used instead.
        Indexed forms such as 'cuda:1' or 'mps:0' are also covered by the fallback.
        """
        device_lower = device.lower()
        if "cuda" in device_lower and not torch.cuda.is_available():
            _LOGGER.warning("⚠️ CUDA not available, switching to CPU.")
            device_lower = "cpu"
        # startswith (rather than ==) so that 'mps:0' is also checked for availability.
        elif device_lower.startswith("mps") and not torch.backends.mps.is_available():
            _LOGGER.warning("⚠️ Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
            device_lower = "cpu"
        return torch.device(device_lower)

    def _preprocess_input(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """
        Converts input to a torch.Tensor and moves it to the correct device.

        NumPy arrays are converted and cast to float32. Tensors are moved to the
        device as-is, keeping their original dtype.
        """
        if isinstance(features, np.ndarray):
            features = torch.from_numpy(features).float()

        # Ensure tensor is on the correct device
        return features.to(self.device)

    def predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
        """
        Predicts on a single feature vector.

        Args:
            features (np.ndarray | torch.Tensor): A 1D array/tensor, or a 2D one
                with exactly one row, describing a single sample.

        Returns:
            Dict[str, Any]: A dictionary containing the prediction.
                - For regression: {'predictions': float} when the model emits a
                  single value for the sample. Otherwise (multi-target output)
                  {'predictions': np.ndarray} holding that sample's output row.
                - For classification: {'labels': int, 'probabilities': np.ndarray}

        Raises:
            ValueError: If more than one sample is supplied.
        """
        if features.ndim == 1:
            features = features.reshape(1, -1)

        if features.shape[0] != 1:
            raise ValueError("The predict() method is for a single sample. Use predict_batch() for multiple samples.")

        results_batch = self.predict_batch(features)

        # Extract the single result from the batch
        if self.task == "regression":
            preds = results_batch[PyTorchInferenceKeys.PREDICTIONS]
            if preds.size == 1:
                return {PyTorchInferenceKeys.PREDICTIONS: preds.item()}
            # Multi-target output: .item() would raise, so return the sample's row.
            return {PyTorchInferenceKeys.PREDICTIONS: preds[0]}
        else:  # classification
            return {
                PyTorchInferenceKeys.LABELS: results_batch[PyTorchInferenceKeys.LABELS].item(),
                PyTorchInferenceKeys.PROBABILITIES: results_batch[PyTorchInferenceKeys.PROBABILITIES][0]
            }

    def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
        """
        Predicts on a batch of feature vectors.

        Args:
            features (np.ndarray | torch.Tensor): A 2D array/tensor where each row is a sample.

        Returns:
            Dict[str, Any]: A dictionary containing the predictions.
                - For regression: {'predictions': np.ndarray}
                - For classification: {'labels': np.ndarray, 'probabilities': np.ndarray}

        Raises:
            ValueError: If the input is not 2D.
        """
        if features.ndim != 2:
            raise ValueError("Input for batch prediction must be a 2D array or tensor.")

        input_tensor = self._preprocess_input(features)

        with torch.no_grad():
            # Move outputs back to the CPU so they can be converted to NumPy.
            output = self.model(input_tensor).cpu()

        if self.task == "classification":
            # NOTE(review): softmax over dim=1 assumes one logit per class. A
            # single-logit binary head would yield probability 1.0 and label 0
            # for every sample. Confirm the model's output layout.
            probs = nn.functional.softmax(output, dim=1)
            labels = torch.argmax(probs, dim=1)
            return {
                PyTorchInferenceKeys.LABELS: labels.numpy(),
                PyTorchInferenceKeys.PROBABILITIES: probs.numpy()
            }
        else:  # regression
            return {PyTorchInferenceKeys.PREDICTIONS: output.numpy()}
128
+
129
+
130
def info():
    """Forward the module's ``__all__`` list to the shared ``_script_info`` helper."""
    names_to_report = __all__
    _script_info(names_to_report)
ml_tools/ML_trainer.py CHANGED
@@ -7,9 +7,9 @@ import numpy as np
7
7
 
8
8
  from .ML_callbacks import Callback, History, TqdmProgressBar
9
9
  from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot
10
- from .utilities import _script_info
10
+ from ._script_info import _script_info
11
11
  from .keys import LogKeys
12
- from .logger import _LOGGER
12
+ from ._logger import _LOGGER
13
13
 
14
14
 
15
15
  __all__ = [
@@ -105,7 +105,7 @@ class MyTrainer:
105
105
  pin_memory=(self.device.type == "cuda")
106
106
  )
107
107
 
108
- def fit(self, epochs: int = 10, batch_size: int = 32, shuffle: bool = True):
108
+ def fit(self, epochs: int = 10, batch_size: int = 10, shuffle: bool = True):
109
109
  """
110
110
  Starts the training-validation process of the model.
111
111
 
@@ -113,6 +113,13 @@ class MyTrainer:
113
113
  epochs (int): The total number of epochs to train for.
114
114
  batch_size (int): The number of samples per batch.
115
115
  shuffle (bool): Whether to shuffle the training data at each epoch.
116
+
117
+ Note:
118
+ For regression tasks using `nn.MSELoss` or `nn.L1Loss`, the trainer
119
+ automatically aligns the model's output tensor with the target tensor's
120
+ shape using `output.view_as(target)`. This handles the common case
121
+ where a model outputs a shape of `[batch_size, 1]` and the target has a
122
+ shape of `[batch_size]`.
116
123
  """
117
124
  self.epochs = epochs
118
125
  self._create_dataloaders(batch_size, shuffle)
@@ -189,9 +196,10 @@ class MyTrainer:
189
196
  logs = {LogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)} # type: ignore
190
197
  return logs
191
198
 
192
- def predict(self, dataloader: DataLoader):
199
+ def _predict_for_eval(self, dataloader: DataLoader):
193
200
  """
194
- Yields model predictions batch by batch, avoids loading all predictions into memory at once.
201
+ Private method to yield model predictions batch by batch for evaluation.
202
+ This is used internally by the `evaluate` method.
195
203
 
196
204
  Args:
197
205
  dataloader (DataLoader): The dataloader to predict on.
@@ -213,13 +221,14 @@ class MyTrainer:
213
221
  preds = torch.argmax(probs, dim=1)
214
222
  y_pred_batch = preds.numpy()
215
223
  y_prob_batch = probs.numpy()
224
+ # regression
216
225
  else:
217
226
  y_pred_batch = output.numpy()
218
227
  y_prob_batch = None
219
228
 
220
229
  yield y_pred_batch, y_prob_batch, y_true_batch
221
230
 
222
- def evaluate(self, data: Optional[Union[DataLoader, Dataset]] = None, save_dir: Optional[Union[str,Path]] = None):
231
+ def evaluate(self, save_dir: Optional[Union[str,Path]], data: Optional[Union[DataLoader, Dataset]] = None):
223
232
  """
224
233
  Evaluates the model on the given data.
225
234
 
@@ -251,7 +260,7 @@ class MyTrainer:
251
260
 
252
261
  # Collect results from the predict generator
253
262
  all_preds, all_probs, all_true = [], [], []
254
- for y_pred_b, y_prob_b, y_true_b in self.predict(eval_loader):
263
+ for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
255
264
  all_preds.append(y_pred_b)
256
265
  if y_prob_b is not None:
257
266
  all_probs.append(y_prob_b)
@@ -270,7 +279,7 @@ class MyTrainer:
270
279
  plot_losses(self.history, save_dir=save_dir)
271
280
 
272
281
  def explain(self, explain_dataset: Optional[Dataset] = None, n_samples: int = 100,
273
- feature_names: Optional[List[str]] = None, save_dir: Optional[str] = None):
282
+ feature_names: Optional[List[str]] = None, save_dir: Optional[Union[str,Path]] = None):
274
283
  """
275
284
  Explains model predictions using SHAP and saves all artifacts.
276
285