dragon-ml-toolbox 13.7.0__py3-none-any.whl → 14.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,543 @@
+ import torch
+ import pandas as pd
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from typing import List, Literal, Union, Optional, Dict, Any
+ from pathlib import Path
+ import warnings
+
+ # --- Third-party imports ---
+ try:
+     from pytorch_tabular.models.common.heads import LinearHeadConfig
+     from pytorch_tabular.config import (
+         DataConfig,
+         ModelConfig,
+         OptimizerConfig,
+         TrainerConfig,
+         ExperimentConfig,
+     )
+     from pytorch_tabular.models import (
+         CategoryEmbeddingModelConfig,
+         TabNetModelConfig,
+         TabTransformerConfig,
+         FTTransformerConfig,
+         AutoIntConfig,
+         NodeConfig,
+         GANDALFConfig
+     )
+     from pytorch_tabular.tabular_model import TabularModel
+ except ImportError:
+     print("----------------------------------------------------------------")
+     print("ERROR: `pytorch-tabular` is not installed.")
+     print("Please install it to use the models in this script:")
+     print('\npip install "dragon-ml-toolbox[py-tab]"')
+     print("----------------------------------------------------------------")
+     raise
+
+ # --- Local ML-Tools imports ---
+ from ._logger import _LOGGER
+ from ._script_info import _script_info
+ from ._schema import FeatureSchema
+ from .path_manager import make_fullpath, sanitize_filename
+ from .keys import SHAPKeys
+ from .ML_datasetmaster import _PytorchDataset
+ from .ML_evaluation import (
+     classification_metrics,
+     regression_metrics
+ )
+ from .ML_evaluation_multi import (
+     multi_target_regression_metrics,
+     multi_label_classification_metrics
+ )
+
+
+
+ __all__ = [
+     "PyTabularTrainer"
+ ]
+
+
+ # --- Model Configuration Mapping ---
+ # Maps a simple string name to the required ModelConfig class
+ SUPPORTED_MODELS: Dict[str, Any] = {
+     "TabNet": TabNetModelConfig,
+     "TabTransformer": TabTransformerConfig,
+     "FTTransformer": FTTransformerConfig,
+     "AutoInt": AutoIntConfig,
+     "NODE": NodeConfig,
+     "GATE": GANDALFConfig,  # Registered as "GATE", but maps to pytorch-tabular's GANDALF model
+     "CategoryEmbedding": CategoryEmbeddingModelConfig,  # A basic MLP
+ }
+
+
+ class PyTabularTrainer:
+     """
+     A wrapper for models from the `pytorch-tabular` library, designed to be
+     compatible with the `dragon-ml-toolbox` ecosystem.
+
+     This class acts as a high-level trainer that adapts the `ML_datasetmaster`
+     datasets into the format required by `pytorch-tabular` and routes
+     evaluation results to the standard `ML_evaluation` functions.
+
+     It handles:
+     - Automatic `DataConfig` creation from a `FeatureSchema`.
+     - Model and Trainer configuration.
+     - Training and evaluation.
+     - SHAP explanations.
+     """
+
+     def __init__(self,
+                  schema: FeatureSchema,
+                  target_names: List[str],
+                  kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification"],
+                  model_name: str,
+                  model_config_params: Optional[Dict[str, Any]] = None,
+                  optimizer_config_params: Optional[Dict[str, Any]] = None,
+                  trainer_config_params: Optional[Dict[str, Any]] = None):
+         """
+         Initializes the Model, Data, and Trainer configurations.
+
+         Args:
+             schema (FeatureSchema):
+                 The definitive schema object from data_exploration.
+             target_names (List[str]):
+                 A list of target column names.
+             kind (Literal[...]):
+                 The type of ML task. This is used to set the `pytorch-tabular`
+                 task and to route to the correct evaluation function.
+             model_name (str):
+                 The name of the model to use. Must be one of:
+                 "TabNet", "TabTransformer", "FTTransformer", "AutoInt",
+                 "NODE", "GATE", "CategoryEmbedding".
+             model_config_params (Dict, optional):
+                 Overrides for the chosen model's `ModelConfig`
+                 (e.g., `{"n_d": 16, "n_a": 16}` for TabNet).
+             optimizer_config_params (Dict, optional):
+                 Overrides for the `OptimizerConfig` (e.g., `{"lr": 0.005}`).
+             trainer_config_params (Dict, optional):
+                 Overrides for the `TrainerConfig` (e.g., `{"max_epochs": 100}`).
+         """
+         _LOGGER.info(f"Initializing PyTabularTrainer for model: {model_name}")
+
+         # --- 1. Store key info ---
+         self.schema = schema
+         self.target_names = target_names
+         self.kind = kind
+         self.model_name = model_name
+         self._is_fitted = False
+
+         if model_name not in SUPPORTED_MODELS:
+             _LOGGER.error(f"Model '{model_name}' is not supported. Choose from: {list(SUPPORTED_MODELS.keys())}")
+             raise ValueError(f"Unsupported model: {model_name}")
+
+         # --- 2. Map ML-Tools 'kind' to pytorch-tabular 'task' ---
+         if kind == "regression":
+             self.task = "regression"
+             self._pt_target_names = target_names
+         elif kind == "classification":
+             self.task = "classification"
+             self._pt_target_names = target_names
+         elif kind == "multi_target_regression":
+             self.task = "multi-label-regression"  # pytorch-tabular's name
+             self._pt_target_names = target_names
+         elif kind == "multi_label_classification":
+             self.task = "multi-label-classification"
+             self._pt_target_names = target_names
+         else:
+             _LOGGER.error(f"Unknown task 'kind': {kind}")
+             raise ValueError(f"Unknown task 'kind': {kind}")
+
+         # --- 3. Create DataConfig from FeatureSchema ---
+         # Note: pytorch-tabular handles scaling internally
+         self.data_config = DataConfig(
+             target=self._pt_target_names,
+             continuous_cols=list(schema.continuous_feature_names),
+             categorical_cols=list(schema.categorical_feature_names),
+             continuous_feature_transform="quantile_normal",
+         )
+
+         # --- 4. Create ModelConfig ---
+         model_config_class = SUPPORTED_MODELS[model_name]
+
+         # Apply user overrides
+         if model_config_params is None:
+             model_config_params = {}
+
+         # Set task in params
+         model_config_params["task"] = self.task
+
+         # Handle multi-target output for regression
+         if self.task == "multi-label-regression":
+             # Must configure the model's output head
+             if "head" not in model_config_params:
+                 _LOGGER.info("Configuring model head for multi-target regression.")
+                 model_config_params["head"] = "LinearHead"
+                 model_config_params["head_config"] = {
+                     "layers": "",  # No hidden layers in the head
+                     "output_dim": len(self.target_names)
+                 }
+
+         self.model_config = model_config_class(**model_config_params)
+
+         # --- 5. Create OptimizerConfig ---
+         if optimizer_config_params is None:
+             optimizer_config_params = {}
+         self.optimizer_config = OptimizerConfig(**optimizer_config_params)
+
+         # --- 6. Create TrainerConfig ---
+         if trainer_config_params is None:
+             trainer_config_params = {}
+
+         # Default to GPU if available
+         if "accelerator" not in trainer_config_params:
+             if torch.cuda.is_available():
+                 trainer_config_params["accelerator"] = "cuda"
+             elif torch.backends.mps.is_available():
+                 trainer_config_params["accelerator"] = "mps"
+             else:
+                 trainer_config_params["accelerator"] = "cpu"
+
+         # Set other sensible defaults. pytorch-tabular logs the validation
+         # loss as "valid_loss", and the TrainerConfig flag for restoring the
+         # best checkpoint is "load_best".
+         if "checkpoints" not in trainer_config_params:
+             trainer_config_params["checkpoints"] = "valid_loss"
+             trainer_config_params["load_best"] = True
+
+         if "early_stopping" not in trainer_config_params:
+             trainer_config_params["early_stopping"] = "valid_loss"
+
+         self.trainer_config = TrainerConfig(**trainer_config_params)
+
+         # --- 7. Instantiate the TabularModel ---
+         self.tabular_model = TabularModel(
+             data_config=self.data_config,
+             model_config=self.model_config,
+             optimizer_config=self.optimizer_config,
+             trainer_config=self.trainer_config,
+         )
+
+     def _dataset_to_dataframe(self, dataset: _PytorchDataset) -> pd.DataFrame:
+         """Converts a `_PytorchDataset` back into a pandas DataFrame."""
+         try:
+             features_np = dataset.features.cpu().numpy()
+             labels_np = dataset.labels.cpu().numpy()
+             feature_names = dataset.feature_names
+             target_names = dataset.target_names
+         except Exception as e:
+             _LOGGER.error(f"Failed to extract data from provided dataset: {e}")
+             raise
+
+         # Create features DataFrame
+         df = pd.DataFrame(features_np, columns=feature_names)
+
+         # Add labels
+         if labels_np.ndim == 1:
+             df[target_names[0]] = labels_np
+         elif labels_np.ndim == 2:
+             for i, name in enumerate(target_names):
+                 df[name] = labels_np[:, i]
+
+         return df
+
+     def fit(self,
+             train_dataset: _PytorchDataset,
+             test_dataset: _PytorchDataset,
+             epochs: int = 20,
+             batch_size: int = 10):
+         """
+         Trains the model using the provided datasets.
+
+         Args:
+             train_dataset (_PytorchDataset): The training dataset.
+             test_dataset (_PytorchDataset): The validation dataset.
+             epochs (int): The number of epochs to train for.
+             batch_size (int): The batch size.
+         """
+         _LOGGER.info(f"Converting datasets to pandas DataFrame for {self.model_name}...")
+         train_df = self._dataset_to_dataframe(train_dataset)
+         test_df = self._dataset_to_dataframe(test_dataset)
+
+         # Batch size lives in the TrainerConfig; update the model's merged
+         # config so the requested value is actually used by the datamodule.
+         self.tabular_model.config["batch_size"] = batch_size
+
+         _LOGGER.info(f"Starting training for {epochs} epochs...")
+         with warnings.catch_warnings():
+             # Suppress abundant pytorch-lightning warnings
+             warnings.simplefilter("ignore")
+             self.tabular_model.fit(
+                 train=train_df,
+                 validation=test_df,
+                 max_epochs=epochs
+             )
+
+         self._is_fitted = True
+         _LOGGER.info("Training complete.")
+
+     def evaluate(self,
+                  save_dir: Union[str, Path],
+                  data: _PytorchDataset,
+                  classification_threshold: float = 0.5):
+         """
+         Evaluates the model and saves reports using the standard ML_evaluation functions.
+
+         Args:
+             save_dir (str | Path): Directory to save all reports and plots.
+             data (_PytorchDataset): The data to evaluate on.
+             classification_threshold (float): Threshold for multi-label tasks.
+         """
+         if not self._is_fitted:
+             _LOGGER.error("Model is not fitted. Call .fit() first.")
+             raise RuntimeError("Model is not fitted. Call .fit() first.")
+
+         print("\n--- Model Evaluation (PyTorch-Tabular) ---")
+
+         eval_df = self._dataset_to_dataframe(data)
+
+         # Get raw predictions from pytorch-tabular
+         raw_preds_df = self.tabular_model.predict(
+             eval_df,
+             include_input_features=False
+         )
+
+         # Extract y_true from the dataframe
+         y_true = eval_df[self.target_names].to_numpy()
+
+         y_pred = None
+         y_prob = None
+
+         # --- Route based on task kind ---
+
+         if self.kind == "regression":
+             pred_col_name = f"{self.target_names[0]}_prediction"
+             y_pred = raw_preds_df[pred_col_name].to_numpy()
+             regression_metrics(y_true.flatten(), y_pred.flatten(), save_dir)
+
+         elif self.kind == "classification":
+             y_pred = raw_preds_df["prediction"].to_numpy()
+             # Get class names from the model's datamodule
+             if self.tabular_model.datamodule is None:
+                 _LOGGER.error("Model's datamodule is not initialized. Cannot extract class names for probabilities.")
+                 raise RuntimeError("Datamodule not found. Was the model trained or loaded correctly?")
+             class_names = self.tabular_model.datamodule.data_config.target_classes[self.target_names[0]]
+             prob_cols = [f"{c}_probability" for c in class_names]
+             y_prob = raw_preds_df[prob_cols].values
+             classification_metrics(save_dir, y_true.flatten(), y_pred, y_prob)
+
+         elif self.kind == "multi_target_regression":
+             pred_cols = [f"{name}_prediction" for name in self.target_names]
+             y_pred = raw_preds_df[pred_cols].to_numpy()
+             multi_target_regression_metrics(y_true, y_pred, self.target_names, save_dir)
+
+         elif self.kind == "multi_label_classification":
+             prob_cols = [f"{name}_probability" for name in self.target_names]
+             y_prob = raw_preds_df[prob_cols].to_numpy()
+             # y_pred is derived from y_prob
+             multi_label_classification_metrics(y_true, y_prob, self.target_names, save_dir, classification_threshold)
+
+     def explain(self,
+                 save_dir: Union[str, Path],
+                 explain_dataset: _PytorchDataset):
+         """
+         Generates SHAP explanations and saves plots and summary CSVs.
+
+         This method uses pytorch-tabular's internal `.explain()` method
+         and then formats the output to match the ML_evaluation standard.
+
+         Args:
+             save_dir (str | Path): Directory to save all SHAP artifacts.
+             explain_dataset (_PytorchDataset): The dataset to explain.
+         """
+         if not self._is_fitted:
+             _LOGGER.error("Model is not fitted. Call .fit() first.")
+             raise RuntimeError("Model is not fitted. Call .fit() first.")
+
+         print(f"\n--- SHAP Value Explanation ({self.model_name}) ---")
+
+         explain_df = self._dataset_to_dataframe(explain_dataset)
+
+         # We must use the dataframe *without* the target columns for explanation
+         feature_df: pd.DataFrame = explain_df[self.schema.feature_names]  # type: ignore
+
+         # This returns a DataFrame (single-target) or Dict[str, DataFrame]
+         with warnings.catch_warnings():
+             warnings.simplefilter("ignore")
+             shap_output = self.tabular_model.explain(feature_df)
+
+         save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+         plt.ioff()
+
+         # --- 1. Handle single-target (regression/classification) ---
+         if isinstance(shap_output, pd.DataFrame):
+             # shap_output is (n_samples, n_features)
+             shap_values = shap_output.to_numpy()
+
+             # Save Bar Plot
+             self._save_shap_plots(
+                 shap_values=shap_values,
+                 instances_df=feature_df,
+                 save_dir=save_dir_path,
+                 suffix=""  # No suffix for single target
+             )
+             # Save Summary Data
+             self._save_shap_csv(
+                 shap_values=shap_values,
+                 feature_names=list(self.schema.feature_names),
+                 save_dir=save_dir_path,
+                 suffix=""
+             )
+
+         # --- 2. Handle multi-target ---
+         elif isinstance(shap_output, dict):
+             for target_name, shap_df in shap_output.items():  # type: ignore
+                 _LOGGER.info(f" -> Generating SHAP plots for target: '{target_name}'")
+                 shap_values = shap_df.values
+                 sanitized_name = sanitize_filename(target_name)
+
+                 # Save Bar Plot
+                 self._save_shap_plots(
+                     shap_values=shap_values,
+                     instances_df=feature_df,
+                     save_dir=save_dir_path,
+                     suffix=f"_{sanitized_name}",
+                     title_suffix=f" for '{target_name}'"
+                 )
+                 # Save Summary Data
+                 self._save_shap_csv(
+                     shap_values=shap_values,
+                     feature_names=list(self.schema.feature_names),
+                     save_dir=save_dir_path,
+                     suffix=f"_{sanitized_name}"
+                 )
+
+         plt.ion()
+         _LOGGER.info(f"All SHAP plots saved to '{save_dir_path.name}'")
+
+     def _save_shap_plots(self, shap_values: np.ndarray,
+                          instances_df: pd.DataFrame,
+                          save_dir: Path,
+                          suffix: str = "",
+                          title_suffix: str = ""):
+         """Internal helper to save standard SHAP plots."""
+         try:
+             import shap
+         except ImportError:
+             _LOGGER.error("`shap` is required for plotting. Please install it: pip install shap")
+             return
+
+         # Save Bar Plot
+         bar_path = save_dir / f"shap_bar_plot{suffix}.svg"
+         shap.summary_plot(shap_values, instances_df, plot_type="bar", show=False)
+         ax = plt.gca()
+         ax.set_xlabel("SHAP Value Impact", labelpad=10)
+         plt.title(f"SHAP Feature Importance{title_suffix}")
+         plt.tight_layout()
+         plt.savefig(bar_path)
+         plt.close()
+
+         # Save Dot Plot
+         dot_path = save_dir / f"shap_dot_plot{suffix}.svg"
+         shap.summary_plot(shap_values, instances_df, plot_type="dot", show=False)
+         ax = plt.gca()
+         ax.set_xlabel("SHAP Value Impact", labelpad=10)
+         if plt.gcf().axes and len(plt.gcf().axes) > 1:
+             cb = plt.gcf().axes[-1]
+             cb.set_ylabel("", size=1)
+         plt.title(f"SHAP Feature Importance{title_suffix}")
+         plt.tight_layout()
+         plt.savefig(dot_path)
+         plt.close()
+
+     def _save_shap_csv(self, shap_values: Union[np.ndarray, List[np.ndarray]],
+                        feature_names: List[str],
+                        save_dir: Path,
+                        suffix: str = ""):
+         """Internal helper to save standard SHAP summary CSV."""
+
+         shap_summary_filename = f"{SHAPKeys.SAVENAME}{suffix}.csv"
+         summary_path = save_dir / shap_summary_filename
+
+         # Handle multi-class (list of arrays) vs. regression (single array)
+         if isinstance(shap_values, list):
+             mean_abs_shap = np.abs(np.stack(shap_values)).mean(axis=0).mean(axis=0)
+         else:
+             mean_abs_shap = np.abs(shap_values).mean(axis=0)
+
+         mean_abs_shap = mean_abs_shap.flatten()
+
+         summary_df = pd.DataFrame({
+             SHAPKeys.FEATURE_COLUMN: feature_names,
+             SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
+         }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
+
+         summary_df.to_csv(summary_path, index=False)
468
+
469
+ def save_model(self, directory: Union[str, Path]):
470
+ """
471
+ Saves the entire trained model, configuration, and datamodule
472
+ to a directory.
473
+
474
+ Args:
475
+ directory (str | Path): The directory to save the model.
476
+ The directory will be created.
477
+ """
478
+ if not self._is_fitted:
479
+ _LOGGER.error("Cannot save a model that has not been fitted.")
480
+ return
481
+
482
+ save_path = make_fullpath(directory, make=True, enforce="directory")
483
+ self.tabular_model.save_model(str(save_path))
484
+ _LOGGER.info(f"Model saved to '{save_path.name}'")
+
+     @classmethod
+     def load_model(cls,
+                    directory: Union[str, Path],
+                    schema: FeatureSchema,
+                    target_names: List[str],
+                    kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification"]
+                    ) -> 'PyTabularTrainer':
+         """
+         Loads a saved model and reconstructs the PyTabularTrainer wrapper.
+
+         Note: The schema, target_names, and kind must be provided again,
+         as they are not serialized by pytorch-tabular.
+
+         Args:
+             directory (str | Path): The directory from which to load the model.
+             schema (FeatureSchema): The schema used during original training.
+             target_names (List[str]): The target names used during original training.
+             kind (Literal[...]): The task 'kind' used during original training.
+
+         Returns:
+             PyTabularTrainer: A new instance of the trainer with the loaded model.
+         """
+         load_path = make_fullpath(directory, enforce="directory")
+
+         _LOGGER.info(f"Loading model from '{load_path.name}'...")
+
+         # Load the internal pytorch-tabular model
+         loaded_tabular_model = TabularModel.load_model(str(load_path))
+
+         if loaded_tabular_model.model is None:
+             _LOGGER.error("Loaded model's internal '.model' attribute is None. Load failed.")
+             raise RuntimeError("Loaded model is incomplete.")
+
+         model_name = loaded_tabular_model.model._model_name
+
+         if model_name.startswith("GANDALF"):  # Handle GANDALF's dynamic name
+             model_name = "GATE"
+
+         # Re-create the wrapper
+         wrapper = cls(
+             schema=schema,
+             target_names=target_names,
+             kind=kind,
+             model_name=model_name
+             # Configs are already part of the loaded_tabular_model;
+             # we just need to pass the minimum to the __init__.
+         )
+
+         # Overwrite the un-trained model with the loaded trained model
+         wrapper.tabular_model = loaded_tabular_model
+         wrapper._is_fitted = True
+
+         _LOGGER.info(f"Successfully loaded '{model_name}' model.")
+         return wrapper
+
+
+ def info():
+     _script_info(__all__)
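
For context, a minimal usage sketch of the trainer above. The `schema`, `train_ds`, and `test_ds` objects are assumed here: a `FeatureSchema` and two `_PytorchDataset` instances produced elsewhere with the toolbox's data_exploration and `ML_datasetmaster` utilities. Paths and the target column name are illustrative only.

    # Assumed inputs: `schema` (FeatureSchema), `train_ds` / `test_ds` (_PytorchDataset).
    trainer = PyTabularTrainer(
        schema=schema,
        target_names=["price"],      # illustrative target column
        kind="regression",
        model_name="FTTransformer",
    )
    trainer.fit(train_ds, test_ds, epochs=50, batch_size=64)
    trainer.evaluate(save_dir="reports/ft_transformer", data=test_ds)
    trainer.explain(save_dir="reports/ft_transformer/shap", explain_dataset=test_ds)
    trainer.save_model("models/ft_transformer")

    # Reloading requires the same schema / target_names / kind,
    # since pytorch-tabular does not serialize them.
    restored = PyTabularTrainer.load_model(
        "models/ft_transformer",
        schema=schema,
        target_names=["price"],
        kind="regression",
    )
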
@@ -0,0 +1,88 @@
+ import json
+ import torch
+ from torchvision import transforms
+ from typing import Dict, Any, List, Callable, Union
+ from pathlib import Path
+
+ from .ML_vision_transformers import TRANSFORM_REGISTRY
+ from ._logger import _LOGGER
+ from .keys import VisionTransformRecipeKeys
+ from .path_manager import make_fullpath
+
+
+ def save_recipe(recipe: Dict[str, Any], filepath: Path) -> None:
+     """
+     Saves a transform recipe dictionary to a JSON file.
+
+     Args:
+         recipe (Dict[str, Any]): The recipe dictionary to save.
+         filepath (Path): The path to the output .json file.
+     """
+     final_filepath = filepath.with_suffix(".json")
+
+     try:
+         with open(final_filepath, 'w') as f:
+             json.dump(recipe, f, indent=4)
+         _LOGGER.info(f"Transform recipe saved as '{final_filepath.name}'.")
+     except Exception as e:
+         _LOGGER.error(f"Failed to save recipe to '{final_filepath}': {e}")
+         raise
+
+
+ def load_recipe_and_build_transform(filepath: Union[str, Path]) -> transforms.Compose:
+     """
+     Loads a transform recipe from a .json file and reconstructs the
+     torchvision.transforms.Compose pipeline.
+
+     Args:
+         filepath (str | Path): Path to the saved transform recipe .json file.
+
+     Returns:
+         transforms.Compose: The reconstructed transformation pipeline.
+
+     Raises:
+         ValueError: If a transform name in the recipe is not found in
+             torchvision.transforms or the custom TRANSFORM_REGISTRY.
+     """
+     # validate filepath
+     final_filepath = make_fullpath(filepath, enforce="file")
+
+     try:
+         with open(final_filepath, 'r') as f:
+             recipe = json.load(f)
+     except Exception as e:
+         _LOGGER.error(f"Failed to load recipe from '{final_filepath}': {e}")
+         raise
+
+     pipeline_steps: List[Callable] = []
+
+     if VisionTransformRecipeKeys.PIPELINE not in recipe:
+         _LOGGER.error("Recipe file is invalid: missing 'pipeline' key.")
+         raise ValueError("Invalid recipe format.")
+
+     for step in recipe[VisionTransformRecipeKeys.PIPELINE]:
+         t_name = step[VisionTransformRecipeKeys.NAME]
+         t_kwargs = step[VisionTransformRecipeKeys.KWARGS]
+
+         transform_class: Any = None
+
+         # 1. Check standard torchvision transforms
+         if hasattr(transforms, t_name):
+             transform_class = getattr(transforms, t_name)
+         # 2. Check custom transforms
+         elif t_name in TRANSFORM_REGISTRY:
+             transform_class = TRANSFORM_REGISTRY[t_name]
+         # 3. Not found
+         else:
+             _LOGGER.error(f"Unknown transform '{t_name}' in recipe. Not found in torchvision.transforms or TRANSFORM_REGISTRY.")
+             raise ValueError(f"Unknown transform name: {t_name}")
+
+         # Instantiate the transform
+         try:
+             pipeline_steps.append(transform_class(**t_kwargs))
+         except Exception as e:
+             _LOGGER.error(f"Failed to instantiate transform '{t_name}' with kwargs {t_kwargs}: {e}")
+             raise
+
+     _LOGGER.info(f"Successfully loaded and built transform pipeline from '{final_filepath.name}'.")
+     return transforms.Compose(pipeline_steps)
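
To make the recipe format concrete, a hypothetical recipe file might look like the following. The "pipeline" key matches the error message above; "name" and "kwargs" are assumed to be the string values of `VisionTransformRecipeKeys.NAME` and `VisionTransformRecipeKeys.KWARGS`. Each entry is resolved first against `torchvision.transforms`, then against the custom `TRANSFORM_REGISTRY`.

    {
        "pipeline": [
            {"name": "Resize", "kwargs": {"size": [224, 224]}},
            {"name": "ToTensor", "kwargs": {}},
            {"name": "Normalize", "kwargs": {"mean": [0.485, 0.456, 0.406],
                                             "std": [0.229, 0.224, 0.225]}}
        ]
    }

Loading this file with `load_recipe_and_build_transform("recipe.json")` would rebuild the equivalent of `transforms.Compose([Resize(size=[224, 224]), ToTensor(), Normalize(mean=..., std=...)])`.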