dragon-ml-toolbox 11.1.1__py3-none-any.whl → 12.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: the registry flags this version of dragon-ml-toolbox as potentially problematic.
- {dragon_ml_toolbox-11.1.1.dist-info → dragon_ml_toolbox-12.0.1.dist-info}/METADATA +22 -36
- dragon_ml_toolbox-12.0.1.dist-info/RECORD +40 -0
- ml_tools/ETL_cleaning.py +1 -0
- ml_tools/ETL_engineering.py +17 -5
- ml_tools/GUI_tools.py +2 -1
- ml_tools/MICE_imputation.py +5 -2
- ml_tools/ML_callbacks.py +3 -3
- ml_tools/ML_datasetmaster.py +1 -0
- ml_tools/ML_evaluation.py +2 -1
- ml_tools/ML_evaluation_multi.py +1 -0
- ml_tools/ML_inference.py +1 -0
- ml_tools/ML_models.py +3 -1
- ml_tools/ML_optimization.py +2 -1
- ml_tools/ML_scaler.py +3 -0
- ml_tools/ML_utilities.py +219 -0
- ml_tools/PSO_optimization.py +5 -6
- ml_tools/RNN_forecast.py +2 -0
- ml_tools/SQL.py +1 -0
- ml_tools/VIF_factor.py +2 -1
- ml_tools/_logger.py +0 -2
- ml_tools/custom_logger.py +1 -0
- ml_tools/data_exploration.py +16 -10
- ml_tools/ensemble_inference.py +5 -6
- ml_tools/ensemble_learning.py +3 -2
- ml_tools/handle_excel.py +1 -0
- ml_tools/math_utilities.py +235 -0
- ml_tools/path_manager.py +2 -1
- ml_tools/serde.py +103 -0
- ml_tools/utilities.py +19 -453
- dragon_ml_toolbox-11.1.1.dist-info/RECORD +0 -37
- {dragon_ml_toolbox-11.1.1.dist-info → dragon_ml_toolbox-12.0.1.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-11.1.1.dist-info → dragon_ml_toolbox-12.0.1.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-11.1.1.dist-info → dragon_ml_toolbox-12.0.1.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-11.1.1.dist-info → dragon_ml_toolbox-12.0.1.dist-info}/top_level.txt +0 -0
ml_tools/VIF_factor.py
CHANGED
@@ -1,4 +1,3 @@
-
 import pandas as pd
 import numpy as np
 import matplotlib.pyplot as plt
@@ -7,11 +6,13 @@ from statsmodels.stats.outliers_influence import variance_inflation_factor
 from statsmodels.tools.tools import add_constant
 import warnings
 from pathlib import Path
+
 from .utilities import yield_dataframes_from_dir, save_dataframe
 from .path_manager import sanitize_filename, make_fullpath
 from ._logger import _LOGGER
 from ._script_info import _script_info
 
+
 __all__ = [
     "compute_vif",
     "drop_vif_based",
ml_tools/_logger.py
CHANGED
ml_tools/custom_logger.py
CHANGED
ml_tools/data_exploration.py
CHANGED
@@ -417,7 +417,7 @@ def encode_categorical_features(
 
     # Handle the dataset splitting logic
     if split_resulting_dataset:
-        df_categorical = df_encoded[valid_columns].to_frame()
+        df_categorical = df_encoded[valid_columns].to_frame() # type: ignore
         df_non_categorical = df.drop(columns=valid_columns)
         return mappings, df_non_categorical, df_categorical
     else:
@@ -493,9 +493,9 @@ def split_continuous_binary(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
     return df_cont, df_bin # type: ignore
 
 
-def plot_correlation_heatmap(df: pd.DataFrame,
+def plot_correlation_heatmap(df: pd.DataFrame,
+                             plot_title: str,
                              save_dir: Union[str, Path, None] = None,
-                             plot_title: str="Correlation Heatmap",
                              method: Literal["pearson", "kendall", "spearman"]="pearson"):
     """
     Plots a heatmap of pairwise correlations between numeric features in a DataFrame.
@@ -503,7 +503,7 @@ def plot_correlation_heatmap(df: pd.DataFrame,
     Args:
         df (pd.DataFrame): The input dataset.
         save_dir (str | Path | None): If provided, the heatmap will be saved to this directory as a svg file.
-        plot_title:
+        plot_title: The suffix "`method` Correlation Heatmap" will be automatically appended.
         method (str): Correlation method to use. Must be one of:
             - 'pearson' (default): measures linear correlation (assumes normally distributed data),
             - 'kendall': rank correlation (non-parametric),
@@ -518,6 +518,9 @@ def plot_correlation_heatmap(df: pd.DataFrame,
     if numeric_df.empty:
         _LOGGER.warning("No numeric columns found. Heatmap not generated.")
         return
+    if method not in ["pearson", "kendall", "spearman"]:
+        _LOGGER.error(f"'method' must be pearson, kendall, or spearman.")
+        raise ValueError()
 
     corr = numeric_df.corr(method=method)
 
@@ -538,7 +541,10 @@ def plot_correlation_heatmap(df: pd.DataFrame,
         cbar_kws={"shrink": 0.8}
     )
 
-
+    # add suffix to title
+    full_plot_title = f"{plot_title} - {method.title()} Correlation Heatmap"
+
+    plt.title(full_plot_title)
     plt.xticks(rotation=45, ha='right')
     plt.yticks(rotation=0)
 
@@ -547,13 +553,13 @@ def plot_correlation_heatmap(df: pd.DataFrame,
     if save_dir:
         save_path = make_fullpath(save_dir, make=True)
         # sanitize the plot title to save the file
-
-
+        sanitized_plot_title = sanitize_filename(plot_title)
+        plot_filename = sanitized_plot_title + ".svg"
 
-        full_path = save_path /
+        full_path = save_path / plot_filename
 
         plt.savefig(full_path, bbox_inches="tight", format='svg')
-        _LOGGER.info(f"Saved correlation heatmap: '{
+        _LOGGER.info(f"Saved correlation heatmap: '{plot_filename}'")
 
     plt.show()
     plt.close()
@@ -968,7 +974,7 @@ def reconstruct_one_hot(
     # Handle rows where all OHE columns were 0 (e.g., original value was NaN).
     # In these cases, idxmax returns the first column name, but the sum of values is 0.
     all_zero_mask = new_df[ohe_cols].sum(axis=1) == 0
-    new_column_values.loc[all_zero_mask] = np.nan
+    new_column_values.loc[all_zero_mask] = np.nan # type: ignore
 
     # Assign the new reconstructed column to the DataFrame
     new_df[base_name] = new_column_values
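For reference, a minimal usage sketch of the updated plot_correlation_heatmap signature. The DataFrame, title, and save directory are made up, and the ml_tools import path is assumed from the wheel layout; only the new required plot_title parameter, its position before save_dir, and the automatic "<Method> Correlation Heatmap" title suffix come from the diff above.

import pandas as pd
from ml_tools.data_exploration import plot_correlation_heatmap  # assumed import path

df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [2, 4, 6, 8], "c": [1, 0, 1, 0]})

# plot_title is now required and precedes save_dir; the method name is appended to the
# displayed title, and the saved file is named after the sanitized plot_title (e.g. "demo_features.svg").
plot_correlation_heatmap(df, plot_title="demo features", save_dir="plots", method="spearman")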
ml_tools/ensemble_inference.py
CHANGED
@@ -1,18 +1,17 @@
-from ._script_info import _script_info
-from ._logger import _LOGGER
-from .path_manager import make_fullpath, list_files_by_extension
-from .keys import EnsembleKeys
-
 from typing import Union, Literal, Dict, Any, Optional, List
 from pathlib import Path
 import json
-
 import joblib
 import numpy as np
 # Inference models
 import xgboost
 import lightgbm
 
+from ._script_info import _script_info
+from ._logger import _LOGGER
+from .path_manager import make_fullpath, list_files_by_extension
+from .keys import EnsembleKeys
+
 
 __all__ = [
     "InferenceHandler",
ml_tools/ensemble_learning.py
CHANGED
@@ -13,7 +13,8 @@ import lightgbm as lgb
 from sklearn.model_selection import train_test_split
 from sklearn.base import clone
 
-from .utilities import yield_dataframes_from_dir,
+from .utilities import yield_dataframes_from_dir, train_dataset_yielder
+from .serde import serialize_object
 from .path_manager import sanitize_filename, make_fullpath
 from ._script_info import _script_info
 from .keys import EnsembleKeys
@@ -481,7 +482,7 @@ def train_test_pipeline(model, model_name: str, dataset_id: str, task: Literal["
 
 ###### 4. Execution ######
 def run_ensemble_pipeline(datasets_dir: Union[str,Path], save_dir: Union[str,Path], target_columns: list[str], model_object: Union[RegressionTreeModels, ClassificationTreeModels],
-                          handle_classification_imbalance: HandleImbalanceStrategy=None, save_model: bool=
+                          handle_classification_imbalance: HandleImbalanceStrategy=None, save_model: bool=True,
                           test_size: float=0.2, debug:bool=False, generate_learning_curves: bool = False):
     #Check models
     if isinstance(model_object, RegressionTreeModels):
ml_tools/handle_excel.py
CHANGED
@@ -2,6 +2,7 @@ from pathlib import Path
 from openpyxl import load_workbook, Workbook
 import pandas as pd
 from typing import List, Optional, Union
+
 from .path_manager import sanitize_filename, make_fullpath
 from ._script_info import _script_info
 from ._logger import _LOGGER
ml_tools/math_utilities.py
ADDED
@@ -0,0 +1,235 @@
+import pandas as pd
+import numpy as np
+import math
+from typing import Union, Sequence, Optional
+
+from ._script_info import _script_info
+from ._logger import _LOGGER
+
+
+__all__ = [
+    "normalize_mixed_list",
+    "threshold_binary_values",
+    "threshold_binary_values_batch",
+    "discretize_categorical_values",
+]
+
+
+def normalize_mixed_list(data: list, threshold: int = 2) -> list[float]:
+    """
+    Normalize a mixed list of numeric values and strings cast to floats so that the sum of the values equals 1.0,
+    applying heuristic adjustments to correct for potential data entry scale mismatches.
+
+    Parameters:
+        data (list):
+            A list of values that may include strings, floats, integers, or None.
+            None values are treated as 0.0.
+
+        threshold (int, optional):
+            The number of log10 orders of magnitude below the median scale
+            at which a value is considered suspect and is scaled upward accordingly.
+            Default is 2.
+
+    Returns:
+        List[float]: A list of normalized float values summing to 1.0.
+
+    Notes:
+        - Zeros and None values remain zero.
+        - Input strings are automatically cast to floats if possible.
+
+    Example:
+        >>> normalize_mixed_list([1, "0.01", 4, None])
+        [0.2, 0.2, 0.6, 0.0]
+    """
+    # Step 1: Convert all values to float, treat None as 0.0
+    float_list = [float(x) if x is not None else 0.0 for x in data]
+
+    # Raise for negative values
+    if any(x < 0 for x in float_list):
+        _LOGGER.error("Negative values are not allowed in the input list.")
+        raise ValueError()
+
+    # Step 2: Compute log10 of non-zero values
+    nonzero = [x for x in float_list if x > 0]
+    if not nonzero:
+        return [0.0 for _ in float_list]
+
+    log_scales = [math.log10(x) for x in nonzero]
+    log_median = np.median(log_scales)
+
+    # Step 3: Adjust values that are much smaller than median
+    adjusted = []
+    for x in float_list:
+        if x == 0.0:
+            adjusted.append(0.0)
+        else:
+            log_x = math.log10(x)
+            if log_median - log_x > threshold:
+                scale_diff = round(log_median - log_x)
+                adjusted.append(x * (10 ** scale_diff))
+            else:
+                adjusted.append(x)
+
+    # Step 4: Normalize to sum to 1.0
+    total = sum(adjusted)
+    if total == 0:
+        return [0.0 for _ in adjusted]
+
+    return [x / total for x in adjusted]
+
+
+def threshold_binary_values(
+    input_array: Union[Sequence[float], np.ndarray, pd.Series],
+    binary_values: Optional[int] = None
+) -> Union[np.ndarray, pd.Series, list[float], tuple[float]]:
+    """
+    Thresholds binary features in a 1D input. The number of binary features are counted starting from the end.
+
+    Binary elements are converted to 0 or 1 using a 0.5 threshold.
+
+    Parameters:
+        input_array: 1D sequence, NumPy array, or pandas Series.
+        binary_values (Optional[int]) :
+            - If `None`, all values are treated as binary.
+            - If `int`, only this many last `binary_values` are thresholded.
+
+    Returns:
+        Any:
+            Same type as input
+    """
+    original_type = type(input_array)
+
+    if isinstance(input_array, (pd.Series, np.ndarray)):
+        array = np.asarray(input_array)
+    elif isinstance(input_array, (list, tuple)):
+        array = np.array(input_array)
+    else:
+        _LOGGER.error("Unsupported input type")
+        raise TypeError()
+
+    array = array.flatten()
+    total = array.shape[0]
+
+    bin_count = total if binary_values is None else binary_values
+    if not (0 <= bin_count <= total):
+        _LOGGER.error("'binary_values' must be between 0 and the total number of elements")
+        raise ValueError()
+
+    if bin_count == 0:
+        result = array
+    else:
+        cont_part = array[:-bin_count] if bin_count < total else np.array([])
+        bin_part = (array[-bin_count:] > 0.5).astype(int)
+        result = np.concatenate([cont_part, bin_part])
+
+    if original_type is pd.Series:
+        return pd.Series(result, index=input_array.index if hasattr(input_array, 'index') else None) # type: ignore
+    elif original_type is list:
+        return result.tolist()
+    elif original_type is tuple:
+        return tuple(result)
+    else:
+        return result
+
+
+def threshold_binary_values_batch(
+    input_array: np.ndarray,
+    binary_values: int
+) -> np.ndarray:
+    """
+    Threshold the last `binary_values` columns of a 2D NumPy array to binary {0,1} using 0.5 cutoff.
+
+    Parameters
+    ----------
+    input_array : np.ndarray
+        2D array with shape (batch_size, n_features).
+    binary_values : int
+        Number of binary features located at the END of each row.
+
+    Returns
+    -------
+    np.ndarray
+        Thresholded array, same shape as input.
+    """
+    if input_array.ndim != 2:
+        _LOGGER.error(f"Expected 2D array, got {input_array.ndim}D array.")
+        raise AssertionError()
+
+    batch_size, total_features = input_array.shape
+
+    if not (0 <= binary_values <= total_features):
+        _LOGGER.error("'binary_values' out of valid range.")
+        raise AssertionError()
+
+    if binary_values == 0:
+        return input_array.copy()
+
+    cont_part = input_array[:, :-binary_values] if binary_values < total_features else np.empty((batch_size, 0))
+    bin_part = input_array[:, -binary_values:] > 0.5
+    bin_part = bin_part.astype(np.int32)
+
+    return np.hstack([cont_part, bin_part])
+
+
+def discretize_categorical_values(
+    input_array: np.ndarray,
+    categorical_info: dict[int, int],
+    start_at_zero: bool = False
+) -> np.ndarray:
+    """
+    Rounds specified columns of a 2D NumPy array to the nearest integer and
+    clamps the result to a valid categorical range.
+
+    Parameters
+    ----------
+    input_array : np.ndarray
+        2D array with shape (batch_size, n_features) containing continuous values.
+    categorical_info : dict[int, int]
+        A dictionary mapping column indices to their cardinality (number of categories).
+        Example: {3: 4} means column 3 will be clamped to its 4 valid categories.
+    start_at_zero : bool
+        If True, categories range from 0 to k-1.
+        If False, categories range from 1 to k.
+
+    Returns
+    -------
+    np.ndarray
+        A new array with the specified columns converted to integer categories.
+    """
+    # --- Input Validation ---
+    if input_array.ndim != 2:
+        _LOGGER.error(f"Expected 2D array, got {input_array.ndim}D array.")
+        raise ValueError()
+
+    if not isinstance(categorical_info, dict) or not categorical_info:
+        _LOGGER.error(f"'categorical_info' is not a dictionary, or is empty.")
+        raise ValueError()
+
+    _, total_features = input_array.shape
+    for col_idx, cardinality in categorical_info.items():
+        if not (0 <= col_idx < total_features):
+            _LOGGER.error(f"Column index {col_idx} is out of bounds for an array with {total_features} features.")
+            raise ValueError()
+        if not isinstance(cardinality, int) or cardinality < 2:
+            _LOGGER.error(f"Cardinality for column {col_idx} must be an integer >= 2, but got {cardinality}.")
+            raise ValueError()
+
+    # --- Core Logic ---
+    output_array = input_array.copy()
+
+    for col_idx, cardinality in categorical_info.items():
+        # 1. Round the column values using "round half up"
+        rounded_col = np.floor(output_array[:, col_idx] + 0.5)
+
+        # 2. Determine clamping bounds
+        min_bound = 0 if start_at_zero else 1
+        max_bound = cardinality - 1 if start_at_zero else cardinality
+
+        # 3. Clamp the values and update the output array
+        output_array[:, col_idx] = np.clip(rounded_col, min_bound, max_bound)
+
+    return output_array.astype(np.int32)
+
+
+def info():
+    _script_info(__all__)
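A short usage sketch of the new math_utilities helpers. The input values are made up and the ml_tools import path is assumed from the wheel layout; the behavior follows the source added above.

import numpy as np
from ml_tools.math_utilities import threshold_binary_values, discretize_categorical_values  # assumed import path

# The last two features of the 1D vector are treated as binary and snapped to 0/1 with a 0.5 cutoff.
vector = [0.7, 12.3, 0.2, 0.9]
print(threshold_binary_values(vector, binary_values=2))  # -> [0.7, 12.3, 0.0, 1.0]

# Column 2 holds a 3-category feature: values are rounded ("half up") and clamped to 1..3;
# note that the returned array is cast to int32 as a whole.
batch = np.array([[0.1, 5.0, 2.7],
                  [0.4, 3.2, 0.2]])
print(discretize_categorical_values(batch, categorical_info={2: 3}))  # -> [[0 5 3] [0 3 1]]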
ml_tools/path_manager.py
CHANGED
ml_tools/serde.py
ADDED
@@ -0,0 +1,103 @@
+import joblib
+from joblib.externals.loky.process_executor import TerminatedWorkerError
+from typing import Any, Union, TypeVar, get_origin, Type, Optional
+from pathlib import Path
+
+from .path_manager import make_fullpath, sanitize_filename
+from ._script_info import _script_info
+from ._logger import _LOGGER
+
+
+__all__ = [
+    "serialize_object",
+    "deserialize_object",
+]
+
+
+def serialize_object(obj: Any, save_dir: Union[str,Path], filename: str, verbose: bool=True, raise_on_error: bool=False) -> None:
+    """
+    Serializes a Python object using joblib; suitable for Python built-ins, numpy, and pandas.
+
+    Parameters:
+        obj (Any) : The Python object to serialize.
+        save_dir (str | Path) : Directory path where the serialized object will be saved.
+        filename (str) : Name for the output file, extension will be appended if needed.
+    """
+    try:
+        save_path = make_fullpath(save_dir, make=True)
+        sanitized_name = sanitize_filename(filename)
+        if not sanitized_name.endswith('.joblib'):
+            sanitized_name = sanitized_name + ".joblib"
+        full_path = save_path / sanitized_name
+        joblib.dump(obj, full_path)
+    except (IOError, OSError, TypeError, TerminatedWorkerError) as e:
+        _LOGGER.error(f"Failed to serialize object of type '{type(obj)}'.")
+        if raise_on_error:
+            raise e
+        return None
+    else:
+        if verbose:
+            _LOGGER.info(f"Object of type '{type(obj)}' saved to '{full_path}'")
+        return None
+
+# Define a TypeVar to link the expected type to the return type of deserialization
+T = TypeVar('T')
+
+def deserialize_object(
+    filepath: Union[str, Path],
+    expected_type: Optional[Type[T]] = None,
+    verbose: bool = True,
+    raise_on_error: bool = True
+) -> Optional[T]:
+    """
+    Loads a serialized object from a .joblib file.
+
+    Parameters:
+        filepath (str | Path): Full path to the serialized .joblib file.
+        expected_type (Type[T] | None): The expected type of the object.
+            If provided, the function raises a TypeError if the loaded object
+            is not an instance of this type. It correctly handles generics
+            like `list[str]` by checking the base type (e.g., `list`).
+            Defaults to None, which skips the type check.
+        verbose (bool): If True, logs success messages.
+        raise_on_error (bool): If True, raises exceptions on errors. If False, returns None instead.
+
+    Returns:
+        (Any | None): The deserialized Python object, which will match the
+            `expected_type` if provided. Returns None if an error
+            occurs and `raise_on_error` is False.
+    """
+    true_filepath = make_fullpath(filepath)
+
+    try:
+        obj = joblib.load(true_filepath)
+    except (IOError, OSError, EOFError, TypeError, ValueError) as e:
+        _LOGGER.error(f"Failed to deserialize object from '{true_filepath}'.")
+        if raise_on_error:
+            raise e
+        return None
+    else:
+        # --- Type Validation Step ---
+        if expected_type:
+            # get_origin handles generics (e.g., list[str] -> list)
+            # If it's not a generic, get_origin returns None, so we use the type itself.
+            type_to_check = get_origin(expected_type) or expected_type
+
+            # Can't do an isinstance check on 'Any', skip it.
+            if type_to_check is not Any and not isinstance(obj, type_to_check):
+                error_msg = (
+                    f"Type mismatch: Expected an instance of '{expected_type}', "
+                    f"but found '{type(obj)}' in '{true_filepath}'."
+                )
+                _LOGGER.error(error_msg)
+                if raise_on_error:
+                    raise TypeError()
+                return None
+
+        if verbose:
+            _LOGGER.info(f"Loaded object of type '{type(obj)}' from '{true_filepath}'.")
+
+        return obj
+
+def info():
+    _script_info(__all__)
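A brief usage sketch of the new serde helpers. The directory, filename, and demo dict are illustrative, and the ml_tools import path is assumed from the wheel layout; the function signatures and behavior come from the source added above.

from ml_tools.serde import serialize_object, deserialize_object  # assumed import path

payload = {"weights": [0.1, 0.9], "label": "demo"}

# Writes artifacts/model_payload.joblib (the .joblib extension is appended automatically).
serialize_object(payload, save_dir="artifacts", filename="model_payload")

# Optional type check: raises TypeError if the loaded object is not a dict.
restored = deserialize_object("artifacts/model_payload.joblib", expected_type=dict)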