dragon-ml-toolbox 6.2.1__py3-none-any.whl → 6.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic. Click here for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dragon-ml-toolbox
3
- Version: 6.2.1
3
+ Version: 6.4.0
4
4
  Summary: A collection of tools for data science and machine learning projects.
5
5
  Author-email: Karl Loza <luigiloza@gmail.com>
6
6
  License-Expression: MIT
@@ -160,6 +160,8 @@ SQL
160
160
  utilities
161
161
  ```
162
162
 
163
+ ---
164
+
163
165
  ### 🔬 MICE Imputation and Variance Inflation Factor [mice]
164
166
 
165
167
  ⚠️ Important: This group has strict version requirements. It is highly recommended to install this group in a separate virtual environment.
@@ -178,6 +180,8 @@ path_manager
178
180
  utilities
179
181
  ```
180
182
 
183
+ ---
184
+
181
185
  ### 📋 Excel File Handling [excel]
182
186
 
183
187
  Installs dependencies required to process and handle .xlsx or .xls files.
@@ -194,6 +198,8 @@ handle_excel
194
198
  path_manager
195
199
  ```
196
200
 
201
+ ---
202
+
197
203
  ### 🎰 GUI for Boosting Algorithms (XGBoost, LightGBM) [gui-boost]
198
204
 
199
205
  For GUIs that include plotting functionality, you must also install the [plot] extra.
@@ -215,6 +221,8 @@ ensemble_inference
215
221
  path_manager
216
222
  ```
217
223
 
224
+ ---
225
+
218
226
  ### 🤖 GUI for PyTorch Models [gui-torch]
219
227
 
220
228
  For GUIs that include plotting functionality, you must also install the [plot] extra.
@@ -232,10 +240,13 @@ pip install "dragon-ml-toolbox[gui-torch,plot]"
232
240
  ```Bash
233
241
  custom_logger
234
242
  GUI_tools
243
+ ML_models
235
244
  ML_inference
236
245
  path_manager
237
246
  ```
238
247
 
248
+ ---
249
+
239
250
  ### 🎫 Base Tools [base]
240
251
 
241
252
  General purpose functions and classes.
@@ -254,6 +265,8 @@ utilities
254
265
  path_manager
