dragon-ml-toolbox 14.1.0__tar.gz → 14.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/LICENSE-THIRD-PARTY.md +10 -0
  2. {dragon_ml_toolbox-14.1.0/dragon_ml_toolbox.egg-info → dragon_ml_toolbox-14.2.0}/PKG-INFO +10 -2
  3. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/README.md +6 -1
  4. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0/dragon_ml_toolbox.egg-info}/PKG-INFO +10 -2
  5. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/dragon_ml_toolbox.egg-info/SOURCES.txt +1 -1
  6. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/dragon_ml_toolbox.egg-info/requires.txt +4 -0
  7. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ML_models.py +1 -1
  8. dragon_ml_toolbox-14.2.0/ml_tools/ML_models_advanced.py +323 -0
  9. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/pyproject.toml +6 -9
  10. dragon_ml_toolbox-14.1.0/ml_tools/_ML_pytorch_tabular.py +0 -543
  11. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/LICENSE +0 -0
  12. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/dragon_ml_toolbox.egg-info/dependency_links.txt +0 -0
  13. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/dragon_ml_toolbox.egg-info/top_level.txt +0 -0
  14. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ETL_cleaning.py +0 -0
  15. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ETL_engineering.py +0 -0
  16. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/GUI_tools.py +0 -0
  17. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/MICE_imputation.py +0 -0
  18. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ML_callbacks.py +0 -0
  19. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ML_datasetmaster.py +0 -0
  20. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ML_evaluation.py +0 -0
  21. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ML_evaluation_multi.py +0 -0
  22. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ML_inference.py +0 -0
  23. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ML_optimization.py +0 -0
  24. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ML_scaler.py +0 -0
  25. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ML_trainer.py +0 -0
  26. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ML_utilities.py +0 -0
  27. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ML_vision_datasetmaster.py +0 -0
  28. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ML_vision_evaluation.py +0 -0
  29. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ML_vision_inference.py +0 -0
  30. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ML_vision_models.py +0 -0
  31. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ML_vision_transformers.py +0 -0
  32. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/PSO_optimization.py +0 -0
  33. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/RNN_forecast.py +0 -0
  34. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/SQL.py +0 -0
  35. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/VIF_factor.py +0 -0
  36. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/_ML_vision_recipe.py +0 -0
  37. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/__init__.py +0 -0
  38. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/_logger.py +0 -0
  39. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/_schema.py +0 -0
  40. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/_script_info.py +0 -0
  41. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/constants.py +0 -0
  42. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/custom_logger.py +0 -0
  43. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/data_exploration.py +0 -0
  44. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ensemble_evaluation.py +0 -0
  45. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ensemble_inference.py +0 -0
  46. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/ensemble_learning.py +0 -0
  47. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/handle_excel.py +0 -0
  48. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/keys.py +0 -0
  49. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/math_utilities.py +0 -0
  50. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/optimization_tools.py +0 -0
  51. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/path_manager.py +0 -0
  52. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/serde.py +0 -0
  53. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/ml_tools/utilities.py +0 -0
  54. {dragon_ml_toolbox-14.1.0 → dragon_ml_toolbox-14.2.0}/setup.cfg +0 -0
@@ -27,3 +27,13 @@ This project depends on the following third-party packages. Each is governed by
  - [plotnine](https://github.com/has2k1/plotnine/blob/main/LICENSE)
  - [tqdm](https://github.com/tqdm/tqdm/blob/master/LICENSE)
  - [pyarrow](https://github.com/apache/arrow/blob/main/LICENSE.txt)
+ - [colorlog](https://github.com/borntyping/python-colorlog/blob/main/LICENSE)
+ - [evotorch](https://github.com/nnaisense/evotorch/blob/master/LICENSE)
+ - [FreeSimpleGUI](https://github.com/spyoungtech/FreeSimpleGUI/blob/main/license.txt)
+ - [nuitka](https://github.com/Nuitka/Nuitka/blob/main/LICENSE.txt)
+ - [omegaconf](https://github.com/omry/omegaconf/blob/master/LICENSE)
+ - [ordered-set](https://github.com/rspeer/ordered-set/blob/master/MIT-LICENSE)
+ - [pyinstaller](https://github.com/pyinstaller/pyinstaller/blob/develop/COPYING.txt)
+ - [pytorch_tabular](https://github.com/manujosephv/pytorch_tabular/blob/main/LICENSE)
+ - [torchmetrics](https://github.com/Lightning-AI/torchmetrics/blob/master/LICENSE)
+ - [zstandard](https://github.com/indygreg/python-zstandard/blob/main/LICENSE)
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dragon-ml-toolbox
- Version: 14.1.0
+ Version: 14.2.0
  Summary: A collection of tools for data science and machine learning projects.
  Author-email: "Karl L. Loza Vidaurre" <luigiloza@gmail.com>
  License-Expression: MIT
@@ -35,6 +35,9 @@ Requires-Dist: evotorch; extra == "ml"
  Requires-Dist: pyarrow; extra == "ml"
  Requires-Dist: colorlog; extra == "ml"
  Requires-Dist: torchmetrics; extra == "ml"
+ Provides-Extra: py-tab
+ Requires-Dist: pytorch_tabular; extra == "py-tab"
+ Requires-Dist: omegaconf; extra == "py-tab"
  Provides-Extra: mice
  Requires-Dist: numpy<2.0; extra == "mice"
  Requires-Dist: pandas; extra == "mice"
@@ -143,10 +146,16 @@ ML_evaluation_multi
  ML_evaluation
  ML_inference
  ML_models
+ ML_models_advanced # Requires the extra flag [py-tab]
  ML_optimization
  ML_scaler
  ML_trainer
  ML_utilities
+ ML_vision_datasetmaster
+ ML_vision_evaluation
+ ML_vision_inference
+ ML_vision_models
+ ML_vision_transformers
  optimization_tools
  path_manager
  PSO_optimization
@@ -192,7 +201,6 @@ pip install "dragon-ml-toolbox[excel]"
  #### Modules:

  ```Bash
- constants
  custom_logger
  handle_excel
  path_manager
@@ -67,10 +67,16 @@ ML_evaluation_multi
  ML_evaluation
  ML_inference
  ML_models
+ ML_models_advanced # Requires the extra flag [py-tab]
  ML_optimization
  ML_scaler
  ML_trainer
  ML_utilities
+ ML_vision_datasetmaster
+ ML_vision_evaluation
+ ML_vision_inference
+ ML_vision_models
+ ML_vision_transformers
  optimization_tools
  path_manager
  PSO_optimization
@@ -116,7 +122,6 @@ pip install "dragon-ml-toolbox[excel]"
  #### Modules:

  ```Bash
- constants
  custom_logger
  handle_excel
  path_manager
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dragon-ml-toolbox
- Version: 14.1.0
+ Version: 14.2.0
  Summary: A collection of tools for data science and machine learning projects.
  Author-email: "Karl L. Loza Vidaurre" <luigiloza@gmail.com>
  License-Expression: MIT
@@ -35,6 +35,9 @@ Requires-Dist: evotorch; extra == "ml"
  Requires-Dist: pyarrow; extra == "ml"
  Requires-Dist: colorlog; extra == "ml"
  Requires-Dist: torchmetrics; extra == "ml"
+ Provides-Extra: py-tab
+ Requires-Dist: pytorch_tabular; extra == "py-tab"
+ Requires-Dist: omegaconf; extra == "py-tab"
  Provides-Extra: mice
  Requires-Dist: numpy<2.0; extra == "mice"
  Requires-Dist: pandas; extra == "mice"
@@ -143,10 +146,16 @@ ML_evaluation_multi
  ML_evaluation
  ML_inference
  ML_models
+ ML_models_advanced # Requires the extra flag [py-tab]
  ML_optimization
  ML_scaler
  ML_trainer
  ML_utilities
+ ML_vision_datasetmaster
+ ML_vision_evaluation
+ ML_vision_inference
+ ML_vision_models
+ ML_vision_transformers
  optimization_tools
  path_manager
  PSO_optimization
@@ -192,7 +201,6 @@ pip install "dragon-ml-toolbox[excel]"
  #### Modules:

  ```Bash
- constants
  custom_logger
  handle_excel
  path_manager
@@ -17,6 +17,7 @@ ml_tools/ML_evaluation.py
  ml_tools/ML_evaluation_multi.py
  ml_tools/ML_inference.py
  ml_tools/ML_models.py
+ ml_tools/ML_models_advanced.py
  ml_tools/ML_optimization.py
  ml_tools/ML_scaler.py
  ml_tools/ML_trainer.py
@@ -30,7 +31,6 @@ ml_tools/PSO_optimization.py
  ml_tools/RNN_forecast.py
  ml_tools/SQL.py
  ml_tools/VIF_factor.py
- ml_tools/_ML_pytorch_tabular.py
  ml_tools/_ML_vision_recipe.py
  ml_tools/__init__.py
  ml_tools/_logger.py
@@ -63,5 +63,9 @@ nuitka
  zstandard
  ordered-set

+ [py-tab]
+ pytorch_tabular
+ omegaconf
+
  [pyinstaller]
  pyinstaller
@@ -1,6 +1,6 @@
  import torch
  from torch import nn
- from typing import List, Union, Tuple, Dict, Any, Literal, Optional
+ from typing import List, Union, Tuple, Dict, Any
  from pathlib import Path
  import json

@@ -0,0 +1,323 @@
+ import torch
+ from torch import nn
+ from typing import Union, Dict, Any
+ from pathlib import Path
+ import json
+
+ from ._logger import _LOGGER
+ from .path_manager import make_fullpath
+ from .keys import PytorchModelArchitectureKeys
+ from ._schema import FeatureSchema
+ from ._script_info import _script_info
+ from .ML_models import _ArchitectureHandlerMixin
+
+ # Imports from pytorch_tabular
+ try:
+     from omegaconf import DictConfig
+     from pytorch_tabular.models import GatedAdditiveTreeEnsembleModel, NODEModel
+ except ImportError:
+     _LOGGER.error(f"GATE and NODE require 'pip install pytorch_tabular omegaconf' dependencies.")
+     raise ImportError()
+
+
+ __all__ = [
+     "DragonGateModel",
+     "DragonNodeModel",
+ ]
+
+
+ class _BasePytabWrapper(nn.Module, _ArchitectureHandlerMixin):
+     """
+     Internal Base Class: Do not use directly.
+
+     This is an adapter to make pytorch_tabular models compatible with the
+     dragon-ml-toolbox pipeline.
+
+     It handles:
+     1. Schema-based initialization.
+     2. Single-tensor forward pass, which is then split into the
+        dict {'continuous': ..., 'categorical': ...} that pytorch_tabular expects.
+     3. Saving/Loading architecture using the pipeline's _ArchitectureHandlerMixin.
+     """
+     def __init__(self, schema: FeatureSchema):
+         super().__init__()
+
+         self.schema = schema
+         self.model_name = "Base" # To be overridden by child
+         self.internal_model: nn.Module = None # type: ignore # To be set by child
+         self.model_hparams: Dict = dict() # To be set by child
+
+         # --- Derive indices from schema ---
+         categorical_map = schema.categorical_index_map
+
+         if categorical_map:
+             # The order of keys/values is implicitly linked and must be preserved
+             self.categorical_indices = list(categorical_map.keys())
+             self.cardinalities = list(categorical_map.values())
+         else:
+             self.categorical_indices = []
+             self.cardinalities = []
+
+         # Derive numerical indices by finding what's not categorical
+         all_indices = set(range(len(schema.feature_names)))
+         categorical_indices_set = set(self.categorical_indices)
+         self.numerical_indices = sorted(list(all_indices - categorical_indices_set))
+
+     def _build_pt_config(self, out_targets: int, **kwargs) -> DictConfig:
+         """Helper to create the minimal config dict for a pytorch_tabular model."""
+         # 'regression' is the most neutral for model architecture. The final output_dim is what truly matters.
+         task = "regression"
+
+         config_dict = {
+             # --- Data / Schema Params ---
+             'task': task,
+             'continuous_cols': list(self.schema.continuous_feature_names),
+             'categorical_cols': list(self.schema.categorical_feature_names),
+             'continuous_dim': len(self.numerical_indices),
+             'categorical_dim': len(self.categorical_indices),
+             'categorical_cardinality': self.cardinalities,
+             'target': ['dummy_target'], # Required, but not used
+
+             # --- Model Params ---
+             'output_dim': out_targets,
+             **kwargs
+         }
+
+         # Add common params that most models need
+         if 'loss' not in config_dict:
+             config_dict['loss'] = 'NotUsed'
+         if 'metrics' not in config_dict:
+             config_dict['metrics'] = []
+
+         return DictConfig(config_dict)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Accepts a single tensor and converts it to the dict
+         that pytorch_tabular models expect.
+         """
+         # 1. Split the single tensor input
+         x_cont = x[:, self.numerical_indices].float()
+         x_cat = x[:, self.categorical_indices].long()
+
+         # 2. Create the input dict
+         input_dict = {
+             'continuous': x_cont,
+             'categorical': x_cat
+         }
+
+         # 3. Pass to the internal pytorch_tabular model
+         # The model returns a dict, we extract the logits
+         model_output_dict = self.internal_model(input_dict)
+
+         # 4. Return the logits tensor
+         return model_output_dict['logits']
+
+     def get_architecture_config(self) -> Dict[str, Any]:
+         """Returns the full configuration of the model."""
+         # Deconstruct schema into a JSON-friendly dict
+         schema_dict = {
+             'feature_names': self.schema.feature_names,
+             'continuous_feature_names': self.schema.continuous_feature_names,
+             'categorical_feature_names': self.schema.categorical_feature_names,
+             'categorical_index_map': self.schema.categorical_index_map,
+             'categorical_mappings': self.schema.categorical_mappings
+         }
+
+         config = {
+             'schema_dict': schema_dict,
+             'out_targets': self.out_targets,
+             **self.model_hparams
+         }
+         return config
+
+     @classmethod
+     def load(cls: type, file_or_dir: Union[str, Path], verbose: bool = True) -> nn.Module:
+         """Loads a model architecture from a JSON file."""
+         user_path = make_fullpath(file_or_dir)
+
+         if user_path.is_dir():
+             json_filename = PytorchModelArchitectureKeys.SAVENAME + ".json"
+             target_path = make_fullpath(user_path / json_filename, enforce="file")
+         elif user_path.is_file():
+             target_path = user_path
+         else:
+             _LOGGER.error(f"Invalid path: '{file_or_dir}'")
+             raise IOError()
+
+         with open(target_path, 'r') as f:
+             saved_data = json.load(f)
+
+         saved_class_name = saved_data[PytorchModelArchitectureKeys.MODEL]
+         config = saved_data[PytorchModelArchitectureKeys.CONFIG]
+
+         if saved_class_name != cls.__name__:
+             _LOGGER.error(f"Model class mismatch. File specifies '{saved_class_name}', but '{cls.__name__}' was expected.")
+             raise ValueError()
+
+         # --- RECONSTRUCTION LOGIC ---
+         if 'schema_dict' not in config:
+             _LOGGER.error("Invalid architecture file: missing 'schema_dict'. This file may be from an older version.")
+             raise ValueError("Missing 'schema_dict' in config.")
+
+         schema_data = config.pop('schema_dict')
+
+         # JSON saves all dict keys as strings, convert them back to int.
+         raw_index_map = schema_data['categorical_index_map']
+         if raw_index_map is not None:
+             rehydrated_index_map = {int(k): v for k, v in raw_index_map.items()}
+         else:
+             rehydrated_index_map = None
+
+         # JSON deserializes tuples as lists, convert them back.
+         schema = FeatureSchema(
+             feature_names=tuple(schema_data['feature_names']),
+             continuous_feature_names=tuple(schema_data['continuous_feature_names']),
+             categorical_feature_names=tuple(schema_data['categorical_feature_names']),
+             categorical_index_map=rehydrated_index_map,
+             categorical_mappings=schema_data['categorical_mappings']
+         )
+
+         config['schema'] = schema
+         # --- End Reconstruction ---
+
+         model = cls(**config)
+         if verbose:
+             _LOGGER.info(f"Successfully loaded architecture for '{saved_class_name}'")
+         return model
+
+     def __repr__(self) -> str:
+         internal_model_str = str(self.internal_model)
+         # Grab the first line of the internal model's repr
+         internal_repr = internal_model_str.split('\n')[0]
+         return f"{self.model_name}(internal_model={internal_repr})"
+
+
+ class DragonGateModel(_BasePytabWrapper):
+     """
+     Adapter for the Gated Additive Tree Ensemble (GATE) model from the 'pytorch_tabular' library.
+
+     GATE is a hybrid model that uses Gated Feature Learning Units (GFLUs) to
+     learn powerful feature representations. These learned features are then
+     fed into an additive ensemble of differentiable decision trees, combining
+     the representation learning of deep networks with the structured
+     decision-making of tree ensembles.
+     """
+     def __init__(self, *,
+                  schema: FeatureSchema,
+                  out_targets: int,
+                  embedding_dim: int = 32,
+                  gflu_stages: int = 6,
+                  num_trees: int = 20,
+                  tree_depth: int = 5,
+                  dropout: float = 0.1):
+         """
+         Args:
+             schema (FeatureSchema):
+                 The definitive schema object from data_exploration.
+             out_targets (int):
+                 Number of output targets.
+             embedding_dim (int):
+                 Dimension of the categorical embeddings. (Recommended: 16 to 64)
+             gflu_stages (int):
+                 Number of Gated Feature Learning Units (GFLU) stages. (Recommended: 2 to 6)
+             num_trees (int):
+                 Number of trees in the ensemble. (Recommended: 10 to 50)
+             tree_depth (int):
+                 Depth of each tree. (Recommended: 4 to 8)
+             dropout (float):
+                 Dropout rate for the GFLU.
+         """
+         super().__init__(schema)
+         self.model_name = "DragonGateModel"
+         self.out_targets = out_targets
+
+         # Store hparams for saving/loading
+         self.model_hparams = {
+             'embedding_dim': embedding_dim,
+             'gflu_stages': gflu_stages,
+             'num_trees': num_trees,
+             'tree_depth': tree_depth,
+             'dropout': dropout
+         }
+
+         # Build the minimal config for the GateModel
+         pt_config = self._build_pt_config(
+             out_targets=out_targets,
+             embedding_dim=embedding_dim,
+             gflu_stages=gflu_stages,
+             num_trees=num_trees,
+             tree_depth=tree_depth,
+             dropout=dropout,
+             # GATE-specific params
+             gflu_dropout=dropout,
+             chain_trees=False,
+         )
+
+         # Instantiate the internal pytorch_tabular model
+         self.internal_model = GatedAdditiveTreeEnsembleModel(config=pt_config)
+
+
+ class DragonNodeModel(_BasePytabWrapper):
+     """
+     Adapter for the Neural Oblivious Decision Ensembles (NODE) model from the 'pytorch_tabular' library.
+
+     NODE is a model based on an ensemble of differentiable 'oblivious'
+     decision trees. An oblivious tree uses the same splitting feature and
+     threshold across all nodes at the same depth. This structure, combined
+     with a differentiable formulation, allows the model to be trained
+     end-to-end with gradient descent, learning feature interactions and
+     splitting thresholds simultaneously.
+     """
+     def __init__(self, *,
+                  schema: FeatureSchema,
+                  out_targets: int,
+                  embedding_dim: int = 32,
+                  num_trees: int = 1024,
+                  tree_depth: int = 6,
+                  dropout: float = 0.1):
+         """
+         Args:
+             schema (FeatureSchema):
+                 The definitive schema object from data_exploration.
+             out_targets (int):
+                 Number of output targets.
+             embedding_dim (int):
+                 Dimension of the categorical embeddings. (Recommended: 16 to 64)
+             num_trees (int):
+                 Total number of trees in the ensemble. (Recommended: 256 to 2048)
+             tree_depth (int):
+                 Depth of each tree. (Recommended: 4 to 8)
+             dropout (float):
+                 Dropout rate.
+         """
+         super().__init__(schema)
+         self.model_name = "DragonNodeModel"
+         self.out_targets = out_targets
+
+         # Store hparams for saving/loading
+         self.model_hparams = {
+             'embedding_dim': embedding_dim,
+             'num_trees': num_trees,
+             'tree_depth': tree_depth,
+             'dropout': dropout
+         }
+
+         # Build the minimal config for the NodeModel
+         pt_config = self._build_pt_config(
+             out_targets=out_targets,
+             embedding_dim=embedding_dim,
+             num_trees=num_trees,
+             tree_depth=tree_depth,
+             # NODE-specific params
+             num_layers=1, # NODE uses num_layers=1 for a single ensemble
+             total_trees=num_trees,
+             dropout_rate=dropout,
+         )
+
+         # Instantiate the internal pytorch_tabular model
+         self.internal_model = NODEModel(config=pt_config)
+
+
+ def info():
+     _script_info(__all__)
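
For orientation, here is a minimal usage sketch of the new wrappers. It is not part of the package: the import paths are inferred from the `ml_tools/` layout in SOURCES.txt, and the feature names, schema values, and shapes are illustrative assumptions (the `FeatureSchema` fields mirror the reconstruction shown in `load()` above).

```python
import torch

# Assumed import paths, inferred from the ml_tools/ layout in SOURCES.txt.
from ml_tools._schema import FeatureSchema
from ml_tools.ML_models_advanced import DragonGateModel

# Hypothetical schema: feature indices 0-2 are continuous, index 3 is
# categorical with cardinality 4. Field names follow those used in load().
schema = FeatureSchema(
    feature_names=("age", "height", "weight", "color"),
    continuous_feature_names=("age", "height", "weight"),
    categorical_feature_names=("color",),
    categorical_index_map={3: 4},
    categorical_mappings={"color": {"red": 0, "green": 1, "blue": 2, "black": 3}},
)

model = DragonGateModel(schema=schema, out_targets=1)

# A single (batch, features) tensor; forward() splits it into the
# {'continuous': ..., 'categorical': ...} dict that pytorch_tabular expects.
x = torch.rand(8, 4)
x[:, 3] = torch.randint(0, 4, (8,)).float()  # categorical codes stored as floats
logits = model(x)  # expected shape: (8, 1)
```

Because each wrapper is a plain `nn.Module` with a single-tensor forward pass, it should drop into the same training and inference pipeline as the native `ML_models` classes.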
@@ -1,6 +1,6 @@
  [project]
  name = "dragon-ml-toolbox"
- version = "14.1.0"
+ version = "14.2.0"
  description = "A collection of tools for data science and machine learning projects."
  authors = [
      { name = "Karl L. Loza Vidaurre", email = "luigiloza@gmail.com" }
@@ -45,14 +45,11 @@ ML = [
      "torchmetrics",
  ]

- # pytorch-tabular API. Additionally Requires PyTorch with CUDA / MPS support
- # py-tab = [
- #     "pytorch_tabular",
- #     "pytorch-lightning",
- #     "wandb",
- #     "plotly",
- #     "captum",
- # ]
+ # pytorch-tabular API for advanced models
+ py-tab = [
+     "pytorch_tabular",
+     "omegaconf"
+ ]

  # MICE and VIF - Requires a new virtual-env due to dependency version conflicts
  mice = [
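
With `py-tab` now an active extra rather than a commented-out block, the advanced models install the same way as the other optional groups in the README; the removed module's fallback message (below) uses the identical command:

```Bash
pip install "dragon-ml-toolbox[py-tab]"
```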
@@ -1,543 +0,0 @@
- import torch
- import pandas as pd
- import numpy as np
- import matplotlib.pyplot as plt
- from typing import List, Literal, Union, Optional, Dict, Any
- from pathlib import Path
- import warnings
-
- # --- Third-party imports ---
- try:
-     from pytorch_tabular.models.common.heads import LinearHeadConfig
-     from pytorch_tabular.config import (
-         DataConfig,
-         ModelConfig,
-         OptimizerConfig,
-         TrainerConfig,
-         ExperimentConfig,
-     )
-     from pytorch_tabular.models import (
-         CategoryEmbeddingModelConfig,
-         TabNetModelConfig,
-         TabTransformerConfig,
-         FTTransformerConfig,
-         AutoIntConfig,
-         NodeConfig,
-         GANDALFConfig
-     )
-     from pytorch_tabular.tabular_model import TabularModel
- except ImportError:
-     print("----------------------------------------------------------------")
-     print("ERROR: `pytorch-tabular` is not installed.")
-     print("Please install it to use the models in this script:")
-     print('\npip install "dragon-ml-toolbox[py-tab]"')
-     print("----------------------------------------------------------------")
-     raise
-
- # --- Local ML-Tools imports ---
- from ._logger import _LOGGER
- from ._script_info import _script_info
- from ._schema import FeatureSchema
- from .path_manager import make_fullpath, sanitize_filename
- from .keys import SHAPKeys
- from .ML_datasetmaster import _PytorchDataset
- from .ML_evaluation import (
-     classification_metrics,
-     regression_metrics
- )
- from .ML_evaluation_multi import (
-     multi_target_regression_metrics,
-     multi_label_classification_metrics
- )
-
-
- __all__ = [
-     "PyTabularTrainer"
- ]
-
-
- # --- Model Configuration Mapping ---
- # Maps a simple string name to the required ModelConfig class
- SUPPORTED_MODELS: Dict[str, Any] = {
-     "TabNet": TabNetModelConfig,
-     "TabTransformer": TabTransformerConfig,
-     "FTTransformer": FTTransformerConfig,
-     "AutoInt": AutoIntConfig,
-     "NODE": NodeConfig,
-     "GATE": GANDALFConfig, # Gated Additive Tree Ensemble
-     "CategoryEmbedding": CategoryEmbeddingModelConfig, # A basic MLP
- }
-
-
- class PyTabularTrainer:
-     """
-     A wrapper for models from the `pytorch-tabular` library, designed to be
-     compatible with the `dragon-ml-toolbox` ecosystem.
-
-     This class acts as a high-level trainer that adapts the `ML_datasetmaster`
-     datasets into the format required by `pytorch-tabular` and routes
-     evaluation results to the standard `ML_evaluation` functions.
-
-     It handles:
-     - Automatic `DataConfig` creation from a `FeatureSchema`.
-     - Model and Trainer configuration.
-     - Training and evaluation.
-     - SHAP explanations.
-     """
-
-     def __init__(self,
-                  schema: FeatureSchema,
-                  target_names: List[str],
-                  kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification"],
-                  model_name: str,
-                  model_config_params: Optional[Dict[str, Any]] = None,
-                  optimizer_config_params: Optional[Dict[str, Any]] = None,
-                  trainer_config_params: Optional[Dict[str, Any]] = None):
-         """
-         Initializes the Model, Data, and Trainer configurations.
-
-         Args:
-             schema (FeatureSchema):
-                 The definitive schema object from data_exploration.
-             target_names (List[str]):
-                 A list of target column names.
-             kind (Literal[...]):
-                 The type of ML task. This is used to set the `pytorch-tabular`
-                 task and to route to the correct evaluation function.
-             model_name (str):
-                 The name of the model to use. Must be one of:
-                 "TabNet", "TabTransformer", "FTTransformer", "AutoInt",
-                 "NODE", "GATE", "CategoryEmbedding".
-             model_config_params (Dict, optional):
-                 Overrides for the chosen model's `ModelConfig`.
-                 (e.g., `{"n_d": 16, "n_a": 16}` for TabNet).
-             optimizer_config_params (Dict, optional):
-                 Overrides for the `OptimizerConfig` (e.g., `{"lr": 0.005}`).
-             trainer_config_params (Dict, optional):
-                 Overrides for the `TrainerConfig` (e.g., `{"max_epochs": 100}`).
-         """
-         _LOGGER.info(f"Initializing PyTabularTrainer for model: {model_name}")
-
-         # --- 1. Store key info ---
-         self.schema = schema
-         self.target_names = target_names
-         self.kind = kind
-         self.model_name = model_name
-         self._is_fitted = False
-
-         if model_name not in SUPPORTED_MODELS:
-             _LOGGER.error(f"Model '{model_name}' is not supported. Choose from: {list(SUPPORTED_MODELS.keys())}")
-             raise ValueError(f"Unsupported model: {model_name}")
-
-         # --- 2. Map ML-Tools 'kind' to pytorch-tabular 'task' ---
-         if kind == "regression":
-             self.task = "regression"
-             self._pt_target_names = target_names
-         elif kind == "classification":
-             self.task = "classification"
-             self._pt_target_names = target_names
-         elif kind == "multi_target_regression":
-             self.task = "multi-label-regression" # pytorch-tabular's name
-             self._pt_target_names = target_names
-         elif kind == "multi_label_classification":
-             self.task = "multi-label-classification"
-             self._pt_target_names = target_names
-         else:
-             _LOGGER.error(f"Unknown task 'kind': {kind}")
-             raise ValueError()
-
-         # --- 3. Create DataConfig from FeatureSchema ---
-         # Note: pytorch-tabular handles scaling internally
-         self.data_config = DataConfig(
-             target=self._pt_target_names,
-             continuous_cols=list(schema.continuous_feature_names),
-             categorical_cols=list(schema.categorical_feature_names),
-             continuous_feature_transform="quantile_normal",
-         )
-
-         # --- 4. Create ModelConfig ---
-         model_config_class = SUPPORTED_MODELS[model_name]
-
-         # Apply user overrides
-         if model_config_params is None:
-             model_config_params = {}
-
-         # Set task in params
-         model_config_params["task"] = self.task
-
-         # Handle multi-target output for regression
-         if self.task == "multi-label-regression":
-             # Must configure the model's output head
-             if "head" not in model_config_params:
-                 _LOGGER.info("Configuring model head for multi-target regression.")
-                 model_config_params["head"] = "LinearHead"
-                 model_config_params["head_config"] = {
-                     "layers": "", # No hidden layers in the head
-                     "output_dim": len(self.target_names)
-                 }
-
-         self.model_config = model_config_class(**model_config_params)
-
-         # --- 5. Create OptimizerConfig ---
-         if optimizer_config_params is None:
-             optimizer_config_params = {}
-         self.optimizer_config = OptimizerConfig(**optimizer_config_params)
-
-         # --- 6. Create TrainerConfig ---
-         if trainer_config_params is None:
-             trainer_config_params = {}
-
-         # Default to GPU if available
-         if "accelerator" not in trainer_config_params:
-             if torch.cuda.is_available():
-                 trainer_config_params["accelerator"] = "cuda"
-             elif torch.backends.mps.is_available():
-                 trainer_config_params["accelerator"] = "mps"
-             else:
-                 trainer_config_params["accelerator"] = "cpu"
-
-         # Set other sensible defaults
-         if "checkpoints" not in trainer_config_params:
-             trainer_config_params["checkpoints"] = "val_loss"
-             trainer_config_params["load_best_at_end"] = True
-
-         if "early_stopping" not in trainer_config_params:
-             trainer_config_params["early_stopping"] = "val_loss"
-
-         self.trainer_config = TrainerConfig(**trainer_config_params)
-
-         # --- 7. Instantiate the TabularModel ---
-         self.tabular_model = TabularModel(
-             data_config=self.data_config,
-             model_config=self.model_config,
-             optimizer_config=self.optimizer_config,
-             trainer_config=self.trainer_config,
-         )
-
-     def _dataset_to_dataframe(self, dataset: _PytorchDataset) -> pd.DataFrame:
-         """Converts an _PytorchDataset back into a pandas DataFrame."""
-         try:
-             features_np = dataset.features.cpu().numpy()
-             labels_np = dataset.labels.cpu().numpy()
-             feature_names = dataset.feature_names
-             target_names = dataset.target_names
-         except Exception as e:
-             _LOGGER.error(f"Failed to extract data from provided dataset: {e}")
-             raise
-
-         # Create features DataFrame
-         df = pd.DataFrame(features_np, columns=feature_names)
-
-         # Add labels
-         if labels_np.ndim == 1:
-             df[target_names[0]] = labels_np
-         elif labels_np.ndim == 2:
-             for i, name in enumerate(target_names):
-                 df[name] = labels_np[:, i]
-
-         return df
-
-     def fit(self,
-             train_dataset: _PytorchDataset,
-             test_dataset: _PytorchDataset,
-             epochs: int = 20,
-             batch_size: int = 10):
-         """
-         Trains the model using the provided datasets.
-
-         Args:
-             train_dataset (_PytorchDataset): The training dataset.
-             test_dataset (_PytorchDataset): The validation dataset.
-             epochs (int): The number of epochs to train for.
-             batch_size (int): The batch size.
-         """
-         _LOGGER.info(f"Converting datasets to pandas DataFrame for {self.model_name}...")
-         train_df = self._dataset_to_dataframe(train_dataset)
-         test_df = self._dataset_to_dataframe(test_dataset)
-
-         _LOGGER.info(f"Starting training for {epochs} epochs...")
-         with warnings.catch_warnings():
-             # Suppress abundant pytorch-lightning warnings
-             warnings.simplefilter("ignore")
-             self.tabular_model.fit(
-                 train=train_df,
-                 validation=test_df,
-                 max_epochs=epochs
-             )
-
-         self._is_fitted = True
-         _LOGGER.info("Training complete.")
-
-     def evaluate(self,
-                  save_dir: Union[str, Path],
-                  data: _PytorchDataset,
-                  classification_threshold: float = 0.5):
-         """
-         Evaluates the model and saves reports using the standard ML_evaluation functions.
-
-         Args:
-             save_dir (str | Path): Directory to save all reports and plots.
-             data (_PytorchDataset): The data to evaluate on.
-             classification_threshold (float): Threshold for multi-label tasks.
-         """
-         if not self._is_fitted:
-             _LOGGER.error("Model is not fitted. Call .fit() first.")
-             raise RuntimeError()
-
-         print("\n--- Model Evaluation (PyTorch-Tabular) ---")
-
-         eval_df = self._dataset_to_dataframe(data)
-
-         # Get raw predictions from pytorch-tabular
-         raw_preds_df = self.tabular_model.predict(
-             eval_df,
-             include_input_features=False
-         )
-
-         # Extract y_true from the dataframe
-         y_true = eval_df[self.target_names].to_numpy()
-
-         y_pred = None
-         y_prob = None
-
-         # --- Route based on task kind ---
-
-         if self.kind == "regression":
-             pred_col_name = f"{self.target_names[0]}_prediction"
-             y_pred = raw_preds_df[pred_col_name].to_numpy()
-             regression_metrics(y_true.flatten(), y_pred.flatten(), save_dir)
-
-         elif self.kind == "classification":
-             y_pred = raw_preds_df["prediction"].to_numpy()
-             # Get class names from the model's datamodule
-             if self.tabular_model.datamodule is None:
-                 _LOGGER.error("Model's datamodule is not initialized. Cannot extract class names for probabilities.")
-                 raise RuntimeError("Datamodule not found. Was the model trained or loaded correctly?")
-             class_names = self.tabular_model.datamodule.data_config.target_classes[self.target_names[0]]
-             prob_cols = [f"{c}_probability" for c in class_names]
-             y_prob = raw_preds_df[prob_cols].values
-             classification_metrics(save_dir, y_true.flatten(), y_pred, y_prob)
-
-         elif self.kind == "multi_target_regression":
-             pred_cols = [f"{name}_prediction" for name in self.target_names]
-             y_pred = raw_preds_df[pred_cols].to_numpy()
-             multi_target_regression_metrics(y_true, y_pred, self.target_names, save_dir)
-
-         elif self.kind == "multi_label_classification":
-             prob_cols = [f"{name}_probability" for name in self.target_names]
-             y_prob = raw_preds_df[prob_cols].to_numpy()
-             # y_pred is derived from y_prob
-             multi_label_classification_metrics(y_true, y_prob, self.target_names, save_dir, classification_threshold)
-
-     def explain(self,
-                 save_dir: Union[str, Path],
-                 explain_dataset: _PytorchDataset):
-         """
-         Generates SHAP explanations and saves plots and summary CSVs.
-
-         This method uses pytorch-tabular's internal `.explain()` method
-         and then formats the output to match the ML_evaluation standard.
-
-         Args:
-             save_dir (str | Path): Directory to save all SHAP artifacts.
-             explain_dataset (_PytorchDataset): The dataset to explain.
-         """
-         if not self._is_fitted:
-             _LOGGER.error("Model is not fitted. Call .fit() first.")
-             raise RuntimeError()
-
-         print(f"\n--- SHAP Value Explanation ({self.model_name}) ---")
-
-         explain_df = self._dataset_to_dataframe(explain_dataset)
-
-         # We must use the dataframe *without* the target columns for explanation
-         feature_df: pd.DataFrame = explain_df[self.schema.feature_names] # type: ignore
-
-         # This returns a DataFrame (single-target) or Dict[str, DataFrame]
-         with warnings.catch_warnings():
-             warnings.simplefilter("ignore")
-             shap_output = self.tabular_model.explain(feature_df)
-
-         save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
-         plt.ioff()
-
-         # --- 1. Handle single-target (regression/classification) ---
-         if isinstance(shap_output, pd.DataFrame):
-             # shap_output is (n_samples, n_features)
-             shap_values = shap_output.to_numpy()
-
-             # Save Bar Plot
-             self._save_shap_plots(
-                 shap_values=shap_values,
-                 instances_df=feature_df,
-                 save_dir=save_dir_path,
-                 suffix="" # No suffix for single target
-             )
-             # Save Summary Data
-             self._save_shap_csv(
-                 shap_values=shap_values,
-                 feature_names=list(self.schema.feature_names),
-                 save_dir=save_dir_path,
-                 suffix=""
-             )
-
-         # --- 2. Handle multi-target ---
-         elif isinstance(shap_output, dict):
-             for target_name, shap_df in shap_output.items(): # type: ignore
-                 _LOGGER.info(f" -> Generating SHAP plots for target: '{target_name}'")
-                 shap_values = shap_df.values
-                 sanitized_name = sanitize_filename(target_name)
-
-                 # Save Bar Plot
-                 self._save_shap_plots(
-                     shap_values=shap_values,
-                     instances_df=feature_df,
-                     save_dir=save_dir_path,
-                     suffix=f"_{sanitized_name}",
-                     title_suffix=f" for '{target_name}'"
-                 )
-                 # Save Summary Data
-                 self._save_shap_csv(
-                     shap_values=shap_values,
-                     feature_names=list(self.schema.feature_names),
-                     save_dir=save_dir_path,
-                     suffix=f"_{sanitized_name}"
-                 )
-
-         plt.ion()
-         _LOGGER.info(f"All SHAP plots saved to '{save_dir_path.name}'")
-
-     def _save_shap_plots(self, shap_values: np.ndarray,
-                          instances_df: pd.DataFrame,
-                          save_dir: Path,
-                          suffix: str = "",
-                          title_suffix: str = ""):
-         """Internal helper to save standard SHAP plots."""
-         try:
-             import shap
-         except ImportError:
-             _LOGGER.error("`shap` is required for plotting. Please install it: pip install shap")
-             return
-
-         # Save Bar Plot
-         bar_path = save_dir / f"shap_bar_plot{suffix}.svg"
-         shap.summary_plot(shap_values, instances_df, plot_type="bar", show=False)
-         ax = plt.gca()
-         ax.set_xlabel("SHAP Value Impact", labelpad=10)
-         plt.title(f"SHAP Feature Importance{title_suffix}")
-         plt.tight_layout()
-         plt.savefig(bar_path)
-         plt.close()
-
-         # Save Dot Plot
-         dot_path = save_dir / f"shap_dot_plot{suffix}.svg"
-         shap.summary_plot(shap_values, instances_df, plot_type="dot", show=False)
-         ax = plt.gca()
-         ax.set_xlabel("SHAP Value Impact", labelpad=10)
-         if plt.gcf().axes and len(plt.gcf().axes) > 1:
-             cb = plt.gcf().axes[-1]
-             cb.set_ylabel("", size=1)
-         plt.title(f"SHAP Feature Importance{title_suffix}")
-         plt.tight_layout()
-         plt.savefig(dot_path)
-         plt.close()
-
-     def _save_shap_csv(self, shap_values: np.ndarray,
-                        feature_names: List[str],
-                        save_dir: Path,
-                        suffix: str = ""):
-         """Internal helper to save standard SHAP summary CSV."""
-
-         shap_summary_filename = f"{SHAPKeys.SAVENAME}{suffix}.csv"
-         summary_path = save_dir / shap_summary_filename
-
-         # Handle multi-class (list of arrays) vs. regression (single array)
-         if isinstance(shap_values, list):
-             mean_abs_shap = np.abs(np.stack(shap_values)).mean(axis=0).mean(axis=0)
-         else:
-             mean_abs_shap = np.abs(shap_values).mean(axis=0)
-
-         mean_abs_shap = mean_abs_shap.flatten()
-
-         summary_df = pd.DataFrame({
-             SHAPKeys.FEATURE_COLUMN: feature_names,
-             SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
-         }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
-
-         summary_df.to_csv(summary_path, index=False)
-
-     def save_model(self, directory: Union[str, Path]):
-         """
-         Saves the entire trained model, configuration, and datamodule
-         to a directory.
-
-         Args:
-             directory (str | Path): The directory to save the model.
-                 The directory will be created.
-         """
-         if not self._is_fitted:
-             _LOGGER.error("Cannot save a model that has not been fitted.")
-             return
-
-         save_path = make_fullpath(directory, make=True, enforce="directory")
-         self.tabular_model.save_model(str(save_path))
-         _LOGGER.info(f"Model saved to '{save_path.name}'")
-
-     @classmethod
-     def load_model(cls,
-                    directory: Union[str, Path],
-                    schema: FeatureSchema,
-                    target_names: List[str],
-                    kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification"]
-                    ) -> 'PyTabularTrainer':
-         """
-         Loads a saved model and reconstructs the PyTabularTrainer wrapper.
-
-         Note: The schema, target_names, and kind must be provided again
-         as they are not serialized by pytorch-tabular.
-
-         Args:
-             directory (str | Path): The directory from which to load the model.
-             schema (FeatureSchema): The schema used during original training.
-             target_names (List[str]): The target names used during original training.
-             kind (Literal[...]): The task 'kind' used during original training.
-
-         Returns:
-             PyTabularTrainer: A new instance of the trainer with the loaded model.
-         """
-         load_path = make_fullpath(directory, enforce="directory")
-
-         _LOGGER.info(f"Loading model from '{load_path.name}'...")
-
-         # Load the internal pytorch-tabular model
-         loaded_tabular_model = TabularModel.load_model(str(load_path))
-
-         if loaded_tabular_model.model is None:
-             _LOGGER.error("Loaded model's internal '.model' attribute is None. Load failed.")
-             raise RuntimeError("Loaded model is incomplete.")
-
-         model_name = loaded_tabular_model.model._model_name
-
-         if model_name.startswith("GANDALF"): # Handle GANDALF's dynamic name
-             model_name = "GATE"
-
-         # Re-create the wrapper
-         wrapper = cls(
-             schema=schema,
-             target_names=target_names,
-             kind=kind,
-             model_name=model_name
-             # Configs are already part of the loaded_tabular_model
-             # We just need to pass the minimum to the __init__
-         )
-
-         # Overwrite the un-trained model with the loaded trained model
-         wrapper.tabular_model = loaded_tabular_model
-         wrapper._is_fitted = True
-
-         _LOGGER.info(f"Successfully loaded '{model_name}' model.")
-         return wrapper
-
-
- def info():
-     _script_info(__all__)