dragon-ml-toolbox 3.12.6__py3-none-any.whl → 4.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic. Click here for more details.

@@ -0,0 +1,230 @@
1
+ Metadata-Version: 2.4
2
+ Name: dragon-ml-toolbox
3
+ Version: 4.0.0
4
+ Summary: A collection of tools for data science and machine learning projects.
5
+ Author-email: Karl Loza <luigiloza@gmail.com>
6
+ License-Expression: MIT
7
+ Project-URL: Homepage, https://github.com/DrAg0n-BoRn/ML_tools
8
+ Project-URL: Changelog, https://github.com/DrAg0n-BoRn/ML_tools/blob/master/CHANGELOG.md
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: Operating System :: OS Independent
11
+ Requires-Python: >=3.10
12
+ Description-Content-Type: text/markdown
13
+ License-File: LICENSE
14
+ License-File: LICENSE-THIRD-PARTY.md
15
+ Provides-Extra: ml
16
+ Requires-Dist: numpy; extra == "ml"
17
+ Requires-Dist: pandas; extra == "ml"
18
+ Requires-Dist: polars; extra == "ml"
19
+ Requires-Dist: joblib; extra == "ml"
20
+ Requires-Dist: scikit-learn; extra == "ml"
21
+ Requires-Dist: matplotlib; extra == "ml"
22
+ Requires-Dist: seaborn; extra == "ml"
23
+ Requires-Dist: imbalanced-learn; extra == "ml"
24
+ Requires-Dist: ipython; extra == "ml"
25
+ Requires-Dist: ipykernel; extra == "ml"
26
+ Requires-Dist: notebook; extra == "ml"
27
+ Requires-Dist: jupyterlab; extra == "ml"
28
+ Requires-Dist: ipywidgets; extra == "ml"
29
+ Requires-Dist: xgboost; extra == "ml"
30
+ Requires-Dist: lightgbm; extra == "ml"
31
+ Requires-Dist: shap; extra == "ml"
32
+ Requires-Dist: tqdm; extra == "ml"
33
+ Requires-Dist: Pillow; extra == "ml"
34
+ Provides-Extra: mice
35
+ Requires-Dist: numpy<2.0; extra == "mice"
36
+ Requires-Dist: pandas; extra == "mice"
37
+ Requires-Dist: polars; extra == "mice"
38
+ Requires-Dist: joblib; extra == "mice"
39
+ Requires-Dist: miceforest>=6.0.0; extra == "mice"
40
+ Requires-Dist: plotnine>=0.12; extra == "mice"
41
+ Requires-Dist: matplotlib; extra == "mice"
42
+ Requires-Dist: statsmodels; extra == "mice"
43
+ Requires-Dist: lightgbm<=4.5.0; extra == "mice"
44
+ Requires-Dist: shap; extra == "mice"
45
+ Provides-Extra: pytorch
46
+ Requires-Dist: torch; extra == "pytorch"
47
+ Requires-Dist: torchvision; extra == "pytorch"
48
+ Provides-Extra: excel
49
+ Requires-Dist: pandas; extra == "excel"
50
+ Requires-Dist: openpyxl; extra == "excel"
51
+ Requires-Dist: ipython; extra == "excel"
52
+ Requires-Dist: ipykernel; extra == "excel"
53
+ Requires-Dist: notebook; extra == "excel"
54
+ Requires-Dist: jupyterlab; extra == "excel"
55
+ Requires-Dist: ipywidgets; extra == "excel"
56
+ Provides-Extra: gui-boost
57
+ Requires-Dist: numpy; extra == "gui-boost"
58
+ Requires-Dist: joblib; extra == "gui-boost"
59
+ Requires-Dist: FreeSimpleGUI>=5.2; extra == "gui-boost"
60
+ Requires-Dist: pyinstaller; extra == "gui-boost"
61
+ Requires-Dist: xgboost; extra == "gui-boost"
62
+ Requires-Dist: lightgbm; extra == "gui-boost"
63
+ Provides-Extra: gui-torch
64
+ Requires-Dist: numpy; extra == "gui-torch"
65
+ Requires-Dist: FreeSimpleGUI>=5.2; extra == "gui-torch"
66
+ Requires-Dist: pyinstaller; extra == "gui-torch"
67
+ Provides-Extra: plot
68
+ Requires-Dist: matplotlib; extra == "plot"
69
+ Requires-Dist: seaborn; extra == "plot"
70
+ Dynamic: license-file
71
+
72
+ # dragon-ml-toolbox
73
+
74
+ A collection of Python utilities for data science and machine learning, structured as a modular package for easy reuse and installation. This package has no base dependencies, allowing for lightweight and customized virtual environments.
75
+
76
+ ### Features:
77
+
78
+ - Modular scripts for data exploration, logging, machine learning, and more.
79
+ - Designed for seamless integration as a Git submodule or installable Python package.
80
+
81
+ ## Installation
82
+
83
+ **Python 3.10+**
84
+
85
+ ### Via PyPI
86
+
87
+ Install the latest stable release from PyPI:
88
+
89
+ ```bash
90
+ pip install dragon-ml-toolbox
91
+ ```
92
+
93
+ ### Via GitHub (Editable)
94
+
95
+ Clone the repository and install it in editable mode (add an extra, e.g. `pip install -e ".[ml]"`, to include optional dependencies):
96
+
97
+ ```bash
98
+ git clone https://github.com/DrAg0n-BoRn/ML_tools.git
99
+ cd ML_tools
100
+ pip install -e .
101
+ ```
102
+
103
+ ### Via conda-forge
104
+
105
+ Install from the conda-forge channel:
106
+
107
+ ```bash
108
+ conda install -c conda-forge dragon-ml-toolbox
109
+ ```
110
+ **Note:** The conda-forge package is outdated and may be broken due to dependency incompatibilities. Install from PyPI instead.
111
+
112
+ ## Modular Installation
113
+
114
+ ### 📦 Core Machine Learning Toolbox [ML]
115
+
116
+ Installs a comprehensive set of tools for typical data science workflows, including data manipulation, modeling, and evaluation. PyTorch is required but is not installed by this extra; install it separately (see below).
117
+
118
+ ```Bash
119
+ pip install "dragon-ml-toolbox[ML]"
120
+ ```
121
+
122
+ To install the default PyPI builds of Torch and Torchvision (CPU-only on some platforms):
123
+
124
+ ```Bash
125
+ pip install "dragon-ml-toolbox[pytorch]"
126
+ ```
127
+
128
+ ⚠️ To make use of GPU acceleration (highly recommended), follow the official instructions: [PyTorch website](https://pytorch.org/get-started/locally/)
129
+
130
+ #### Modules:
131
+
132
+ ```bash
133
+ custom_logger
134
+ data_exploration
135
+ datasetmaster
136
+ ensemble_learning
137
+ ensemble_inference
138
+ ETL_engineering
139
+ ML_callbacks
140
+ ML_evaluation
141
+ ML_trainer
142
+ ML_inference
143
+ path_manager
144
+ PSO_optimization
145
+ RNN_forecast
146
+ utilities
147
+ ```
148
+
149
+ ### 🔬 MICE Imputation and Variance Inflation Factor [mice]
150
+
151
+ ⚠️ Important: This group has strict version requirements. It is highly recommended to install this group in a separate virtual environment.
152
+
153
+ ```Bash
154
+ pip install "dragon-ml-toolbox[mice]"
155
+ ```
156
+
157
+ #### Modules:
158
+
159
+ ```bash
160
+ custom_logger
161
+ MICE_imputation
162
+ VIF_factor
163
+ path_manager
164
+ utilities
165
+ ```
166
+
167
+ ### 📋 Excel File Handling [excel]
168
+
169
+ Installs the dependencies (pandas, openpyxl) required to read, process, and write Excel (.xlsx) files.
170
+
171
+ ```Bash
172
+ pip install "dragon-ml-toolbox[excel]"
173
+ ```
174
+
175
+ #### Modules:
176
+
177
+ ```bash
178
+ custom_logger
179
+ handle_excel
180
+ path_manager
181
+ ```
182
+
183
+ ### 🎰 GUI for Boosting Algorithms (XGBoost, LightGBM) [gui-boost]
184
+
185
+ For GUIs that include plotting functionality, you must also install the [plot] extra.
186
+
187
+ ```Bash
188
+ pip install "dragon-ml-toolbox[gui-boost]"
189
+ ```
190
+
191
+ ```Bash
192
+ pip install "dragon-ml-toolbox[gui-boost,plot]"
193
+ ```
194
+
195
+ #### Modules:
196
+
197
+ ```bash
198
+ GUI_tools
199
+ ensemble_inference
200
+ path_manager
201
+ ```
202
+
203
+ ### 🤖 GUI for PyTorch Models [gui-torch]
204
+
205
+ For GUIs that include plotting functionality, you must also install the [plot] extra.
206
+
207
+ ```Bash
208
+ pip install "dragon-ml-toolbox[gui-torch]"
209
+ ```
210
+
211
+ ```Bash
212
+ pip install "dragon-ml-toolbox[gui-torch,plot]"
213
+ ```
214
+
215
+ #### Modules:
216
+
217
+ ```bash
218
+ GUI_tools
219
+ ML_inference
220
+ path_manager
221
+ ```
222
+
223
+ ## Usage
224
+
225
+ After installation, import modules like this:
226
+
227
+ ```python
228
+ from ml_tools.utilities import serialize_object, deserialize_object
229
+ from ml_tools.custom_logger import custom_logger
230
+ ```
@@ -0,0 +1,29 @@
1
+ dragon_ml_toolbox-4.0.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
2
+ dragon_ml_toolbox-4.0.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=lY4_rJPnLnMu7YBQaY-_iz1JRDcLdQzNCyeLAF1glJY,1837
3
+ ml_tools/ETL_engineering.py,sha256=m_IY-4hSp5X5TfJbWQ-MJNRxkxl4fcsxOnsivMs8tiM,39506
4
+ ml_tools/GUI_tools.py,sha256=n4ZZ5kEjwK5rkOCFJE41HeLFfjhpJVLUSzk9Kd9Kr_0,45410
5
+ ml_tools/MICE_imputation.py,sha256=b6ZTs8RedXFifOpuMCzr68xM16mCBVh1Ua6kcGfiVtg,11462
6
+ ml_tools/ML_callbacks.py,sha256=0a-Rbr0Xp_B1FNopOKBBmuJ4MqazS5JgDiT7wx1dHvE,13161
7
+ ml_tools/ML_evaluation.py,sha256=4dVqe6JF1Ukmk1sAcY8E5EG1oB1_oy2HXE5OT-pZwCs,10273
8
+ ml_tools/ML_inference.py,sha256=Fh-X2UQn3AznWBjf-7iPSxwE-EzkGQm1VEIRUAkURmE,5336
9
+ ml_tools/ML_trainer.py,sha256=dJjMfCEEM07Txy9KEH-2srZ3CZUa4lFWTJhpNWQ4Ndk,14974
10
+ ml_tools/PSO_optimization.py,sha256=z8zPyoMtE-Vl5LcB24ZtNNTcEw9kVHftkV1VkV5pLD8,24662
11
+ ml_tools/RNN_forecast.py,sha256=2CyjBLSYYc3xLHxwLXUmP5Qv8AmV1OB_EndETNX1IBk,1956
12
+ ml_tools/VIF_factor.py,sha256=2nUMupfUoogf8o6ghoFZk_OwWhFXU0R3C9Gj0HOlI14,10415
13
+ ml_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
+ ml_tools/_logger.py,sha256=TpgYguxO-CWYqqgLW0tqFjtwZ58PE_W2OCfWNGZr0n0,1175
15
+ ml_tools/_pytorch_models.py,sha256=ewPPsTHgmRPzMMWwObZOdH1vxm2Ij2VWZP38NC6zSH4,10135
16
+ ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
17
+ ml_tools/custom_logger.py,sha256=a3ywSCQT7j5ypR-usnKh2l861d_aVJ93ZRVqxrHsBBw,4112
18
+ ml_tools/data_exploration.py,sha256=rJhvxUqVbEuB_7HG-PfLH3vaA7hrZEtbVHg9QO9VS4A,22837
19
+ ml_tools/datasetmaster.py,sha256=_tNC2v98eCQGr3nMW_EFs83TRgRme8Uc7ttg1vosmQU,30106
20
+ ml_tools/ensemble_inference.py,sha256=0SNX3YAz5bpvtwYmqEwqyWeIJP2Pb-v-bemENRSO7qg,9426
21
+ ml_tools/ensemble_learning.py,sha256=Zi1oy6G2FWnTI5hBwjlexwF3JKALFS2FN6F8HAlVi_s,35391
22
+ ml_tools/handle_excel.py,sha256=J9iwIqMZemoxK49J5osSwp9Ge0h9YTKyYGbOm53hcno,13007
23
+ ml_tools/keys.py,sha256=kK9UF-hek2VcPGFILCKl5geoN6flmMOu7IzhdEA6z5Y,1068
24
+ ml_tools/path_manager.py,sha256=ElDa25bntANujTjY7xN4ZfCDiZp-9Ud3x0aJSJptZBY,13419
25
+ ml_tools/utilities.py,sha256=mz-M351DzxWxnYVcLX-7ZQ6c-RGoCV9g4VTS9Qif2Es,18348
26
+ dragon_ml_toolbox-4.0.0.dist-info/METADATA,sha256=tGUq_S7xEHszYM0vUifYUzV4dmV_TQ9W8Ja4sJOEPTs,5994
27
+ dragon_ml_toolbox-4.0.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
28
+ dragon_ml_toolbox-4.0.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
29
+ dragon_ml_toolbox-4.0.0.dist-info/RECORD,,
@@ -1,8 +1,8 @@
1
1
  import polars as pl
2
2
  import re
3
3
  from typing import Literal, Union, Optional, Any, Callable, List, Dict, Tuple
4
- from .utilities import _script_info
5
- from .logger import _LOGGER
4
+ from ._script_info import _script_info
5
+ from ._logger import _LOGGER
6
6
 
7
7
 
8
8
  __all__ = [
ml_tools/GUI_tools.py CHANGED
@@ -4,9 +4,9 @@ import traceback
4
4
  import FreeSimpleGUI as sg
5
5
  from functools import wraps
6
6
  from typing import Any, Dict, Tuple, List, Literal, Union, Optional, Callable
7
- from .utilities import _script_info
7
+ from ._script_info import _script_info
8
8
  import numpy as np
9
- from .logger import _LOGGER
9
+ from ._logger import _LOGGER
10
10
  from .keys import _OneHotOtherPlaceholder
11
11
 
12
12
 
@@ -3,11 +3,12 @@ import miceforest as mf
3
3
  from pathlib import Path
4
4
  import matplotlib.pyplot as plt
5
5
  import numpy as np
6
- from .utilities import load_dataframe, list_csv_paths, sanitize_filename, _script_info, merge_dataframes, save_dataframe, threshold_binary_values, make_fullpath
6
+ from .utilities import load_dataframe, merge_dataframes, save_dataframe, threshold_binary_values
7
+ from .path_manager import sanitize_filename, make_fullpath, list_csv_paths
7
8
  from plotnine import ggplot, labs, theme, element_blank # type: ignore
8
9
  from typing import Optional, Union
9
- from .logger import _LOGGER
10
-
10
+ from ._logger import _LOGGER
11
+ from ._script_info import _script_info
11
12
 
12
13
  __all__ = [
13
14
  "apply_mice",
ml_tools/ML_callbacks.py CHANGED
@@ -1,10 +1,11 @@
1
1
  import numpy as np
2
2
  import torch
3
3
  from tqdm.auto import tqdm
4
- from .utilities import make_fullpath
4
+ from .path_manager import make_fullpath
5
5
  from .keys import LogKeys
6
- from .logger import _LOGGER
6
+ from ._logger import _LOGGER
7
7
  from typing import Optional
8
+ from ._script_info import _script_info
8
9
 
9
10
 
10
11
  __all__ = [
@@ -270,7 +271,7 @@ class ModelCheckpoint(Callback):
270
271
  self.last_best_filepath = new_filepath
271
272
 
272
273
  def _save_rolling_checkpoints(self, epoch, logs):
273
- """Saves the latest model and keeps only the last 5."""
274
+ """Saves the latest model and keeps only the most recent ones."""
274
275
  filename = f"epoch_{epoch}.pth"
275
276
  filepath = self.save_dir / filename
276
277
 
@@ -334,4 +335,7 @@ class LRScheduler(Callback):
334
335
  if current_lr != self.previous_lr:
335
336
  _LOGGER.info(f"Epoch {epoch}: Learning rate changed to {current_lr:.6f}")
336
337
  self.previous_lr = current_lr
337
-
338
+
339
+
340
def info():
    """Pass this module's public API list (``__all__``) to the shared ``_script_info`` helper."""
    public_names = __all__
    _script_info(public_names)
ml_tools/ML_evaluation.py CHANGED
@@ -14,9 +14,10 @@ from sklearn.metrics import (
14
14
  import torch
15
15
  import shap
16
16
  from pathlib import Path
17
- from .utilities import make_fullpath
18
- from .logger import _LOGGER
17
+ from .path_manager import make_fullpath
18
+ from ._logger import _LOGGER
19
19
  from typing import Union, Optional
20
+ from ._script_info import _script_info
20
21
 
21
22
 
22
23
  __all__ = [
@@ -62,7 +63,7 @@ def plot_losses(history: dict, save_dir: Optional[Union[str, Path]] = None):
62
63
  plt.tight_layout()
63
64
 
64
65
  if save_dir:
65
- save_dir_path = make_fullpath(save_dir, make=True)
66
+ save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
66
67
  save_path = save_dir_path / "loss_plot.svg"
67
68
  plt.savefig(save_path)
68
69
  _LOGGER.info(f"📉 Loss plot saved as '{save_path.name}'")
@@ -88,7 +89,7 @@ def classification_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optio
88
89
  print(report)
89
90
 
90
91
  if save_dir:
91
- save_dir_path = make_fullpath(save_dir, make=True)
92
+ save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
92
93
  # Save text report
93
94
  report_path = save_dir_path / "classification_report.txt"
94
95
  report_path.write_text(report, encoding="utf-8")
@@ -158,7 +159,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Optiona
158
159
  print(report_string)
159
160
 
160
161
  if save_dir:
161
- save_dir_path = make_fullpath(save_dir, make=True)
162
+ save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
162
163
  # Save text report
163
164
  report_path = save_dir_path / "regression_report.txt"
164
165
  report_path.write_text(report_string)
@@ -220,7 +221,7 @@ def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain
220
221
  _LOGGER.info("Using SHAP values for the positive class (class 1) for plots.")
221
222
 
222
223
  if save_dir:
223
- save_dir_path = make_fullpath(save_dir, make=True)
224
+ save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
224
225
  # Save Bar Plot
225
226
  bar_path = save_dir_path / "shap_bar_plot.svg"
226
227
  shap.summary_plot(shap_values_for_plot, instances_to_explain, feature_names=feature_names, plot_type="bar", show=False)
@@ -253,3 +254,7 @@ def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain
253
254
  else:
254
255
  _LOGGER.info("No save directory provided. Displaying SHAP dot plot.")
255
256
  shap.summary_plot(shap_values_for_plot, instances_to_explain, feature_names=feature_names, plot_type="dot")
257
+
258
+
259
def info():
    """Hand this module's exported names (``__all__``) to ``_script_info`` for reporting."""
    exported = __all__
    _script_info(exported)
@@ -0,0 +1,131 @@
1
+ import torch
2
+ from torch import nn
3
+ import numpy as np
4
+ from pathlib import Path
5
+ from typing import Union, Literal, Dict, Any
6
+
7
+ from ._script_info import _script_info
8
+ from ._logger import _LOGGER
9
+ from .path_manager import make_fullpath
10
+ from .keys import PyTorchInferenceKeys
11
+
12
+ __all__ = [
13
+ "PyTorchInferenceHandler"
14
+ ]
15
+
16
class PyTorchInferenceHandler:
    """
    Handles loading a PyTorch model's state dictionary and performing inference
    for either regression or classification tasks.
    """

    # Tasks understood by predict() / predict_batch().
    _VALID_TASKS = ("classification", "regression")

    def __init__(self,
                 model: nn.Module,
                 state_dict: Union[str, Path],
                 task: Literal["classification", "regression"],
                 device: str = 'cpu'):
        """
        Initializes the handler by loading a model's state_dict.

        Args:
            model (nn.Module): An instantiated PyTorch model with the correct architecture.
            state_dict (str | Path): The path to the saved .pth model state_dict file.
            task (str): The type of task, 'regression' or 'classification'.
            device (str): The device to run inference on ('cpu', 'cuda', 'mps').

        Raises:
            ValueError: If `task` is not 'classification' or 'regression'.
            Exception: Any error raised while loading the state_dict is logged and re-raised.
        """
        self.model = model
        # Validate up front: predict() and predict_batch() branch on this string
        # in opposite directions, so an unknown value would be treated as
        # classification by one and regression by the other.
        self.task = self._validate_task(task)
        self.device = self._validate_device(device)

        model_p = make_fullpath(state_dict, enforce="file")

        try:
            # Load the state dictionary and apply it to the model structure.
            # NOTE(review): torch.load unpickles the file; only load state_dicts from
            # trusted sources (torch>=1.13 supports `weights_only=True` to restrict this).
            self.model.load_state_dict(torch.load(model_p, map_location=self.device))
            self.model.to(self.device)
            self.model.eval()  # Set the model to evaluation mode
            _LOGGER.info(f"✅ Model state loaded from '{model_p.name}' and set to evaluation mode.")
        except Exception as e:
            _LOGGER.error(f"❌ Failed to load model state from '{model_p}': {e}")
            raise

    @classmethod
    def _validate_task(cls, task: str) -> str:
        """Return `task` unchanged if it is supported, otherwise raise ValueError."""
        if task not in cls._VALID_TASKS:
            raise ValueError(f"Invalid task '{task}'. Expected one of {cls._VALID_TASKS}.")
        return task

    def _validate_device(self, device: str) -> torch.device:
        """Validates the selected device and returns a torch.device object (falls back to CPU)."""
        device_lower = device.lower()
        if "cuda" in device_lower and not torch.cuda.is_available():
            _LOGGER.warning("⚠️ CUDA not available, switching to CPU.")
            device_lower = "cpu"
        elif device_lower == "mps" and not torch.backends.mps.is_available():
            _LOGGER.warning("⚠️ Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
            device_lower = "cpu"
        return torch.device(device_lower)

    def _preprocess_input(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Converts input to a torch.Tensor and moves it to the correct device.

        NumPy arrays are cast to float32; tensors are moved to the device with
        their dtype unchanged.
        """
        if isinstance(features, np.ndarray):
            features = torch.from_numpy(features).float()

        # Ensure tensor is on the correct device
        return features.to(self.device)

    @staticmethod
    def _logits_to_probabilities(logits: torch.Tensor) -> torch.Tensor:
        """
        Convert raw classification outputs (assumed (batch, n_outputs)) to class probabilities.

        Multi-column outputs go through softmax over dim=1. A single output column
        is interpreted as a binary logit: softmax over one column would always give
        probability 1.0 (and label 0), so a sigmoid is applied instead and expanded
        to two columns ``[P(class 0), P(class 1)]``.
        """
        if logits.ndim == 2 and logits.shape[1] == 1:
            positive = torch.sigmoid(logits)
            return torch.cat([1.0 - positive, positive], dim=1)
        return nn.functional.softmax(logits, dim=1)

    def predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
        """
        Predicts on a single feature vector.

        Args:
            features (np.ndarray | torch.Tensor): A 1D or 2D array/tensor for a single sample.

        Returns:
            Dict[str, Any]: A dictionary containing the prediction.
                - For regression: {'predictions': float} for single-output models,
                  or {'predictions': np.ndarray} (one value per output) for multi-output models.
                - For classification: {'labels': int, 'probabilities': np.ndarray}

        Raises:
            ValueError: If more than one sample is given.
        """
        if features.ndim == 1:
            features = features.reshape(1, -1)

        if features.shape[0] != 1:
            raise ValueError("The predict() method is for a single sample. Use predict_batch() for multiple samples.")

        results_batch = self.predict_batch(features)

        # Extract the single result from the batch
        if self.task == "regression":
            predictions = results_batch[PyTorchInferenceKeys.PREDICTIONS]
            # `.item()` only works when there is exactly one value; multi-output
            # models produce a row of values, which is returned as an array.
            if predictions.size == 1:
                return {PyTorchInferenceKeys.PREDICTIONS: predictions.item()}
            return {PyTorchInferenceKeys.PREDICTIONS: predictions[0]}
        else:  # classification
            return {
                PyTorchInferenceKeys.LABELS: results_batch[PyTorchInferenceKeys.LABELS].item(),
                PyTorchInferenceKeys.PROBABILITIES: results_batch[PyTorchInferenceKeys.PROBABILITIES][0]
            }

    def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
        """
        Predicts on a batch of feature vectors.

        Args:
            features (np.ndarray | torch.Tensor): A 2D array/tensor where each row is a sample.

        Returns:
            Dict[str, Any]: A dictionary containing the predictions.
                - For regression: {'predictions': np.ndarray}
                - For classification: {'labels': np.ndarray, 'probabilities': np.ndarray}

        Raises:
            ValueError: If `features` is not 2D.
        """
        if features.ndim != 2:
            raise ValueError("Input for batch prediction must be a 2D array or tensor.")

        input_tensor = self._preprocess_input(features)

        with torch.no_grad():
            output = self.model(input_tensor).cpu()

        if self.task == "classification":
            probs = self._logits_to_probabilities(output)
            labels = torch.argmax(probs, dim=1)
            return {
                PyTorchInferenceKeys.LABELS: labels.numpy(),
                PyTorchInferenceKeys.PROBABILITIES: probs.numpy()
            }
        else:  # regression
            return {PyTorchInferenceKeys.PREDICTIONS: output.numpy()}
128
+
129
+
130
def info():
    """Delegate to ``_script_info`` with this module's public names (``__all__``)."""
    module_api = __all__
    _script_info(module_api)
ml_tools/ML_trainer.py CHANGED
@@ -7,9 +7,9 @@ import numpy as np
7
7
 
8
8
  from .ML_callbacks import Callback, History, TqdmProgressBar
9
9
  from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot
10
- from .utilities import _script_info
10
+ from ._script_info import _script_info
11
11
  from .keys import LogKeys
12
- from .logger import _LOGGER
12
+ from ._logger import _LOGGER
13
13
 
14
14
 
15
15
  __all__ = [
@@ -105,7 +105,7 @@ class MyTrainer:
105
105
  pin_memory=(self.device.type == "cuda")
106
106
  )
107
107
 
108
- def fit(self, epochs: int = 10, batch_size: int = 32, shuffle: bool = True):
108
+ def fit(self, epochs: int = 10, batch_size: int = 10, shuffle: bool = True):
109
109
  """
110
110
  Starts the training-validation process of the model.
111
111
 
@@ -113,6 +113,13 @@ class MyTrainer:
113
113
  epochs (int): The total number of epochs to train for.
114
114
  batch_size (int): The number of samples per batch.
115
115
  shuffle (bool): Whether to shuffle the training data at each epoch.
116
+
117
+ Note:
118
+ For regression tasks using `nn.MSELoss` or `nn.L1Loss`, the trainer
119
+ automatically aligns the model's output tensor with the target tensor's
120
+ shape using `output.view_as(target)`. This handles the common case
121
+ where a model outputs a shape of `[batch_size, 1]` and the target has a
122
+ shape of `[batch_size]`.
116
123
  """
117
124
  self.epochs = epochs
118
125
  self._create_dataloaders(batch_size, shuffle)
@@ -189,9 +196,10 @@ class MyTrainer:
189
196
  logs = {LogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)} # type: ignore
190
197
  return logs
191
198
 
192
- def predict(self, dataloader: DataLoader):
199
+ def _predict_for_eval(self, dataloader: DataLoader):
193
200
  """
194
- Yields model predictions batch by batch, avoids loading all predictions into memory at once.
201
+ Private method to yield model predictions batch by batch for evaluation.
202
+ This is used internally by the `evaluate` method.
195
203
 
196
204
  Args:
197
205
  dataloader (DataLoader): The dataloader to predict on.
@@ -213,13 +221,14 @@ class MyTrainer:
213
221
  preds = torch.argmax(probs, dim=1)
214
222
  y_pred_batch = preds.numpy()
215
223
  y_prob_batch = probs.numpy()
224
+ # regression
216
225
  else:
217
226
  y_pred_batch = output.numpy()
218
227
  y_prob_batch = None
219
228
 
220
229
  yield y_pred_batch, y_prob_batch, y_true_batch
221
230
 
222
- def evaluate(self, data: Optional[Union[DataLoader, Dataset]] = None, save_dir: Optional[Union[str,Path]] = None):
231
+ def evaluate(self, save_dir: Optional[Union[str,Path]], data: Optional[Union[DataLoader, Dataset]] = None):
223
232
  """
224
233
  Evaluates the model on the given data.
225
234
 
@@ -251,7 +260,7 @@ class MyTrainer:
251
260
 
252
261
  # Collect results from the predict generator
253
262
  all_preds, all_probs, all_true = [], [], []
254
- for y_pred_b, y_prob_b, y_true_b in self.predict(eval_loader):
263
+ for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
255
264
  all_preds.append(y_pred_b)
256
265
  if y_prob_b is not None:
257
266
  all_probs.append(y_prob_b)
@@ -270,7 +279,7 @@ class MyTrainer:
270
279
  plot_losses(self.history, save_dir=save_dir)
271
280
 
272
281
  def explain(self, explain_dataset: Optional[Dataset] = None, n_samples: int = 100,
273
- feature_names: Optional[List[str]] = None, save_dir: Optional[str] = None):
282
+ feature_names: Optional[List[str]] = None, save_dir: Optional[Union[str,Path]] = None):
274
283
  """
275
284
  Explains model predictions using SHAP and saves all artifacts.
276
285
 
@@ -2,28 +2,23 @@ import numpy as np
2
2
  from pathlib import Path
3
3
  import xgboost as xgb
4
4
  import lightgbm as lgb
5
- from sklearn.ensemble import HistGradientBoostingRegressor
6
- from sklearn.base import ClassifierMixin
7
5
  from typing import Literal, Union, Tuple, Dict, Optional
8
6
  import pandas as pd
9
7
  from copy import deepcopy
10
8
  from .utilities import (
11
- _script_info,
12
- list_csv_paths,
13
9
  threshold_binary_values,
14
10
  threshold_binary_values_batch,
15
- deserialize_object,
16
- list_files_by_extension,
17
- save_dataframe,
18
- make_fullpath,
19
- yield_dataframes_from_dir,
20
- sanitize_filename)
11
+ deserialize_object,
12
+ save_dataframe,
13
+ yield_dataframes_from_dir)
14
+ from .path_manager import sanitize_filename, make_fullpath, list_files_by_extension, list_csv_paths
21
15
  import torch
22
16
  from tqdm import trange
23
17
  import matplotlib.pyplot as plt
24
18
  import seaborn as sns
25
- from .logger import _LOGGER
19
+ from ._logger import _LOGGER
26
20
  from .keys import ModelSaveKeys
21
+ from ._script_info import _script_info
27
22
 
28
23
 
29
24
  __all__ = [
@@ -125,7 +120,7 @@ class ObjectiveFunction():
125
120
  return features_array * noise
126
121
 
127
122
  def check_model(self):
128
- if isinstance(self.model, ClassifierMixin) or isinstance(self.model, xgb.XGBClassifier) or isinstance(self.model, lgb.LGBMClassifier):
123
+ if isinstance(self.model, xgb.XGBClassifier) or isinstance(self.model, lgb.LGBMClassifier):
129
124
  raise ValueError(f"[Model Check Failed] ❌\nThe loaded model ({type(self.model).__name__}) is a Classifier.\nOptimization is not suitable for standard classification tasks.")
130
125
  if self.model is None:
131
126
  raise ValueError("Loaded model is None")
ml_tools/RNN_forecast.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import torch
2
2
  from torch import nn
3
3
  import numpy as np
4
+ from ._script_info import _script_info
4
5
 
5
6
  __all__ = [
6
7
  "rnn_forecast"
@@ -47,3 +48,7 @@ def rnn_forecast(model: nn.Module, start_sequence: torch.Tensor, steps: int, dev
47
48
 
48
49
  # Concatenate all predictions and flatten the array for easy use
49
50
  return np.concatenate(predictions).flatten()
51
+
52
+
53
def info():
    """Report this module's public names (``__all__``) via the shared ``_script_info`` helper."""
    # Must actually call the helper with __all__; a bare `_script_info` reference is a no-op.
    _script_info(__all__)
ml_tools/VIF_factor.py CHANGED
@@ -7,9 +7,10 @@ from statsmodels.stats.outliers_influence import variance_inflation_factor
7
7
  from statsmodels.tools.tools import add_constant
8
8
  import warnings
9
9
  from pathlib import Path
10
- from .utilities import sanitize_filename, yield_dataframes_from_dir, save_dataframe, _script_info, make_fullpath
11
- from .logger import _LOGGER
12
-
10
+ from .utilities import yield_dataframes_from_dir, save_dataframe
11
+ from .path_manager import sanitize_filename, make_fullpath
12
+ from ._logger import _LOGGER
13
+ from ._script_info import _script_info
13
14
 
14
15
  __all__ = [
15
16
  "compute_vif",