255
266
  ```
256
267
 
268
+ ---
269
+
257
270
  ### ⚒️ APP bundlers
258
271
 
259
272
  Choose one if needed.
@@ -1,13 +1,13 @@
1
- dragon_ml_toolbox-6.2.1.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
2
- dragon_ml_toolbox-6.2.1.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=lY4_rJPnLnMu7YBQaY-_iz1JRDcLdQzNCyeLAF1glJY,1837
1
+ dragon_ml_toolbox-6.4.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
2
+ dragon_ml_toolbox-6.4.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=lY4_rJPnLnMu7YBQaY-_iz1JRDcLdQzNCyeLAF1glJY,1837
3
3
  ml_tools/ETL_engineering.py,sha256=4wwZXi9_U7xfCY70jGBaKniOeZ0m75ppxWpQBd_DmLc,39369
4
4
  ml_tools/GUI_tools.py,sha256=n4ZZ5kEjwK5rkOCFJE41HeLFfjhpJVLUSzk9Kd9Kr_0,45410
5
5
  ml_tools/MICE_imputation.py,sha256=oFHg-OytOzPYTzBR_wIRHhP71cMn3aupDeT59ABsXlQ,11576
6
6
  ml_tools/ML_callbacks.py,sha256=noedVMmHZ72Odbg28zqx5wkhhvX2v-jXicKE_NCAiqU,13838
7
- ml_tools/ML_datasetmaster.py,sha256=bbKCNA_b_uDIfxP9YIYKZm-VSfUSD15LvegFxpE9DIQ,34315
7
+ ml_tools/ML_datasetmaster.py,sha256=98dAfP-i7BjhGpmGSaxtuUOZeiUN_8KpjBwZEmPCpgk,35485
8
8
  ml_tools/ML_evaluation.py,sha256=-Z5fXQi2ou6l5Oyir06bO90SZIZVrjQfgoVAqKgSjks,13800
9
- ml_tools/ML_inference.py,sha256=blEDgzvDqatxbfloBKsyNPacRwoq9g6WTpIKQ3zoTak,5758
10
- ml_tools/ML_models.py,sha256=SJhKHGAN2VTBqzcHUOpFWuVZ2Y7U1M4P_axG_LNYWcI,6460
9
+ ml_tools/ML_inference.py,sha256=62F5RPC19bTHXUMTjnj2KMMg-wJdhLdVZDw--xJyiwM,12715
10
+ ml_tools/ML_models.py,sha256=QBPlu5d6QCKh-rlUJOAR3qVdFgOFqEzRPv1jXvRdOsw,10380
11
11
  ml_tools/ML_optimization.py,sha256=GX-qZ2mCI3gWRCTP5w7lXrZpfGle3J_mE0O68seIoio,13475
12
12
  ml_tools/ML_trainer.py,sha256=1q_CDXuMfndRsPuNofUn2mg2TlhG6MYuGqjWxTDgN9c,15112
13
13
  ml_tools/PSO_optimization.py,sha256=9Y074d-B5h4Wvp9YPiy6KAeXM-Yv6Il3gWalKvOLVgo,22705
@@ -17,7 +17,7 @@ ml_tools/VIF_factor.py,sha256=2nUMupfUoogf8o6ghoFZk_OwWhFXU0R3C9Gj0HOlI14,10415
17
17
  ml_tools/__init__.py,sha256=q0y9faQ6e17XCQ7eUiCZ1FJ4Bg5EQqLjZ9f_l5REUUY,41
18
18
  ml_tools/_logger.py,sha256=TpgYguxO-CWYqqgLW0tqFjtwZ58PE_W2OCfWNGZr0n0,1175
19
19
  ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
20
- ml_tools/custom_logger.py,sha256=njM_0XPbQ1S-x5LeSQAaTo2if-XVOR_pQSGg4EDeiTU,4603
20
+ ml_tools/custom_logger.py,sha256=nyLRxaRxkqYOFdSjI0X2BWXB8C2IU18QfmqIFKqSedI,5820
21
21
  ml_tools/data_exploration.py,sha256=P4f8OpRa7Q4i-11nkppxXw5Lx2lwlpn20GwWBbN_xbM,23901
22
22
  ml_tools/ensemble_evaluation.py,sha256=wnqoTPg4WYWf2A8z5XT0eSlW4snEuLCXQVj88sZKzQ4,24683
23
23
  ml_tools/ensemble_inference.py,sha256=rtU7eUaQne615n2g7IHZCJI-OvrBCcjxbTkEIvtCGFQ,9414
@@ -27,7 +27,7 @@ ml_tools/keys.py,sha256=HtPG8-MWh89C32A7eIlfuuA-DLwkxGkoDfwR2TGN9CQ,1074
27
27
  ml_tools/optimization_tools.py,sha256=EL5tgNFwRo-82pbRE1CFVy9noNhULD7wprWuKadPheg,5090
28
28
  ml_tools/path_manager.py,sha256=Z8e7w3MPqQaN8xmTnKuXZS6CIW59BFwwqGhGc00sdp4,13692
29
29
  ml_tools/utilities.py,sha256=LqXXTovaHbA5AOKRk6Ru6DgAPAM0wPfYU70kUjYBryo,19231
30
- dragon_ml_toolbox-6.2.1.dist-info/METADATA,sha256=acX_886jWy_IUIO9-HsSalgpD9HAh8a5Q22rbUlBzbU,6698
31
- dragon_ml_toolbox-6.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
32
- dragon_ml_toolbox-6.2.1.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
33
- dragon_ml_toolbox-6.2.1.dist-info/RECORD,,
30
+ dragon_ml_toolbox-6.4.0.dist-info/METADATA,sha256=jzs_BIaUzjLYIMUOVOgDl2qkyO-Z7Q00rZLZDzkxBkQ,6738
31
+ dragon_ml_toolbox-6.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
32
+ dragon_ml_toolbox-6.4.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
33
+ dragon_ml_toolbox-6.4.0.dist-info/RECORD,,
@@ -16,6 +16,7 @@ from pathlib import Path
16
16
  from .path_manager import make_fullpath
17
17
  from ._logger import _LOGGER
18
18
  from ._script_info import _script_info
19
+ from .custom_logger import save_list_strings
19
20
 
20
21
 
21
22
  # --- public-facing API ---
@@ -144,6 +145,9 @@ class DatasetMaker(_BaseMaker):
144
145
  self.features = pandas_df.drop(columns=label_col)
145
146
  self.labels_map = None
146
147
  self.scaler = None
148
+
149
+ self._feature_names = self.features.columns.tolist()
150
+ self._target_name = str(self.labels.name)
147
151
 
148
152
  self._is_split = False
149
153
  self._is_balanced = False
@@ -347,6 +351,23 @@ class DatasetMaker(_BaseMaker):
347
351
  if not self._is_split:
348
352
  raise RuntimeError("Data has not been split yet. Call .split_data() or .process() first.")
349
353
  return self.features_train, self.features_test, self.labels_train, self.labels_test # type: ignore
354
+
355
@property
def feature_names(self) -> list[str]:
    """
    Names of the feature columns.

    Returns:
        list[str]: The list object stored in `self._feature_names`
        (the same object, not a copy).
    """
    names = self._feature_names
    return names
359
+
360
@property
def target_name(self) -> str:
    """
    Name of the target column.

    Returns:
        str: The value stored in `self._target_name`.
    """
    name = self._target_name
    return name
364
+
365
def save_feature_names(self, directory: Union[str, Path], verbose: bool=True) -> None:
    """
    Saves the feature names as a text file.

    Delegates to `save_list_strings`, always using the filename "feature_names".

    Args:
        directory (str | Path): Destination directory, forwarded unchanged.
        verbose (bool): Forwarded unchanged to `save_list_strings`.
    """
    names = self._feature_names
    save_list_strings(list_strings=names, directory=directory, filename="feature_names", verbose=verbose)
350
371
 
351
372
  @staticmethod
352
373
  def _embed_categorical(cat_df: pandas.DataFrame, random_state: Optional[int] = None, **kwargs) -> pandas.DataFrame:
@@ -413,7 +434,7 @@ class SimpleDatasetMaker:
413
434
  target = pandas_df.iloc[:, -1]
414
435
 
415
436
  self._feature_names = features.columns.tolist()
416
- self._target_name = target.name
437
+ self._target_name = str(target.name)
417
438
 
418
439
  #set id
419
440
  self._id: Optional[str] = None
@@ -452,7 +473,7 @@ class SimpleDatasetMaker:
452
473
  @property
453
474
  def target_name(self) -> str:
454
475
  """Returns the name of the target column."""
455
- return str(self._target_name)
476
+ return self._target_name
456
477
 
457
478
  @property
458
479
  def id(self) -> Optional[str]:
@@ -474,6 +495,13 @@ class SimpleDatasetMaker:
474
495
  print(f" X_test shape: {self._X_test_shape}")
475
496
  print(f" y_test shape: {self._y_test_shape}")
476
497
  print("-------------------------------------------")
498
+
499
def save_feature_names(self, directory: Union[str, Path], verbose: bool=True) -> None:
    """
    Saves the feature names as a text file.

    Delegates to `save_list_strings`, always using the filename "feature_names".

    Args:
        directory (str | Path): Destination directory, forwarded unchanged.
        verbose (bool): Forwarded unchanged to `save_list_strings`.
    """
    names = self._feature_names
    save_list_strings(list_strings=names, directory=directory, filename="feature_names", verbose=verbose)
477
505
 
478
506
 
479
507
  # --- VisionDatasetMaker ---
ml_tools/ML_inference.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
2
  from torch import nn
3
3
  import numpy as np
4
4
  from pathlib import Path
5
- from typing import Union, Literal, Dict, Any
5
+ from typing import Union, Literal, Dict, Any, Optional
6
6
 
7
7
  from ._script_info import _script_info
8
8
  from ._logger import _LOGGER
@@ -10,7 +10,9 @@ from .path_manager import make_fullpath
10
10
  from .keys import PyTorchInferenceKeys
11
11
 
12
12
  __all__ = [
13
- "PyTorchInferenceHandler"
13
+ "PyTorchInferenceHandler",
14
+ "multi_inference_regression",
15
+ "multi_inference_classification"
14
16
  ]
15
17
 
16
18
  class PyTorchInferenceHandler:
@@ -22,7 +24,8 @@ class PyTorchInferenceHandler:
22
24
  model: nn.Module,
23
25
  state_dict: Union[str, Path],
24
26
  task: Literal["classification", "regression"],
25
- device: str = 'cpu'):
27
+ device: str = 'cpu',
28
+ target_id: Optional[str]=None):
26
29
  """
27
30
  Initializes the handler by loading a model's state_dict.
28
31
 
@@ -31,10 +34,12 @@ class PyTorchInferenceHandler:
31
34
  state_dict (str | Path): The path to the saved .pth model state_dict file.
32
35
  task (str): The type of task, 'regression' or 'classification'.
33
36
  device (str): The device to run inference on ('cpu', 'cuda', 'mps').
37
+ target_id (str | None): Target name as used in the training set.
34
38
  """
35
39
  self.model = model
36
40
  self.task = task
37
41
  self.device = self._validate_device(device)
42
+ self.target_id = target_id
38
43
 
39
44
  model_p = make_fullpath(state_dict, enforce="file")
40
45
 
@@ -128,10 +133,155 @@ class PyTorchInferenceHandler:
128
133
  else: # classification
129
134
  return {
130
135
  PyTorchInferenceKeys.LABELS: tensor_results[PyTorchInferenceKeys.LABELS].item(),
131
- # Move tensor to CPU before converting to NumPy
136
+ # Move tensor to CPU before converting to NumPy
132
137
  PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
133
138
  }
134
-
139
+
140
+
141
def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
                               feature_vector: Union[np.ndarray, torch.Tensor],
                               output: Literal["numpy","torch"]="numpy") -> dict[str,Any]:
    """
    Performs regression inference using multiple models on the same input.

    Every handler receives the same `feature_vector`; the predictions of each
    handler are stored in the returned dictionary under that handler's `target_id`.

    All arguments and handlers are validated before any model is run, so an
    invalid handler late in the list fails fast instead of after earlier
    handlers have already performed inference.

    The function adapts its behavior based on the input dimensions:
        - 1D input: the vector is reshaped to a batch of one and the first element
          of each handler's predictions is returned.
        - 2D input: each handler's batch predictions are returned as produced.

    Args:
        handlers (list[PyTorchInferenceHandler]): Initialized inference handlers.
            Each handler must have a unique, non-None `target_id` and `task="regression"`.
        feature_vector (Union[np.ndarray, torch.Tensor]): An input sample (1D) or a
            batch of samples (2D) fed to every handler.
        output (Literal["numpy", "torch"], optional): Selects the handler method used.
            - "numpy": uses `predict_batch_numpy`.
            - "torch": uses `predict_batch`.

    Returns:
        (dict[str, Any]): A dictionary mapping each handler's `target_id` to its
        predicted regression values.

    Raises:
        AttributeError: If any handler's `target_id` is None.
        ValueError: If `output` is not "numpy" or "torch", if the input
            `feature_vector` is not 1D or 2D, if any handler's `task` is not
            'regression', or if two handlers share the same `target_id`.
    """
    if output not in ("numpy", "torch"):
        raise ValueError(f"Invalid output format '{output}'. Expected 'numpy' or 'torch'.")

    # Remember whether the caller passed a single sample so the result can be unpacked later.
    is_single_sample = feature_vector.ndim == 1

    # Reshape a 1D vector to a 2D batch of one for uniform processing.
    if is_single_sample:
        feature_vector = feature_vector.reshape(1, -1)

    if feature_vector.ndim != 2:
        raise ValueError("Input feature_vector must be a 1D or 2D array/tensor.")

    # Validate every handler up front: no model runs unless the whole list is usable.
    seen_ids: set = set()
    for handler in handlers:
        if handler.target_id is None:
            raise AttributeError("All inference handlers must have a 'target_id' attribute.")
        if handler.task != "regression":
            raise ValueError(
                f"Invalid task type: The handler for target_id '{handler.target_id}' "
                f"is for '{handler.task}', but only 'regression' tasks are supported."
            )
        # A repeated id would silently overwrite an earlier handler's result.
        if handler.target_id in seen_ids:
            raise ValueError(
                f"Duplicate target_id '{handler.target_id}': each handler must have a unique 'target_id'."
            )
        seen_ids.add(handler.target_id)

    results: dict[str,Any] = dict()
    for handler in handlers:
        if output == "numpy":
            predictions = handler.predict_batch_numpy(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]
        else: # torch
            predictions = handler.predict_batch(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]

        # For a 1D input, return only the prediction of the single sample.
        results[handler.target_id] = predictions[0] if is_single_sample else predictions

    return results
207
+
208
+
209
def multi_inference_classification(
    handlers: list[PyTorchInferenceHandler],
    feature_vector: Union[np.ndarray, torch.Tensor],
    output: Literal["numpy","torch"]="numpy"
    ) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Performs classification inference using multiple models on the same input.

    Every handler receives the same `feature_vector`. Two dictionaries are
    returned, both keyed by each handler's `target_id`: one with the predicted
    labels and one with the probabilities.

    All arguments and handlers are validated before any model is run, so an
    invalid handler late in the list fails fast instead of after earlier
    handlers have already performed inference.

    The function adapts its behavior based on the input dimensions:
        - 1D input: the vector is reshaped to a batch of one and the first element
          of each handler's labels and probabilities is returned.
        - 2D input: each handler's batch labels and probabilities are returned as produced.

    Args:
        handlers (list[PyTorchInferenceHandler]): Initialized inference handlers.
            Each handler must have a unique, non-None `target_id` and `task="classification"`.
        feature_vector (Union[np.ndarray, torch.Tensor]): An input sample (1D)
            or a batch of samples (2D) fed to every handler.
        output (Literal["numpy", "torch"], optional): Selects the handler method used.
            - "numpy": uses `predict_batch_numpy`.
            - "torch": uses `predict_batch`.

    Returns:
        (tuple[dict[str, Any], dict[str, Any]]): A tuple containing two dictionaries:
            1. A dictionary mapping `target_id` to the predicted label(s).
            2. A dictionary mapping `target_id` to the prediction probabilities.

    Raises:
        AttributeError: If any handler's `target_id` is None.
        ValueError: If `output` is not "numpy" or "torch", if the input
            `feature_vector` is not 1D or 2D, if any handler's `task` is not
            'classification', or if two handlers share the same `target_id`.
    """
    if output not in ("numpy", "torch"):
        raise ValueError(f"Invalid output format '{output}'. Expected 'numpy' or 'torch'.")

    # Remember whether the caller passed a single sample so the result can be unpacked later.
    is_single_sample = feature_vector.ndim == 1

    # Reshape a 1D vector to a 2D batch of one for uniform processing.
    if is_single_sample:
        feature_vector = feature_vector.reshape(1, -1)

    if feature_vector.ndim != 2:
        raise ValueError("Input feature_vector must be a 1D or 2D array/tensor.")

    # Validate every handler up front: no model runs unless the whole list is usable.
    seen_ids: set = set()
    for handler in handlers:
        if handler.target_id is None:
            raise AttributeError("All inference handlers must have a 'target_id' attribute.")
        if handler.task != "classification":
            raise ValueError(
                f"Invalid task type: The handler for target_id '{handler.target_id}' "
                f"is for '{handler.task}', but this function only supports 'classification'."
            )
        # A repeated id would silently overwrite an earlier handler's result.
        if handler.target_id in seen_ids:
            raise ValueError(
                f"Duplicate target_id '{handler.target_id}': each handler must have a unique 'target_id'."
            )
        seen_ids.add(handler.target_id)

    labels_results: dict[str, Any] = dict()
    probs_results: dict[str, Any] = dict()

    for handler in handlers:
        # The batch methods return both labels and probabilities in one call.
        if output == "numpy":
            result = handler.predict_batch_numpy(feature_vector)
        else: # torch
            result = handler.predict_batch(feature_vector)

        labels = result[PyTorchInferenceKeys.LABELS]
        probabilities = result[PyTorchInferenceKeys.PROBABILITIES]

        # For a 1D input, return only the entries of the single sample.
        labels_results[handler.target_id] = labels[0] if is_single_sample else labels
        probs_results[handler.target_id] = probabilities[0] if is_single_sample else probabilities

    return labels_results, probs_results
284
+
135
285
 
136
286
  def info():
137
287
  _script_info(__all__)
ml_tools/ML_models.py CHANGED
@@ -1,12 +1,18 @@
1
1
  import torch
2
2
  from torch import nn
3
3
  from ._script_info import _script_info
4
- from typing import List
4
+ from typing import List, Union
5
+ from pathlib import Path
6
+ import json
7
+ from ._logger import _LOGGER
8
+ from .path_manager import make_fullpath
5
9
 
6
10
 
7
11
  __all__ = [
8
12
  "MultilayerPerceptron",
9
- "SequencePredictorLSTM"
13
+ "SequencePredictorLSTM",
14
+ "save_architecture",
15
+ "load_architecture"
10
16
  ]
11
17
 
12
18
 
@@ -45,6 +51,12 @@ class MultilayerPerceptron(nn.Module):
45
51
  raise TypeError("hidden_layers must be a list of integers.")
46
52
  if not (0.0 <= drop_out < 1.0):
47
53
  raise ValueError("drop_out must be a float between 0.0 and 1.0.")
54
+
55
+ # --- Save configuration ---
56
+ self.in_features = in_features
57
+ self.out_targets = out_targets
58
+ self.hidden_layers = hidden_layers
59
+ self.drop_out = drop_out
48
60
 
49
61
  # --- Build network layers ---
50
62
  layers = []
@@ -67,6 +79,15 @@ class MultilayerPerceptron(nn.Module):
67
79
  """Defines the forward pass of the model."""
68
80
  return self._layers(x)
69
81
 
82
def get_config(self) -> dict:
    """
    Returns the configuration of the model.

    The dictionary holds the values saved on the instance as `in_features`,
    `out_targets`, `hidden_layers` and `drop_out`.

    Returns:
        dict: The configuration. `hidden_layers` is a copy, so modifying the
        returned dictionary never alters the configuration stored on the model.
    """
    return {
        'in_features': self.in_features,
        'out_targets': self.out_targets,
        # Copy the list: handing out the internal object would let callers mutate the model's config.
        'hidden_layers': list(self.hidden_layers),
        'drop_out': self.drop_out
    }
90
+
70
91
  def __repr__(self) -> str:
71
92
  """Returns the developer-friendly string representation of the model."""
72
93
  # Extracts the number of neurons from each nn.Linear layer
@@ -114,7 +135,14 @@ class SequencePredictorLSTM(nn.Module):
114
135
  raise ValueError("recurrent_layers must be a positive integer.")
115
136
  if not (0.0 <= dropout < 1.0):
116
137
  raise ValueError("dropout must be a float between 0.0 and 1.0.")
117
-
138
+
139
+ # --- Save configuration ---
140
+ self.features = features
141
+ self.hidden_size = hidden_size
142
+ self.recurrent_layers = recurrent_layers
143
+ self.dropout = dropout
144
+
145
+ # Build model
118
146
  self.lstm = nn.LSTM(
119
147
  input_size=features,
120
148
  hidden_size=hidden_size,
@@ -144,6 +172,15 @@ class SequencePredictorLSTM(nn.Module):
144
172
 
145
173
  return predictions
146
174
 
175
def get_config(self) -> dict:
    """
    Returns the configuration of the model.

    Returns:
        dict: The values saved on the instance as `features`, `hidden_size`,
        `recurrent_layers` and `dropout`, keyed by those same names.
    """
    config = dict(
        features=self.features,
        hidden_size=self.hidden_size,
        recurrent_layers=self.recurrent_layers,
        dropout=self.dropout,
    )
    return config
183
+
147
184
  def __repr__(self) -> str:
148
185
  """Returns the developer-friendly string representation of the model."""
149
186
  return (
@@ -153,5 +190,80 @@ class SequencePredictorLSTM(nn.Module):
153
190
  )
154
191
 
155
192
 
193
def save_architecture(model: nn.Module, directory: Union[str, Path], verbose: bool=True) -> None:
    """
    Saves a model's architecture to an 'architecture.json' file.

    This function relies on the model having a `get_config()` method that
    returns a dictionary of the arguments needed to initialize it. The file
    stores the model's class name under 'model_class' and that dictionary
    under 'config'.

    The configuration is serialized before the file is opened, so a
    configuration that cannot be converted to JSON does not leave a partially
    written file behind.

    Args:
        model (nn.Module): The PyTorch model instance to save.
        directory (str | Path): The directory to save the JSON file.
        verbose (bool): If True, logs a confirmation message after saving.

    Raises:
        AttributeError: If the model does not have a callable `get_config()` method.
        TypeError: If the configuration is not JSON serializable (raised by `json.dumps`).
    """
    get_config = getattr(model, 'get_config', None)
    # A plain attribute named 'get_config' is not usable either, so require a callable.
    if not callable(get_config):
        raise AttributeError(
            f"Model '{model.__class__.__name__}' does not have a 'get_config()' method. "
            "Please implement it to return the model's constructor arguments."
        )

    config = {
        'model_class': model.__class__.__name__,
        'config': get_config()
    }

    # Serialize first: a failure here must not create or truncate the output file.
    payload = json.dumps(config, indent=4)

    # Ensure the target directory exists
    path_dir = make_fullpath(directory, make=True, enforce="directory")
    full_path = path_dir / "architecture.json"

    with open(full_path, 'w', encoding="utf-8") as f:
        f.write(payload)

    if verbose:
        _LOGGER.info(f"✅ Architecture for '{model.__class__.__name__}' saved to '{path_dir.name}'")
227
+
228
+
229
def load_architecture(filepath: Union[str, Path], expected_model_class: type, verbose: bool=True) -> nn.Module:
    """
    Loads a model architecture from a JSON file.

    The file must contain the keys 'model_class' (a class name) and 'config'
    (a dictionary of keyword arguments). The model is instantiated from the
    class supplied by the caller, never from the name stored in the file; the
    stored name is only compared against `expected_model_class.__name__`.

    Args:
        filepath (Union[str, Path]): The path of the JSON architecture file.
        expected_model_class (type): The model class expected to load (e.g., MultilayerPerceptron).
        verbose (bool): If True, logs a confirmation message after loading.

    Returns:
        nn.Module: An instance of `expected_model_class` built as
        `expected_model_class(**config)`.

    Raises:
        FileNotFoundError: If the filepath does not exist
            (NOTE(review): presumably raised by `make_fullpath` - confirm).
        KeyError: If the file lacks the 'model_class' or 'config' key.
        ValueError: If the class name in the file does not match the `expected_model_class`.
    """
    path_obj = make_fullpath(filepath, enforce="file")

    # JSON text is UTF-8 by convention; do not depend on the platform's locale encoding.
    with open(path_obj, 'r', encoding="utf-8") as f:
        saved_data = json.load(f)

    saved_class_name = saved_data['model_class']
    config = saved_data['config']

    if saved_class_name != expected_model_class.__name__:
        raise ValueError(
            f"Model class mismatch. File specifies '{saved_class_name}', "
            f"but you expected '{expected_model_class.__name__}'."
        )

    # Create an instance of the model using the provided class and config
    model = expected_model_class(**config)
    if verbose:
        _LOGGER.info(f"✅ Successfully loaded architecture for '{saved_class_name}'")
    return model
266
+
267
+
156
268
  def info():
157
269
  _script_info(__all__)
ml_tools/custom_logger.py CHANGED
@@ -10,7 +10,9 @@ from ._logger import _LOGGER
10
10
 
11
11
 
12
12
  __all__ = [
13
- "custom_logger"
13
+ "custom_logger",
14
+ "save_list_strings",
15
+ "load_list_strings"
14
16
  ]
15
17
 
16
18
 
@@ -136,5 +138,39 @@ def _log_dict_to_json(data: Dict[Any, Any], path: Path) -> None:
136
138
  json.dump(data, f, indent=4, ensure_ascii=False)
137
139
 
138
140
 
141
def save_list_strings(list_strings: list[str], directory: Union[str,Path], filename: str, verbose: bool=True):
    """
    Saves a list of strings as a text file, one string per line.

    The file is always written as UTF-8 so that it can be read back by
    `load_list_strings` on any platform, regardless of the locale encoding.

    Args:
        list_strings (list[str]): The strings to write; each one is followed by a newline.
        directory (str | Path): Destination directory, resolved through `make_fullpath`.
        filename (str): Name of the file. It is passed through `sanitize_filename`,
            and ".txt" is appended when the sanitized name does not already end with it.
        verbose (bool): If True, logs the name of the saved file.
    """
    target_dir = make_fullpath(directory, make=True, enforce="directory")
    sanitized_name = sanitize_filename(filename)

    if not sanitized_name.endswith(".txt"):
        sanitized_name = sanitized_name + ".txt"

    full_path = target_dir / sanitized_name
    with open(full_path, 'w', encoding="utf-8") as f:
        f.writelines(f"{string_data}\n" for string_data in list_strings)

    if verbose:
        _LOGGER.info(f"✅ Text file saved as '{full_path.name}'.")


def load_list_strings(text_file: Union[str,Path], verbose: bool=True) -> list[str]:
    """
    Loads a text file as a list of strings, one string per line.

    The file is read as UTF-8, matching what `save_list_strings` writes.
    Leading and trailing whitespace (including the line break) is stripped
    from every line.

    Args:
        text_file (str | Path): Path of the text file, resolved through `make_fullpath`.
        verbose (bool): If True, logs a confirmation message after loading.

    Returns:
        list[str]: The stripped lines of the file, in order.

    Raises:
        ValueError: If the file contains no lines.
    """
    target_path = make_fullpath(text_file, enforce="file")

    with open(target_path, 'r', encoding="utf-8") as f:
        loaded_strings = [line.strip() for line in f]

    if not loaded_strings:
        raise ValueError("❌ The text file is empty.")

    if verbose:
        _LOGGER.info("✅ Text file loaded as list of strings.")

    return loaded_strings
173
+
174
+
139
175
  def info():
140
176
  _script_info(__all__)