dragon-ml-toolbox 3.12.6__py3-none-any.whl → 4.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic. Click here for more details.
- dragon_ml_toolbox-4.1.0.dist-info/METADATA +253 -0
- dragon_ml_toolbox-4.1.0.dist-info/RECORD +30 -0
- ml_tools/ETL_engineering.py +2 -2
- ml_tools/GUI_tools.py +2 -2
- ml_tools/MICE_imputation.py +4 -3
- ml_tools/ML_callbacks.py +8 -4
- ml_tools/ML_evaluation.py +11 -6
- ml_tools/ML_inference.py +131 -0
- ml_tools/ML_trainer.py +17 -8
- ml_tools/PSO_optimization.py +116 -62
- ml_tools/RNN_forecast.py +5 -0
- ml_tools/SQL.py +272 -0
- ml_tools/VIF_factor.py +4 -3
- ml_tools/_logger.py +36 -0
- ml_tools/_pytorch_models.py +1 -1
- ml_tools/_script_info.py +8 -0
- ml_tools/{logger.py → custom_logger.py} +4 -66
- ml_tools/data_exploration.py +2 -66
- ml_tools/datasetmaster.py +3 -2
- ml_tools/ensemble_inference.py +249 -0
- ml_tools/ensemble_learning.py +40 -294
- ml_tools/handle_excel.py +3 -2
- ml_tools/keys.py +13 -2
- ml_tools/path_manager.py +194 -31
- ml_tools/utilities.py +2 -180
- dragon_ml_toolbox-3.12.6.dist-info/METADATA +0 -137
- dragon_ml_toolbox-3.12.6.dist-info/RECORD +0 -26
- ml_tools/ML_tutorial.py +0 -300
- {dragon_ml_toolbox-3.12.6.dist-info → dragon_ml_toolbox-4.1.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-3.12.6.dist-info → dragon_ml_toolbox-4.1.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-3.12.6.dist-info → dragon_ml_toolbox-4.1.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-3.12.6.dist-info → dragon_ml_toolbox-4.1.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: dragon-ml-toolbox
|
|
3
|
+
Version: 4.1.0
|
|
4
|
+
Summary: A collection of tools for data science and machine learning projects.
|
|
5
|
+
Author-email: Karl Loza <luigiloza@gmail.com>
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/DrAg0n-BoRn/ML_tools
|
|
8
|
+
Project-URL: Changelog, https://github.com/DrAg0n-BoRn/ML_tools/blob/master/CHANGELOG.md
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Requires-Python: >=3.10
|
|
12
|
+
Description-Content-Type: text/markdown
|
|
13
|
+
License-File: LICENSE
|
|
14
|
+
License-File: LICENSE-THIRD-PARTY.md
|
|
15
|
+
Provides-Extra: base
|
|
16
|
+
Requires-Dist: pandas; extra == "base"
|
|
17
|
+
Requires-Dist: numpy; extra == "base"
|
|
18
|
+
Requires-Dist: polars; extra == "base"
|
|
19
|
+
Requires-Dist: joblib; extra == "base"
|
|
20
|
+
Provides-Extra: ml
|
|
21
|
+
Requires-Dist: numpy; extra == "ml"
|
|
22
|
+
Requires-Dist: pandas; extra == "ml"
|
|
23
|
+
Requires-Dist: polars; extra == "ml"
|
|
24
|
+
Requires-Dist: joblib; extra == "ml"
|
|
25
|
+
Requires-Dist: scikit-learn; extra == "ml"
|
|
26
|
+
Requires-Dist: matplotlib; extra == "ml"
|
|
27
|
+
Requires-Dist: seaborn; extra == "ml"
|
|
28
|
+
Requires-Dist: imbalanced-learn; extra == "ml"
|
|
29
|
+
Requires-Dist: ipython; extra == "ml"
|
|
30
|
+
Requires-Dist: ipykernel; extra == "ml"
|
|
31
|
+
Requires-Dist: notebook; extra == "ml"
|
|
32
|
+
Requires-Dist: jupyterlab; extra == "ml"
|
|
33
|
+
Requires-Dist: ipywidgets; extra == "ml"
|
|
34
|
+
Requires-Dist: xgboost; extra == "ml"
|
|
35
|
+
Requires-Dist: lightgbm; extra == "ml"
|
|
36
|
+
Requires-Dist: shap; extra == "ml"
|
|
37
|
+
Requires-Dist: tqdm; extra == "ml"
|
|
38
|
+
Requires-Dist: Pillow; extra == "ml"
|
|
39
|
+
Provides-Extra: mice
|
|
40
|
+
Requires-Dist: numpy<2.0; extra == "mice"
|
|
41
|
+
Requires-Dist: pandas; extra == "mice"
|
|
42
|
+
Requires-Dist: polars; extra == "mice"
|
|
43
|
+
Requires-Dist: joblib; extra == "mice"
|
|
44
|
+
Requires-Dist: miceforest>=6.0.0; extra == "mice"
|
|
45
|
+
Requires-Dist: plotnine>=0.12; extra == "mice"
|
|
46
|
+
Requires-Dist: matplotlib; extra == "mice"
|
|
47
|
+
Requires-Dist: statsmodels; extra == "mice"
|
|
48
|
+
Requires-Dist: lightgbm<=4.5.0; extra == "mice"
|
|
49
|
+
Requires-Dist: shap; extra == "mice"
|
|
50
|
+
Provides-Extra: pytorch
|
|
51
|
+
Requires-Dist: torch; extra == "pytorch"
|
|
52
|
+
Requires-Dist: torchvision; extra == "pytorch"
|
|
53
|
+
Provides-Extra: excel
|
|
54
|
+
Requires-Dist: pandas; extra == "excel"
|
|
55
|
+
Requires-Dist: openpyxl; extra == "excel"
|
|
56
|
+
Requires-Dist: ipython; extra == "excel"
|
|
57
|
+
Requires-Dist: ipykernel; extra == "excel"
|
|
58
|
+
Requires-Dist: notebook; extra == "excel"
|
|
59
|
+
Requires-Dist: jupyterlab; extra == "excel"
|
|
60
|
+
Requires-Dist: ipywidgets; extra == "excel"
|
|
61
|
+
Provides-Extra: gui-boost
|
|
62
|
+
Requires-Dist: numpy; extra == "gui-boost"
|
|
63
|
+
Requires-Dist: joblib; extra == "gui-boost"
|
|
64
|
+
Requires-Dist: FreeSimpleGUI>=5.2; extra == "gui-boost"
|
|
65
|
+
Requires-Dist: pyinstaller; extra == "gui-boost"
|
|
66
|
+
Requires-Dist: xgboost; extra == "gui-boost"
|
|
67
|
+
Requires-Dist: lightgbm; extra == "gui-boost"
|
|
68
|
+
Provides-Extra: gui-torch
|
|
69
|
+
Requires-Dist: numpy; extra == "gui-torch"
|
|
70
|
+
Requires-Dist: FreeSimpleGUI>=5.2; extra == "gui-torch"
|
|
71
|
+
Requires-Dist: pyinstaller; extra == "gui-torch"
|
|
72
|
+
Provides-Extra: plot
|
|
73
|
+
Requires-Dist: matplotlib; extra == "plot"
|
|
74
|
+
Requires-Dist: seaborn; extra == "plot"
|
|
75
|
+
Dynamic: license-file
|
|
76
|
+
|
|
77
|
+
# dragon-ml-toolbox
|
|
78
|
+
|
|
79
|
+
A collection of Python utilities for data science and machine learning, structured as a modular package for easy reuse and installation. This package has no base dependencies, allowing for lightweight and customized virtual environments.
|
|
80
|
+
|
|
81
|
+
### Features:
|
|
82
|
+
|
|
83
|
+
- Modular scripts for data exploration, logging, machine learning, and more.
|
|
84
|
+
- Designed for seamless integration as a Git submodule or installable Python package.
|
|
85
|
+
|
|
86
|
+
## Installation
|
|
87
|
+
|
|
88
|
+
**Python 3.10+**
|
|
89
|
+
|
|
90
|
+
### Via PyPI
|
|
91
|
+
|
|
92
|
+
Install the latest stable release from PyPI:
|
|
93
|
+
|
|
94
|
+
```bash
|
|
95
|
+
pip install dragon-ml-toolbox
|
|
96
|
+
```
|
|
97
|
+
|
|
98
|
+
### Via GitHub (Editable)
|
|
99
|
+
|
|
100
|
+
Clone the repository and install in editable mode with optional dependencies:
|
|
101
|
+
|
|
102
|
+
```bash
|
|
103
|
+
git clone https://github.com/DrAg0n-BoRn/ML_tools.git
|
|
104
|
+
cd ML_tools
|
|
105
|
+
pip install -e .
|
|
106
|
+
```
|
|
107
|
+
|
|
108
|
+
### Via conda-forge
|
|
109
|
+
|
|
110
|
+
Install from the conda-forge channel:
|
|
111
|
+
|
|
112
|
+
```bash
|
|
113
|
+
conda install -c conda-forge dragon-ml-toolbox
|
|
114
|
+
```
|
|
115
|
+
|
|
116
|
+
## Modular Installation
|
|
117
|
+
|
|
118
|
+
### 📦 Core Machine Learning Toolbox [ML]
|
|
119
|
+
|
|
120
|
+
Installs a comprehensive set of tools for typical data science workflows, including data manipulation, modeling, and evaluation. PyTorch is required but is not installed by this extra; install it separately as shown below.
|
|
121
|
+
|
|
122
|
+
```Bash
|
|
123
|
+
pip install "dragon-ml-toolbox[ML]"
|
|
124
|
+
```
|
|
125
|
+
|
|
126
|
+
To install the standard CPU-only versions of Torch and Torchvision:
|
|
127
|
+
|
|
128
|
+
```Bash
|
|
129
|
+
pip install "dragon-ml-toolbox[pytorch]"
|
|
130
|
+
```
|
|
131
|
+
|
|
132
|
+
⚠️ To make use of GPU acceleration (highly recommended), follow the official instructions: [PyTorch website](https://pytorch.org/get-started/locally/)
|
|
133
|
+
|
|
134
|
+
#### Modules:
|
|
135
|
+
|
|
136
|
+
```bash
|
|
137
|
+
custom_logger
|
|
138
|
+
data_exploration
|
|
139
|
+
datasetmaster
|
|
140
|
+
ensemble_learning
|
|
141
|
+
ensemble_inference
|
|
142
|
+
ETL_engineering
|
|
143
|
+
ML_callbacks
|
|
144
|
+
ML_evaluation
|
|
145
|
+
ML_trainer
|
|
146
|
+
ML_inference
|
|
147
|
+
path_manager
|
|
148
|
+
PSO_optimization
|
|
149
|
+
SQL
|
|
150
|
+
RNN_forecast
|
|
151
|
+
utilities
|
|
152
|
+
```
|
|
153
|
+
|
|
154
|
+
### 🔬 MICE Imputation and Variance Inflation Factor [mice]
|
|
155
|
+
|
|
156
|
+
⚠️ Important: This group has strict version requirements. It is highly recommended to install this group in a separate virtual environment.
|
|
157
|
+
|
|
158
|
+
```Bash
|
|
159
|
+
pip install "dragon-ml-toolbox[mice]"
|
|
160
|
+
```
|
|
161
|
+
|
|
162
|
+
#### Modules:
|
|
163
|
+
|
|
164
|
+
```bash
|
|
165
|
+
custom_logger
|
|
166
|
+
MICE_imputation
|
|
167
|
+
VIF_factor
|
|
168
|
+
path_manager
|
|
169
|
+
utilities
|
|
170
|
+
```
|
|
171
|
+
|
|
172
|
+
### 📋 Excel File Handling [excel]
|
|
173
|
+
|
|
174
|
+
Installs dependencies required to process and handle .xlsx or .xls files.
|
|
175
|
+
|
|
176
|
+
```Bash
|
|
177
|
+
pip install "dragon-ml-toolbox[excel]"
|
|
178
|
+
```
|
|
179
|
+
|
|
180
|
+
#### Modules:
|
|
181
|
+
|
|
182
|
+
```bash
|
|
183
|
+
custom_logger
|
|
184
|
+
handle_excel
|
|
185
|
+
path_manager
|
|
186
|
+
```
|
|
187
|
+
|
|
188
|
+
### 🎰 GUI for Boosting Algorithms (XGBoost, LightGBM) [gui-boost]
|
|
189
|
+
|
|
190
|
+
For GUIs that include plotting functionality, you must also install the [plot] extra.
|
|
191
|
+
|
|
192
|
+
```Bash
|
|
193
|
+
pip install "dragon-ml-toolbox[gui-boost]"
|
|
194
|
+
```
|
|
195
|
+
|
|
196
|
+
```Bash
|
|
197
|
+
pip install "dragon-ml-toolbox[gui-boost,plot]"
|
|
198
|
+
```
|
|
199
|
+
|
|
200
|
+
#### Modules:
|
|
201
|
+
|
|
202
|
+
```bash
|
|
203
|
+
GUI_tools
|
|
204
|
+
ensemble_inference
|
|
205
|
+
path_manager
|
|
206
|
+
```
|
|
207
|
+
|
|
208
|
+
### 🤖 GUI for PyTorch Models [gui-torch]
|
|
209
|
+
|
|
210
|
+
For GUIs that include plotting functionality, you must also install the [plot] extra.
|
|
211
|
+
|
|
212
|
+
```Bash
|
|
213
|
+
pip install "dragon-ml-toolbox[gui-torch]"
|
|
214
|
+
```
|
|
215
|
+
|
|
216
|
+
```Bash
|
|
217
|
+
pip install "dragon-ml-toolbox[gui-torch,plot]"
|
|
218
|
+
```
|
|
219
|
+
|
|
220
|
+
#### Modules:
|
|
221
|
+
|
|
222
|
+
```bash
|
|
223
|
+
GUI_tools
|
|
224
|
+
ML_inference
|
|
225
|
+
path_manager
|
|
226
|
+
```
|
|
227
|
+
|
|
228
|
+
### 🎫 Base Tools [base]
|
|
229
|
+
|
|
230
|
+
General purpose functions and classes.
|
|
231
|
+
|
|
232
|
+
```Bash
|
|
233
|
+
pip install "dragon-ml-toolbox[base]"
|
|
234
|
+
```
|
|
235
|
+
|
|
236
|
+
#### Modules:
|
|
237
|
+
|
|
238
|
+
```bash
|
|
239
|
+
ETL_engineering
|
|
240
|
+
custom_logger
|
|
241
|
+
SQL
|
|
242
|
+
utilities
|
|
243
|
+
path_manager
|
|
244
|
+
```
|
|
245
|
+
|
|
246
|
+
## Usage
|
|
247
|
+
|
|
248
|
+
After installation, import modules like this:
|
|
249
|
+
|
|
250
|
+
```python
|
|
251
|
+
from ml_tools.utilities import serialize_object, deserialize_object
|
|
252
|
+
from ml_tools.custom_logger import custom_logger
|
|
253
|
+
```
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
dragon_ml_toolbox-4.1.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
|
|
2
|
+
dragon_ml_toolbox-4.1.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=lY4_rJPnLnMu7YBQaY-_iz1JRDcLdQzNCyeLAF1glJY,1837
|
|
3
|
+
ml_tools/ETL_engineering.py,sha256=m_IY-4hSp5X5TfJbWQ-MJNRxkxl4fcsxOnsivMs8tiM,39506
|
|
4
|
+
ml_tools/GUI_tools.py,sha256=n4ZZ5kEjwK5rkOCFJE41HeLFfjhpJVLUSzk9Kd9Kr_0,45410
|
|
5
|
+
ml_tools/MICE_imputation.py,sha256=b6ZTs8RedXFifOpuMCzr68xM16mCBVh1Ua6kcGfiVtg,11462
|
|
6
|
+
ml_tools/ML_callbacks.py,sha256=0a-Rbr0Xp_B1FNopOKBBmuJ4MqazS5JgDiT7wx1dHvE,13161
|
|
7
|
+
ml_tools/ML_evaluation.py,sha256=4dVqe6JF1Ukmk1sAcY8E5EG1oB1_oy2HXE5OT-pZwCs,10273
|
|
8
|
+
ml_tools/ML_inference.py,sha256=Fh-X2UQn3AznWBjf-7iPSxwE-EzkGQm1VEIRUAkURmE,5336
|
|
9
|
+
ml_tools/ML_trainer.py,sha256=dJjMfCEEM07Txy9KEH-2srZ3CZUa4lFWTJhpNWQ4Ndk,14974
|
|
10
|
+
ml_tools/PSO_optimization.py,sha256=xtnPute5pkS_w-VvqOBgRLgke09mjfacGC2m9DiipHE,27626
|
|
11
|
+
ml_tools/RNN_forecast.py,sha256=2CyjBLSYYc3xLHxwLXUmP5Qv8AmV1OB_EndETNX1IBk,1956
|
|
12
|
+
ml_tools/SQL.py,sha256=9zzS6AFEJM9aj6nE31hDe8S9TqLonk-J1amwZoiHNbk,10468
|
|
13
|
+
ml_tools/VIF_factor.py,sha256=2nUMupfUoogf8o6ghoFZk_OwWhFXU0R3C9Gj0HOlI14,10415
|
|
14
|
+
ml_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
15
|
+
ml_tools/_logger.py,sha256=TpgYguxO-CWYqqgLW0tqFjtwZ58PE_W2OCfWNGZr0n0,1175
|
|
16
|
+
ml_tools/_pytorch_models.py,sha256=ewPPsTHgmRPzMMWwObZOdH1vxm2Ij2VWZP38NC6zSH4,10135
|
|
17
|
+
ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
|
|
18
|
+
ml_tools/custom_logger.py,sha256=a3ywSCQT7j5ypR-usnKh2l861d_aVJ93ZRVqxrHsBBw,4112
|
|
19
|
+
ml_tools/data_exploration.py,sha256=rJhvxUqVbEuB_7HG-PfLH3vaA7hrZEtbVHg9QO9VS4A,22837
|
|
20
|
+
ml_tools/datasetmaster.py,sha256=_tNC2v98eCQGr3nMW_EFs83TRgRme8Uc7ttg1vosmQU,30106
|
|
21
|
+
ml_tools/ensemble_inference.py,sha256=0SNX3YAz5bpvtwYmqEwqyWeIJP2Pb-v-bemENRSO7qg,9426
|
|
22
|
+
ml_tools/ensemble_learning.py,sha256=Zi1oy6G2FWnTI5hBwjlexwF3JKALFS2FN6F8HAlVi_s,35391
|
|
23
|
+
ml_tools/handle_excel.py,sha256=J9iwIqMZemoxK49J5osSwp9Ge0h9YTKyYGbOm53hcno,13007
|
|
24
|
+
ml_tools/keys.py,sha256=kK9UF-hek2VcPGFILCKl5geoN6flmMOu7IzhdEA6z5Y,1068
|
|
25
|
+
ml_tools/path_manager.py,sha256=ElDa25bntANujTjY7xN4ZfCDiZp-9Ud3x0aJSJptZBY,13419
|
|
26
|
+
ml_tools/utilities.py,sha256=mz-M351DzxWxnYVcLX-7ZQ6c-RGoCV9g4VTS9Qif2Es,18348
|
|
27
|
+
dragon_ml_toolbox-4.1.0.dist-info/METADATA,sha256=eJQwYS8B7RMy4H8DveKsDVmj4ikBSJb_hkuTSzmObz4,6278
|
|
28
|
+
dragon_ml_toolbox-4.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
29
|
+
dragon_ml_toolbox-4.1.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
|
|
30
|
+
dragon_ml_toolbox-4.1.0.dist-info/RECORD,,
|
ml_tools/ETL_engineering.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import polars as pl
|
|
2
2
|
import re
|
|
3
3
|
from typing import Literal, Union, Optional, Any, Callable, List, Dict, Tuple
|
|
4
|
-
from .
|
|
5
|
-
from .
|
|
4
|
+
from ._script_info import _script_info
|
|
5
|
+
from ._logger import _LOGGER
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
__all__ = [
|
ml_tools/GUI_tools.py
CHANGED
|
@@ -4,9 +4,9 @@ import traceback
|
|
|
4
4
|
import FreeSimpleGUI as sg
|
|
5
5
|
from functools import wraps
|
|
6
6
|
from typing import Any, Dict, Tuple, List, Literal, Union, Optional, Callable
|
|
7
|
-
from .
|
|
7
|
+
from ._script_info import _script_info
|
|
8
8
|
import numpy as np
|
|
9
|
-
from .
|
|
9
|
+
from ._logger import _LOGGER
|
|
10
10
|
from .keys import _OneHotOtherPlaceholder
|
|
11
11
|
|
|
12
12
|
|
ml_tools/MICE_imputation.py
CHANGED
|
@@ -3,11 +3,12 @@ import miceforest as mf
|
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
import matplotlib.pyplot as plt
|
|
5
5
|
import numpy as np
|
|
6
|
-
from .utilities import load_dataframe,
|
|
6
|
+
from .utilities import load_dataframe, merge_dataframes, save_dataframe, threshold_binary_values
|
|
7
|
+
from .path_manager import sanitize_filename, make_fullpath, list_csv_paths
|
|
7
8
|
from plotnine import ggplot, labs, theme, element_blank # type: ignore
|
|
8
9
|
from typing import Optional, Union
|
|
9
|
-
from .
|
|
10
|
-
|
|
10
|
+
from ._logger import _LOGGER
|
|
11
|
+
from ._script_info import _script_info
|
|
11
12
|
|
|
12
13
|
__all__ = [
|
|
13
14
|
"apply_mice",
|
ml_tools/ML_callbacks.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
import numpy as np
|
|
2
2
|
import torch
|
|
3
3
|
from tqdm.auto import tqdm
|
|
4
|
-
from .
|
|
4
|
+
from .path_manager import make_fullpath
|
|
5
5
|
from .keys import LogKeys
|
|
6
|
-
from .
|
|
6
|
+
from ._logger import _LOGGER
|
|
7
7
|
from typing import Optional
|
|
8
|
+
from ._script_info import _script_info
|
|
8
9
|
|
|
9
10
|
|
|
10
11
|
__all__ = [
|
|
@@ -270,7 +271,7 @@ class ModelCheckpoint(Callback):
|
|
|
270
271
|
self.last_best_filepath = new_filepath
|
|
271
272
|
|
|
272
273
|
def _save_rolling_checkpoints(self, epoch, logs):
|
|
273
|
-
"""Saves the latest model and keeps only the
|
|
274
|
+
"""Saves the latest model and keeps only the most recent ones."""
|
|
274
275
|
filename = f"epoch_{epoch}.pth"
|
|
275
276
|
filepath = self.save_dir / filename
|
|
276
277
|
|
|
@@ -334,4 +335,7 @@ class LRScheduler(Callback):
|
|
|
334
335
|
if current_lr != self.previous_lr:
|
|
335
336
|
_LOGGER.info(f"Epoch {epoch}: Learning rate changed to {current_lr:.6f}")
|
|
336
337
|
self.previous_lr = current_lr
|
|
337
|
-
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def info():
    """Pass this module's `__all__` list to `_script_info`."""
    _script_info(__all__)
|
ml_tools/ML_evaluation.py
CHANGED
|
@@ -14,9 +14,10 @@ from sklearn.metrics import (
|
|
|
14
14
|
import torch
|
|
15
15
|
import shap
|
|
16
16
|
from pathlib import Path
|
|
17
|
-
from .
|
|
18
|
-
from .
|
|
17
|
+
from .path_manager import make_fullpath
|
|
18
|
+
from ._logger import _LOGGER
|
|
19
19
|
from typing import Union, Optional
|
|
20
|
+
from ._script_info import _script_info
|
|
20
21
|
|
|
21
22
|
|
|
22
23
|
__all__ = [
|
|
@@ -62,7 +63,7 @@ def plot_losses(history: dict, save_dir: Optional[Union[str, Path]] = None):
|
|
|
62
63
|
plt.tight_layout()
|
|
63
64
|
|
|
64
65
|
if save_dir:
|
|
65
|
-
save_dir_path = make_fullpath(save_dir, make=True)
|
|
66
|
+
save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
66
67
|
save_path = save_dir_path / "loss_plot.svg"
|
|
67
68
|
plt.savefig(save_path)
|
|
68
69
|
_LOGGER.info(f"📉 Loss plot saved as '{save_path.name}'")
|
|
@@ -88,7 +89,7 @@ def classification_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optio
|
|
|
88
89
|
print(report)
|
|
89
90
|
|
|
90
91
|
if save_dir:
|
|
91
|
-
save_dir_path = make_fullpath(save_dir, make=True)
|
|
92
|
+
save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
92
93
|
# Save text report
|
|
93
94
|
report_path = save_dir_path / "classification_report.txt"
|
|
94
95
|
report_path.write_text(report, encoding="utf-8")
|
|
@@ -158,7 +159,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Optiona
|
|
|
158
159
|
print(report_string)
|
|
159
160
|
|
|
160
161
|
if save_dir:
|
|
161
|
-
save_dir_path = make_fullpath(save_dir, make=True)
|
|
162
|
+
save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
162
163
|
# Save text report
|
|
163
164
|
report_path = save_dir_path / "regression_report.txt"
|
|
164
165
|
report_path.write_text(report_string)
|
|
@@ -220,7 +221,7 @@ def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain
|
|
|
220
221
|
_LOGGER.info("Using SHAP values for the positive class (class 1) for plots.")
|
|
221
222
|
|
|
222
223
|
if save_dir:
|
|
223
|
-
save_dir_path = make_fullpath(save_dir, make=True)
|
|
224
|
+
save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
224
225
|
# Save Bar Plot
|
|
225
226
|
bar_path = save_dir_path / "shap_bar_plot.svg"
|
|
226
227
|
shap.summary_plot(shap_values_for_plot, instances_to_explain, feature_names=feature_names, plot_type="bar", show=False)
|
|
@@ -253,3 +254,7 @@ def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain
|
|
|
253
254
|
else:
|
|
254
255
|
_LOGGER.info("No save directory provided. Displaying SHAP dot plot.")
|
|
255
256
|
shap.summary_plot(shap_values_for_plot, instances_to_explain, feature_names=feature_names, plot_type="dot")
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def info():
    """Pass this module's `__all__` list to `_script_info`."""
    _script_info(__all__)
|
ml_tools/ML_inference.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
import numpy as np
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Union, Literal, Dict, Any
|
|
6
|
+
|
|
7
|
+
from ._script_info import _script_info
|
|
8
|
+
from ._logger import _LOGGER
|
|
9
|
+
from .path_manager import make_fullpath
|
|
10
|
+
from .keys import PyTorchInferenceKeys
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"PyTorchInferenceHandler"
|
|
14
|
+
]
|
|
15
|
+
|
|
16
|
+
class PyTorchInferenceHandler:
    """
    Handles loading a PyTorch model's state dictionary and performing inference
    for either regression or classification tasks.
    """
    # Task names accepted by the constructor. `predict` and `predict_batch`
    # both branch on `self.task`, so anything else must be rejected up front.
    _VALID_TASKS = ("classification", "regression")

    def __init__(self,
                 model: nn.Module,
                 state_dict: Union[str, Path],
                 task: Literal["classification", "regression"],
                 device: str = 'cpu'):
        """
        Initializes the handler by loading a model's state_dict.

        Args:
            model (nn.Module): An instantiated PyTorch model with the correct architecture.
            state_dict (str | Path): The path to the saved .pth model state_dict file.
            task (str): The type of task, 'regression' or 'classification'.
            device (str): The device to run inference on ('cpu', 'cuda', 'mps').

        Raises:
            ValueError: If `task` is not 'classification' or 'regression'.
            Exception: Any error raised while loading the state_dict is logged
                and re-raised unchanged.
        """
        # Fail fast: without this check an unknown task is treated as
        # regression by `predict_batch` but as classification by `predict`,
        # which only surfaces later as a confusing KeyError.
        if task not in self._VALID_TASKS:
            raise ValueError(
                f"Invalid task '{task}'. Expected one of {self._VALID_TASKS}."
            )

        self.model = model
        self.task = task
        self.device = self._validate_device(device)

        model_p = make_fullpath(state_dict, enforce="file")

        try:
            # Load the state dictionary and apply it to the model structure.
            # NOTE(security): torch.load unpickles the file, so only load
            # checkpoints from trusted sources (consider `weights_only=True`
            # where the installed torch version supports it).
            self.model.load_state_dict(torch.load(model_p, map_location=self.device))
            self.model.to(self.device)
            self.model.eval()  # Set the model to evaluation mode
            _LOGGER.info(f"✅ Model state loaded from '{model_p.name}' and set to evaluation mode.")
        except Exception as e:
            _LOGGER.error(f"❌ Failed to load model state from '{model_p}': {e}")
            raise

    def _validate_device(self, device: str) -> torch.device:
        """
        Validates the selected device and returns a torch.device object.

        Falls back to CPU (with a warning) when CUDA or MPS is requested but
        not available.
        """
        device_lower = device.lower()
        if "cuda" in device_lower and not torch.cuda.is_available():
            _LOGGER.warning("⚠️ CUDA not available, switching to CPU.")
            device_lower = "cpu"
        # Prefix match so indexed forms such as "mps:0" also get the fallback.
        elif device_lower.startswith("mps") and not torch.backends.mps.is_available():
            _LOGGER.warning("⚠️ Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
            device_lower = "cpu"
        return torch.device(device_lower)

    def _preprocess_input(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """
        Converts input to a torch.Tensor and moves it to the correct device.

        NumPy arrays are converted to float32 tensors; tensors are passed
        through with their dtype unchanged.
        """
        if isinstance(features, np.ndarray):
            features = torch.from_numpy(features).float()

        # Ensure tensor is on the correct device
        return features.to(self.device)

    def predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
        """
        Predicts on a single feature vector.

        Args:
            features (np.ndarray | torch.Tensor): A 1D or 2D array/tensor for a single sample.

        Returns:
            Dict[str, Any]: A dictionary containing the prediction, keyed by
            `PyTorchInferenceKeys`.
                - For regression: PREDICTIONS -> Python scalar
                - For classification: LABELS -> Python scalar,
                  PROBABILITIES -> np.ndarray (the row for this sample)

        Raises:
            ValueError: If more than one sample is supplied.
        """
        # A 1D vector is promoted to a batch of one.
        if features.ndim == 1:
            features = features.reshape(1, -1)

        if features.shape[0] != 1:
            raise ValueError("The predict() method is for a single sample. Use predict_batch() for multiple samples.")

        results_batch = self.predict_batch(features)

        # Extract the single result from the batch
        if self.task == "regression":
            return {PyTorchInferenceKeys.PREDICTIONS: results_batch[PyTorchInferenceKeys.PREDICTIONS].item()}
        else:  # classification
            return {
                PyTorchInferenceKeys.LABELS: results_batch[PyTorchInferenceKeys.LABELS].item(),
                PyTorchInferenceKeys.PROBABILITIES: results_batch[PyTorchInferenceKeys.PROBABILITIES][0]
            }

    def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
        """
        Predicts on a batch of feature vectors.

        Args:
            features (np.ndarray | torch.Tensor): A 2D array/tensor where each row is a sample.

        Returns:
            Dict[str, Any]: A dictionary containing the predictions, keyed by
            `PyTorchInferenceKeys`.
                - For regression: PREDICTIONS -> np.ndarray
                - For classification: LABELS -> np.ndarray, PROBABILITIES -> np.ndarray

        Raises:
            ValueError: If `features` is not 2-dimensional.
        """
        if features.ndim != 2:
            raise ValueError("Input for batch prediction must be a 2D array or tensor.")

        input_tensor = self._preprocess_input(features)

        # Inference only: no gradients needed. Results are moved to CPU so
        # they can be converted to NumPy below.
        with torch.no_grad():
            output = self.model(input_tensor).cpu()

            if self.task == "classification":
                probs = nn.functional.softmax(output, dim=1)
                labels = torch.argmax(probs, dim=1)
                return {
                    PyTorchInferenceKeys.LABELS: labels.numpy(),
                    PyTorchInferenceKeys.PROBABILITIES: probs.numpy()
                }
            else:  # regression
                return {PyTorchInferenceKeys.PREDICTIONS: output.numpy()}
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def info():
    """Pass this module's `__all__` list to `_script_info`."""
    _script_info(__all__)
|
ml_tools/ML_trainer.py
CHANGED
|
@@ -7,9 +7,9 @@ import numpy as np
|
|
|
7
7
|
|
|
8
8
|
from .ML_callbacks import Callback, History, TqdmProgressBar
|
|
9
9
|
from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot
|
|
10
|
-
from .
|
|
10
|
+
from ._script_info import _script_info
|
|
11
11
|
from .keys import LogKeys
|
|
12
|
-
from .
|
|
12
|
+
from ._logger import _LOGGER
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
__all__ = [
|
|
@@ -105,7 +105,7 @@ class MyTrainer:
|
|
|
105
105
|
pin_memory=(self.device.type == "cuda")
|
|
106
106
|
)
|
|
107
107
|
|
|
108
|
-
def fit(self, epochs: int = 10, batch_size: int =
|
|
108
|
+
def fit(self, epochs: int = 10, batch_size: int = 10, shuffle: bool = True):
|
|
109
109
|
"""
|
|
110
110
|
Starts the training-validation process of the model.
|
|
111
111
|
|
|
@@ -113,6 +113,13 @@ class MyTrainer:
|
|
|
113
113
|
epochs (int): The total number of epochs to train for.
|
|
114
114
|
batch_size (int): The number of samples per batch.
|
|
115
115
|
shuffle (bool): Whether to shuffle the training data at each epoch.
|
|
116
|
+
|
|
117
|
+
Note:
|
|
118
|
+
For regression tasks using `nn.MSELoss` or `nn.L1Loss`, the trainer
|
|
119
|
+
automatically aligns the model's output tensor with the target tensor's
|
|
120
|
+
shape using `output.view_as(target)`. This handles the common case
|
|
121
|
+
where a model outputs a shape of `[batch_size, 1]` and the target has a
|
|
122
|
+
shape of `[batch_size]`.
|
|
116
123
|
"""
|
|
117
124
|
self.epochs = epochs
|
|
118
125
|
self._create_dataloaders(batch_size, shuffle)
|
|
@@ -189,9 +196,10 @@ class MyTrainer:
|
|
|
189
196
|
logs = {LogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)} # type: ignore
|
|
190
197
|
return logs
|
|
191
198
|
|
|
192
|
-
def
|
|
199
|
+
def _predict_for_eval(self, dataloader: DataLoader):
|
|
193
200
|
"""
|
|
194
|
-
|
|
201
|
+
Private method to yield model predictions batch by batch for evaluation.
|
|
202
|
+
This is used internally by the `evaluate` method.
|
|
195
203
|
|
|
196
204
|
Args:
|
|
197
205
|
dataloader (DataLoader): The dataloader to predict on.
|
|
@@ -213,13 +221,14 @@ class MyTrainer:
|
|
|
213
221
|
preds = torch.argmax(probs, dim=1)
|
|
214
222
|
y_pred_batch = preds.numpy()
|
|
215
223
|
y_prob_batch = probs.numpy()
|
|
224
|
+
# regression
|
|
216
225
|
else:
|
|
217
226
|
y_pred_batch = output.numpy()
|
|
218
227
|
y_prob_batch = None
|
|
219
228
|
|
|
220
229
|
yield y_pred_batch, y_prob_batch, y_true_batch
|
|
221
230
|
|
|
222
|
-
def evaluate(self,
|
|
231
|
+
def evaluate(self, save_dir: Optional[Union[str,Path]], data: Optional[Union[DataLoader, Dataset]] = None):
|
|
223
232
|
"""
|
|
224
233
|
Evaluates the model on the given data.
|
|
225
234
|
|
|
@@ -251,7 +260,7 @@ class MyTrainer:
|
|
|
251
260
|
|
|
252
261
|
# Collect results from the predict generator
|
|
253
262
|
all_preds, all_probs, all_true = [], [], []
|
|
254
|
-
for y_pred_b, y_prob_b, y_true_b in self.
|
|
263
|
+
for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
|
|
255
264
|
all_preds.append(y_pred_b)
|
|
256
265
|
if y_prob_b is not None:
|
|
257
266
|
all_probs.append(y_prob_b)
|
|
@@ -270,7 +279,7 @@ class MyTrainer:
|
|
|
270
279
|
plot_losses(self.history, save_dir=save_dir)
|
|
271
280
|
|
|
272
281
|
def explain(self, explain_dataset: Optional[Dataset] = None, n_samples: int = 100,
|
|
273
|
-
feature_names: Optional[List[str]] = None, save_dir: Optional[str] = None):
|
|
282
|
+
feature_names: Optional[List[str]] = None, save_dir: Optional[Union[str,Path]] = None):
|
|
274
283
|
"""
|
|
275
284
|
Explains model predictions using SHAP and saves all artifacts.
|
|
276
285
|
|