dragon-ml-toolbox 10.6.0.tar.gz → 10.8.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-10.6.0/dragon_ml_toolbox.egg-info → dragon_ml_toolbox-10.8.0}/PKG-INFO +1 -1
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0/dragon_ml_toolbox.egg-info}/PKG-INFO +1 -1
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/ML_datasetmaster.py +50 -23
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/ML_models.py +11 -7
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/ML_scaler.py +7 -5
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/ML_trainer.py +3 -7
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/keys.py +21 -5
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/path_manager.py +33 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/utilities.py +125 -2
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/pyproject.toml +1 -1
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/LICENSE +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/README.md +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/dragon_ml_toolbox.egg-info/SOURCES.txt +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/dragon_ml_toolbox.egg-info/dependency_links.txt +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/dragon_ml_toolbox.egg-info/requires.txt +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/dragon_ml_toolbox.egg-info/top_level.txt +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/ETL_cleaning.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/ETL_engineering.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/GUI_tools.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/MICE_imputation.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/ML_callbacks.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/ML_evaluation.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/ML_evaluation_multi.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/ML_inference.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/ML_optimization.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/PSO_optimization.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/RNN_forecast.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/SQL.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/VIF_factor.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/__init__.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/_logger.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/_script_info.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/custom_logger.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/data_exploration.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/ensemble_evaluation.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/ensemble_inference.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/ensemble_learning.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/handle_excel.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/ml_tools/optimization_tools.py +0 -0
- {dragon_ml_toolbox-10.6.0 → dragon_ml_toolbox-10.8.0}/setup.cfg +0 -0
ml_tools/ML_datasetmaster.py (+50 -23)

```diff
@@ -15,6 +15,8 @@ from ._logger import _LOGGER
 from ._script_info import _script_info
 from .custom_logger import save_list_strings
 from .ML_scaler import PytorchScaler
+from .keys import DatasetKeys
+
 
 __all__ = [
     "DatasetMaker",
```
```diff
@@ -34,7 +36,9 @@ class _PytorchDataset(Dataset):
     def __init__(self, features: Union[numpy.ndarray, pandas.DataFrame],
                  labels: Union[numpy.ndarray, pandas.Series],
                  labels_dtype: torch.dtype,
-                 features_dtype: torch.dtype = torch.float32
+                 features_dtype: torch.dtype = torch.float32,
+                 feature_names: Optional[List[str]] = None,
+                 target_names: Optional[List[str]] = None):
         """
         integer labels for classification.
 
```
```diff
@@ -50,12 +54,30 @@ class _PytorchDataset(Dataset):
             self.labels = torch.tensor(labels, dtype=labels_dtype)
         else:
             self.labels = torch.tensor(labels.values, dtype=labels_dtype)
+
+        self._feature_names = feature_names
+        self._target_names = target_names
 
     def __len__(self):
         return len(self.features)
 
     def __getitem__(self, index):
         return self.features[index], self.labels[index]
+
+    @property
+    def feature_names(self):
+        if self._feature_names is not None:
+            return self._feature_names
+        else:
+            _LOGGER.error(f"Dataset {self.__class__} has not been initialized with any feature names.")
+            raise ValueError()
+
+    @property
+    def target_names(self):
+        if self._target_names is not None:
+            return self._target_names
+        else:
+            _LOGGER.error(f"Dataset {self.__class__} has not been initialized with any target names.")
 
 
 # --- Abstract Base Class (New) ---
```
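The new accessors fail loudly when a dataset was built without metadata. Below is a minimal standalone sketch of this guarded-property pattern; `NamedDataset` and the logger setup are illustrative, not part of the package:

```python
# Standalone sketch of the guarded-property pattern added to _PytorchDataset:
# return the stored names when present, otherwise log and raise.
import logging
from typing import List, Optional

_LOGGER = logging.getLogger("ml_tools")

class NamedDataset:
    def __init__(self, feature_names: Optional[List[str]] = None):
        self._feature_names = feature_names

    @property
    def feature_names(self) -> List[str]:
        if self._feature_names is not None:
            return self._feature_names
        _LOGGER.error(f"Dataset {self.__class__} has not been initialized with any feature names.")
        raise ValueError()

ds = NamedDataset(feature_names=["x1", "x2"])
print(ds.feature_names)  # ['x1', 'x2']
```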
```diff
@@ -71,6 +93,7 @@ class _BaseDatasetMaker(ABC):
         self.scaler: Optional[PytorchScaler] = None
         self._id: Optional[str] = None
         self._feature_names: List[str] = []
+        self._target_names: List[str] = []
         self._X_train_shape = (0,0)
         self._X_test_shape = (0,0)
         self._y_train_shape = (0,)
```
```diff
@@ -122,6 +145,10 @@ class _BaseDatasetMaker(ABC):
     @property
     def feature_names(self) -> list[str]:
         return self._feature_names
+
+    @property
+    def target_names(self) -> list[str]:
+        return self._target_names
 
     @property
     def id(self) -> Optional[str]:
```
```diff
@@ -142,10 +169,17 @@ class _BaseDatasetMaker(ABC):
         """Saves a list of feature names as a text file"""
         save_list_strings(list_strings=self._feature_names,
                           directory=directory,
-                          filename=
-                          verbose=verbose)
+                          filename=DatasetKeys.FEATURE_NAMES,
+                          verbose=verbose)
+
+    def save_target_names(self, directory: Union[str, Path], verbose: bool=True) -> None:
+        """Saves a list of target names as a text file"""
+        save_list_strings(list_strings=self._target_names,
+                          directory=directory,
+                          filename=DatasetKeys.TARGET_NAMES,
+                          verbose=verbose)
 
-    def save_scaler(self, save_dir: Union[str, Path]):
+    def save_scaler(self, save_dir: Union[str, Path], verbose: bool=True) -> None:
         """
         Saves the fitted PytorchScaler's state to a .pth file.
 
```
```diff
@@ -158,14 +192,15 @@ class _BaseDatasetMaker(ABC):
             _LOGGER.error("No scaler was fitted or provided.")
             raise RuntimeError()
         if not self.id:
-            _LOGGER.error("Must set the `id` before saving scaler.")
+            _LOGGER.error("Must set the dataset `id` before saving scaler.")
             raise ValueError()
         save_path = make_fullpath(save_dir, make=True, enforce="directory")
         sanitized_id = sanitize_filename(self.id)
-        filename = f"
+        filename = f"{DatasetKeys.SCALER_PREFIX}{sanitized_id}.pth"
         filepath = save_path / filename
-        self.scaler.save(filepath)
-
+        self.scaler.save(filepath, verbose=False)
+        if verbose:
+            _LOGGER.info(f"Scaler for dataset '{self.id}' saved to '{filepath.name}'.")
 
 
 # Single target dataset
```
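Together with `save_feature_names`, these helpers now derive every filename from `DatasetKeys` (defined in keys.py below). A standalone sketch of the naming convention, with the key values copied from this diff and re-declared locally so the snippet runs on its own:

```python
# Filename conventions used by the save_* helpers above.
FEATURE_NAMES = "feature_names"   # -> feature_names.txt
TARGET_NAMES = "target_names"     # -> target_names.txt
SCALER_PREFIX = "scaler_"

def scaler_filename(sanitized_id: str) -> str:
    # mirrors: filename = f"{DatasetKeys.SCALER_PREFIX}{sanitized_id}.pth"
    return f"{SCALER_PREFIX}{sanitized_id}.pth"

print(scaler_filename("price"))  # scaler_price.pth
```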
```diff
@@ -183,7 +218,7 @@ class DatasetMaker(_BaseDatasetMaker):
     `train_dataset` -> PyTorch Dataset
     `test_dataset` -> PyTorch Dataset
     `feature_names` -> list[str]
-    `
+    `target_names` -> list[str]
     `id` -> str
 
     The ID can be manually set to any string if needed, it is the target name by default.
```
```diff
@@ -211,8 +246,8 @@ class DatasetMaker(_BaseDatasetMaker):
         features = pandas_df.iloc[:, :-1]
         target = pandas_df.iloc[:, -1]
         self._feature_names = features.columns.tolist()
-        self.
-        self._id = self.
+        self._target_names = [str(target.name)]
+        self._id = self._target_names[0]
 
         # --- 2. Split ---
         X_train, X_test, y_train, y_test = train_test_split(
```
```diff
@@ -229,12 +264,8 @@ class DatasetMaker(_BaseDatasetMaker):
         )
 
         # --- 4. Create Datasets ---
-        self._train_ds = _PytorchDataset(X_train_final, y_train.values, label_dtype)
-        self._test_ds = _PytorchDataset(X_test_final, y_test.values, label_dtype)
-
-    @property
-    def target_name(self) -> str:
-        return self._target_name
+        self._train_ds = _PytorchDataset(X_train_final, y_train.values, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
+        self._test_ds = _PytorchDataset(X_test_final, y_test.values, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
 
 
 # --- New Multi-Target Class ---
```
```diff
@@ -280,12 +311,8 @@ class DatasetMakerMulti(_BaseDatasetMaker):
             X_train, y_train, X_test, label_dtype, continuous_feature_columns
         )
 
-        self._train_ds = _PytorchDataset(X_train_final, y_train, label_dtype)
-        self._test_ds = _PytorchDataset(X_test_final, y_test, label_dtype)
-
-    @property
-    def target_names(self) -> list[str]:
-        return self._target_names
+        self._train_ds = _PytorchDataset(X_train_final, y_train, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
+        self._test_ds = _PytorchDataset(X_test_final, y_test, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
 
 
 # --- Private Base Class ---
```
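For callers, the visible change is that `DatasetMaker.target_name` (a str) is gone and both maker classes expose `target_names` (a list[str]) from the shared base. A hedged migration sketch follows; the constructor call is an assumption, since the full `DatasetMaker` signature is not part of this diff, and only the properties shown here are guaranteed:

```python
# Hedged migration sketch (assumes dragon-ml-toolbox is installed and that
# DatasetMaker accepts a DataFrame whose last column is the target; verify
# the real constructor signature before use).
import pandas as pd
from ml_tools.ML_datasetmaster import DatasetMaker

df = pd.DataFrame({"x1": [0.1, 0.2, 0.3, 0.4],
                   "x2": [1.0, 2.0, 3.0, 4.0],
                   "y":  [0.0, 1.0, 0.0, 1.0]})
maker = DatasetMaker(df)                 # hypothetical call

# 10.6.0: name = maker.target_name      # removed in 10.8.0
print(maker.target_names)                # ['y']
print(maker.train_dataset.target_names)  # the datasets now carry names too
```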
ml_tools/ML_models.py (+11 -7)

```diff
@@ -6,7 +6,7 @@ import json
 from ._logger import _LOGGER
 from .path_manager import make_fullpath
 from ._script_info import _script_info
-from .keys import
+from .keys import PytorchModelArchitectureKeys
 
 
 __all__ = [
```
```diff
@@ -29,11 +29,14 @@ class _ArchitectureHandlerMixin:
             raise AttributeError()
 
         path_dir = make_fullpath(directory, make=True, enforce="directory")
-
+
+        json_filename = PytorchModelArchitectureKeys.SAVENAME + ".json"
+
+        full_path = path_dir / json_filename
 
         config = {
-
-
+            PytorchModelArchitectureKeys.MODEL: self.__class__.__name__,
+            PytorchModelArchitectureKeys.CONFIG: self.get_architecture_config() # type: ignore
         }
 
         with open(full_path, 'w') as f:
```
```diff
@@ -48,7 +51,8 @@ class _ArchitectureHandlerMixin:
         user_path = make_fullpath(file_or_dir)
 
         if user_path.is_dir():
-
+            json_filename = PytorchModelArchitectureKeys.SAVENAME + ".json"
+            target_path = make_fullpath(user_path / json_filename, enforce="file")
         elif user_path.is_file():
             target_path = user_path
         else:
```
```diff
@@ -58,8 +62,8 @@ class _ArchitectureHandlerMixin:
         with open(target_path, 'r') as f:
             saved_data = json.load(f)
 
-        saved_class_name = saved_data[
-        config = saved_data[
+        saved_class_name = saved_data[PytorchModelArchitectureKeys.MODEL]
+        config = saved_data[PytorchModelArchitectureKeys.CONFIG]
 
         if saved_class_name != cls.__name__:
             _LOGGER.error(f"Model class mismatch. File specifies '{saved_class_name}', but '{cls.__name__}' was expected.")
```
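The architecture file is plain JSON keyed by `PytorchModelArchitectureKeys`. A simplified standalone sketch of the on-disk format; the helper name and config values are illustrative, not package API:

```python
# Standalone sketch of the architecture.json format written by the mixin;
# key values are mirrored from PytorchModelArchitectureKeys in this diff.
import json
from pathlib import Path

MODEL, CONFIG, SAVENAME = "model_class", "config", "architecture"

def save_architecture(directory: Path, class_name: str, config: dict) -> Path:
    directory.mkdir(parents=True, exist_ok=True)
    full_path = directory / (SAVENAME + ".json")
    with open(full_path, "w") as f:
        json.dump({MODEL: class_name, CONFIG: config}, f)
    return full_path

path = save_architecture(Path("demo_model"), "MLP", {"in_features": 8})
print(json.loads(path.read_text()))
# {'model_class': 'MLP', 'config': {'in_features': 8}}
```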
ml_tools/ML_scaler.py (+7 -5)

```diff
@@ -149,24 +149,25 @@ class PytorchScaler:
 
         return data_clone
 
-    def save(self, filepath: Union[str, Path]):
+    def save(self, filepath: Union[str, Path], verbose: bool=True):
         """
         Saves the scaler's state (mean, std, indices) to a .pth file.
 
         Args:
             filepath (str | Path): The path to save the file.
         """
-        path_obj = make_fullpath(filepath)
+        path_obj = make_fullpath(filepath, make=True, enforce="file")
         state = {
             'mean': self.mean_,
             'std': self.std_,
             'continuous_feature_indices': self.continuous_feature_indices
         }
         torch.save(state, path_obj)
-
+        if verbose:
+            _LOGGER.info(f"PytorchScaler state saved to '{path_obj.name}'.")
 
     @staticmethod
-    def load(filepath: Union[str, Path]) -> 'PytorchScaler':
+    def load(filepath: Union[str, Path], verbose: bool=True) -> 'PytorchScaler':
         """
         Loads a scaler's state from a .pth file.
 
```
```diff
@@ -178,7 +179,8 @@ class PytorchScaler:
         """
         path_obj = make_fullpath(filepath, enforce="file")
         state = torch.load(path_obj)
-
+        if verbose:
+            _LOGGER.info(f"PytorchScaler state loaded from '{path_obj.name}'.")
         return PytorchScaler(
             mean=state['mean'],
             std=state['std'],
```
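Both methods keep `verbose=True` as the default, so existing callers gain the log lines without code changes; pass `verbose=False` to silence them, as `save_scaler` now does internally. A hedged round-trip sketch; the constructor keywords mirror the state dict restored by `load`, but are otherwise an assumption:

```python
# Hedged round-trip sketch (assumes dragon-ml-toolbox and torch are installed;
# constructor keywords mirror the state dict keys shown in load() above).
import torch
from ml_tools.ML_scaler import PytorchScaler

scaler = PytorchScaler(mean=torch.zeros(3), std=torch.ones(3),
                       continuous_feature_indices=[0, 1, 2])
scaler.save("scaler_demo.pth")  # logs: PytorchScaler state saved to 'scaler_demo.pth'.
restored = PytorchScaler.load("scaler_demo.pth", verbose=False)  # silent
```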
ml_tools/ML_trainer.py (+3 -7)

```diff
@@ -357,7 +357,7 @@ class MLTrainer:
                 If None, the trainer's test dataset is used.
             n_samples (int): The number of samples to use for both background and explanation.
             feature_names (list[str] | None): Feature names.
-            target_names (list[str] | None): Target names
+            target_names (list[str] | None): Target names for multi-target tasks.
             save_dir (str | Path): Directory to save all SHAP artifacts.
         """
         # Internal helper to create a dataloader and get a random sample
```
```diff
@@ -408,12 +408,8 @@ class MLTrainer:
             if hasattr(target_dataset, "feature_names"):
                 feature_names = target_dataset.feature_names # type: ignore
             else:
-
-
-                    feature_names = target_dataset.dataset.feature_names # type: ignore
-                except AttributeError:
-                    _LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a `feature_names` attribute.")
-                    raise ValueError()
+                _LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a `feature_names` attribute.")
+                raise ValueError()
 
         # 3. Call the plotting function
         if self.kind in ["regression", "classification"]:
```
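The deleted branch tried `target_dataset.dataset.feature_names`, i.e. unwrapping a `torch.utils.data.Subset`; 10.8.0 errors immediately instead. The standalone snippet below (only torch required) demonstrates why the fallback existed: `Subset` does not forward attributes of the wrapped dataset:

```python
# Why the removed fallback mattered: Subset hides the wrapped dataset's
# attributes, so hasattr(subset, "feature_names") is False after wrapping.
from torch.utils.data import Dataset, Subset

class Toy(Dataset):
    feature_names = ["x1", "x2"]
    def __len__(self): return 2
    def __getitem__(self, i): return i

subset = Subset(Toy(), [0])
print(hasattr(subset, "feature_names"))  # False -> 10.8.0 raises ValueError
print(subset.dataset.feature_names)      # ['x1', 'x2'] (the old fallback path)
```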
ml_tools/keys.py (+21 -5)

```diff
@@ -38,11 +38,27 @@ class PyTorchInferenceKeys:
     PROBABILITIES = "probabilities"
 
 
-class
-    """Keys for saving and loading
-    MODEL = 'model_class'
-    CONFIG = "config"
-    SAVENAME = "architecture
+class PytorchModelArchitectureKeys:
+    """Keys for saving and loading model architecture."""
+    MODEL = 'model_class'
+    CONFIG = "config"
+    SAVENAME = "architecture"
+
+
+class PytorchArtifactPathKeys:
+    """Keys for model artifact paths."""
+    FEATURES_PATH = "feature_names_path"
+    TARGETS_PATH = "target_names_path"
+    ARCHITECTURE_PATH = "model_architecture_path"
+    WEIGHTS_PATH = "model_weights_path"
+    SCALER_PATH = "scaler_path"
+
+
+class DatasetKeys:
+    """Keys for saving dataset artifacts"""
+    FEATURE_NAMES = "feature_names"
+    TARGET_NAMES = "target_names"
+    SCALER_PREFIX = "scaler_"
 
 
 class _OneHotOtherPlaceholder:
```
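These key classes are the contract between the `save_*` helpers above and `find_model_artifacts` in utilities.py below. A short sketch of how the values compose into on-disk names; the `.txt` extension for the name lists is inferred from the artifact layout documented later in this diff:

```python
# Hedged sketch (assumes ml_tools.keys is importable as shown in this diff).
from ml_tools.keys import DatasetKeys, PytorchModelArchitectureKeys

print(DatasetKeys.FEATURE_NAMES + ".txt")               # feature_names.txt
print(DatasetKeys.TARGET_NAMES + ".txt")                # target_names.txt
print(f"{DatasetKeys.SCALER_PREFIX}my_target.pth")      # scaler_my_target.pth
print(PytorchModelArchitectureKeys.SAVENAME + ".json")  # architecture.json
```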
ml_tools/path_manager.py (+33 -0)

```diff
@@ -13,6 +13,7 @@ __all__ = [
     "sanitize_filename",
     "list_csv_paths",
     "list_files_by_extension",
+    "list_subdirectories"
 ]
 
 
```
```diff
@@ -385,5 +386,37 @@ def list_files_by_extension(directory: Union[str,Path], extension: str, verbose:
     return name_path_dict
 
 
+def list_subdirectories(root_dir: Union[str,Path], verbose: bool=True) -> dict[str, Path]:
+    """
+    Scans a directory and returns a dictionary of its immediate subdirectories.
+
+    Args:
+        root_dir (str | Path): The path to the directory to scan.
+        verbose (bool): If True, prints the number of directories found.
+
+    Returns:
+        dict[str, Path]: A dictionary mapping subdirectory names (str) to their full Path objects.
+    """
+    root_path = make_fullpath(root_dir, enforce="directory")
+
+    directories = [p.resolve() for p in root_path.iterdir() if p.is_dir()]
+
+    if len(directories) < 1:
+        _LOGGER.error(f"No subdirectories found inside '{root_path}'")
+        raise IOError()
+
+    if verbose:
+        count = len(directories)
+        # Use pluralization for better readability
+        plural = 'ies' if count != 1 else 'y'
+        print(f"Found {count} subdirector{plural} in '{root_path.name}'.")
+
+    # Create a dictionary where the key is the directory's name (a string)
+    # and the value is the full Path object.
+    dir_map = {p.name: p for p in directories}
+
+    return dir_map
+
+
 def info():
     _script_info(__all__)
```
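A usage sketch for the new helper (assumes the package is importable as `ml_tools`; the temporary tree makes the example self-contained):

```python
import tempfile
from pathlib import Path
from ml_tools.path_manager import list_subdirectories

root = Path(tempfile.mkdtemp())
(root / "model_1").mkdir()
(root / "model_2").mkdir()

subdirs = list_subdirectories(root)  # prints: Found 2 subdirectories in '...'
print(sorted(subdirs))               # ['model_1', 'model_2']
```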
ml_tools/utilities.py (+125 -2)

```diff
@@ -6,9 +6,10 @@ from pathlib import Path
 from typing import Literal, Union, Sequence, Optional, Any, Iterator, Tuple, overload
 import joblib
 from joblib.externals.loky.process_executor import TerminatedWorkerError
-from .path_manager import sanitize_filename, make_fullpath, list_csv_paths
+from .path_manager import sanitize_filename, make_fullpath, list_csv_paths, list_files_by_extension, list_subdirectories
 from ._script_info import _script_info
 from ._logger import _LOGGER
+from .keys import DatasetKeys, PytorchModelArchitectureKeys, PytorchArtifactPathKeys
 
 
 # Keep track of available tools
```
```diff
@@ -24,7 +25,8 @@ __all__ = [
     "deserialize_object",
     "distribute_dataset_by_target",
     "train_dataset_orchestrator",
-    "train_dataset_yielder"
+    "train_dataset_yielder",
+    "find_model_artifacts"
 ]
 
 
```
````diff
@@ -560,5 +562,126 @@ def train_dataset_yielder(
     yield (df_features, df_target, feature_names, target_col)
 
 
+def find_model_artifacts(target_directory: Union[str,Path], load_scaler: bool, verbose: bool=False) -> list[dict[str,Any]]:
+    """
+    Scans subdirectories to find paths to model weights, target names, feature names, and model architecture. Optionally a scaler path if `load_scaler` is True.
+
+    This function operates on a specific directory structure. It expects the
+    `target_directory` to contain one or more subdirectories, where each
+    subdirectory represents a single trained model result.
+
+    The expected directory structure for each model is as follows:
+    ```
+    target_directory
+    ├── model_1
+    │   ├── *.pth
+    │   ├── scaler_*.pth (Required if `load_scaler` is True)
+    │   ├── feature_names.txt
+    │   ├── target_names.txt
+    │   └── architecture.json
+    └── model_2/
+        └── ...
+    ```
+
+    Args:
+        target_directory (str | Path): The path to the root directory that contains model subdirectories.
+        load_scaler (bool): If True, the function requires and searches for a scaler file (`.pth`) in each model subdirectory.
+        verbose (bool): If True, enables detailed logging during the file paths search process.
+
+    Returns:
+        (list[dict[str, Path]]): A list of dictionaries, where each dictionary
+        corresponds to a model found in a subdirectory. The dictionary
+        maps standardized keys to the absolute paths of the model's
+        artifacts (weights, architecture, features, targets, and scaler).
+        The scaler path will be `None` if `load_scaler` is False.
+    """
+    # validate directory
+    root_path = make_fullpath(target_directory, enforce="directory")
+
+    # store results
+    all_artifacts: list[dict] = list()
+
+    # find model directories
+    result_dirs_dict = list_subdirectories(root_dir=root_path, verbose=verbose)
+    for dir_name, dir_path in result_dirs_dict.items():
+        # find files
+        model_pth_dict = list_files_by_extension(directory=dir_path, extension="pth", verbose=verbose)
+
+        # restriction
+        if load_scaler:
+            if len(model_pth_dict) != 2:
+                _LOGGER.error(f"Directory {dir_path} should contain exactly 2 '.pth' files: scaler and weights.")
+                raise IOError()
+        else:
+            if len(model_pth_dict) != 1:
+                _LOGGER.error(f"Directory {dir_path} should contain exactly 1 '.pth' file: weights.")
+                raise IOError()
+
+        ##### Scaler and Weights #####
+        scaler_path = None
+        weights_path = None
+
+        # load weights and scaler if present
+        for pth_filename, pth_path in model_pth_dict.items():
+            if load_scaler and pth_filename.lower().startswith(DatasetKeys.SCALER_PREFIX):
+                scaler_path = pth_path
+            else:
+                weights_path = pth_path
+
+        # validation
+        if not weights_path:
+            _LOGGER.error(f"Error parsing the model weights path from '{dir_name}'")
+            raise IOError()
+
+        if load_scaler and not scaler_path:
+            _LOGGER.error(f"Error parsing the scaler path from '{dir_name}'")
+            raise IOError()
+
+        ##### Target and Feature names #####
+        target_names_path = None
+        feature_names_path = None
+
+        # load feature and target names
+        model_txt_dict = list_files_by_extension(directory=dir_path, extension="txt", verbose=verbose)
+
+        for txt_filename, txt_path in model_txt_dict.items():
+            if txt_filename == DatasetKeys.FEATURE_NAMES:
+                feature_names_path = txt_path
+            elif txt_filename == DatasetKeys.TARGET_NAMES:
+                target_names_path = txt_path
+
+        # validation
+        if not target_names_path or not feature_names_path:
+            _LOGGER.error(f"Error parsing features path or targets path from '{dir_name}'")
+            raise IOError()
+
+        ##### load model architecture path #####
+        architecture_path = None
+
+        model_json_dict = list_files_by_extension(directory=dir_path, extension="json", verbose=verbose)
+
+        for json_filename, json_path in model_json_dict.items():
+            if json_filename == PytorchModelArchitectureKeys.SAVENAME:
+                architecture_path = json_path
+
+        # validation
+        if not architecture_path:
+            _LOGGER.error(f"Error parsing the model architecture path from '{dir_name}'")
+            raise IOError()
+
+        ##### Paths dictionary #####
+        parsing_dict = {
+            PytorchArtifactPathKeys.WEIGHTS_PATH: weights_path,
+            PytorchArtifactPathKeys.ARCHITECTURE_PATH: architecture_path,
+            PytorchArtifactPathKeys.FEATURES_PATH: feature_names_path,
+            PytorchArtifactPathKeys.TARGETS_PATH: target_names_path,
+            PytorchArtifactPathKeys.SCALER_PATH: scaler_path
+        }
+
+        all_artifacts.append(parsing_dict)
+
+    return all_artifacts
+
+
 def info():
     _script_info(__all__)
````
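A usage sketch for the new function (assumes `results/` follows the layout in the docstring above, e.g. as produced by the `save_*` helpers from ML_datasetmaster.py plus the architecture mixin):

```python
from ml_tools.utilities import find_model_artifacts
from ml_tools.keys import PytorchArtifactPathKeys

for artifact in find_model_artifacts("results", load_scaler=True):
    print(artifact[PytorchArtifactPathKeys.WEIGHTS_PATH])       # model_*/...pth
    print(artifact[PytorchArtifactPathKeys.ARCHITECTURE_PATH])  # .../architecture.json
    print(artifact[PytorchArtifactPathKeys.SCALER_PATH])        # .../scaler_*.pth
```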