dragon-ml-toolbox 10.6.0__py3-none-any.whl → 10.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


dragon_ml_toolbox-10.8.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dragon-ml-toolbox
- Version: 10.6.0
+ Version: 10.8.0
  Summary: A collection of tools for data science and machine learning projects.
  Author-email: Karl Loza <luigiloza@gmail.com>
  License-Expression: MIT
dragon_ml_toolbox-10.8.0.dist-info/RECORD CHANGED
@@ -1,18 +1,18 @@
- dragon_ml_toolbox-10.6.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
- dragon_ml_toolbox-10.6.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=lY4_rJPnLnMu7YBQaY-_iz1JRDcLdQzNCyeLAF1glJY,1837
+ dragon_ml_toolbox-10.8.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
+ dragon_ml_toolbox-10.8.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=lY4_rJPnLnMu7YBQaY-_iz1JRDcLdQzNCyeLAF1glJY,1837
  ml_tools/ETL_cleaning.py,sha256=lSP5q6-ukGhJBPV8dlsqJvPXAzj4du_0J-SbtEd0Pjg,19292
  ml_tools/ETL_engineering.py,sha256=a6KCWH6kRatZtjaFEF_o917ApPMK5_vRD-BjfCDAl-E,49400
  ml_tools/GUI_tools.py,sha256=kEQWg-bog3pB5tI22gMGKWaCGHnz9TB2Lvvfhf5F2CI,45412
  ml_tools/MICE_imputation.py,sha256=kVSythWfxJFR4-2mtcYCWQaQ1Oz5yyx_SJu5gjnS7H8,11670
  ml_tools/ML_callbacks.py,sha256=JPvEw_cW5tYNJ2rMSgnNrKLuni_UrmuhDFaOw-u2SvA,13926
- ml_tools/ML_datasetmaster.py,sha256=CBZFpvm0qiY-8gP89iKTkd7jvU-rGQcJwk-_mBJmRSg,29273
+ ml_tools/ML_datasetmaster.py,sha256=BMmdCVAZ-HSnnSPLzKla2TdZKvHkHj4t9A0V1Ba3i-I,30821
  ml_tools/ML_evaluation.py,sha256=28JJ2M71p4pxniwav2Hv3b1a5dsvaoIYNLm-UJQuXvY,16002
  ml_tools/ML_evaluation_multi.py,sha256=2jTSNFCu8cz5C05EusnrDyffs59M2Fq3UXSHxo2TR1A,12515
  ml_tools/ML_inference.py,sha256=SGDPiPxs_OYDKKRZziFMyaWcC8A37c70W9t-dMP5niI,23066
- ml_tools/ML_models.py,sha256=A_yeULMxT3IAuJuwIF5nXdAQwQDGsxHlbDSxtlzVG44,27699
+ ml_tools/ML_models.py,sha256=FliuqGhxP7AWHCweTLlfssXFOjwvFhIYJsgj_w_-EI4,27901
  ml_tools/ML_optimization.py,sha256=a2Uxe1g-y4I-gFa8ENIM8QDS-Pz3hoPRRaVXAWMbyQA,13491
- ml_tools/ML_scaler.py,sha256=O8JzHr2551zPpKRRReEIMvq0lNAAPau6hV59KUMAySg,7420
- ml_tools/ML_trainer.py,sha256=xM-o-gbPhWXm2lOVXbeaTFotgJSDRSHyE7H0-9OOij4,23712
+ ml_tools/ML_scaler.py,sha256=IrZsAr1xjvuLi8s5IKR-qbk2mS_awl3mn_xoXg5TJyA,7535
+ ml_tools/ML_trainer.py,sha256=xw1zMgYpdqwsTt604xe3GTQNvpg6z6Ze-avmitGBFeU,23539
  ml_tools/PSO_optimization.py,sha256=q0VYpssQGbPum7xdnkDXlJQKhZMYZo8acHpKhajPK3c,22954
  ml_tools/RNN_forecast.py,sha256=8rNZr-eWOBXMiDQV22e_tQTPM5LM2IFggEAa1FaoXaI,1965
  ml_tools/SQL.py,sha256=WDgdZUYuLBUpv-4Am9XjVY_Aq_jxBWdLrbcgAIEwefI,10704
@@ -26,11 +26,11 @@ ml_tools/ensemble_evaluation.py,sha256=xMEMfXJ5MjTkTfr1LkFOeD7iUtnVDCW3S9lm3zT-6
  ml_tools/ensemble_inference.py,sha256=EFHnbjbu31fcVp88NBx8lWAVdu2Gpg9MY9huVZJHFfM,9350
  ml_tools/ensemble_learning.py,sha256=3s0kH4i_naj0IVl_T4knst-Hwg4TScWjEdsXX5KAi7I,21929
  ml_tools/handle_excel.py,sha256=He4UT15sCGhaG-JKfs7uYVAubxWjrqgJ6U7OhMR2fuE,14005
- ml_tools/keys.py,sha256=C_P-8EBVYFlajtddWkOFHN0imuRyN-tkjECaoFSJOxg,1230
+ ml_tools/keys.py,sha256=sZANLHvp_93pPigviMOz7AhampGlpokcop_llzsjWBw,1689
  ml_tools/optimization_tools.py,sha256=P3I6lIpvZ8Xf2kX5FvvBKBmrK2pB6idBpkTzfUJxTeE,5073
- ml_tools/path_manager.py,sha256=7sRvAoNrboRY6ef9gH3_qdzoZ66iLs7Aii4P39K0kEk,13819
- ml_tools/utilities.py,sha256=SVMaSDigh6SUoAeig2_sXLLIj5w5mUs5KuVWpHvFDec,19816
- dragon_ml_toolbox-10.6.0.dist-info/METADATA,sha256=j_-cjm2w_DaUz9k9r4tlJ34zbsM9rQn7od2X_LaoSHU,6968
- dragon_ml_toolbox-10.6.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- dragon_ml_toolbox-10.6.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
- dragon_ml_toolbox-10.6.0.dist-info/RECORD,,
+ ml_tools/path_manager.py,sha256=wLJlz3Y9_1-LB9em4B2VYDCVuTOX2eOc7D6hbbebjgM,14990
+ ml_tools/utilities.py,sha256=xddY0uASKQWSuUsYJEcfDUkeC-ccbYlkycqHKdkPnhk,25105
+ dragon_ml_toolbox-10.8.0.dist-info/METADATA,sha256=Ly11G7vOgCFbYwEYXQXa8RBgvWof9thiBxVjlk9DZu4,6968
+ dragon_ml_toolbox-10.8.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ dragon_ml_toolbox-10.8.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+ dragon_ml_toolbox-10.8.0.dist-info/RECORD,,
ml_tools/ML_datasetmaster.py CHANGED
@@ -15,6 +15,8 @@ from ._logger import _LOGGER
  from ._script_info import _script_info
  from .custom_logger import save_list_strings
  from .ML_scaler import PytorchScaler
+ from .keys import DatasetKeys
+
 
  __all__ = [
      "DatasetMaker",
@@ -34,7 +36,9 @@ class _PytorchDataset(Dataset):
      def __init__(self, features: Union[numpy.ndarray, pandas.DataFrame],
                   labels: Union[numpy.ndarray, pandas.Series],
                   labels_dtype: torch.dtype,
-                  features_dtype: torch.dtype = torch.float32):
+                  features_dtype: torch.dtype = torch.float32,
+                  feature_names: Optional[List[str]] = None,
+                  target_names: Optional[List[str]] = None):
          """
          integer labels for classification.
 
@@ -50,12 +54,30 @@ class _PytorchDataset(Dataset):
              self.labels = torch.tensor(labels, dtype=labels_dtype)
          else:
              self.labels = torch.tensor(labels.values, dtype=labels_dtype)
+
+         self._feature_names = feature_names
+         self._target_names = target_names
 
      def __len__(self):
          return len(self.features)
 
      def __getitem__(self, index):
          return self.features[index], self.labels[index]
+
+     @property
+     def feature_names(self):
+         if self._feature_names is not None:
+             return self._feature_names
+         else:
+             _LOGGER.error(f"Dataset {self.__class__} has not been initialized with any feature names.")
+             raise ValueError()
+
+     @property
+     def target_names(self):
+         if self._target_names is not None:
+             return self._target_names
+         else:
+             _LOGGER.error(f"Dataset {self.__class__} has not been initialized with any target names.")
 
 
  # --- Abstract Base Class (New) ---
@@ -71,6 +93,7 @@ class _BaseDatasetMaker(ABC):
          self.scaler: Optional[PytorchScaler] = None
          self._id: Optional[str] = None
          self._feature_names: List[str] = []
+         self._target_names: List[str] = []
          self._X_train_shape = (0,0)
          self._X_test_shape = (0,0)
          self._y_train_shape = (0,)
@@ -122,6 +145,10 @@ class _BaseDatasetMaker(ABC):
      @property
      def feature_names(self) -> list[str]:
          return self._feature_names
+
+     @property
+     def target_names(self) -> list[str]:
+         return self._target_names
 
      @property
      def id(self) -> Optional[str]:
@@ -142,10 +169,17 @@ class _BaseDatasetMaker(ABC):
          """Saves a list of feature names as a text file"""
          save_list_strings(list_strings=self._feature_names,
                            directory=directory,
-                           filename="feature_names",
-                           verbose=verbose)
+                           filename=DatasetKeys.FEATURE_NAMES,
+                           verbose=verbose)
+
+     def save_target_names(self, directory: Union[str, Path], verbose: bool=True) -> None:
+         """Saves a list of target names as a text file"""
+         save_list_strings(list_strings=self._target_names,
+                           directory=directory,
+                           filename=DatasetKeys.TARGET_NAMES,
+                           verbose=verbose)
 
-     def save_scaler(self, save_dir: Union[str, Path]):
+     def save_scaler(self, save_dir: Union[str, Path], verbose: bool=True) -> None:
          """
          Saves the fitted PytorchScaler's state to a .pth file.
 
@@ -158,14 +192,15 @@ class _BaseDatasetMaker(ABC):
              _LOGGER.error("No scaler was fitted or provided.")
              raise RuntimeError()
          if not self.id:
-             _LOGGER.error("Must set the `id` before saving scaler.")
+             _LOGGER.error("Must set the dataset `id` before saving scaler.")
              raise ValueError()
          save_path = make_fullpath(save_dir, make=True, enforce="directory")
          sanitized_id = sanitize_filename(self.id)
-         filename = f"scaler_{sanitized_id}.pth"
+         filename = f"{DatasetKeys.SCALER_PREFIX}{sanitized_id}.pth"
          filepath = save_path / filename
-         self.scaler.save(filepath)
-         _LOGGER.info(f"Scaler for dataset '{self.id}' saved to '{filepath.name}'.")
+         self.scaler.save(filepath, verbose=False)
+         if verbose:
+             _LOGGER.info(f"Scaler for dataset '{self.id}' saved to '{filepath.name}'.")
 
 
  # Single target dataset
@@ -183,7 +218,7 @@ class DatasetMaker(_BaseDatasetMaker):
      `train_dataset` -> PyTorch Dataset
      `test_dataset` -> PyTorch Dataset
      `feature_names` -> list[str]
-     `target_name` -> str
+     `target_names` -> list[str]
      `id` -> str
 
      The ID can be manually set to any string if needed, it is the target name by default.
@@ -211,8 +246,8 @@ class DatasetMaker(_BaseDatasetMaker):
          features = pandas_df.iloc[:, :-1]
          target = pandas_df.iloc[:, -1]
          self._feature_names = features.columns.tolist()
-         self._target_name = str(target.name)
-         self._id = self._target_name
+         self._target_names = [str(target.name)]
+         self._id = self._target_names[0]
 
          # --- 2. Split ---
          X_train, X_test, y_train, y_test = train_test_split(
@@ -229,12 +264,8 @@ class DatasetMaker(_BaseDatasetMaker):
          )
 
          # --- 4. Create Datasets ---
-         self._train_ds = _PytorchDataset(X_train_final, y_train.values, label_dtype)
-         self._test_ds = _PytorchDataset(X_test_final, y_test.values, label_dtype)
-
-     @property
-     def target_name(self) -> str:
-         return self._target_name
+         self._train_ds = _PytorchDataset(X_train_final, y_train.values, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
+         self._test_ds = _PytorchDataset(X_test_final, y_test.values, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
 
 
  # --- New Multi-Target Class ---
@@ -280,12 +311,8 @@ class DatasetMakerMulti(_BaseDatasetMaker):
              X_train, y_train, X_test, label_dtype, continuous_feature_columns
          )
 
-         self._train_ds = _PytorchDataset(X_train_final, y_train, label_dtype)
-         self._test_ds = _PytorchDataset(X_test_final, y_test, label_dtype)
-
-     @property
-     def target_names(self) -> list[str]:
-         return self._target_names
+         self._train_ds = _PytorchDataset(X_train_final, y_train, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
+         self._test_ds = _PytorchDataset(X_test_final, y_test, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
 
 
  # --- Private Base Class ---
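
The ML_datasetmaster changes converge `DatasetMaker` and `DatasetMakerMulti` on a single `target_names -> list[str]` interface (replacing the single-target `target_name` property) and thread the feature and target names down into the wrapped `_PytorchDataset`. One caveat visible in the hunks: unlike `feature_names`, the new `_PytorchDataset.target_names` property logs an error when unset but does not raise. A minimal sketch of the resulting artifact-saving flow; the `DatasetMaker` constructor takes more arguments than this diff shows, so the call below is an assumption:

```python
# Sketch, assuming a DataFrame whose last column is the target and
# default values for constructor arguments not shown in this diff.
import pandas as pd
from ml_tools.ML_datasetmaster import DatasetMaker

df = pd.DataFrame({"x1": range(100), "x2": range(100), "y": range(100)})  # hypothetical data
maker = DatasetMaker(df)                 # id defaults to the target name, "y"

maker.save_feature_names("artifacts")    # writes feature_names.txt (DatasetKeys.FEATURE_NAMES)
maker.save_target_names("artifacts")     # new in 10.8.0: writes target_names.txt
print(maker.train_dataset.target_names)  # names are now carried by the dataset itself
```
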
ml_tools/ML_models.py CHANGED
@@ -6,7 +6,7 @@ import json
  from ._logger import _LOGGER
  from .path_manager import make_fullpath
  from ._script_info import _script_info
- from .keys import PytorchModelKeys
+ from .keys import PytorchModelArchitectureKeys
 
 
  __all__ = [
@@ -29,11 +29,14 @@ class _ArchitectureHandlerMixin:
              raise AttributeError()
 
          path_dir = make_fullpath(directory, make=True, enforce="directory")
-         full_path = path_dir / PytorchModelKeys.SAVENAME
+
+         json_filename = PytorchModelArchitectureKeys.SAVENAME + ".json"
+
+         full_path = path_dir / json_filename
 
          config = {
-             PytorchModelKeys.MODEL: self.__class__.__name__,
-             PytorchModelKeys.CONFIG: self.get_architecture_config() # type: ignore
+             PytorchModelArchitectureKeys.MODEL: self.__class__.__name__,
+             PytorchModelArchitectureKeys.CONFIG: self.get_architecture_config() # type: ignore
          }
 
          with open(full_path, 'w') as f:
@@ -48,7 +51,8 @@ class _ArchitectureHandlerMixin:
          user_path = make_fullpath(file_or_dir)
 
          if user_path.is_dir():
-             target_path = make_fullpath(user_path / PytorchModelKeys.SAVENAME, enforce="file")
+             json_filename = PytorchModelArchitectureKeys.SAVENAME + ".json"
+             target_path = make_fullpath(user_path / json_filename, enforce="file")
          elif user_path.is_file():
              target_path = user_path
          else:
@@ -58,8 +62,8 @@ class _ArchitectureHandlerMixin:
          with open(target_path, 'r') as f:
              saved_data = json.load(f)
 
-         saved_class_name = saved_data[PytorchModelKeys.MODEL]
-         config = saved_data[PytorchModelKeys.CONFIG]
+         saved_class_name = saved_data[PytorchModelArchitectureKeys.MODEL]
+         config = saved_data[PytorchModelArchitectureKeys.CONFIG]
 
          if saved_class_name != cls.__name__:
              _LOGGER.error(f"Model class mismatch. File specifies '{saved_class_name}', but '{cls.__name__}' was expected.")
ml_tools/ML_scaler.py CHANGED
@@ -149,24 +149,25 @@ class PytorchScaler:
 
          return data_clone
 
-     def save(self, filepath: Union[str, Path]):
+     def save(self, filepath: Union[str, Path], verbose: bool=True):
          """
          Saves the scaler's state (mean, std, indices) to a .pth file.
 
          Args:
              filepath (str | Path): The path to save the file.
          """
-         path_obj = make_fullpath(filepath)
+         path_obj = make_fullpath(filepath, make=True, enforce="file")
          state = {
              'mean': self.mean_,
              'std': self.std_,
              'continuous_feature_indices': self.continuous_feature_indices
          }
          torch.save(state, path_obj)
-         _LOGGER.info(f"PytorchScaler state saved to '{path_obj.name}'.")
+         if verbose:
+             _LOGGER.info(f"PytorchScaler state saved to '{path_obj.name}'.")
 
      @staticmethod
-     def load(filepath: Union[str, Path]) -> 'PytorchScaler':
+     def load(filepath: Union[str, Path], verbose: bool=True) -> 'PytorchScaler':
          """
          Loads a scaler's state from a .pth file.
 
@@ -178,7 +179,8 @@ class PytorchScaler:
          """
          path_obj = make_fullpath(filepath, enforce="file")
          state = torch.load(path_obj)
-         _LOGGER.info(f"PytorchScaler state loaded from '{path_obj.name}'.")
+         if verbose:
+             _LOGGER.info(f"PytorchScaler state loaded from '{path_obj.name}'.")
          return PytorchScaler(
              mean=state['mean'],
              std=state['std'],
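
`PytorchScaler.save` and `PytorchScaler.load` both gain a `verbose` flag, which is what lets `_BaseDatasetMaker.save_scaler` silence the scaler's own log line and emit a single dataset-level message; `save` also now calls `make_fullpath(filepath, make=True, enforce="file")`, so missing parent directories are created on the way. A round-trip sketch, with constructor arguments inferred from the `load` hunk above:

```python
# Round-trip sketch (file name and values are hypothetical).
import torch
from ml_tools.ML_scaler import PytorchScaler

scaler = PytorchScaler(mean=torch.tensor([0.5]),
                       std=torch.tensor([2.0]),
                       continuous_feature_indices=[0])
scaler.save("artifacts/scaler_demo.pth", verbose=False)     # silent; parent dir is created
restored = PytorchScaler.load("artifacts/scaler_demo.pth")  # logs one line by default
```
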
ml_tools/ML_trainer.py CHANGED
@@ -357,7 +357,7 @@ class MLTrainer:
              If None, the trainer's test dataset is used.
          n_samples (int): The number of samples to use for both background and explanation.
          feature_names (list[str] | None): Feature names.
-         target_names (list[str] | None): Target names
+         target_names (list[str] | None): Target names for multi-target tasks.
          save_dir (str | Path): Directory to save all SHAP artifacts.
          """
          # Internal helper to create a dataloader and get a random sample
@@ -408,12 +408,8 @@ class MLTrainer:
          if hasattr(target_dataset, "feature_names"):
              feature_names = target_dataset.feature_names # type: ignore
          else:
-             try:
-                 # Handle PyTorch Subset
-                 feature_names = target_dataset.dataset.feature_names # type: ignore
-             except AttributeError:
-                 _LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a `feature_names` attribute.")
-                 raise ValueError()
+             _LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a `feature_names` attribute.")
+             raise ValueError()
 
          # 3. Call the plotting function
          if self.kind in ["regression", "classification"]:
ml_tools/keys.py CHANGED
@@ -38,11 +38,27 @@ class PyTorchInferenceKeys:
      PROBABILITIES = "probabilities"
 
 
- class PytorchModelKeys:
-     """Keys for saving and loading models"""
-     MODEL = 'model_class',
-     CONFIG = "config",
-     SAVENAME = "architecture.json"
+ class PytorchModelArchitectureKeys:
+     """Keys for saving and loading model architecture."""
+     MODEL = 'model_class'
+     CONFIG = "config"
+     SAVENAME = "architecture"
+
+
+ class PytorchArtifactPathKeys:
+     """Keys for model artifact paths."""
+     FEATURES_PATH = "feature_names_path"
+     TARGETS_PATH = "target_names_path"
+     ARCHITECTURE_PATH = "model_architecture_path"
+     WEIGHTS_PATH = "model_weights_path"
+     SCALER_PATH = "scaler_path"
+
+
+ class DatasetKeys:
+     """Keys for saving dataset artifacts"""
+     FEATURE_NAMES = "feature_names"
+     TARGET_NAMES = "target_names"
+     SCALER_PREFIX = "scaler_"
 
 
  class _OneHotOtherPlaceholder:
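
Beyond the rename, the keys.py rewrite fixes a real bug: in 10.6.0 the trailing commas on `MODEL` and `CONFIG` made those constants one-element tuples rather than strings, as the removed lines above show. A two-line check of the Python semantics involved:

```python
# A trailing comma builds a tuple, so the 10.6.0 constants were not strings.
MODEL = 'model_class',   # old definition -> ('model_class',)
assert MODEL == ('model_class',) and MODEL != 'model_class'
```
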
ml_tools/path_manager.py CHANGED
@@ -13,6 +13,7 @@ __all__ = [
      "sanitize_filename",
      "list_csv_paths",
      "list_files_by_extension",
+     "list_subdirectories"
  ]
 
 
@@ -385,5 +386,37 @@ def list_files_by_extension(directory: Union[str,Path], extension: str, verbose:
      return name_path_dict
 
 
+ def list_subdirectories(root_dir: Union[str,Path], verbose: bool=True) -> dict[str, Path]:
+     """
+     Scans a directory and returns a dictionary of its immediate subdirectories.
+
+     Args:
+         root_dir (str | Path): The path to the directory to scan.
+         verbose (bool): If True, prints the number of directories found.
+
+     Returns:
+         dict[str, Path]: A dictionary mapping subdirectory names (str) to their full Path objects.
+     """
+     root_path = make_fullpath(root_dir, enforce="directory")
+
+     directories = [p.resolve() for p in root_path.iterdir() if p.is_dir()]
+
+     if len(directories) < 1:
+         _LOGGER.error(f"No subdirectories found inside '{root_path}'")
+         raise IOError()
+
+     if verbose:
+         count = len(directories)
+         # Use pluralization for better readability
+         plural = 'ies' if count != 1 else 'y'
+         print(f"Found {count} subdirector{plural} in '{root_path.name}'.")
+
+     # Create a dictionary where the key is the directory's name (a string)
+     # and the value is the full Path object.
+     dir_map = {p.name: p for p in directories}
+
+     return dir_map
+
+
  def info():
      _script_info(__all__)
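
A usage sketch for the new `list_subdirectories` helper (the directory name is hypothetical); note that it raises `IOError` when the directory contains no subdirectories:

```python
# Minimal sketch: map run names to their directories.
from ml_tools.path_manager import list_subdirectories

runs = list_subdirectories("training_results", verbose=True)
# -> {"model_1": Path(".../training_results/model_1"), ...}
for name, path in runs.items():
    print(f"{name} -> {path}")
```
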
ml_tools/utilities.py CHANGED
@@ -6,9 +6,10 @@ from pathlib import Path
  from typing import Literal, Union, Sequence, Optional, Any, Iterator, Tuple, overload
  import joblib
  from joblib.externals.loky.process_executor import TerminatedWorkerError
- from .path_manager import sanitize_filename, make_fullpath, list_csv_paths
+ from .path_manager import sanitize_filename, make_fullpath, list_csv_paths, list_files_by_extension, list_subdirectories
  from ._script_info import _script_info
  from ._logger import _LOGGER
+ from .keys import DatasetKeys, PytorchModelArchitectureKeys, PytorchArtifactPathKeys
 
 
  # Keep track of available tools
@@ -24,7 +25,8 @@ __all__ = [
      "deserialize_object",
      "distribute_dataset_by_target",
      "train_dataset_orchestrator",
-     "train_dataset_yielder"
+     "train_dataset_yielder",
+     "find_model_artifacts"
  ]
 
 
@@ -560,5 +562,126 @@ def train_dataset_yielder(
      yield (df_features, df_target, feature_names, target_col)
 
 
+ def find_model_artifacts(target_directory: Union[str,Path], load_scaler: bool, verbose: bool=False) -> list[dict[str,Any]]:
+     """
+     Scans subdirectories to find paths to model weights, target names, feature names, and model architecture. Optionally a scaler path if `load_scaler` is True.
+
+     This function operates on a specific directory structure. It expects the
+     `target_directory` to contain one or more subdirectories, where each
+     subdirectory represents a single trained model result.
+
+     The expected directory structure for each model is as follows:
+     ```
+     target_directory/
+     ├── model_1/
+     │   ├── *.pth
+     │   ├── scaler_*.pth (Required if `load_scaler` is True)
+     │   ├── feature_names.txt
+     │   ├── target_names.txt
+     │   └── architecture.json
+     └── model_2/
+         └── ...
+     ```
+
+     Args:
+         target_directory (str | Path): The path to the root directory that contains model subdirectories.
+         load_scaler (bool): If True, the function requires and searches for a scaler file (`.pth`) in each model subdirectory.
+         verbose (bool): If True, enables detailed logging during the file paths search process.
+
+     Returns:
+         (list[dict[str, Path]]): A list of dictionaries, where each dictionary
+             corresponds to a model found in a subdirectory. The dictionary
+             maps standardized keys to the absolute paths of the model's
+             artifacts (weights, architecture, features, targets, and scaler).
+             The scaler path will be `None` if `load_scaler` is False.
+     """
+     # validate directory
+     root_path = make_fullpath(target_directory, enforce="directory")
+
+     # store results
+     all_artifacts: list[dict] = list()
+
+     # find model directories
+     result_dirs_dict = list_subdirectories(root_dir=root_path, verbose=verbose)
+     for dir_name, dir_path in result_dirs_dict.items():
+         # find files
+         model_pth_dict = list_files_by_extension(directory=dir_path, extension="pth", verbose=verbose)
+
+         # restriction
+         if load_scaler:
+             if len(model_pth_dict) != 2:
+                 _LOGGER.error(f"Directory {dir_path} should contain exactly 2 '.pth' files: scaler and weights.")
+                 raise IOError()
+         else:
+             if len(model_pth_dict) != 1:
+                 _LOGGER.error(f"Directory {dir_path} should contain exactly 1 '.pth' file: weights.")
+                 raise IOError()
+
+         ##### Scaler and Weights #####
+         scaler_path = None
+         weights_path = None
+
+         # load weights and scaler if present
+         for pth_filename, pth_path in model_pth_dict.items():
+             if load_scaler and pth_filename.lower().startswith(DatasetKeys.SCALER_PREFIX):
+                 scaler_path = pth_path
+             else:
+                 weights_path = pth_path
+
+         # validation
+         if not weights_path:
+             _LOGGER.error(f"Error parsing the model weights path from '{dir_name}'")
+             raise IOError()
+
+         if load_scaler and not scaler_path:
+             _LOGGER.error(f"Error parsing the scaler path from '{dir_name}'")
+             raise IOError()
+
+         ##### Target and Feature names #####
+         target_names_path = None
+         feature_names_path = None
+
+         # load feature and target names
+         model_txt_dict = list_files_by_extension(directory=dir_path, extension="txt", verbose=verbose)
+
+         for txt_filename, txt_path in model_txt_dict.items():
+             if txt_filename == DatasetKeys.FEATURE_NAMES:
+                 feature_names_path = txt_path
+             elif txt_filename == DatasetKeys.TARGET_NAMES:
+                 target_names_path = txt_path
+
+         # validation
+         if not target_names_path or not feature_names_path:
+             _LOGGER.error(f"Error parsing features path or targets path from '{dir_name}'")
+             raise IOError()
+
+         ##### load model architecture path #####
+         architecture_path = None
+
+         model_json_dict = list_files_by_extension(directory=dir_path, extension="json", verbose=verbose)
+
+         for json_filename, json_path in model_json_dict.items():
+             if json_filename == PytorchModelArchitectureKeys.SAVENAME:
+                 architecture_path = json_path
+
+         # validation
+         if not architecture_path:
+             _LOGGER.error(f"Error parsing the model architecture path from '{dir_name}'")
+             raise IOError()
+
+         ##### Paths dictionary #####
+         parsing_dict = {
+             PytorchArtifactPathKeys.WEIGHTS_PATH: weights_path,
+             PytorchArtifactPathKeys.ARCHITECTURE_PATH: architecture_path,
+             PytorchArtifactPathKeys.FEATURES_PATH: feature_names_path,
+             PytorchArtifactPathKeys.TARGETS_PATH: target_names_path,
+             PytorchArtifactPathKeys.SCALER_PATH: scaler_path
+         }
+
+         all_artifacts.append(parsing_dict)
+
+     return all_artifacts
+
+
  def info():
      _script_info(__all__)
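
`find_model_artifacts` is the piece that ties the release together: it consumes the directory layout the 10.8.0 savers produce (`save_feature_names`, `save_target_names`, `save_scaler`, the architecture JSON) and returns one path dictionary per model, keyed by `PytorchArtifactPathKeys`. A consumption sketch, with a hypothetical `results` directory:

```python
# Sketch: gather every model's artifact paths from a results directory.
from ml_tools.utilities import find_model_artifacts
from ml_tools.keys import PytorchArtifactPathKeys

for artifacts in find_model_artifacts("results", load_scaler=True, verbose=False):
    weights = artifacts[PytorchArtifactPathKeys.WEIGHTS_PATH]
    architecture = artifacts[PytorchArtifactPathKeys.ARCHITECTURE_PATH]
    features = artifacts[PytorchArtifactPathKeys.FEATURES_PATH]
    targets = artifacts[PytorchArtifactPathKeys.TARGETS_PATH]
    scaler = artifacts[PytorchArtifactPathKeys.SCALER_PATH]  # None when load_scaler=False
    print(weights, architecture, features, targets, scaler)
```
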