dragon-ml-toolbox 13.0.0__py3-none-any.whl → 14.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/METADATA +12 -2
- dragon_ml_toolbox-14.7.0.dist-info/RECORD +49 -0
- {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
- ml_tools/MICE_imputation.py +207 -5
- ml_tools/ML_configuration.py +108 -0
- ml_tools/ML_datasetmaster.py +241 -260
- ml_tools/ML_evaluation.py +229 -76
- ml_tools/ML_evaluation_multi.py +45 -16
- ml_tools/ML_inference.py +0 -1
- ml_tools/ML_models.py +135 -55
- ml_tools/ML_models_advanced.py +323 -0
- ml_tools/ML_optimization.py +49 -36
- ml_tools/ML_trainer.py +498 -29
- ml_tools/ML_utilities.py +351 -4
- ml_tools/ML_vision_datasetmaster.py +1492 -0
- ml_tools/ML_vision_evaluation.py +260 -0
- ml_tools/ML_vision_inference.py +428 -0
- ml_tools/ML_vision_models.py +641 -0
- ml_tools/ML_vision_transformers.py +203 -0
- ml_tools/PSO_optimization.py +5 -1
- ml_tools/_ML_vision_recipe.py +88 -0
- ml_tools/__init__.py +1 -0
- ml_tools/_schema.py +96 -0
- ml_tools/custom_logger.py +37 -14
- ml_tools/data_exploration.py +576 -138
- ml_tools/ensemble_evaluation.py +53 -10
- ml_tools/keys.py +43 -1
- ml_tools/math_utilities.py +1 -1
- ml_tools/optimization_tools.py +65 -86
- ml_tools/serde.py +78 -17
- ml_tools/utilities.py +192 -3
- dragon_ml_toolbox-13.0.0.dist-info/RECORD +0 -41
- ml_tools/ML_simple_optimization.py +0 -413
- {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/top_level.txt +0 -0
ml_tools/ensemble_evaluation.py
CHANGED
```diff
@@ -112,7 +112,7 @@ def evaluate_model_classification(
     report_df = pd.DataFrame(report_dict).iloc[:-1, :].T
     plt.figure(figsize=figsize)
     sns.heatmap(report_df, annot=True, cmap=heatmap_cmap, fmt='.2f',
-                annot_kws={"size": base_fontsize - 4})
+                annot_kws={"size": base_fontsize - 4}, vmin=0.0, vmax=1.0)
     plt.title(f"{model_name} - {target_name}", fontsize=base_fontsize)
     plt.xticks(fontsize=base_fontsize - 2)
     plt.yticks(fontsize=base_fontsize - 2)
@@ -133,6 +133,7 @@ def evaluate_model_classification(
         normalize="true",
         ax=ax
     )
+    disp.im_.set_clim(vmin=0.0, vmax=1.0)
 
     ax.set_title(f"{model_name} - {target_name}", fontsize=base_fontsize)
     ax.tick_params(axis='both', labelsize=base_fontsize)
```
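Both changes above pin the colour scale of the metric plots to the [0, 1] range, so heatmaps and normalized confusion matrices stay visually comparable across models. A minimal sketch of the same idea with toy data (not the toolbox function itself):

```python
# Minimal sketch (toy data, not the toolbox function): pinning the colour scale
# to [0, 1] makes metric heatmaps and normalized confusion matrices comparable
# across models, since a score of 0.5 always maps to the same colour.
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

y_true = np.array([0, 1, 1, 0, 1, 0, 1, 1])   # toy labels
y_pred = np.array([0, 1, 0, 0, 1, 1, 1, 1])   # toy predictions

# Heatmap of illustrative per-class scores with a fixed [0, 1] colour range
scores = np.array([[0.67, 0.67, 0.67], [0.80, 0.80, 0.80]])
sns.heatmap(scores, annot=True, fmt='.2f', vmin=0.0, vmax=1.0)
plt.show()

# Normalized confusion matrix: clamp the image colour limits the same way
disp = ConfusionMatrixDisplay.from_predictions(y_true, y_pred, normalize="true")
disp.im_.set_clim(vmin=0.0, vmax=1.0)
plt.show()
```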
```diff
@@ -327,7 +328,8 @@ def plot_calibration_curve(
     target_name: str,
     figure_size: tuple = (10, 10),
     base_fontsize: int = 24,
-    n_bins: int = 15
+    n_bins: int = 15,
+    line_color: str = 'darkorange'
 ) -> plt.Figure: # type: ignore
     """
     Plots the calibration curve (reliability diagram) for a classifier.
@@ -348,22 +350,63 @@ def plot_calibration_curve(
     """
     fig, ax = plt.subplots(figsize=figure_size)
 
-
-
-
-
-
-
+    # --- Step 1: Get probabilities from the estimator ---
+    # We do this manually so we can pass them to from_predictions
+    try:
+        y_prob = model.predict_proba(x_test)
+        # Use probabilities for the positive class (assuming binary)
+        y_score = y_prob[:, 1]
+    except Exception as e:
+        _LOGGER.error(f"Could not get probabilities from model: {e}")
+        plt.close(fig)
+        return fig  # Return empty figure
+
+    # --- Step 2: Get binned data *without* plotting ---
+    with plt.ioff():
+        fig_temp, ax_temp = plt.subplots()
+        cal_display_temp = CalibrationDisplay.from_predictions(
+            y_test,
+            y_score,
+            n_bins=n_bins,
+            ax=ax_temp,
+            name="temp"
+        )
+        line_x, line_y = cal_display_temp.line_.get_data()  # type: ignore
+        plt.close(fig_temp)
+
+    # --- Step 3: Build the plot from scratch on ax ---
+
+    # 3a. Plot the ideal diagonal line
+    ax.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
+
+    # 3b. Use regplot for the regression line and its CI
+    sns.regplot(
+        x=line_x,
+        y=line_y,
+        ax=ax,
+        scatter=False,  # No scatter dots
+        label=f"Calibration Curve ({n_bins} bins)",
+        line_kws={
+            'color': line_color,
+            'linestyle': '--',
+            'linewidth': 2
+        }
     )
 
+    # --- Step 4: Apply original formatting ---
     ax.set_title(f"{model_name} - Reliability Curve for {target_name}", fontsize=base_fontsize)
     ax.tick_params(axis='both', labelsize=base_fontsize - 2)
     ax.set_xlabel("Mean Predicted Probability", fontsize=base_fontsize)
     ax.set_ylabel("Fraction of Positives", fontsize=base_fontsize)
-
+
+    # Set limits
+    ax.set_ylim(0.0, 1.0)
+    ax.set_xlim(0.0, 1.0)
+
+    ax.legend(fontsize=base_fontsize - 4, loc='lower right')
     fig.tight_layout()
 
-    # Save figure
+    # --- Step 5: Save figure (using original logic) ---
     save_path = make_fullpath(save_dir, make=True)
     sanitized_target_name = sanitize_filename(target_name)
     full_save_path = save_path / f"Calibration_Plot_{sanitized_target_name}.svg"
```
ml_tools/keys.py
CHANGED
```diff
@@ -36,6 +36,7 @@ class PyTorchInferenceKeys:
     # For classification tasks
     LABELS = "labels"
     PROBABILITIES = "probabilities"
+    LABEL_NAMES = "label_names"
 
 
 class PytorchModelArchitectureKeys:
@@ -55,10 +56,13 @@ class PytorchArtifactPathKeys:
 
 
 class DatasetKeys:
-    """Keys for saving dataset artifacts"""
+    """Keys for saving dataset artifacts. Also used by FeatureSchema"""
    FEATURE_NAMES = "feature_names"
    TARGET_NAMES = "target_names"
    SCALER_PREFIX = "scaler_"
+    # Feature Schema
+    CONTINUOUS_NAMES = "continuous_feature_names"
+    CATEGORICAL_NAMES = "categorical_feature_names"
 
 
 class SHAPKeys:
@@ -77,6 +81,44 @@ class PyTorchCheckpointKeys:
     BEST_SCORE = "best_score"
 
 
+class UtilityKeys:
+    """Keys used for utility modules"""
+    MODEL_PARAMS_FILE = "model_parameters"
+    TOTAL_PARAMS = "Total Parameters"
+    TRAINABLE_PARAMS = "Trainable Parameters"
+    PTH_FILE = "pth report "
+    MODEL_ARCHITECTURE_FILE = "model_architecture_summary"
+
+
+class VisionKeys:
+    """For vision ML metrics"""
+    SEGMENTATION_REPORT = "segmentation_report"
+    SEGMENTATION_HEATMAP = "segmentation_metrics_heatmap"
+    SEGMENTATION_CONFUSION_MATRIX = "segmentation_confusion_matrix"
+    # Object detection
+    OBJECT_DETECTION_REPORT = "object_detection_report"
+
+
+class VisionTransformRecipeKeys:
+    """Defines the key names for the transform recipe JSON file."""
+    TASK = "task"
+    PIPELINE = "pipeline"
+    NAME = "name"
+    KWARGS = "kwargs"
+    PRE_TRANSFORMS = "pre_transforms"
+
+    RESIZE_SIZE = "resize_size"
+    CROP_SIZE = "crop_size"
+    MEAN = "mean"
+    STD = "std"
+
+
+class ObjectDetectionKeys:
+    """Used by the object detection dataset"""
+    BOXES = "boxes"
+    LABELS = "labels"
+
+
 class _OneHotOtherPlaceholder:
     """Used internally by GUI_tools."""
     OTHER_GUI = "OTHER"
```
ml_tools/math_utilities.py
CHANGED
```diff
@@ -219,7 +219,7 @@ def discretize_categorical_values(
         _LOGGER.error(f"'categorical_info' is not a dictionary, or is empty.")
         raise ValueError()
 
-    _, total_features =
+    _, total_features = working_array.shape
     for col_idx, cardinality in categorical_info.items():
         if not isinstance(col_idx, int):
             _LOGGER.error(f"Column index key {col_idx} is not an integer.")
```
ml_tools/optimization_tools.py
CHANGED
```diff
@@ -9,6 +9,7 @@ from .utilities import yield_dataframes_from_dir
 from ._logger import _LOGGER
 from ._script_info import _script_info
 from .SQL import DatabaseManager
+from ._schema import FeatureSchema
 
 
 __all__ = [
@@ -19,35 +20,25 @@ __all__ = [
 
 
 def create_optimization_bounds(
-
+    schema: FeatureSchema,
     continuous_bounds_map: Dict[str, Tuple[float, float]],
-    categorical_map: Dict[int, int],
-    target_column: Optional[str] = None,
     start_at_zero: bool = True
 ) -> Tuple[List[float], List[float]]:
     """
-    Generates the lower and upper bounds lists for the optimizer from a
+    Generates the lower and upper bounds lists for the optimizer from a FeatureSchema.
 
     This helper function automates the creation of unbiased bounds for
     categorical features and combines them with user-defined bounds for
-    continuous features
-
-    It reads *only* the header of the provided CSV to determine the full
-    list of feature columns and their order, excluding the specified target.
-    This is memory-efficient as the full dataset is not loaded.
+    continuous features, using the schema as the single source of truth
+    for feature order and type.
 
     Args:
-
-
-
+        schema (FeatureSchema):
+            The definitive schema object created by
+            `data_exploration.finalize_feature_schema()`.
         continuous_bounds_map (Dict[str, Tuple[float, float]]):
             A dictionary mapping the *name* of each **continuous** feature
             to its (min_bound, max_bound) tuple.
-        categorical_map (Dict[int, int]):
-            The map from the *index* of each **categorical** feature to its cardinality.
-            (e.g., {2: 4} for a feature at index 2 with 4 categories).
-        target_column (Optional[str], optional):
-            The name of the target column to exclude. If None (default), the *last column* in the CSV is assumed to be the target.
         start_at_zero (bool):
            - If True, assumes categorical encoding is [0, 1, ..., k-1].
              Bounds will be set as [-0.5, k - 0.5].
@@ -59,98 +50,86 @@ def create_optimization_bounds(
         A tuple containing two lists: (lower_bounds, upper_bounds).
 
     Raises:
-        ValueError: If a feature is
-
-
+        ValueError: If a feature is missing from `continuous_bounds_map`
+                    or if a feature name in the map is not a
+                    continuous feature according to the schema.
     """
-    # 1.
-
-
-        df_header = pd.read_csv(full_csv_path, nrows=0, encoding="utf-8")
-    except Exception as e:
-        _LOGGER.error(f"Failed to read header from CSV: {e}")
-        raise
-
-    all_column_names = df_header.columns.to_list()
-    feature_names: List[str] = []
-
-    if target_column is None:
-        feature_names = all_column_names[:-1]
-        excluded_target = all_column_names[-1]
-        _LOGGER.info(f"No target_column provided. Assuming last column '{excluded_target}' is the target.")
-    else:
-        if target_column not in all_column_names:
-            _LOGGER.error(f"Target column '{target_column}' not found in CSV header.")
-            raise ValueError()
-        feature_names = [name for name in all_column_names if name != target_column]
-        _LOGGER.info(f"Excluding target column '{target_column}'.")
-
-    # 2. Initialize bound lists
+    # 1. Get feature names and map from schema
+    feature_names = schema.feature_names
+    categorical_index_map = schema.categorical_index_map
     total_features = len(feature_names)
+
     if total_features <= 0:
-        _LOGGER.error("
+        _LOGGER.error("Schema contains no features.")
         raise ValueError()
+
+    _LOGGER.info(f"Generating bounds for {total_features} total features...")
 
+    # 2. Initialize bound lists
     lower_bounds: List[Optional[float]] = [None] * total_features
     upper_bounds: List[Optional[float]] = [None] * total_features
-
-    _LOGGER.info(f"Generating bounds for {total_features} total features...")
 
     # 3. Populate categorical bounds (Index-based)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        upper_bounds[index] = high
+    if categorical_index_map:
+        for index, cardinality in categorical_index_map.items():
+            if not (0 <= index < total_features):
+                _LOGGER.error(f"Categorical index {index} is out of range for the {total_features} features.")
+                raise ValueError()
+
+            if start_at_zero:
+                # Rule for [0, k-1]: bounds are [-0.5, k - 0.5]
+                low = -0.5
+                high = float(cardinality) - 0.5
+            else:
+                # Rule for [1, k]: bounds are [0.5, k + 0.5]
+                low = 0.5
+                high = float(cardinality) + 0.5
+
+            lower_bounds[index] = low
+            upper_bounds[index] = high
 
-
+        _LOGGER.info(f"Automatically set bounds for {len(categorical_index_map)} categorical features.")
+    else:
+        _LOGGER.info("No categorical features found in schema.")
 
     # 4. Populate continuous bounds (Name-based)
+    # Use schema.continuous_feature_names for robust checking
+    continuous_names_set = set(schema.continuous_feature_names)
+
+    if continuous_names_set != set(continuous_bounds_map.keys()):
+        missing_in_map = continuous_names_set - set(continuous_bounds_map.keys())
+        if missing_in_map:
+            _LOGGER.error(f"The following continuous features are missing from 'continuous_bounds_map': {list(missing_in_map)}")
+
+        extra_in_map = set(continuous_bounds_map.keys()) - continuous_names_set
+        if extra_in_map:
+            _LOGGER.error(f"The following features in 'continuous_bounds_map' are not defined as continuous in the schema: {list(extra_in_map)}")
+
+        raise ValueError("Mismatch between 'continuous_bounds_map' and schema's continuous features.")
+
     count_continuous = 0
     for name, (low, high) in continuous_bounds_map.items():
-
-
-
-
-            _LOGGER.warning(f"Feature name '{name}' from 'continuous_bounds_map' not found in the CSV's feature columns.")
-            continue
-
+        # Map name to its index in the *feature-only* list
+        # This is guaranteed to be correct by the schema
+        index = feature_names.index(name)
+
         if lower_bounds[index] is not None:
-            # This
-            _LOGGER.error(f"Feature '{name}' (at index {index}) is defined
+            # This should be impossible if schema is correct, but good to check
+            _LOGGER.error(f"Schema conflict: Feature '{name}' (at index {index}) is defined as both continuous and categorical.")
             raise ValueError()
-
+
         lower_bounds[index] = float(low)
         upper_bounds[index] = float(high)
         count_continuous += 1
 
     _LOGGER.info(f"Manually set bounds for {count_continuous} continuous features.")
 
-    # 5. Validation
-
-
-        if lower_bounds[i] is None:
-            missing_indices.append(i)
-
-    if missing_indices:
+    # 5. Final Validation (all Nones should be filled)
+    if None in lower_bounds:
+        missing_indices = [i for i, b in enumerate(lower_bounds) if b is None]
         missing_names = [feature_names[i] for i in missing_indices]
-        _LOGGER.error(f"
-        raise
-
-        # _LOGGER.info("All bounds successfully created.")
+        _LOGGER.error(f"Failed to create all bounds. This indicates an internal logic error. Missing: {missing_names}")
+        raise RuntimeError("Internal error: Not all bounds were populated.")
 
     # Cast to float lists, as 'None' sentinels are gone
     return (
```
ml_tools/serde.py
CHANGED
```diff
@@ -6,15 +6,22 @@ from pathlib import Path
 from .path_manager import make_fullpath, sanitize_filename
 from ._script_info import _script_info
 from ._logger import _LOGGER
+from ._schema import FeatureSchema
 
 
 __all__ = [
     "serialize_object_filename",
     "serialize_object",
     "deserialize_object",
+    "serialize_schema",
+    "deserialize_schema"
 ]
 
 
+# Base types that have a generic `type()` log.
+_SIMPLE_TYPES = (list, dict, tuple, set, str, int, float, bool)
+
+
 def serialize_object_filename(obj: Any, save_dir: Union[str,Path], filename: str, verbose: bool=True, raise_on_error: bool=False) -> None:
     """
     Serializes a Python object using joblib; suitable for Python built-ins, numpy, and pandas.
@@ -24,22 +31,25 @@ def serialize_object_filename(obj: Any, save_dir: Union[str,Path], filename: str, verbose: bool=True, raise_on_error: bool=False) -> None:
         save_dir (str | Path) : Directory path where the serialized object will be saved.
         filename (str) : Name for the output file, extension will be appended if needed.
     """
+    if obj is None:
+        _LOGGER.warning(f"Attempted to serialize a None object. Skipping save for '{filename}'.")
+        return
+
     try:
-        save_path = make_fullpath(save_dir, make=True)
+        save_path = make_fullpath(save_dir, make=True, enforce="directory")
         sanitized_name = sanitize_filename(filename)
-        if not sanitized_name.endswith('.joblib'):
-            sanitized_name = sanitized_name + ".joblib"
         full_path = save_path / sanitized_name
-
-
-        _LOGGER.error(f"Failed to serialize object of type '{type(obj)}'.")
+    except (IOError, OSError, TypeError) as e:
+        _LOGGER.error(f"Failed to construct save path from dir='{save_dir}' and filename='{filename}'. Error: {e}")
         if raise_on_error:
             raise e
         return None
-
-
-
-
+
+    # call serialize_object with the fully constructed path.
+    serialize_object(obj=obj,
+                     file_path=full_path,
+                     verbose=verbose,
+                     raise_on_error=raise_on_error)
 
 
 def serialize_object(obj: Any, file_path: Path, verbose: bool = True, raise_on_error: bool = False) -> None:
@@ -54,10 +64,13 @@ def serialize_object(obj: Any, file_path: Path, verbose: bool = True, raise_on_error: bool = False) -> None:
         '.joblib' extension will be appended if missing.
         raise_on_error (bool) : If True, raises exceptions on failure.
     """
+    if obj is None:
+        _LOGGER.warning(f"Attempted to serialize a None object. Skipping save for '{file_path}'.")
+        return
+
     try:
         # Ensure the extension is correct
-
-        file_path = file_path.with_suffix(file_path.suffix + '.joblib')
+        file_path = file_path.with_suffix('.joblib')
 
         # Ensure the parent directory exists
         _save_dir = make_fullpath(file_path.parent, make=True, enforce="directory")
@@ -72,7 +85,11 @@ def serialize_object(obj: Any, file_path: Path, verbose: bool = True, raise_on_error: bool = False) -> None:
             return None
         else:
             if verbose:
-
+                if type(obj) in _SIMPLE_TYPES:
+                    _LOGGER.info(f"Object of type '{type(obj)}' saved to '{file_path}'")
+                else:
+                    _LOGGER.info(f"Object '{obj}' saved to '{file_path}'")
+
             return None
 
 
@@ -116,16 +133,60 @@ def deserialize_object(
     # Can't do an isinstance check on 'Any', skip it.
     if type_to_check is not Any and not isinstance(obj, type_to_check):
         error_msg = (
-            f"Type mismatch: Expected an instance of '{expected_type}', "
-            f"but found '{type(obj)}' in '{true_filepath}'."
+            f"Type mismatch: Expected an instance of '{expected_type}', but found '{type(obj)}' in '{true_filepath}'."
         )
         _LOGGER.error(error_msg)
         raise TypeError()
 
     if verbose:
-
+        # log special objects
+        if type(obj) in _SIMPLE_TYPES:
+            _LOGGER.info(f"Loaded object of type '{type(obj)}' from '{true_filepath}'.")
+        else:
+            _LOGGER.info(f"Loaded object '{obj}' from '{true_filepath}'.")
 
-    return obj
+    return obj  # type: ignore
+
+
+def serialize_schema(schema: FeatureSchema, file_path: Path):
+    """
+    Serializes a FeatureSchema object to a .joblib file.
+
+    This is a high-level wrapper around `serialize_object` that
+    specifically handles `FeatureSchema` instances and ensures
+    errors are raised on failure.
+
+    Args:
+        schema (FeatureSchema): The schema object to serialize.
+        file_path (Path): The full file path to save the schema to.
+    """
+    serialize_object(obj=schema,
+                     file_path=file_path,
+                     verbose=True,
+                     raise_on_error=True)
+
+
+def deserialize_schema(file_path: Path):
+    """
+    Deserializes a FeatureSchema object from a .joblib file.
+
+    This is a high-level wrapper around `deserialize_object` that
+    validates the loaded object is an instance of `FeatureSchema`.
+
+    Args:
+        file_path (Path): The full file path of the serialized schema.
+
+    Returns:
+        FeatureSchema: The deserialized schema object.
+
+    Raises:
+        TypeError: If the deserialized object is not an instance of `FeatureSchema`.
+    """
+    schema = deserialize_object(filepath=file_path,
+                                expected_type=FeatureSchema,
+                                verbose=True)
+    return schema
+
 
 def info():
     _script_info(__all__)
```