dragon-ml-toolbox 13.0.0__py3-none-any.whl → 14.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/METADATA +12 -2
  2. dragon_ml_toolbox-14.7.0.dist-info/RECORD +49 -0
  3. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
  4. ml_tools/MICE_imputation.py +207 -5
  5. ml_tools/ML_configuration.py +108 -0
  6. ml_tools/ML_datasetmaster.py +241 -260
  7. ml_tools/ML_evaluation.py +229 -76
  8. ml_tools/ML_evaluation_multi.py +45 -16
  9. ml_tools/ML_inference.py +0 -1
  10. ml_tools/ML_models.py +135 -55
  11. ml_tools/ML_models_advanced.py +323 -0
  12. ml_tools/ML_optimization.py +49 -36
  13. ml_tools/ML_trainer.py +498 -29
  14. ml_tools/ML_utilities.py +351 -4
  15. ml_tools/ML_vision_datasetmaster.py +1492 -0
  16. ml_tools/ML_vision_evaluation.py +260 -0
  17. ml_tools/ML_vision_inference.py +428 -0
  18. ml_tools/ML_vision_models.py +641 -0
  19. ml_tools/ML_vision_transformers.py +203 -0
  20. ml_tools/PSO_optimization.py +5 -1
  21. ml_tools/_ML_vision_recipe.py +88 -0
  22. ml_tools/__init__.py +1 -0
  23. ml_tools/_schema.py +96 -0
  24. ml_tools/custom_logger.py +37 -14
  25. ml_tools/data_exploration.py +576 -138
  26. ml_tools/ensemble_evaluation.py +53 -10
  27. ml_tools/keys.py +43 -1
  28. ml_tools/math_utilities.py +1 -1
  29. ml_tools/optimization_tools.py +65 -86
  30. ml_tools/serde.py +78 -17
  31. ml_tools/utilities.py +192 -3
  32. dragon_ml_toolbox-13.0.0.dist-info/RECORD +0 -41
  33. ml_tools/ML_simple_optimization.py +0 -413
  34. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/WHEEL +0 -0
  35. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/licenses/LICENSE +0 -0
  36. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/top_level.txt +0 -0
@@ -112,7 +112,7 @@ def evaluate_model_classification(
      report_df = pd.DataFrame(report_dict).iloc[:-1, :].T
      plt.figure(figsize=figsize)
      sns.heatmap(report_df, annot=True, cmap=heatmap_cmap, fmt='.2f',
-                 annot_kws={"size": base_fontsize - 4})
+                 annot_kws={"size": base_fontsize - 4}, vmin=0.0, vmax=1.0)
      plt.title(f"{model_name} - {target_name}", fontsize=base_fontsize)
      plt.xticks(fontsize=base_fontsize - 2)
      plt.yticks(fontsize=base_fontsize - 2)
@@ -133,6 +133,7 @@ def evaluate_model_classification(
          normalize="true",
          ax=ax
      )
+     disp.im_.set_clim(vmin=0.0, vmax=1.0)

      ax.set_title(f"{model_name} - {target_name}", fontsize=base_fontsize)
      ax.tick_params(axis='both', labelsize=base_fontsize)
@@ -327,7 +328,8 @@ def plot_calibration_curve(
      target_name: str,
      figure_size: tuple = (10, 10),
      base_fontsize: int = 24,
-     n_bins: int = 15
+     n_bins: int = 15,
+     line_color: str = 'darkorange'
      ) -> plt.Figure: # type: ignore
      """
      Plots the calibration curve (reliability diagram) for a classifier.
@@ -348,22 +350,63 @@ def plot_calibration_curve(
      """
      fig, ax = plt.subplots(figsize=figure_size)

-     disp = CalibrationDisplay.from_estimator(
-         model,
-         x_test,
-         y_test,
-         n_bins=n_bins,
-         ax=ax
+     # --- Step 1: Get probabilities from the estimator ---
+     # We do this manually so we can pass them to from_predictions
+     try:
+         y_prob = model.predict_proba(x_test)
+         # Use probabilities for the positive class (assuming binary)
+         y_score = y_prob[:, 1]
+     except Exception as e:
+         _LOGGER.error(f"Could not get probabilities from model: {e}")
+         plt.close(fig)
+         return fig  # Return empty figure
+
+     # --- Step 2: Get binned data *without* plotting ---
+     with plt.ioff():
+         fig_temp, ax_temp = plt.subplots()
+         cal_display_temp = CalibrationDisplay.from_predictions(
+             y_test,
+             y_score,
+             n_bins=n_bins,
+             ax=ax_temp,
+             name="temp"
+         )
+         line_x, line_y = cal_display_temp.line_.get_data()  # type: ignore
+         plt.close(fig_temp)
+
+     # --- Step 3: Build the plot from scratch on ax ---
+
+     # 3a. Plot the ideal diagonal line
+     ax.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
+
+     # 3b. Use regplot for the regression line and its CI
+     sns.regplot(
+         x=line_x,
+         y=line_y,
+         ax=ax,
+         scatter=False,  # No scatter dots
+         label=f"Calibration Curve ({n_bins} bins)",
+         line_kws={
+             'color': line_color,
+             'linestyle': '--',
+             'linewidth': 2
+         }
      )

+     # --- Step 4: Apply original formatting ---
      ax.set_title(f"{model_name} - Reliability Curve for {target_name}", fontsize=base_fontsize)
      ax.tick_params(axis='both', labelsize=base_fontsize - 2)
      ax.set_xlabel("Mean Predicted Probability", fontsize=base_fontsize)
      ax.set_ylabel("Fraction of Positives", fontsize=base_fontsize)
-     ax.legend(fontsize=base_fontsize - 4)
+
+     # Set limits
+     ax.set_ylim(0.0, 1.0)
+     ax.set_xlim(0.0, 1.0)
+
+     ax.legend(fontsize=base_fontsize - 4, loc='lower right')
      fig.tight_layout()

-     # Save figure
+     # --- Step 5: Save figure (using original logic) ---
      save_path = make_fullpath(save_dir, make=True)
      sanitized_target_name = sanitize_filename(target_name)
      full_save_path = save_path / f"Calibration_Plot_{sanitized_target_name}.svg"
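
Note: the reworked plot_calibration_curve above no longer draws directly via CalibrationDisplay.from_estimator; it computes the binned reliability points on a throwaway axis and redraws them with seaborn. The following is a minimal, standalone sketch of that pattern using synthetic data and illustrative variable names (y_true, y_score), not the library code itself:

    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns
    from sklearn.calibration import CalibrationDisplay

    # Synthetic probabilities and labels, purely for illustration.
    rng = np.random.default_rng(0)
    y_score = rng.random(500)
    y_true = (rng.random(500) < y_score).astype(int)

    # Step 1: compute binned reliability points on a temporary, discarded figure.
    with plt.ioff():
        fig_tmp, ax_tmp = plt.subplots()
        disp = CalibrationDisplay.from_predictions(y_true, y_score, n_bins=15, ax=ax_tmp)
        line_x, line_y = disp.line_.get_data()
        plt.close(fig_tmp)

    # Step 2: rebuild the plot, adding a fitted trend line with a confidence band.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
    sns.regplot(x=line_x, y=line_y, ax=ax, scatter=False,
                label="Calibration Curve (15 bins)",
                line_kws={'color': 'darkorange', 'linestyle': '--', 'linewidth': 2})
    ax.set_xlim(0.0, 1.0)
    ax.set_ylim(0.0, 1.0)
    ax.set_xlabel("Mean Predicted Probability")
    ax.set_ylabel("Fraction of Positives")
    ax.legend(loc='lower right')
    fig.tight_layout()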
ml_tools/keys.py CHANGED
@@ -36,6 +36,7 @@ class PyTorchInferenceKeys:
      # For classification tasks
      LABELS = "labels"
      PROBABILITIES = "probabilities"
+     LABEL_NAMES = "label_names"


  class PytorchModelArchitectureKeys:
@@ -55,10 +56,13 @@ class PytorchArtifactPathKeys:


  class DatasetKeys:
-     """Keys for saving dataset artifacts"""
+     """Keys for saving dataset artifacts. Also used by FeatureSchema"""
      FEATURE_NAMES = "feature_names"
      TARGET_NAMES = "target_names"
      SCALER_PREFIX = "scaler_"
+     # Feature Schema
+     CONTINUOUS_NAMES = "continuous_feature_names"
+     CATEGORICAL_NAMES = "categorical_feature_names"


  class SHAPKeys:
@@ -77,6 +81,44 @@ class PyTorchCheckpointKeys:
      BEST_SCORE = "best_score"


+ class UtilityKeys:
+     """Keys used for utility modules"""
+     MODEL_PARAMS_FILE = "model_parameters"
+     TOTAL_PARAMS = "Total Parameters"
+     TRAINABLE_PARAMS = "Trainable Parameters"
+     PTH_FILE = "pth report "
+     MODEL_ARCHITECTURE_FILE = "model_architecture_summary"
+
+
+ class VisionKeys:
+     """For vision ML metrics"""
+     SEGMENTATION_REPORT = "segmentation_report"
+     SEGMENTATION_HEATMAP = "segmentation_metrics_heatmap"
+     SEGMENTATION_CONFUSION_MATRIX = "segmentation_confusion_matrix"
+     # Object detection
+     OBJECT_DETECTION_REPORT = "object_detection_report"
+
+
+ class VisionTransformRecipeKeys:
+     """Defines the key names for the transform recipe JSON file."""
+     TASK = "task"
+     PIPELINE = "pipeline"
+     NAME = "name"
+     KWARGS = "kwargs"
+     PRE_TRANSFORMS = "pre_transforms"
+
+     RESIZE_SIZE = "resize_size"
+     CROP_SIZE = "crop_size"
+     MEAN = "mean"
+     STD = "std"
+
+
+ class ObjectDetectionKeys:
+     """Used by the object detection dataset"""
+     BOXES = "boxes"
+     LABELS = "labels"
+
+
  class _OneHotOtherPlaceholder:
      """Used internally by GUI_tools."""
      OTHER_GUI = "OTHER"
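
Note: the layout below is purely hypothetical and shown only to illustrate how the new VisionTransformRecipeKeys constants might fit together; the authoritative recipe format lives in ml_tools/_ML_vision_recipe.py, which is not shown in this diff. It assumes the class is importable from ml_tools.keys:

    import json
    from ml_tools.keys import VisionTransformRecipeKeys as K  # assumed import path

    # Hypothetical transform recipe assembled from the key constants.
    recipe = {
        K.TASK: "classification",
        K.PRE_TRANSFORMS: [],
        K.PIPELINE: [
            {K.NAME: "Resize", K.KWARGS: {K.RESIZE_SIZE: 256}},
            {K.NAME: "CenterCrop", K.KWARGS: {K.CROP_SIZE: 224}},
            {K.NAME: "Normalize", K.KWARGS: {K.MEAN: [0.485, 0.456, 0.406],
                                             K.STD: [0.229, 0.224, 0.225]}},
        ],
    }

    with open("transform_recipe.json", "w") as f:
        json.dump(recipe, f, indent=2)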
@@ -219,7 +219,7 @@ def discretize_categorical_values(
          _LOGGER.error(f"'categorical_info' is not a dictionary, or is empty.")
          raise ValueError()

-     _, total_features = input_array.shape
+     _, total_features = working_array.shape
      for col_idx, cardinality in categorical_info.items():
          if not isinstance(col_idx, int):
              _LOGGER.error(f"Column index key {col_idx} is not an integer.")
@@ -9,6 +9,7 @@ from .utilities import yield_dataframes_from_dir
  from ._logger import _LOGGER
  from ._script_info import _script_info
  from .SQL import DatabaseManager
+ from ._schema import FeatureSchema


  __all__ = [
@@ -19,35 +20,25 @@ __all__ = [


  def create_optimization_bounds(
-     csv_path: Union[str, Path],
+     schema: FeatureSchema,
      continuous_bounds_map: Dict[str, Tuple[float, float]],
-     categorical_map: Dict[int, int],
-     target_column: Optional[str] = None,
      start_at_zero: bool = True
      ) -> Tuple[List[float], List[float]]:
      """
-     Generates the lower and upper bounds lists for the optimizer from a CSV header.
+     Generates the lower and upper bounds lists for the optimizer from a FeatureSchema.

      This helper function automates the creation of unbiased bounds for
      categorical features and combines them with user-defined bounds for
-     continuous features.
-
-     It reads *only* the header of the provided CSV to determine the full
-     list of feature columns and their order, excluding the specified target.
-     This is memory-efficient as the full dataset is not loaded.
+     continuous features, using the schema as the single source of truth
+     for feature order and type.

      Args:
-         csv_path (Union[str, Path]):
-             Path to the final, preprocessed CSV file. The column order in
-             this file must match the order expected by the model.
+         schema (FeatureSchema):
+             The definitive schema object created by
+             `data_exploration.finalize_feature_schema()`.
          continuous_bounds_map (Dict[str, Tuple[float, float]]):
              A dictionary mapping the *name* of each **continuous** feature
              to its (min_bound, max_bound) tuple.
-         categorical_map (Dict[int, int]):
-             The map from the *index* of each **categorical** feature to its cardinality.
-             (e.g., {2: 4} for a feature at index 2 with 4 categories).
-         target_column (Optional[str], optional):
-             The name of the target column to exclude. If None (default), the *last column* in the CSV is assumed to be the target.
          start_at_zero (bool):
              - If True, assumes categorical encoding is [0, 1, ..., k-1].
               Bounds will be set as [-0.5, k - 0.5].
@@ -59,98 +50,86 @@ def create_optimization_bounds(
          A tuple containing two lists: (lower_bounds, upper_bounds).

      Raises:
-         ValueError: If a feature is defined in both maps, is missing from
-             both maps, or if a name in `continuous_bounds_map`
-             or `target_column` is not found in the CSV columns.
+         ValueError: If a feature is missing from `continuous_bounds_map`
+             or if a feature name in the map is not a
+             continuous feature according to the schema.
      """
-     # 1. Read header and determine feature names
-     full_csv_path = make_fullpath(csv_path, enforce="file")
-     try:
-         df_header = pd.read_csv(full_csv_path, nrows=0, encoding="utf-8")
-     except Exception as e:
-         _LOGGER.error(f"Failed to read header from CSV: {e}")
-         raise
-
-     all_column_names = df_header.columns.to_list()
-     feature_names: List[str] = []
-
-     if target_column is None:
-         feature_names = all_column_names[:-1]
-         excluded_target = all_column_names[-1]
-         _LOGGER.info(f"No target_column provided. Assuming last column '{excluded_target}' is the target.")
-     else:
-         if target_column not in all_column_names:
-             _LOGGER.error(f"Target column '{target_column}' not found in CSV header.")
-             raise ValueError()
-         feature_names = [name for name in all_column_names if name != target_column]
-         _LOGGER.info(f"Excluding target column '{target_column}'.")
-
-     # 2. Initialize bound lists
+     # 1. Get feature names and map from schema
+     feature_names = schema.feature_names
+     categorical_index_map = schema.categorical_index_map
      total_features = len(feature_names)
+
      if total_features <= 0:
-         _LOGGER.error("No feature columns remain after excluding the target.")
+         _LOGGER.error("Schema contains no features.")
          raise ValueError()
+
+     _LOGGER.info(f"Generating bounds for {total_features} total features...")

+     # 2. Initialize bound lists
      lower_bounds: List[Optional[float]] = [None] * total_features
      upper_bounds: List[Optional[float]] = [None] * total_features
-
-     _LOGGER.info(f"Generating bounds for {total_features} total features...")

      # 3. Populate categorical bounds (Index-based)
-     # The indices in categorical_map (e.g., {2: 4}) directly correspond
-     # to the indices in the `feature_names` list.
-     for index, cardinality in categorical_map.items():
-         if not (0 <= index < total_features):
-             _LOGGER.error(f"Categorical index {index} is out of range for the {total_features} features.")
-             raise ValueError()
-
-         if start_at_zero:
-             # Rule for [0, k-1]: bounds are [-0.5, k - 0.5]
-             low = -0.5
-             high = float(cardinality) - 0.5
-         else:
-             # Rule for [1, k]: bounds are [0.5, k + 0.5]
-             low = 0.5
-             high = float(cardinality) + 0.5
-
-         lower_bounds[index] = low
-         upper_bounds[index] = high
+     if categorical_index_map:
+         for index, cardinality in categorical_index_map.items():
+             if not (0 <= index < total_features):
+                 _LOGGER.error(f"Categorical index {index} is out of range for the {total_features} features.")
+                 raise ValueError()
+
+             if start_at_zero:
+                 # Rule for [0, k-1]: bounds are [-0.5, k - 0.5]
+                 low = -0.5
+                 high = float(cardinality) - 0.5
+             else:
+                 # Rule for [1, k]: bounds are [0.5, k + 0.5]
+                 low = 0.5
+                 high = float(cardinality) + 0.5
+
+             lower_bounds[index] = low
+             upper_bounds[index] = high

-     _LOGGER.info(f"Automatically set bounds for {len(categorical_map)} categorical features.")
+         _LOGGER.info(f"Automatically set bounds for {len(categorical_index_map)} categorical features.")
+     else:
+         _LOGGER.info("No categorical features found in schema.")

      # 4. Populate continuous bounds (Name-based)
+     # Use schema.continuous_feature_names for robust checking
+     continuous_names_set = set(schema.continuous_feature_names)
+
+     if continuous_names_set != set(continuous_bounds_map.keys()):
+         missing_in_map = continuous_names_set - set(continuous_bounds_map.keys())
+         if missing_in_map:
+             _LOGGER.error(f"The following continuous features are missing from 'continuous_bounds_map': {list(missing_in_map)}")
+
+         extra_in_map = set(continuous_bounds_map.keys()) - continuous_names_set
+         if extra_in_map:
+             _LOGGER.error(f"The following features in 'continuous_bounds_map' are not defined as continuous in the schema: {list(extra_in_map)}")
+
+         raise ValueError("Mismatch between 'continuous_bounds_map' and schema's continuous features.")
+
      count_continuous = 0
      for name, (low, high) in continuous_bounds_map.items():
-         try:
-             # Map name to its index in the *feature-only* list
-             index = feature_names.index(name)
-         except ValueError:
-             _LOGGER.warning(f"Feature name '{name}' from 'continuous_bounds_map' not found in the CSV's feature columns.")
-             continue
-
+         # Map name to its index in the *feature-only* list
+         # This is guaranteed to be correct by the schema
+         index = feature_names.index(name)
+
          if lower_bounds[index] is not None:
-             # This index was already set by the categorical map
-             _LOGGER.error(f"Feature '{name}' (at index {index}) is defined in both 'categorical_map' and 'continuous_bounds_map'.")
+             # This should be impossible if schema is correct, but good to check
+             _LOGGER.error(f"Schema conflict: Feature '{name}' (at index {index}) is defined as both continuous and categorical.")
              raise ValueError()
-
+
          lower_bounds[index] = float(low)
          upper_bounds[index] = float(high)
          count_continuous += 1

      _LOGGER.info(f"Manually set bounds for {count_continuous} continuous features.")

-     # 5. Validation: Check for any remaining None values
-     missing_indices = []
-     for i in range(total_features):
-         if lower_bounds[i] is None:
-             missing_indices.append(i)
-
-     if missing_indices:
+     # 5. Final Validation (all Nones should be filled)
+     if None in lower_bounds:
+         missing_indices = [i for i, b in enumerate(lower_bounds) if b is None]
          missing_names = [feature_names[i] for i in missing_indices]
-         _LOGGER.error(f"Bounds not defined for all features. Missing: {missing_names}")
-         raise ValueError()
-
-     # _LOGGER.info("All bounds successfully created.")
+         _LOGGER.error(f"Failed to create all bounds. This indicates an internal logic error. Missing: {missing_names}")
+         raise RuntimeError("Internal error: Not all bounds were populated.")

      # Cast to float lists, as 'None' sentinels are gone
      return (
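
Note: the bound rules themselves are unchanged by this refactor: categorical features encoded [0, k-1] get [-0.5, k - 0.5] (or [0.5, k + 0.5] for [1, k] encoding), while continuous features take user-supplied (min, max) pairs. The following self-contained sketch reproduces those rules with made-up feature names and does not call the library itself:

    from typing import Dict, List, Tuple

    def sketch_bounds(feature_names: List[str],
                      categorical_index_map: Dict[int, int],
                      continuous_bounds_map: Dict[str, Tuple[float, float]],
                      start_at_zero: bool = True) -> Tuple[List[float], List[float]]:
        """Illustrative reimplementation of the bound rules, not the library code."""
        lower = [0.0] * len(feature_names)
        upper = [0.0] * len(feature_names)
        # Categorical features: unbiased bounds around the encoded category range.
        for idx, k in categorical_index_map.items():
            lower[idx], upper[idx] = (-0.5, k - 0.5) if start_at_zero else (0.5, k + 0.5)
        # Continuous features: user-defined (min, max) mapped by feature name.
        for name, (low, high) in continuous_bounds_map.items():
            idx = feature_names.index(name)
            lower[idx], upper[idx] = float(low), float(high)
        return lower, upper

    # Hypothetical three-feature schema: "temp" and "pressure" continuous,
    # "material" categorical at index 2 with 4 categories encoded 0..3.
    low, high = sketch_bounds(
        feature_names=["temp", "pressure", "material"],
        categorical_index_map={2: 4},
        continuous_bounds_map={"temp": (20.0, 80.0), "pressure": (1.0, 5.0)},
    )
    print(low)   # [20.0, 1.0, -0.5]
    print(high)  # [80.0, 5.0, 3.5]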
ml_tools/serde.py CHANGED
@@ -6,15 +6,22 @@ from pathlib import Path
  from .path_manager import make_fullpath, sanitize_filename
  from ._script_info import _script_info
  from ._logger import _LOGGER
+ from ._schema import FeatureSchema


  __all__ = [
      "serialize_object_filename",
      "serialize_object",
      "deserialize_object",
+     "serialize_schema",
+     "deserialize_schema"
  ]


+ # Base types that have a generic `type()` log.
+ _SIMPLE_TYPES = (list, dict, tuple, set, str, int, float, bool)
+
+
  def serialize_object_filename(obj: Any, save_dir: Union[str,Path], filename: str, verbose: bool=True, raise_on_error: bool=False) -> None:
      """
      Serializes a Python object using joblib; suitable for Python built-ins, numpy, and pandas.
@@ -24,22 +31,25 @@ def serialize_object_filename(obj: Any, save_dir: Union[str,Path], filename: str
          save_dir (str | Path) : Directory path where the serialized object will be saved.
          filename (str) : Name for the output file, extension will be appended if needed.
      """
+     if obj is None:
+         _LOGGER.warning(f"Attempted to serialize a None object. Skipping save for '{filename}'.")
+         return
+
      try:
-         save_path = make_fullpath(save_dir, make=True)
+         save_path = make_fullpath(save_dir, make=True, enforce="directory")
          sanitized_name = sanitize_filename(filename)
-         if not sanitized_name.endswith('.joblib'):
-             sanitized_name = sanitized_name + ".joblib"
          full_path = save_path / sanitized_name
-         joblib.dump(obj, full_path)
-     except (IOError, OSError, TypeError, TerminatedWorkerError) as e:
-         _LOGGER.error(f"Failed to serialize object of type '{type(obj)}'.")
+     except (IOError, OSError, TypeError) as e:
+         _LOGGER.error(f"Failed to construct save path from dir='{save_dir}' and filename='{filename}'. Error: {e}")
          if raise_on_error:
              raise e
          return None
-     else:
-         if verbose:
-             _LOGGER.info(f"Object of type '{type(obj)}' saved to '{full_path}'")
-         return None
+
+     # call serialize_object with the fully constructed path.
+     serialize_object(obj=obj,
+                      file_path=full_path,
+                      verbose=verbose,
+                      raise_on_error=raise_on_error)


  def serialize_object(obj: Any, file_path: Path, verbose: bool = True, raise_on_error: bool = False) -> None:
@@ -54,10 +64,13 @@ def serialize_object(obj: Any, file_path: Path, verbose: bool = True, raise_on_e
          '.joblib' extension will be appended if missing.
          raise_on_error (bool) : If True, raises exceptions on failure.
      """
+     if obj is None:
+         _LOGGER.warning(f"Attempted to serialize a None object. Skipping save for '{file_path}'.")
+         return
+
      try:
          # Ensure the extension is correct
-         if file_path.suffix != '.joblib':
-             file_path = file_path.with_suffix(file_path.suffix + '.joblib')
+         file_path = file_path.with_suffix('.joblib')

          # Ensure the parent directory exists
          _save_dir = make_fullpath(file_path.parent, make=True, enforce="directory")
@@ -72,7 +85,11 @@ def serialize_object(obj: Any, file_path: Path, verbose: bool = True, raise_on_e
          return None
      else:
          if verbose:
-             _LOGGER.info(f"Object of type '{type(obj)}' saved to '{file_path}'")
+             if type(obj) in _SIMPLE_TYPES:
+                 _LOGGER.info(f"Object of type '{type(obj)}' saved to '{file_path}'")
+             else:
+                 _LOGGER.info(f"Object '{obj}' saved to '{file_path}'")
+
          return None


@@ -116,16 +133,60 @@ def deserialize_object(
      # Can't do an isinstance check on 'Any', skip it.
      if type_to_check is not Any and not isinstance(obj, type_to_check):
          error_msg = (
-             f"Type mismatch: Expected an instance of '{expected_type}', "
-             f"but found '{type(obj)}' in '{true_filepath}'."
+             f"Type mismatch: Expected an instance of '{expected_type}', but found '{type(obj)}' in '{true_filepath}'."
          )
          _LOGGER.error(error_msg)
          raise TypeError()

      if verbose:
-         _LOGGER.info(f"Loaded object of type '{type(obj)}' from '{true_filepath}'.")
+         # log special objects
+         if type(obj) in _SIMPLE_TYPES:
+             _LOGGER.info(f"Loaded object of type '{type(obj)}' from '{true_filepath}'.")
+         else:
+             _LOGGER.info(f"Loaded object '{obj}' from '{true_filepath}'.")

-     return obj
+     return obj  # type: ignore
+
+
+ def serialize_schema(schema: FeatureSchema, file_path: Path):
+     """
+     Serializes a FeatureSchema object to a .joblib file.
+
+     This is a high-level wrapper around `serialize_object` that
+     specifically handles `FeatureSchema` instances and ensures
+     errors are raised on failure.
+
+     Args:
+         schema (FeatureSchema): The schema object to serialize.
+         file_path (Path): The full file path to save the schema to.
+     """
+     serialize_object(obj=schema,
+                      file_path=file_path,
+                      verbose=True,
+                      raise_on_error=True)
+
+
+ def deserialize_schema(file_path: Path):
+     """
+     Deserializes a FeatureSchema object from a .joblib file.
+
+     This is a high-level wrapper around `deserialize_object` that
+     validates the loaded object is an instance of `FeatureSchema`.
+
+     Args:
+         file_path (Path): The full file path of the serialized schema.
+
+     Returns:
+         FeatureSchema: The deserialized schema object.
+
+     Raises:
+         TypeError: If the deserialized object is not an instance of `FeatureSchema`.
+     """
+     schema = deserialize_object(filepath=file_path,
+                                 expected_type=FeatureSchema,
+                                 verbose=True)
+     return schema
+

  def info():
      _script_info(__all__)
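
Note: a hedged usage sketch of the serde round-trip follows. It uses a plain dict in place of a FeatureSchema (whose constructor is not part of this diff) and assumes the functions are importable from ml_tools.serde:

    from pathlib import Path
    from ml_tools.serde import serialize_object, deserialize_object  # assumed import path

    payload = {"feature_names": ["temp", "pressure"], "targets": ["yield"]}
    path = Path("artifacts/example.joblib")  # '.joblib' suffix is enforced by serialize_object

    # Round-trip a simple object; the parent directory is created automatically.
    serialize_object(obj=payload, file_path=path, verbose=True, raise_on_error=True)
    restored = deserialize_object(filepath=path, expected_type=dict, verbose=True)
    assert restored == payload

    # With a real FeatureSchema (e.g. from data_exploration.finalize_feature_schema()),
    # the new wrappers reduce this to:
    #   serialize_schema(schema, path)
    #   schema = deserialize_schema(path)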