dragon-ml-toolbox 3.12.5__py3-none-any.whl → 4.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dragon-ml-toolbox might be problematic.
@@ -0,0 +1,230 @@
+ Metadata-Version: 2.4
+ Name: dragon-ml-toolbox
+ Version: 4.0.0
+ Summary: A collection of tools for data science and machine learning projects.
+ Author-email: Karl Loza <luigiloza@gmail.com>
+ License-Expression: MIT
+ Project-URL: Homepage, https://github.com/DrAg0n-BoRn/ML_tools
+ Project-URL: Changelog, https://github.com/DrAg0n-BoRn/ML_tools/blob/master/CHANGELOG.md
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Operating System :: OS Independent
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ License-File: LICENSE-THIRD-PARTY.md
+ Provides-Extra: ml
+ Requires-Dist: numpy; extra == "ml"
+ Requires-Dist: pandas; extra == "ml"
+ Requires-Dist: polars; extra == "ml"
+ Requires-Dist: joblib; extra == "ml"
+ Requires-Dist: scikit-learn; extra == "ml"
+ Requires-Dist: matplotlib; extra == "ml"
+ Requires-Dist: seaborn; extra == "ml"
+ Requires-Dist: imbalanced-learn; extra == "ml"
+ Requires-Dist: ipython; extra == "ml"
+ Requires-Dist: ipykernel; extra == "ml"
+ Requires-Dist: notebook; extra == "ml"
+ Requires-Dist: jupyterlab; extra == "ml"
+ Requires-Dist: ipywidgets; extra == "ml"
+ Requires-Dist: xgboost; extra == "ml"
+ Requires-Dist: lightgbm; extra == "ml"
+ Requires-Dist: shap; extra == "ml"
+ Requires-Dist: tqdm; extra == "ml"
+ Requires-Dist: Pillow; extra == "ml"
+ Provides-Extra: mice
+ Requires-Dist: numpy<2.0; extra == "mice"
+ Requires-Dist: pandas; extra == "mice"
+ Requires-Dist: polars; extra == "mice"
+ Requires-Dist: joblib; extra == "mice"
+ Requires-Dist: miceforest>=6.0.0; extra == "mice"
+ Requires-Dist: plotnine>=0.12; extra == "mice"
+ Requires-Dist: matplotlib; extra == "mice"
+ Requires-Dist: statsmodels; extra == "mice"
+ Requires-Dist: lightgbm<=4.5.0; extra == "mice"
+ Requires-Dist: shap; extra == "mice"
+ Provides-Extra: pytorch
+ Requires-Dist: torch; extra == "pytorch"
+ Requires-Dist: torchvision; extra == "pytorch"
+ Provides-Extra: excel
+ Requires-Dist: pandas; extra == "excel"
+ Requires-Dist: openpyxl; extra == "excel"
+ Requires-Dist: ipython; extra == "excel"
+ Requires-Dist: ipykernel; extra == "excel"
+ Requires-Dist: notebook; extra == "excel"
+ Requires-Dist: jupyterlab; extra == "excel"
+ Requires-Dist: ipywidgets; extra == "excel"
+ Provides-Extra: gui-boost
+ Requires-Dist: numpy; extra == "gui-boost"
+ Requires-Dist: joblib; extra == "gui-boost"
+ Requires-Dist: FreeSimpleGUI>=5.2; extra == "gui-boost"
+ Requires-Dist: pyinstaller; extra == "gui-boost"
+ Requires-Dist: xgboost; extra == "gui-boost"
+ Requires-Dist: lightgbm; extra == "gui-boost"
+ Provides-Extra: gui-torch
+ Requires-Dist: numpy; extra == "gui-torch"
+ Requires-Dist: FreeSimpleGUI>=5.2; extra == "gui-torch"
+ Requires-Dist: pyinstaller; extra == "gui-torch"
+ Provides-Extra: plot
+ Requires-Dist: matplotlib; extra == "plot"
+ Requires-Dist: seaborn; extra == "plot"
+ Dynamic: license-file
+
+ # dragon-ml-toolbox
+
+ A collection of Python utilities for data science and machine learning, structured as a modular package for easy reuse and installation. This package has no base dependencies, allowing for lightweight and customized virtual environments.
+
+ ### Features:
+
+ - Modular scripts for data exploration, logging, machine learning, and more.
+ - Designed for seamless integration as a Git submodule or installable Python package.
+
+ ## Installation
+
+ **Python 3.10+**
+
+ ### Via PyPI
+
+ Install the latest stable release from PyPI:
+
+ ```bash
+ pip install dragon-ml-toolbox
+ ```
+
+ ### Via GitHub (Editable)
+
+ Clone the repository and install in editable mode with optional dependencies:
+
+ ```bash
+ git clone https://github.com/DrAg0n-BoRn/ML_tools.git
+ cd ML_tools
+ pip install -e .
+ ```
+
+ ### Via conda-forge
+
+ Install from the conda-forge channel:
+
+ ```bash
+ conda install -c conda-forge dragon-ml-toolbox
+ ```
+ **Note:** This version is outdated or broken due to dependency incompatibilities. Use PyPI instead.
+
+ ## Modular Installation
+
+ ### 📦 Core Machine Learning Toolbox [ML]
+
+ Installs a comprehensive set of tools for typical data science workflows, including data manipulation, modeling, and evaluation. PyTorch is required.
+
+ ```Bash
+ pip install "dragon-ml-toolbox[ML]"
+ ```
+
+ To install the standard CPU-only versions of Torch and Torchvision:
+
+ ```Bash
+ pip install "dragon-ml-toolbox[pytorch]"
+ ```
+
+ ⚠️ To make use of GPU acceleration (highly recommended), follow the official instructions: [PyTorch website](https://pytorch.org/get-started/locally/)
+
+ #### Modules:
+
+ ```bash
+ custom_logger
+ data_exploration
+ datasetmaster
+ ensemble_learning
+ ensemble_inference
+ ETL_engineering
+ ML_callbacks
+ ML_evaluation
+ ML_trainer
+ ML_inference
+ path_manager
+ PSO_optimization
+ RNN_forecast
+ utilities
+ ```
+
+ ### 🔬 MICE Imputation and Variance Inflation Factor [mice]
+
+ ⚠️ Important: This group has strict version requirements. It is highly recommended to install this group in a separate virtual environment.
+
+ ```Bash
+ pip install "dragon-ml-toolbox[mice]"
+ ```
+
+ #### Modules:
+
+ ```bash
+ custom_logger
+ MICE_imputation
+ VIF_factor
+ path_manager
+ utilities
+ ```
+
+ ### 📋 Excel File Handling [excel]
+
+ Installs dependencies required to process and handle .xlsx or .xls files.
+
+ ```Bash
+ pip install "dragon-ml-toolbox[excel]"
+ ```
+
+ #### Modules:
+
+ ```bash
+ custom_logger
+ handle_excel
+ path_manager
+ ```
+
+ ### 🎰 GUI for Boosting Algorithms (XGBoost, LightGBM) [gui-boost]
+
+ For GUIs that include plotting functionality, you must also install the [plot] extra.
+
+ ```Bash
+ pip install "dragon-ml-toolbox[gui-boost]"
+ ```
+
+ ```Bash
+ pip install "dragon-ml-toolbox[gui-boost,plot]"
+ ```
+
+ #### Modules:
+
+ ```bash
+ GUI_tools
+ ensemble_inference
+ path_manager
+ ```
+
+ ### 🤖 GUI for PyTorch Models [gui-torch]
+
+ For GUIs that include plotting functionality, you must also install the [plot] extra.
+
+ ```Bash
+ pip install "dragon-ml-toolbox[gui-torch]"
+ ```
+
+ ```Bash
+ pip install "dragon-ml-toolbox[gui-torch,plot]"
+ ```
+
+ #### Modules:
+
+ ```bash
+ GUI_tools
+ ML_inference
+ path_manager
+ ```
+
+ ## Usage
+
+ After installation, import modules like this:
+
+ ```python
+ from ml_tools.utilities import serialize_object, deserialize_object
+ from ml_tools.custom_logger import custom_logger
+ ```
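Since the base package declares no dependencies of its own, an `ml_tools` module only imports cleanly when its extra is installed. A minimal sketch for probing which optional modules are usable in the current environment (module names are taken from the README lists above; the helper itself is not part of the package):

```python
import importlib

# Module names taken from the README's per-extra module lists.
OPTIONAL_MODULES = ["data_exploration", "MICE_imputation", "handle_excel", "GUI_tools"]

for name in OPTIONAL_MODULES:
    try:
        importlib.import_module(f"ml_tools.{name}")
        print(f"{name}: available")
    except ImportError as exc:
        print(f"{name}: missing dependency ({exc})")
```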
@@ -0,0 +1,29 @@
+ dragon_ml_toolbox-4.0.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
+ dragon_ml_toolbox-4.0.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=lY4_rJPnLnMu7YBQaY-_iz1JRDcLdQzNCyeLAF1glJY,1837
+ ml_tools/ETL_engineering.py,sha256=m_IY-4hSp5X5TfJbWQ-MJNRxkxl4fcsxOnsivMs8tiM,39506
+ ml_tools/GUI_tools.py,sha256=n4ZZ5kEjwK5rkOCFJE41HeLFfjhpJVLUSzk9Kd9Kr_0,45410
+ ml_tools/MICE_imputation.py,sha256=b6ZTs8RedXFifOpuMCzr68xM16mCBVh1Ua6kcGfiVtg,11462
+ ml_tools/ML_callbacks.py,sha256=0a-Rbr0Xp_B1FNopOKBBmuJ4MqazS5JgDiT7wx1dHvE,13161
+ ml_tools/ML_evaluation.py,sha256=4dVqe6JF1Ukmk1sAcY8E5EG1oB1_oy2HXE5OT-pZwCs,10273
+ ml_tools/ML_inference.py,sha256=Fh-X2UQn3AznWBjf-7iPSxwE-EzkGQm1VEIRUAkURmE,5336
+ ml_tools/ML_trainer.py,sha256=dJjMfCEEM07Txy9KEH-2srZ3CZUa4lFWTJhpNWQ4Ndk,14974
+ ml_tools/PSO_optimization.py,sha256=z8zPyoMtE-Vl5LcB24ZtNNTcEw9kVHftkV1VkV5pLD8,24662
+ ml_tools/RNN_forecast.py,sha256=2CyjBLSYYc3xLHxwLXUmP5Qv8AmV1OB_EndETNX1IBk,1956
+ ml_tools/VIF_factor.py,sha256=2nUMupfUoogf8o6ghoFZk_OwWhFXU0R3C9Gj0HOlI14,10415
+ ml_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ ml_tools/_logger.py,sha256=TpgYguxO-CWYqqgLW0tqFjtwZ58PE_W2OCfWNGZr0n0,1175
+ ml_tools/_pytorch_models.py,sha256=ewPPsTHgmRPzMMWwObZOdH1vxm2Ij2VWZP38NC6zSH4,10135
+ ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
+ ml_tools/custom_logger.py,sha256=a3ywSCQT7j5ypR-usnKh2l861d_aVJ93ZRVqxrHsBBw,4112
+ ml_tools/data_exploration.py,sha256=rJhvxUqVbEuB_7HG-PfLH3vaA7hrZEtbVHg9QO9VS4A,22837
+ ml_tools/datasetmaster.py,sha256=_tNC2v98eCQGr3nMW_EFs83TRgRme8Uc7ttg1vosmQU,30106
+ ml_tools/ensemble_inference.py,sha256=0SNX3YAz5bpvtwYmqEwqyWeIJP2Pb-v-bemENRSO7qg,9426
+ ml_tools/ensemble_learning.py,sha256=Zi1oy6G2FWnTI5hBwjlexwF3JKALFS2FN6F8HAlVi_s,35391
+ ml_tools/handle_excel.py,sha256=J9iwIqMZemoxK49J5osSwp9Ge0h9YTKyYGbOm53hcno,13007
+ ml_tools/keys.py,sha256=kK9UF-hek2VcPGFILCKl5geoN6flmMOu7IzhdEA6z5Y,1068
+ ml_tools/path_manager.py,sha256=ElDa25bntANujTjY7xN4ZfCDiZp-9Ud3x0aJSJptZBY,13419
+ ml_tools/utilities.py,sha256=mz-M351DzxWxnYVcLX-7ZQ6c-RGoCV9g4VTS9Qif2Es,18348
+ dragon_ml_toolbox-4.0.0.dist-info/METADATA,sha256=tGUq_S7xEHszYM0vUifYUzV4dmV_TQ9W8Ja4sJOEPTs,5994
+ dragon_ml_toolbox-4.0.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ dragon_ml_toolbox-4.0.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+ dragon_ml_toolbox-4.0.0.dist-info/RECORD,,
@@ -1,8 +1,8 @@
  import polars as pl
  import re
  from typing import Literal, Union, Optional, Any, Callable, List, Dict, Tuple
- from .utilities import _script_info
- from .logger import _LOGGER
+ from ._script_info import _script_info
+ from ._logger import _LOGGER
 
 
  __all__ = [
ml_tools/GUI_tools.py CHANGED
@@ -4,9 +4,9 @@ import traceback
  import FreeSimpleGUI as sg
  from functools import wraps
  from typing import Any, Dict, Tuple, List, Literal, Union, Optional, Callable
- from .utilities import _script_info
+ from ._script_info import _script_info
  import numpy as np
- from .logger import _LOGGER
+ from ._logger import _LOGGER
  from .keys import _OneHotOtherPlaceholder
 
 
@@ -191,13 +191,14 @@ class GUIFactory:
  }
  return sg.Button(text.title(), key=key, **style_args)
 
- def make_frame(self, title: str, layout: List[List[Union[sg.Element, sg.Column]]], **kwargs) -> sg.Frame:
+ def make_frame(self, title: str, layout: List[List[Union[sg.Element, sg.Column]]], center_layout: bool = False, **kwargs) -> sg.Frame:
  """
  Creates a styled frame around a given layout.
 
  Args:
  title (str): The title displayed on the frame's border.
  layout (list): The layout to enclose within the frame.
+ center_layout (bool): If True, the content within the frame will be horizontally centered.
  **kwargs: Override default styles or add other sg.Frame parameters
  (e.g., `title_color='red'`, `relief=sg.RELIEF_SUNKEN`).
  """
@@ -210,6 +211,10 @@ class GUIFactory:
  "background_color": sg.theme_background_color(),
  **kwargs
  }
+
+ if center_layout:
+ style_args["element_justification"] = 'center'
+
  return sg.Frame(title, layout, **style_args)
 
  # --- General-Purpose Layout Generators ---
@@ -218,7 +223,8 @@ class GUIFactory:
  data_dict: Dict[str, Union[Tuple[Union[int,float,None], Union[int,float,None]],List[Union[int,float,None]]]],
  is_target: bool = False,
  layout_mode: Literal["grid", "row"] = 'grid',
- number_columns: int = 5
+ number_columns: int = 5,
+ center_layout: bool = True
  ) -> List[List[sg.Column]]:
  """
  Generates a layout for continuous features or targets.
@@ -228,6 +234,7 @@
  is_target (bool): If True, creates disabled inputs for displaying results.
  layout_mode (str): 'grid' for a multi-row grid layout, or 'row' for a single horizontal row.
  number_columns (int): Number of columns when `layout_mode` is 'grid'.
+ center_layout (bool): If True, the entire grid will be horizontally centered.
 
  Returns:
  A list of lists of sg.Column elements, ready to be used in a window layout.
@@ -264,7 +271,7 @@
  layout = [[label], [element]]
  else:
  range_font = (cfg.fonts.font_family, cfg.fonts.range_size) # type: ignore
- range_text = sg.Text(f"Range: {val_min}-{val_max}", font=range_font, background_color=bg_color) # type: ignore
+ range_text = sg.Text(f"Range: {val_min} - {val_max}", font=range_font, background_color=bg_color) # type: ignore
  layout = [[label], [element], [range_text]]
 
  # each feature is wrapped as a column element
@@ -275,13 +282,14 @@
  return [all_feature_layouts] # A single row containing all features
 
  # Default to 'grid' layout: delegate to the helper method
- return self._build_grid_layout(all_feature_layouts, number_columns, bg_color) # type: ignore
+ return self._build_grid_layout(all_feature_layouts, number_columns, bg_color, center_layout) # type: ignore
 
  def generate_combo_layout(
  self,
  data_dict: Dict[str, Union[List[Any],Tuple[Any,...]]],
  layout_mode: Literal["grid", "row"] = 'grid',
- number_columns: int = 5
+ number_columns: int = 5,
+ center_layout: bool = True
  ) -> List[List[sg.Column]]:
  """
  Generates a layout for categorical or binary features using Combo boxes.
@@ -290,6 +298,7 @@
  data_dict (dict): Keys are feature names, values are lists of options.
  layout_mode (str): 'grid' for a multi-row grid layout, or 'row' for a single horizontal row.
  number_columns (int): Number of columns when `layout_mode` is 'grid'.
+ center_layout (bool): If True, the entire grid will be horizontally centered.
 
  Returns:
  A list of lists of sg.Column elements, ready to be used in a window layout.
@@ -315,13 +324,14 @@
  return [all_feature_layouts] # A single row containing all features
 
  # Default to 'grid' layout: delegate to the helper method
- return self._build_grid_layout(all_feature_layouts, number_columns, bg_color) # type: ignore
+ return self._build_grid_layout(all_feature_layouts, number_columns, bg_color, center_layout) # type: ignore
 
  def generate_multiselect_layout(
  self,
  data_dict: Dict[str, Union[List[Any], Tuple[Any, ...]]],
  layout_mode: Literal["grid", "row"] = 'grid',
- number_columns: int = 5
+ number_columns: int = 5,
+ center_layout: bool = True
  ) -> List[List[sg.Column]]:
  """
  Generates a layout for features using Listbox elements for multiple selections.
@@ -333,6 +343,7 @@
  data_dict (dict): Keys are feature names, values are lists of options.
  layout_mode (str): 'grid' for a multi-row grid layout, or 'row' for a single horizontal row.
  number_columns (int): Number of columns when `layout_mode` is 'grid'.
+ center_layout (bool): If True, the entire grid will be horizontally centered.
 
  Returns:
  A list of lists of sg.Column elements, ready to be used in a window layout.
@@ -366,7 +377,7 @@
  return [all_feature_layouts] # A single row containing all features
 
  # Default to 'grid' layout: delegate to the helper method
- return self._build_grid_layout(all_feature_layouts, number_columns, bg_color) # type: ignore
+ return self._build_grid_layout(all_feature_layouts, number_columns, bg_color, center_layout) # type: ignore
 
  # --- Window Creation ---
  def create_window(self, title: str, layout: List[List[sg.Element]], **kwargs) -> sg.Window:
@@ -396,7 +407,7 @@
 
  return window
 
- def _build_grid_layout(self, all_feature_layouts: List[sg.Column], num_columns: int, bg_color: str) -> List[List[sg.Column]]:
+ def _build_grid_layout(self, all_feature_layouts: List[sg.Column], num_columns: int, bg_color: str, center_layout: bool = True) -> List[List[sg.Column]]:
  """
  Private helper to distribute feature layouts vertically into a grid of columns.
  """
@@ -412,7 +423,12 @@
  gui_columns = [sg.Column([[c] for c in col], background_color=bg_color) for col in final_columns]
 
  # Return a single row containing all the generated vertical columns
- return [gui_columns]
+ if center_layout:
+ # Return a single row containing the columns, centered with Push elements.
+ return [[sg.Push()] + gui_columns + [sg.Push()]] # type: ignore
+ else:
+ # Return a single row containing just the columns.
+ return [gui_columns]
 
 
  # --- Exception Handling Decorator ---
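The `center_layout` additions above center a grid row by padding it with `sg.Push()` elements on both sides. A standalone sketch of that pattern in plain FreeSimpleGUI (not the `GUIFactory` API; the window contents are placeholders), assuming FreeSimpleGUI >= 5.2 is installed:

```python
import FreeSimpleGUI as sg  # assumed installed via the gui-boost or gui-torch extras

# Two feature columns, roughly what _build_grid_layout produces.
columns = [
    sg.Column([[sg.Text("Feature A")], [sg.Input(key="A")]]),
    sg.Column([[sg.Text("Feature B")], [sg.Input(key="B")]]),
]

# Push elements absorb the free horizontal space on each side, so the columns
# end up centered - the effect of center_layout=True in _build_grid_layout.
layout = [[sg.Push()] + columns + [sg.Push()]]

window = sg.Window("center_layout demo", layout)
window.read(close=True)
```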
@@ -3,11 +3,12 @@ import miceforest as mf
  from pathlib import Path
  import matplotlib.pyplot as plt
  import numpy as np
- from .utilities import load_dataframe, list_csv_paths, sanitize_filename, _script_info, merge_dataframes, save_dataframe, threshold_binary_values, make_fullpath
+ from .utilities import load_dataframe, merge_dataframes, save_dataframe, threshold_binary_values
+ from .path_manager import sanitize_filename, make_fullpath, list_csv_paths
  from plotnine import ggplot, labs, theme, element_blank # type: ignore
  from typing import Optional, Union
- from .logger import _LOGGER
-
+ from ._logger import _LOGGER
+ from ._script_info import _script_info
 
  __all__ = [
  "apply_mice",
ml_tools/ML_callbacks.py CHANGED
@@ -1,10 +1,11 @@
  import numpy as np
  import torch
  from tqdm.auto import tqdm
- from .utilities import make_fullpath
+ from .path_manager import make_fullpath
  from .keys import LogKeys
- from .logger import _LOGGER
+ from ._logger import _LOGGER
  from typing import Optional
+ from ._script_info import _script_info
 
 
  __all__ = [
@@ -270,7 +271,7 @@ class ModelCheckpoint(Callback):
  self.last_best_filepath = new_filepath
 
  def _save_rolling_checkpoints(self, epoch, logs):
- """Saves the latest model and keeps only the last 5."""
+ """Saves the latest model and keeps only the most recent ones."""
  filename = f"epoch_{epoch}.pth"
  filepath = self.save_dir / filename
 
@@ -334,4 +335,7 @@
  if current_lr != self.previous_lr:
  _LOGGER.info(f"Epoch {epoch}: Learning rate changed to {current_lr:.6f}")
  self.previous_lr = current_lr
-
+
+
+ def info():
+ _script_info(__all__)
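The docstring fix above reflects that `ModelCheckpoint` keeps a rolling window of recent checkpoints rather than a hard-coded five. The pruning idea can be sketched independently of the library (the helper name and the `keep_last` default are assumptions; only the `epoch_*.pth` naming comes from this diff):

```python
from pathlib import Path

def prune_rolling_checkpoints(save_dir: Path, keep_last: int = 3) -> None:
    """Delete epoch_*.pth files, keeping only the most recently modified ones."""
    checkpoints = sorted(
        save_dir.glob("epoch_*.pth"),
        key=lambda p: p.stat().st_mtime,
        reverse=True,  # newest first
    )
    for old_checkpoint in checkpoints[keep_last:]:
        old_checkpoint.unlink()
```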
ml_tools/ML_evaluation.py CHANGED
@@ -14,9 +14,10 @@ from sklearn.metrics import (
  import torch
  import shap
  from pathlib import Path
- from .utilities import make_fullpath
- from .logger import _LOGGER
+ from .path_manager import make_fullpath
+ from ._logger import _LOGGER
  from typing import Union, Optional
+ from ._script_info import _script_info
 
 
  __all__ = [
@@ -62,7 +63,7 @@ def plot_losses(history: dict, save_dir: Optional[Union[str, Path]] = None):
  plt.tight_layout()
 
  if save_dir:
- save_dir_path = make_fullpath(save_dir, make=True)
+ save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
  save_path = save_dir_path / "loss_plot.svg"
  plt.savefig(save_path)
  _LOGGER.info(f"📉 Loss plot saved as '{save_path.name}'")
@@ -88,7 +89,7 @@ def classification_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optio
  print(report)
 
  if save_dir:
- save_dir_path = make_fullpath(save_dir, make=True)
+ save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
  # Save text report
  report_path = save_dir_path / "classification_report.txt"
  report_path.write_text(report, encoding="utf-8")
@@ -158,7 +159,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Optiona
  print(report_string)
 
  if save_dir:
- save_dir_path = make_fullpath(save_dir, make=True)
+ save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
  # Save text report
  report_path = save_dir_path / "regression_report.txt"
  report_path.write_text(report_string)
@@ -220,7 +221,7 @@ def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain
  _LOGGER.info("Using SHAP values for the positive class (class 1) for plots.")
 
  if save_dir:
- save_dir_path = make_fullpath(save_dir, make=True)
+ save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
  # Save Bar Plot
  bar_path = save_dir_path / "shap_bar_plot.svg"
  shap.summary_plot(shap_values_for_plot, instances_to_explain, feature_names=feature_names, plot_type="bar", show=False)
@@ -253,3 +254,7 @@
  else:
  _LOGGER.info("No save directory provided. Displaying SHAP dot plot.")
  shap.summary_plot(shap_values_for_plot, instances_to_explain, feature_names=feature_names, plot_type="dot")
+
+
+ def info():
+ _script_info(__all__)
@@ -0,0 +1,131 @@
+ import torch
+ from torch import nn
+ import numpy as np
+ from pathlib import Path
+ from typing import Union, Literal, Dict, Any
+
+ from ._script_info import _script_info
+ from ._logger import _LOGGER
+ from .path_manager import make_fullpath
+ from .keys import PyTorchInferenceKeys
+
+ __all__ = [
+ "PyTorchInferenceHandler"
+ ]
+
+ class PyTorchInferenceHandler:
+ """
+ Handles loading a PyTorch model's state dictionary and performing inference
+ for either regression or classification tasks.
+ """
+ def __init__(self,
+ model: nn.Module,
+ state_dict: Union[str, Path],
+ task: Literal["classification", "regression"],
+ device: str = 'cpu'):
+ """
+ Initializes the handler by loading a model's state_dict.
+
+ Args:
+ model (nn.Module): An instantiated PyTorch model with the correct architecture.
+ state_dict (str | Path): The path to the saved .pth model state_dict file.
+ task (str): The type of task, 'regression' or 'classification'.
+ device (str): The device to run inference on ('cpu', 'cuda', 'mps').
+ """
+ self.model = model
+ self.task = task
+ self.device = self._validate_device(device)
+
+ model_p = make_fullpath(state_dict, enforce="file")
+
+ try:
+ # Load the state dictionary and apply it to the model structure
+ self.model.load_state_dict(torch.load(model_p, map_location=self.device))
+ self.model.to(self.device)
+ self.model.eval() # Set the model to evaluation mode
+ _LOGGER.info(f"✅ Model state loaded from '{model_p.name}' and set to evaluation mode.")
+ except Exception as e:
+ _LOGGER.error(f"❌ Failed to load model state from '{model_p}': {e}")
+ raise
+
+ def _validate_device(self, device: str) -> torch.device:
+ """Validates the selected device and returns a torch.device object."""
+ device_lower = device.lower()
+ if "cuda" in device_lower and not torch.cuda.is_available():
+ _LOGGER.warning("⚠️ CUDA not available, switching to CPU.")
+ device_lower = "cpu"
+ elif device_lower == "mps" and not torch.backends.mps.is_available():
+ _LOGGER.warning("⚠️ Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
+ device_lower = "cpu"
+ return torch.device(device_lower)
+
+ def _preprocess_input(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ """Converts input to a torch.Tensor and moves it to the correct device."""
+ if isinstance(features, np.ndarray):
+ features = torch.from_numpy(features).float()
+
+ # Ensure tensor is on the correct device
+ return features.to(self.device)
+
+ def predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
+ """
+ Predicts on a single feature vector.
+
+ Args:
+ features (np.ndarray | torch.Tensor): A 1D or 2D array/tensor for a single sample.
+
+ Returns:
+ Dict[str, Any]: A dictionary containing the prediction.
+ - For regression: {'predictions': float}
+ - For classification: {'labels': int, 'probabilities': np.ndarray}
+ """
+ if features.ndim == 1:
+ features = features.reshape(1, -1)
+
+ if features.shape[0] != 1:
+ raise ValueError("The predict() method is for a single sample. Use predict_batch() for multiple samples.")
+
+ results_batch = self.predict_batch(features)
+
+ # Extract the single result from the batch
+ if self.task == "regression":
+ return {PyTorchInferenceKeys.PREDICTIONS: results_batch[PyTorchInferenceKeys.PREDICTIONS].item()}
+ else: # classification
+ return {
+ PyTorchInferenceKeys.LABELS: results_batch[PyTorchInferenceKeys.LABELS].item(),
+ PyTorchInferenceKeys.PROBABILITIES: results_batch[PyTorchInferenceKeys.PROBABILITIES][0]
+ }
+
+ def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
+ """
+ Predicts on a batch of feature vectors.
+
+ Args:
+ features (np.ndarray | torch.Tensor): A 2D array/tensor where each row is a sample.
+
+ Returns:
+ Dict[str, Any]: A dictionary containing the predictions.
+ - For regression: {'predictions': np.ndarray}
+ - For classification: {'labels': np.ndarray, 'probabilities': np.ndarray}
+ """
+ if features.ndim != 2:
+ raise ValueError("Input for batch prediction must be a 2D array or tensor.")
+
+ input_tensor = self._preprocess_input(features)
+
+ with torch.no_grad():
+ output = self.model(input_tensor).cpu()
+
+ if self.task == "classification":
+ probs = nn.functional.softmax(output, dim=1)
+ labels = torch.argmax(probs, dim=1)
+ return {
+ PyTorchInferenceKeys.LABELS: labels.numpy(),
+ PyTorchInferenceKeys.PROBABILITIES: probs.numpy()
+ }
+ else: # regression
+ return {PyTorchInferenceKeys.PREDICTIONS: output.numpy()}
+
+
+ def info():
+ _script_info(__all__)
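For reference, a minimal usage sketch of the new `PyTorchInferenceHandler` (the architecture, feature width, and checkpoint path are placeholders; the `ml_tools.<module>` import style follows the README's Usage section):

```python
import numpy as np
from torch import nn
from ml_tools.ML_inference import PyTorchInferenceHandler
from ml_tools.keys import PyTorchInferenceKeys

# Placeholder architecture; it must match the one used to produce the state_dict.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

handler = PyTorchInferenceHandler(
    model=model,
    state_dict="outputs/best_model.pth",  # hypothetical path to a saved .pth file
    task="regression",
    device="cpu",
)

# Single sample: returns a plain float under the PREDICTIONS key.
single = handler.predict(np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32))
print(single[PyTorchInferenceKeys.PREDICTIONS])

# Batch of samples: returns a NumPy array of shape (n_samples, 1) for this model.
batch = handler.predict_batch(np.random.rand(5, 4).astype(np.float32))
print(batch[PyTorchInferenceKeys.PREDICTIONS].shape)
```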