dragon-ml-toolbox 6.2.0__tar.gz → 6.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-6.2.0/dragon_ml_toolbox.egg-info → dragon_ml_toolbox-6.3.0}/PKG-INFO +13 -1
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/README.md +12 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0/dragon_ml_toolbox.egg-info}/PKG-INFO +13 -1
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/ML_callbacks.py +3 -1
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/ML_datasetmaster.py +30 -2
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/ML_inference.py +5 -2
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/ML_models.py +115 -3
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/custom_logger.py +37 -1
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/pyproject.toml +1 -1
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/LICENSE +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/dragon_ml_toolbox.egg-info/SOURCES.txt +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/dragon_ml_toolbox.egg-info/dependency_links.txt +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/dragon_ml_toolbox.egg-info/requires.txt +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/dragon_ml_toolbox.egg-info/top_level.txt +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/ETL_engineering.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/GUI_tools.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/MICE_imputation.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/ML_evaluation.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/ML_optimization.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/ML_trainer.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/PSO_optimization.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/RNN_forecast.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/SQL.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/VIF_factor.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/__init__.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/_logger.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/_script_info.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/data_exploration.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/ensemble_evaluation.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/ensemble_inference.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/ensemble_learning.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/handle_excel.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/keys.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/optimization_tools.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/path_manager.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/utilities.py +0 -0
- {dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/setup.cfg +0 -0

{dragon_ml_toolbox-6.2.0/dragon_ml_toolbox.egg-info → dragon_ml_toolbox-6.3.0}/PKG-INFO

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 6.2.0
+Version: 6.3.0
 Summary: A collection of tools for data science and machine learning projects.
 Author-email: Karl Loza <luigiloza@gmail.com>
 License-Expression: MIT
@@ -160,6 +160,8 @@ SQL
 utilities
 ```
 
+---
+
 ### 🔬 MICE Imputation and Variance Inflation Factor [mice]
 
 ⚠️ Important: This group has strict version requirements. It is highly recommended to install this group in a separate virtual environment.
@@ -178,6 +180,8 @@ path_manager
 utilities
 ```
 
+---
+
 ### 📋 Excel File Handling [excel]
 
 Installs dependencies required to process and handle .xlsx or .xls files.
@@ -194,6 +198,8 @@ handle_excel
 path_manager
 ```
 
+---
+
 ### 🎰 GUI for Boosting Algorithms (XGBoost, LightGBM) [gui-boost]
 
 For GUIs that include plotting functionality, you must also install the [plot] extra.
@@ -215,6 +221,8 @@ ensemble_inference
 path_manager
 ```
 
+---
+
 ### 🤖 GUI for PyTorch Models [gui-torch]
 
 For GUIs that include plotting functionality, you must also install the [plot] extra.
@@ -236,6 +244,8 @@ ML_inference
 path_manager
 ```
 
+---
+
 ### 🎫 Base Tools [base]
 
 General purpose functions and classes.
@@ -254,6 +264,8 @@ utilities
 path_manager
 ```
 
+---
+
 ### ⚒️ APP bundlers
 
 Choose one if needed.
````

{dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/README.md

````diff
@@ -79,6 +79,8 @@ SQL
 utilities
 ```
 
+---
+
 ### 🔬 MICE Imputation and Variance Inflation Factor [mice]
 
 ⚠️ Important: This group has strict version requirements. It is highly recommended to install this group in a separate virtual environment.
@@ -97,6 +99,8 @@ path_manager
 utilities
 ```
 
+---
+
 ### 📋 Excel File Handling [excel]
 
 Installs dependencies required to process and handle .xlsx or .xls files.
@@ -113,6 +117,8 @@ handle_excel
 path_manager
 ```
 
+---
+
 ### 🎰 GUI for Boosting Algorithms (XGBoost, LightGBM) [gui-boost]
 
 For GUIs that include plotting functionality, you must also install the [plot] extra.
@@ -134,6 +140,8 @@ ensemble_inference
 path_manager
 ```
 
+---
+
 ### 🤖 GUI for PyTorch Models [gui-torch]
 
 For GUIs that include plotting functionality, you must also install the [plot] extra.
@@ -155,6 +163,8 @@ ML_inference
 path_manager
 ```
 
+---
+
 ### 🎫 Base Tools [base]
 
 General purpose functions and classes.
@@ -173,6 +183,8 @@ utilities
 path_manager
 ```
 
+---
+
 ### ⚒️ APP bundlers
 
 Choose one if needed.
````

{dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0/dragon_ml_toolbox.egg-info}/PKG-INFO

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 6.2.0
+Version: 6.3.0
 Summary: A collection of tools for data science and machine learning projects.
 Author-email: Karl Loza <luigiloza@gmail.com>
 License-Expression: MIT
@@ -160,6 +160,8 @@ SQL
 utilities
 ```
 
+---
+
 ### 🔬 MICE Imputation and Variance Inflation Factor [mice]
 
 ⚠️ Important: This group has strict version requirements. It is highly recommended to install this group in a separate virtual environment.
@@ -178,6 +180,8 @@ path_manager
 utilities
 ```
 
+---
+
 ### 📋 Excel File Handling [excel]
 
 Installs dependencies required to process and handle .xlsx or .xls files.
@@ -194,6 +198,8 @@ handle_excel
 path_manager
 ```
 
+---
+
 ### 🎰 GUI for Boosting Algorithms (XGBoost, LightGBM) [gui-boost]
 
 For GUIs that include plotting functionality, you must also install the [plot] extra.
@@ -215,6 +221,8 @@ ensemble_inference
 path_manager
 ```
 
+---
+
 ### 🤖 GUI for PyTorch Models [gui-torch]
 
 For GUIs that include plotting functionality, you must also install the [plot] extra.
@@ -236,6 +244,8 @@ ML_inference
 path_manager
 ```
 
+---
+
 ### 🎫 Base Tools [base]
 
 General purpose functions and classes.
@@ -254,6 +264,8 @@ utilities
 path_manager
 ```
 
+---
+
 ### ⚒️ APP bundlers
 
 Choose one if needed.
````

{dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/ML_callbacks.py

```diff
@@ -1,7 +1,7 @@
 import numpy as np
 import torch
 from tqdm.auto import tqdm
-from .path_manager import make_fullpath
+from .path_manager import make_fullpath, sanitize_filename
 from .keys import PyTorchLogKeys
 from ._logger import _LOGGER
 from typing import Optional
@@ -212,6 +212,8 @@ class ModelCheckpoint(Callback):
         self.monitor = monitor
         self.save_best_only = save_best_only
         self.verbose = verbose
+        if checkpoint_name:
+            checkpoint_name = sanitize_filename(checkpoint_name)
         self.checkpoint_name = checkpoint_name
 
         # State variables to be managed during training
```

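The only behavioural change here is that a user-supplied `checkpoint_name` is now passed through `sanitize_filename` before being stored on the callback. A minimal sketch of that effect, assuming the import path shown in the diff (`ml_tools.path_manager`); the input string is invented, and the exact characters that get replaced depend on `sanitize_filename`'s implementation, which is not part of this diff:

```python
# Hypothetical illustration of the new sanitisation step in ModelCheckpoint.
# Only the import visible in the diff is used; the example name is made up.
from ml_tools.path_manager import sanitize_filename

raw_name = "best model: run 3?"            # awkward as a filename
safe_name = sanitize_filename(raw_name)    # returns a filesystem-safe variant
print(safe_name)                           # exact output depends on the implementation
```
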
{dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/ML_datasetmaster.py

```diff
@@ -16,6 +16,7 @@ from pathlib import Path
 from .path_manager import make_fullpath
 from ._logger import _LOGGER
 from ._script_info import _script_info
+from .custom_logger import save_list_strings
 
 
 # --- public-facing API ---
@@ -144,6 +145,9 @@ class DatasetMaker(_BaseMaker):
         self.features = pandas_df.drop(columns=label_col)
         self.labels_map = None
         self.scaler = None
+
+        self._feature_names = self.features.columns.tolist()
+        self._target_name = str(self.labels.name)
 
         self._is_split = False
         self._is_balanced = False
@@ -347,6 +351,23 @@ class DatasetMaker(_BaseMaker):
         if not self._is_split:
             raise RuntimeError("Data has not been split yet. Call .split_data() or .process() first.")
         return self.features_train, self.features_test, self.labels_train, self.labels_test # type: ignore
+
+    @property
+    def feature_names(self) -> list[str]:
+        """Returns the list of feature column names."""
+        return self._feature_names
+
+    @property
+    def target_name(self) -> str:
+        """Returns the name of the target column."""
+        return self._target_name
+
+    def save_feature_names(self, directory: Union[str, Path], verbose: bool=True) -> None:
+        """Saves a list of feature names as a text file"""
+        save_list_strings(list_strings=self._feature_names,
+                          directory=directory,
+                          filename="feature_names",
+                          verbose=verbose)
 
     @staticmethod
     def _embed_categorical(cat_df: pandas.DataFrame, random_state: Optional[int] = None, **kwargs) -> pandas.DataFrame:
@@ -413,7 +434,7 @@ class SimpleDatasetMaker:
         target = pandas_df.iloc[:, -1]
 
         self._feature_names = features.columns.tolist()
-        self._target_name = target.name
+        self._target_name = str(target.name)
 
         #set id
         self._id: Optional[str] = None
@@ -452,7 +473,7 @@ class SimpleDatasetMaker:
     @property
     def target_name(self) -> str:
         """Returns the name of the target column."""
-        return
+        return self._target_name
 
     @property
     def id(self) -> Optional[str]:
@@ -474,6 +495,13 @@ class SimpleDatasetMaker:
         print(f" X_test shape: {self._X_test_shape}")
         print(f" y_test shape: {self._y_test_shape}")
         print("-------------------------------------------")
+
+    def save_feature_names(self, directory: Union[str, Path], verbose: bool=True) -> None:
+        """Saves a list of feature names as a text file"""
+        save_list_strings(list_strings=self._feature_names,
+                          directory=directory,
+                          filename="feature_names",
+                          verbose=verbose)
 
 
 # --- VisionDatasetMaker ---
```

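With these additions, both `DatasetMaker` and `SimpleDatasetMaker` expose the column names they were built from and can persist them via the new `save_list_strings` helper. A usage sketch, assuming `SimpleDatasetMaker` is constructed from a pandas DataFrame whose last column is the target (as the diffed code implies); the DataFrame contents and output directory are illustrative, and any constructor options beyond the DataFrame are omitted here:

```python
import pandas as pd
from ml_tools.ML_datasetmaster import SimpleDatasetMaker

# Toy frame: 'a' and 'b' are features, the last column 'y' is the target.
df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "y": [0, 1, 0, 1]})
maker = SimpleDatasetMaker(df)            # further constructor options not shown here

print(maker.feature_names)                # ['a', 'b']
print(maker.target_name)                  # 'y' (now actually returned, and always a str)
maker.save_feature_names("artifacts")     # writes artifacts/feature_names.txt
```
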
{dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/ML_inference.py

```diff
@@ -2,7 +2,7 @@ import torch
 from torch import nn
 import numpy as np
 from pathlib import Path
-from typing import Union, Literal, Dict, Any
+from typing import Union, Literal, Dict, Any, Optional
 
 from ._script_info import _script_info
 from ._logger import _LOGGER
@@ -22,7 +22,8 @@ class PyTorchInferenceHandler:
                  model: nn.Module,
                  state_dict: Union[str, Path],
                  task: Literal["classification", "regression"],
-                 device: str = 'cpu'
+                 device: str = 'cpu',
+                 target_id: Optional[str]=None):
         """
         Initializes the handler by loading a model's state_dict.
 
@@ -31,10 +32,12 @@ class PyTorchInferenceHandler:
             state_dict (str | Path): The path to the saved .pth model state_dict file.
             task (str): The type of task, 'regression' or 'classification'.
             device (str): The device to run inference on ('cpu', 'cuda', 'mps').
+            target_id (str | None): Target name as used in the training set.
         """
         self.model = model
         self.task = task
         self.device = self._validate_device(device)
+        self.target_id = target_id
 
         model_p = make_fullpath(state_dict, enforce="file")
 
```

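The handler now carries an optional `target_id` so downstream code can recover the name of the target column used during training. A construction sketch using only the parameters visible in the diff; the model configuration, state_dict path, and target name are illustrative:

```python
from ml_tools.ML_models import MultilayerPerceptron
from ml_tools.ML_inference import PyTorchInferenceHandler

model = MultilayerPerceptron(in_features=10, out_targets=1,
                             hidden_layers=[64, 32], drop_out=0.2)

handler = PyTorchInferenceHandler(
    model=model,
    state_dict="outputs/best_model.pth",   # hypothetical path to a saved state_dict
    task="regression",
    device="cpu",
    target_id="y",                         # new in 6.3.0: target name from the training set
)
print(handler.target_id)                   # 'y'
```
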
{dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/ML_models.py

```diff
@@ -1,12 +1,18 @@
 import torch
 from torch import nn
 from ._script_info import _script_info
-from typing import List
+from typing import List, Union
+from pathlib import Path
+import json
+from ._logger import _LOGGER
+from .path_manager import make_fullpath
 
 
 __all__ = [
     "MultilayerPerceptron",
-    "SequencePredictorLSTM"
+    "SequencePredictorLSTM",
+    "save_architecture",
+    "load_architecture"
 ]
 
 
@@ -45,6 +51,12 @@ class MultilayerPerceptron(nn.Module):
             raise TypeError("hidden_layers must be a list of integers.")
         if not (0.0 <= drop_out < 1.0):
             raise ValueError("drop_out must be a float between 0.0 and 1.0.")
+
+        # --- Save configuration ---
+        self.in_features = in_features
+        self.out_targets = out_targets
+        self.hidden_layers = hidden_layers
+        self.drop_out = drop_out
 
         # --- Build network layers ---
         layers = []
@@ -67,6 +79,15 @@ class MultilayerPerceptron(nn.Module):
         """Defines the forward pass of the model."""
         return self._layers(x)
 
+    def get_config(self) -> dict:
+        """Returns the configuration of the model."""
+        return {
+            'in_features': self.in_features,
+            'out_targets': self.out_targets,
+            'hidden_layers': self.hidden_layers,
+            'drop_out': self.drop_out
+        }
+
     def __repr__(self) -> str:
         """Returns the developer-friendly string representation of the model."""
         # Extracts the number of neurons from each nn.Linear layer
@@ -114,7 +135,14 @@ class SequencePredictorLSTM(nn.Module):
             raise ValueError("recurrent_layers must be a positive integer.")
         if not (0.0 <= dropout < 1.0):
             raise ValueError("dropout must be a float between 0.0 and 1.0.")
-
+
+        # --- Save configuration ---
+        self.features = features
+        self.hidden_size = hidden_size
+        self.recurrent_layers = recurrent_layers
+        self.dropout = dropout
+
+        # Build model
         self.lstm = nn.LSTM(
             input_size=features,
             hidden_size=hidden_size,
@@ -144,6 +172,15 @@ class SequencePredictorLSTM(nn.Module):
 
         return predictions
 
+    def get_config(self) -> dict:
+        """Returns the configuration of the model."""
+        return {
+            'features': self.features,
+            'hidden_size': self.hidden_size,
+            'recurrent_layers': self.recurrent_layers,
+            'dropout': self.dropout
+        }
+
     def __repr__(self) -> str:
         """Returns the developer-friendly string representation of the model."""
         return (
@@ -153,5 +190,80 @@ class SequencePredictorLSTM(nn.Module):
         )
 
 
+def save_architecture(model: nn.Module, directory: Union[str, Path], verbose: bool=True):
+    """
+    Saves a model's architecture to a 'architecture.json' file.
+
+    This function relies on the model having a `get_config()` method that
+    returns a dictionary of the arguments needed to initialize it.
+
+    Args:
+        model (nn.Module): The PyTorch model instance to save.
+        directory (str | Path): The directory to save the JSON file.
+
+    Raises:
+        AttributeError: If the model does not have a `get_config()` method.
+    """
+    if not hasattr(model, 'get_config'):
+        raise AttributeError(
+            f"Model '{model.__class__.__name__}' does not have a 'get_config()' method. "
+            "Please implement it to return the model's constructor arguments."
+        )
+
+    # Ensure the target directory exists
+    path_dir = make_fullpath(directory, make=True, enforce="directory")
+    full_path = path_dir / "architecture.json"
+
+    config = {
+        'model_class': model.__class__.__name__,
+        'config': model.get_config() # type: ignore
+    }
+
+    with open(full_path, 'w') as f:
+        json.dump(config, f, indent=4)
+
+    if verbose:
+        _LOGGER.info(f"✅ Architecture for '{model.__class__.__name__}' saved to '{path_dir}'")
+
+
+def load_architecture(filepath: Union[str, Path], expected_model_class: type, verbose: bool=True) -> nn.Module:
+    """
+    Loads a model architecture from a JSON file.
+
+    This function instantiates a model by providing an explicit class to use
+    and checking that it matches the class name specified in the file.
+
+    Args:
+        filepath (Union[str, Path]): The path of the JSON architecture file.
+        expected_model_class (type): The model class expected to load (e.g., MultilayerPerceptron).
+
+    Returns:
+        nn.Module: An instance of the model with a freshly initialized state.
+
+    Raises:
+        FileNotFoundError: If the filepath does not exist.
+        ValueError: If the class name in the file does not match the `expected_model_class`.
+    """
+    path_obj = make_fullpath(filepath, enforce="file")
+
+    with open(path_obj, 'r') as f:
+        saved_data = json.load(f)
+
+    saved_class_name = saved_data['model_class']
+    config = saved_data['config']
+
+    if saved_class_name != expected_model_class.__name__:
+        raise ValueError(
+            f"Model class mismatch. File specifies '{saved_class_name}', "
+            f"but you expected '{expected_model_class.__name__}'."
+        )
+
+    # Create an instance of the model using the provided class and config
+    model = expected_model_class(**config)
+    if verbose:
+        _LOGGER.info(f"✅ Successfully loaded architecture for '{saved_class_name}'")
+    return model
+
+
 def info():
     _script_info(__all__)
```

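Together, `get_config()`, `save_architecture()`, and `load_architecture()` give a JSON round-trip for a model's constructor arguments (weights are not included and still go through the usual state_dict path). A round-trip sketch using only the names added in this diff; the layer sizes and output directory are illustrative:

```python
from ml_tools.ML_models import MultilayerPerceptron, save_architecture, load_architecture

model = MultilayerPerceptron(in_features=12, out_targets=3,
                             hidden_layers=[64, 32], drop_out=0.1)
print(model.get_config())               # {'in_features': 12, 'out_targets': 3, ...}

save_architecture(model, "artifacts")   # writes artifacts/architecture.json

# Rebuild a fresh, untrained instance; the class name is checked against the file.
rebuilt = load_architecture("artifacts/architecture.json",
                            expected_model_class=MultilayerPerceptron)
```
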
{dragon_ml_toolbox-6.2.0 → dragon_ml_toolbox-6.3.0}/ml_tools/custom_logger.py

```diff
@@ -10,7 +10,9 @@ from ._logger import _LOGGER
 
 
 __all__ = [
-    "custom_logger"
+    "custom_logger",
+    "save_list_strings",
+    "load_list_strings"
 ]
 
 
@@ -136,5 +138,39 @@ def _log_dict_to_json(data: Dict[Any, Any], path: Path) -> None:
         json.dump(data, f, indent=4, ensure_ascii=False)
 
 
+def save_list_strings(list_strings: list[str], directory: Union[str,Path], filename: str, verbose: bool=True):
+    """Saves a list of strings as a text file."""
+    target_dir = make_fullpath(directory, make=True, enforce="directory")
+    sanitized_name = sanitize_filename(filename)
+
+    if not sanitized_name.endswith(".txt"):
+        sanitized_name = sanitized_name + ".txt"
+
+    full_path = target_dir / sanitized_name
+    with open(full_path, 'w') as f:
+        for string_data in list_strings:
+            f.write(f"{string_data}\n")
+
+    if verbose:
+        _LOGGER.info(f"✅ Text file saved as '{full_path.name}'.")
+
+
+def load_list_strings(text_file: Union[str,Path], verbose: bool=True) -> list[str]:
+    """Loads a text file as a list of strings."""
+    target_path = make_fullpath(text_file, enforce="file")
+    loaded_strings = []
+
+    with open(target_path, 'r') as f:
+        loaded_strings = [line.strip() for line in f]
+
+    if len(loaded_strings) == 0:
+        raise ValueError("❌ The text file is empty.")
+
+    if verbose:
+        _LOGGER.info(f"✅ Text file loaded as list of strings.")
+
+    return loaded_strings
+
+
 def info():
     _script_info(__all__)
```

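The two new helpers are a thin text-file round-trip: one string per line on save (with a `.txt` suffix appended when missing), stripped lines on load. A short sketch; the directory and filename are illustrative:

```python
from ml_tools.custom_logger import save_list_strings, load_list_strings

names = ["age", "height", "weight"]
save_list_strings(list_strings=names, directory="artifacts", filename="feature_names")

restored = load_list_strings("artifacts/feature_names.txt")
assert restored == names
```
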