dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
  2. dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
  3. ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
  4. ml_tools/ETL_cleaning/_basic_clean.py +351 -0
  5. ml_tools/ETL_cleaning/_clean_tools.py +128 -0
  6. ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
  7. ml_tools/ETL_cleaning/_imprimir.py +13 -0
  8. ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
  9. ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
  10. ml_tools/ETL_engineering/_imprimir.py +24 -0
  11. ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
  12. ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
  13. ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
  14. ml_tools/GUI_tools/_imprimir.py +12 -0
  15. ml_tools/IO_tools/_IO_loggers.py +235 -0
  16. ml_tools/IO_tools/_IO_save_load.py +151 -0
  17. ml_tools/IO_tools/_IO_utils.py +140 -0
  18. ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
  19. ml_tools/IO_tools/_imprimir.py +14 -0
  20. ml_tools/MICE/_MICE_imputation.py +132 -0
  21. ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
  22. ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
  23. ml_tools/MICE/_imprimir.py +11 -0
  24. ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
  25. ml_tools/ML_callbacks/_base.py +101 -0
  26. ml_tools/ML_callbacks/_checkpoint.py +232 -0
  27. ml_tools/ML_callbacks/_early_stop.py +208 -0
  28. ml_tools/ML_callbacks/_imprimir.py +12 -0
  29. ml_tools/ML_callbacks/_scheduler.py +197 -0
  30. ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
  31. ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
  32. ml_tools/ML_chain/_dragon_chain.py +140 -0
  33. ml_tools/ML_chain/_imprimir.py +11 -0
  34. ml_tools/ML_configuration/__init__.py +90 -0
  35. ml_tools/ML_configuration/_base_model_config.py +69 -0
  36. ml_tools/ML_configuration/_finalize.py +366 -0
  37. ml_tools/ML_configuration/_imprimir.py +47 -0
  38. ml_tools/ML_configuration/_metrics.py +593 -0
  39. ml_tools/ML_configuration/_models.py +206 -0
  40. ml_tools/ML_configuration/_training.py +124 -0
  41. ml_tools/ML_datasetmaster/__init__.py +28 -0
  42. ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
  43. ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
  44. ml_tools/ML_datasetmaster/_imprimir.py +15 -0
  45. ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
  46. ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
  47. ml_tools/ML_evaluation/__init__.py +53 -0
  48. ml_tools/ML_evaluation/_classification.py +629 -0
  49. ml_tools/ML_evaluation/_feature_importance.py +409 -0
  50. ml_tools/ML_evaluation/_imprimir.py +25 -0
  51. ml_tools/ML_evaluation/_loss.py +92 -0
  52. ml_tools/ML_evaluation/_regression.py +273 -0
  53. ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
  54. ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
  55. ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
  56. ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
  57. ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
  58. ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
  59. ml_tools/ML_finalize_handler/__init__.py +10 -0
  60. ml_tools/ML_finalize_handler/_imprimir.py +8 -0
  61. ml_tools/ML_inference/__init__.py +22 -0
  62. ml_tools/ML_inference/_base_inference.py +166 -0
  63. ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
  64. ml_tools/ML_inference/_dragon_inference.py +332 -0
  65. ml_tools/ML_inference/_imprimir.py +11 -0
  66. ml_tools/ML_inference/_multi_inference.py +180 -0
  67. ml_tools/ML_inference_sequence/__init__.py +10 -0
  68. ml_tools/ML_inference_sequence/_imprimir.py +8 -0
  69. ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
  70. ml_tools/ML_inference_vision/__init__.py +10 -0
  71. ml_tools/ML_inference_vision/_imprimir.py +8 -0
  72. ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
  73. ml_tools/ML_models/__init__.py +32 -0
  74. ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
  75. ml_tools/ML_models/_base_mlp_attention.py +198 -0
  76. ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
  77. ml_tools/ML_models/_dragon_tabular.py +248 -0
  78. ml_tools/ML_models/_imprimir.py +18 -0
  79. ml_tools/ML_models/_mlp_attention.py +134 -0
  80. ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
  81. ml_tools/ML_models_sequence/__init__.py +10 -0
  82. ml_tools/ML_models_sequence/_imprimir.py +8 -0
  83. ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
  84. ml_tools/ML_models_vision/__init__.py +29 -0
  85. ml_tools/ML_models_vision/_base_wrapper.py +254 -0
  86. ml_tools/ML_models_vision/_image_classification.py +182 -0
  87. ml_tools/ML_models_vision/_image_segmentation.py +108 -0
  88. ml_tools/ML_models_vision/_imprimir.py +16 -0
  89. ml_tools/ML_models_vision/_object_detection.py +135 -0
  90. ml_tools/ML_optimization/__init__.py +21 -0
  91. ml_tools/ML_optimization/_imprimir.py +13 -0
  92. ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
  93. ml_tools/ML_optimization/_single_dragon.py +203 -0
  94. ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
  95. ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
  96. ml_tools/ML_scaler/__init__.py +10 -0
  97. ml_tools/ML_scaler/_imprimir.py +8 -0
  98. ml_tools/ML_trainer/__init__.py +20 -0
  99. ml_tools/ML_trainer/_base_trainer.py +297 -0
  100. ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
  101. ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
  102. ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
  103. ml_tools/ML_trainer/_imprimir.py +10 -0
  104. ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
  105. ml_tools/ML_utilities/_artifact_finder.py +382 -0
  106. ml_tools/ML_utilities/_imprimir.py +16 -0
  107. ml_tools/ML_utilities/_inspection.py +325 -0
  108. ml_tools/ML_utilities/_train_tools.py +205 -0
  109. ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
  110. ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
  111. ml_tools/ML_vision_transformers/_imprimir.py +14 -0
  112. ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
  113. ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
  114. ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
  115. ml_tools/PSO_optimization/_imprimir.py +10 -0
  116. ml_tools/SQL/__init__.py +7 -0
  117. ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
  118. ml_tools/SQL/_imprimir.py +8 -0
  119. ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
  120. ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
  121. ml_tools/VIF/_imprimir.py +10 -0
  122. ml_tools/_core/__init__.py +7 -1
  123. ml_tools/_core/_logger.py +8 -18
  124. ml_tools/_core/_schema_load_ops.py +43 -0
  125. ml_tools/_core/_script_info.py +2 -2
  126. ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
  127. ml_tools/data_exploration/_analysis.py +214 -0
  128. ml_tools/data_exploration/_cleaning.py +566 -0
  129. ml_tools/data_exploration/_features.py +583 -0
  130. ml_tools/data_exploration/_imprimir.py +32 -0
  131. ml_tools/data_exploration/_plotting.py +487 -0
  132. ml_tools/data_exploration/_schema_ops.py +176 -0
  133. ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
  134. ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
  135. ml_tools/ensemble_evaluation/_imprimir.py +14 -0
  136. ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
  137. ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
  138. ml_tools/ensemble_inference/_imprimir.py +9 -0
  139. ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
  140. ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
  141. ml_tools/ensemble_learning/_imprimir.py +10 -0
  142. ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
  143. ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
  144. ml_tools/excel_handler/_imprimir.py +13 -0
  145. ml_tools/{keys.py → keys/__init__.py} +4 -1
  146. ml_tools/keys/_imprimir.py +11 -0
  147. ml_tools/{_core → keys}/_keys.py +2 -0
  148. ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
  149. ml_tools/math_utilities/_imprimir.py +11 -0
  150. ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
  151. ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
  152. ml_tools/optimization_tools/_imprimir.py +13 -0
  153. ml_tools/optimization_tools/_optimization_bounds.py +236 -0
  154. ml_tools/optimization_tools/_optimization_plots.py +218 -0
  155. ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
  156. ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
  157. ml_tools/path_manager/_imprimir.py +15 -0
  158. ml_tools/path_manager/_path_tools.py +346 -0
  159. ml_tools/plot_fonts/__init__.py +8 -0
  160. ml_tools/plot_fonts/_imprimir.py +8 -0
  161. ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
  162. ml_tools/schema/__init__.py +15 -0
  163. ml_tools/schema/_feature_schema.py +223 -0
  164. ml_tools/schema/_gui_schema.py +191 -0
  165. ml_tools/schema/_imprimir.py +10 -0
  166. ml_tools/{serde.py → serde/__init__.py} +4 -2
  167. ml_tools/serde/_imprimir.py +10 -0
  168. ml_tools/{_core → serde}/_serde.py +3 -8
  169. ml_tools/{utilities.py → utilities/__init__.py} +11 -6
  170. ml_tools/utilities/_imprimir.py +18 -0
  171. ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
  172. ml_tools/utilities/_utility_tools.py +192 -0
  173. dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
  174. ml_tools/ML_chaining_inference.py +0 -8
  175. ml_tools/ML_configuration.py +0 -86
  176. ml_tools/ML_configuration_pytab.py +0 -14
  177. ml_tools/ML_datasetmaster.py +0 -10
  178. ml_tools/ML_evaluation.py +0 -16
  179. ml_tools/ML_evaluation_multi.py +0 -12
  180. ml_tools/ML_finalize_handler.py +0 -8
  181. ml_tools/ML_inference.py +0 -12
  182. ml_tools/ML_models.py +0 -14
  183. ml_tools/ML_models_advanced.py +0 -14
  184. ml_tools/ML_models_pytab.py +0 -14
  185. ml_tools/ML_optimization.py +0 -14
  186. ml_tools/ML_optimization_pareto.py +0 -8
  187. ml_tools/ML_scaler.py +0 -8
  188. ml_tools/ML_sequence_datasetmaster.py +0 -8
  189. ml_tools/ML_sequence_evaluation.py +0 -10
  190. ml_tools/ML_sequence_inference.py +0 -8
  191. ml_tools/ML_sequence_models.py +0 -8
  192. ml_tools/ML_trainer.py +0 -12
  193. ml_tools/ML_vision_datasetmaster.py +0 -12
  194. ml_tools/ML_vision_evaluation.py +0 -10
  195. ml_tools/ML_vision_inference.py +0 -8
  196. ml_tools/ML_vision_models.py +0 -18
  197. ml_tools/SQL.py +0 -8
  198. ml_tools/_core/_ETL_cleaning.py +0 -694
  199. ml_tools/_core/_IO_tools.py +0 -498
  200. ml_tools/_core/_ML_callbacks.py +0 -702
  201. ml_tools/_core/_ML_configuration.py +0 -1332
  202. ml_tools/_core/_ML_configuration_pytab.py +0 -102
  203. ml_tools/_core/_ML_evaluation.py +0 -867
  204. ml_tools/_core/_ML_evaluation_multi.py +0 -544
  205. ml_tools/_core/_ML_inference.py +0 -646
  206. ml_tools/_core/_ML_models.py +0 -668
  207. ml_tools/_core/_ML_models_pytab.py +0 -693
  208. ml_tools/_core/_ML_trainer.py +0 -2323
  209. ml_tools/_core/_ML_utilities.py +0 -886
  210. ml_tools/_core/_ML_vision_models.py +0 -644
  211. ml_tools/_core/_data_exploration.py +0 -1901
  212. ml_tools/_core/_optimization_tools.py +0 -493
  213. ml_tools/_core/_schema.py +0 -359
  214. ml_tools/plot_fonts.py +0 -8
  215. ml_tools/schema.py +0 -12
  216. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
  217. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
  218. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  219. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
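The pattern across this file list is consistent: each former flat module (for example ml_tools/ETL_cleaning.py) becomes a package whose __init__.py carries the public API, and the monolithic ml_tools/_core implementation files are split into private submodules (_basic_clean.py, _imprimir.py, and so on) inside those packages. The following is a hedged sketch, not part of the package: it assumes, but does not guarantee, that the new __init__ modules re-export the same public names the old flat modules exposed, so the top-level import paths keep resolving after the upgrade.

# Sketch only: checks that the renamed modules now resolve to package __init__ files
# after upgrading to 20.0.0. Module names are taken from the file list above;
# what each package re-exports is an assumption.
import importlib

for name in ("ml_tools.ETL_cleaning", "ml_tools.data_exploration", "ml_tools.ML_trainer"):
    module = importlib.import_module(name)      # should resolve to <name>/__init__.py in 20.0.0
    print(name, "->", module.__file__)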
@@ -1,693 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from torch.utils.data import DataLoader
4
- from typing import Union, Dict, Any, Literal
5
- from pathlib import Path
6
- import json
7
- import warnings
8
-
9
- from ._ML_models import _ArchitectureHandlerMixin
10
- from ._path_manager import make_fullpath
11
- from ._keys import PytorchModelArchitectureKeys
12
- from ._schema import FeatureSchema
13
- from ._script_info import _script_info
14
- from ._logger import get_logger
15
-
16
-
17
- _LOGGER = get_logger("Pytorch Tabular")
18
-
19
-
20
- # Imports from pytorch_tabular
21
- try:
22
- from omegaconf import DictConfig
23
- from pytorch_tabular.models import (
24
- GatedAdditiveTreeEnsembleModel as _GATE,
25
- NODEModel as _NODE,
26
- TabNetModel as _TabNet,
27
- AutoIntModel as _AutoInt
28
- )
29
- except ImportError:
30
- _LOGGER.error(f"GATE and NODE require 'pip install pytorch_tabular omegaconf' dependencies.")
31
- raise ImportError()
32
- else:
33
- # Silence pytorch_tabular INFO logs up to error level
34
- import logging
35
- logging.getLogger("pytorch_tabular").setLevel(logging.ERROR)
36
- logging.getLogger("pytorch_tabular.models.node.node_model").setLevel(logging.ERROR)
37
-
38
-
39
- __all__ = [
40
- "PyTabGateModel",
41
- "PyTabTabNet",
42
- "PyTabAutoInt",
43
- "PyTabNodeModel"
44
- ]
45
-
46
-
47
- class _BasePytabWrapper(nn.Module, _ArchitectureHandlerMixin):
48
- """
49
- Internal Base Class: Do not use directly.
50
-
51
- This is an adapter to make pytorch_tabular models compatible with the
52
- dragon-ml-toolbox pipeline.
53
- """
54
- def __init__(self, schema: FeatureSchema):
55
- super().__init__()
56
-
57
- self.schema = schema
58
- self.model_name = "Base" # To be overridden by child
59
- self.internal_model: nn.Module = None # type: ignore # To be set by child
60
- self.model_hparams: Dict = dict() # To be set by child
61
-
62
- # --- Derive indices from schema ---
63
- categorical_map = schema.categorical_index_map
64
-
65
- if categorical_map:
66
- # The order of keys/values is implicitly linked and must be preserved
67
- self.categorical_indices = list(categorical_map.keys())
68
- self.cardinalities = list(categorical_map.values())
69
- else:
70
- self.categorical_indices = []
71
- self.cardinalities = []
72
-
73
- # Derive numerical indices by finding what's not categorical
74
- all_indices = set(range(len(schema.feature_names)))
75
- categorical_indices_set = set(self.categorical_indices)
76
- self.numerical_indices = sorted(list(all_indices - categorical_indices_set))
77
-
78
- def _build_pt_config(self, out_targets: int, **kwargs) -> DictConfig:
79
- """Helper to create the minimal config dict for a pytorch_tabular model."""
80
- task = "regression"
81
-
82
- config_dict = {
83
- # --- Data / Schema Params ---
84
- 'task': task,
85
- 'continuous_cols': list(self.schema.continuous_feature_names),
86
- 'categorical_cols': list(self.schema.categorical_feature_names),
87
- 'continuous_dim': len(self.numerical_indices),
88
- 'categorical_dim': len(self.categorical_indices),
89
- 'categorical_cardinality': self.cardinalities,
90
- 'target': ['dummy_target'], # Required, but not used
91
-
92
- # --- Model Params ---
93
- 'output_dim': out_targets,
94
- 'target_range': None,
95
- **kwargs
96
- }
97
-
98
- if 'loss' not in config_dict:
99
- config_dict['loss'] = 'MSELoss' # Dummy
100
- if 'metrics' not in config_dict:
101
- config_dict['metrics'] = []
102
-
103
- return DictConfig(config_dict)
104
-
105
- def _build_inferred_config(self, out_targets: int, embedding_dim: int = None) -> DictConfig:
106
- """
107
- Helper to create the inferred_config required by pytorch_tabular v1.0+.
108
- Includes explicit embedding_dims calculation to satisfy BaseModel assertions.
109
- """
110
- # 1. Calculate embedding_dims list of tuples: [(cardinality, dim), ...]
111
- if self.categorical_indices:
112
- if embedding_dim is not None:
113
- # Use the user-provided fixed dimension for all categorical features
114
- embedding_dims = [(card, embedding_dim) for card in self.cardinalities]
115
- else:
116
- # Default heuristic: min(50, (card + 1) // 2)
117
- embedding_dims = [(card, min(50, (card + 1) // 2)) for card in self.cardinalities]
118
- else:
119
- embedding_dims = []
120
-
121
- # 2. Calculate the total dimension of concatenated embeddings
122
- # This fixes the 'Missing key embedded_cat_dim' error
123
- embedded_cat_dim = sum([dim for _, dim in embedding_dims])
124
-
125
- return DictConfig({
126
- "continuous_dim": len(self.numerical_indices),
127
- "categorical_dim": len(self.categorical_indices),
128
- "categorical_cardinality": self.cardinalities,
129
- "output_dim": out_targets,
130
- "embedding_dims": embedding_dims,
131
- "embedded_cat_dim": embedded_cat_dim,
132
- })
133
-
134
- def forward(self, x: torch.Tensor) -> torch.Tensor:
135
- """
136
- Accepts a single tensor and converts it to the dict
137
- that pytorch_tabular models expect.
138
- """
139
- x_cont = x[:, self.numerical_indices].float()
140
- x_cat = x[:, self.categorical_indices].long()
141
-
142
- input_dict = {
143
- 'continuous': x_cont,
144
- 'categorical': x_cat
145
- }
146
-
147
- model_output_dict = self.internal_model(input_dict)
148
- return model_output_dict['logits']
149
-
150
- def get_architecture_config(self) -> Dict[str, Any]:
151
- """Returns the full configuration of the model."""
152
- schema_dict = {
153
- 'feature_names': self.schema.feature_names,
154
- 'continuous_feature_names': self.schema.continuous_feature_names,
155
- 'categorical_feature_names': self.schema.categorical_feature_names,
156
- 'categorical_index_map': self.schema.categorical_index_map,
157
- 'categorical_mappings': self.schema.categorical_mappings
158
- }
159
-
160
- config = {
161
- 'schema_dict': schema_dict,
162
- 'out_targets': self.out_targets,
163
- **self.model_hparams
164
- }
165
- return config
166
-
167
- @classmethod
168
- def load(cls: type, file_or_dir: Union[str, Path], verbose: bool = True) -> nn.Module:
169
- """Loads a model architecture from a JSON file."""
170
- user_path = make_fullpath(file_or_dir)
171
-
172
- if user_path.is_dir():
173
- json_filename = PytorchModelArchitectureKeys.SAVENAME + ".json"
174
- target_path = make_fullpath(user_path / json_filename, enforce="file")
175
- elif user_path.is_file():
176
- target_path = user_path
177
- else:
178
- _LOGGER.error(f"Invalid path: '{file_or_dir}'")
179
- raise IOError()
180
-
181
- with open(target_path, 'r') as f:
182
- saved_data = json.load(f)
183
-
184
- saved_class_name = saved_data[PytorchModelArchitectureKeys.MODEL]
185
- config = saved_data[PytorchModelArchitectureKeys.CONFIG]
186
-
187
- if saved_class_name != cls.__name__:
188
- _LOGGER.error(f"Model class mismatch. File specifies '{saved_class_name}', but '{cls.__name__}' was expected.")
189
- raise ValueError()
190
-
191
- # --- RECONSTRUCTION LOGIC ---
192
- if 'schema_dict' not in config:
193
- _LOGGER.error("Invalid architecture file: missing 'schema_dict'.")
194
- raise ValueError("Missing 'schema_dict' in config.")
195
-
196
- schema_data = config.pop('schema_dict')
197
-
198
- raw_index_map = schema_data['categorical_index_map']
199
- if raw_index_map is not None:
200
- rehydrated_index_map = {int(k): v for k, v in raw_index_map.items()}
201
- else:
202
- rehydrated_index_map = None
203
-
204
- schema = FeatureSchema(
205
- feature_names=tuple(schema_data['feature_names']),
206
- continuous_feature_names=tuple(schema_data['continuous_feature_names']),
207
- categorical_feature_names=tuple(schema_data['categorical_feature_names']),
208
- categorical_index_map=rehydrated_index_map,
209
- categorical_mappings=schema_data['categorical_mappings']
210
- )
211
-
212
- config['schema'] = schema
213
- # --- End Reconstruction ---
214
-
215
- model = cls(**config)
216
- if verbose:
217
- _LOGGER.info(f"Successfully loaded architecture for '{saved_class_name}'")
218
- return model
219
-
220
- def __repr__(self) -> str:
221
- internal_model_str = str(self.internal_model)
222
- internal_repr = internal_model_str.split('\n')[0]
223
- return f"{self.model_name}(internal_model={internal_repr})"
224
-
225
-
226
- class PyTabGateModel(_BasePytabWrapper):
227
- """
228
- Adapter for the Gated Additive Tree Ensemble (GATE) model.
229
- """
230
- def __init__(self, *,
231
- schema: FeatureSchema,
232
- out_targets: int,
233
- embedding_dim: int = 32,
234
- gflu_stages: int = 4,
235
- num_trees: int = 20,
236
- tree_depth: int = 4,
237
- dropout: float = 0.1):
238
- """
239
- Args:
240
- schema (FeatureSchema):
241
- The definitive schema object from data_exploration.
242
- out_targets (int):
243
- Number of output targets.
244
- embedding_dim (int):
245
- Dimension of the categorical embeddings. (Recommended: 16 to 64)
246
- gflu_stages (int):
247
- Number of Gated Feature Learning Units (GFLU) stages. (Recommended: 2 to 6)
248
- num_trees (int):
249
- Number of trees in the ensemble. (Recommended: 10 to 50)
250
- tree_depth (int):
251
- Depth of each tree. (Recommended: 4 to 6)
252
- dropout (float):
253
- Dropout rate for the GFLU.
254
- """
255
- super().__init__(schema)
256
-
257
- warnings.filterwarnings("ignore", message="Implicit dimension choice for softmax")
258
- warnings.filterwarnings("ignore", message="Ignoring head config")
259
-
260
-
261
- self.model_name = "PyTabGateModel"
262
- self.out_targets = out_targets
263
-
264
- self.model_hparams = {
265
- 'embedding_dim': embedding_dim,
266
- 'gflu_stages': gflu_stages,
267
- 'num_trees': num_trees,
268
- 'tree_depth': tree_depth,
269
- 'dropout': dropout
270
- }
271
-
272
- # Build Hyperparameter Config with defaults
273
- pt_config = self._build_pt_config(
274
- out_targets=out_targets,
275
- embedding_dim=embedding_dim,
276
-
277
- # GATE Specific Mappings
278
- gflu_stages=gflu_stages,
279
- num_trees=num_trees,
280
- tree_depth=tree_depth,
281
- gflu_dropout=dropout,
282
- tree_dropout=dropout,
283
- tree_wise_attention=True,
284
- tree_wise_attention_dropout=dropout,
285
-
286
- # GATE Defaults
287
- chain_trees=False,
288
- binning_activation="sigmoid",
289
- feature_mask_function="softmax",
290
- share_head_weights=True,
291
-
292
- # Sparsity
293
- gflu_feature_init_sparsity=0.3,
294
- tree_feature_init_sparsity=0.3,
295
- learnable_sparsity=True,
296
-
297
- # Head Configuration
298
- head="LinearHead",
299
- head_config={
300
- "layers": "",
301
- "activation": "ReLU",
302
- "dropout": 0.0,
303
- "use_batch_norm": False,
304
- "initialization": "kaiming"
305
- },
306
-
307
- # General Defaults (Required to prevent initialization errors)
308
- embedding_dropout=0.0,
309
- batch_norm_continuous_input=False,
310
- virtual_batch_size=None,
311
- learning_rate=1e-3,
312
- target_range=None,
313
- )
314
-
315
- # Build Data Inference Config (Required by PyTabular v1.0+)
316
- inferred_config = self._build_inferred_config(
317
- out_targets=out_targets,
318
- embedding_dim=embedding_dim
319
- )
320
-
321
- # Instantiate the internal pytorch_tabular model
322
- self.internal_model = _GATE(
323
- config=pt_config,
324
- inferred_config=inferred_config
325
- )
326
-
327
- def __repr__(self) -> str:
328
- return (f"{self.model_name}(\n"
329
- f" out_targets={self.out_targets},\n"
330
- f" embedding_dim={self.model_hparams.get('embedding_dim')},\n"
331
- f" gflu_stages={self.model_hparams.get('gflu_stages')},\n"
332
- f" num_trees={self.model_hparams.get('num_trees')},\n"
333
- f" tree_depth={self.model_hparams.get('tree_depth')},\n"
334
- f" dropout={self.model_hparams.get('dropout')}\n"
335
- ")")
336
-
337
-
338
- class PyTabTabNet(_BasePytabWrapper):
339
- """
340
- Adapter for Google's TabNet (Attentive Interpretable Tabular Learning).
341
-
342
- TabNet uses sequential attention to choose which features to reason
343
- from at each decision step, enabling interpretability.
344
- """
345
- def __init__(self, *,
346
- schema: FeatureSchema,
347
- out_targets: int,
348
- n_d: int = 8,
349
- n_a: int = 8,
350
- n_steps: int = 3,
351
- gamma: float = 1.3,
352
- n_independent: int = 2,
353
- n_shared: int = 2,
354
- virtual_batch_size: int = 128,
355
- mask_type: Literal['sparsemax', 'entmax', 'softmax'] = 'sparsemax'):
356
- """
357
- Args:
358
- schema (FeatureSchema): The definitive schema object.
359
- out_targets (int): Number of output targets.
360
- n_d (int): Dimension of the prediction layer (usually 8-64).
361
- n_a (int): Dimension of the attention layer (usually equal to n_d).
362
- n_steps (int): Number of sequential attention steps (usually 3-10).
363
- gamma (float): Relaxation parameter for sparsity (usually 1.0-2.0).
364
- n_independent (int): Number of independent GLU layers in each block.
365
- n_shared (int): Number of shared GLU layers in each block.
366
- virtual_batch_size (int): Batch size for Ghost Batch Normalization.
367
- mask_type (str): Masking function.
368
- - 'sparsemax' for sparse feature selection.
369
- - 'entmax' for moderately sparse selection.
370
- - 'softmax' for dense selection (safest).
371
- """
372
- super().__init__(schema)
373
- self.model_name = "PyTabTabNet"
374
- self.out_targets = out_targets
375
-
376
- self.model_hparams = {
377
- 'n_d': n_d,
378
- 'n_a': n_a,
379
- 'n_steps': n_steps,
380
- 'gamma': gamma,
381
- 'n_independent': n_independent,
382
- 'n_shared': n_shared,
383
- 'virtual_batch_size': virtual_batch_size,
384
- 'mask_type': mask_type
385
- }
386
-
387
- # TabNet does not use standard embeddings, so we don't pass embedding_dim
388
- pt_config = self._build_pt_config(
389
- out_targets=out_targets,
390
-
391
- # TabNet Params
392
- n_d=n_d,
393
- n_a=n_a,
394
- n_steps=n_steps,
395
- gamma=gamma,
396
- n_independent=n_independent,
397
- n_shared=n_shared,
398
- virtual_batch_size=virtual_batch_size,
399
-
400
- # TabNet Defaults
401
- mask_type=mask_type,
402
-
403
- # Head Configuration
404
- head="LinearHead",
405
- head_config={
406
- "layers": "",
407
- "activation": "ReLU",
408
- "dropout": 0.0,
409
- "use_batch_norm": False,
410
- "initialization": "kaiming"
411
- },
412
-
413
- # General Defaults
414
- batch_norm_continuous_input=False,
415
- learning_rate=1e-3
416
- )
417
-
418
- inferred_config = self._build_inferred_config(out_targets=out_targets)
419
-
420
- self.internal_model = _TabNet(
421
- config=pt_config,
422
- inferred_config=inferred_config
423
- )
424
-
425
- def __repr__(self) -> str:
426
- return (f"{self.model_name}(\n"
427
- f" out_targets={self.out_targets},\n"
428
- f" n_d={self.model_hparams.get('n_d')},\n"
429
- f" n_a={self.model_hparams.get('n_a')},\n"
430
- f" n_steps={self.model_hparams.get('n_steps')},\n"
431
- f" gamma={self.model_hparams.get('gamma')},\n"
432
- f" virtual_batch_size={self.model_hparams.get('virtual_batch_size')}\n"
433
- f" mask_type='{self.model_hparams.get('mask_type')}'\n"
434
- f")")
435
-
436
-
437
- class PyTabAutoInt(_BasePytabWrapper):
438
- """
439
- Adapter for AutoInt (Automatic Feature Interaction Learning).
440
-
441
- Uses Multi-Head Self-Attention to automatically learn high-order
442
- feature interactions.
443
- """
444
- def __init__(self, *,
445
- schema: FeatureSchema,
446
- out_targets: int,
447
- embedding_dim: int = 32,
448
- num_heads: int = 2,
449
- num_attn_blocks: int = 3,
450
- attn_dropout: float = 0.1,
451
- has_residuals: bool = True,
452
- deep_layers: bool = True,
453
- layers: str = "128-64-32"):
454
- """
455
- Args:
456
- schema (FeatureSchema): The definitive schema object.
457
- out_targets (int): Number of output targets.
458
- embedding_dim (int): Dimension of feature embeddings (attn_embed_dim).
459
- num_heads (int): Number of attention heads.
460
- num_attn_blocks (int): Number of attention layers.
461
- attn_dropout (float): Dropout between attention layers.
462
- has_residuals (bool): If True, adds residual connections.
463
- deep_layers (bool): If True, adds a standard MLP after attention.
464
- layers (str): Hyphen-separated layer sizes for the deep MLP part.
465
- """
466
- super().__init__(schema)
467
- self.model_name = "PyTabAutoInt"
468
- self.out_targets = out_targets
469
-
470
- self.model_hparams = {
471
- 'embedding_dim': embedding_dim,
472
- 'num_heads': num_heads,
473
- 'num_attn_blocks': num_attn_blocks,
474
- 'attn_dropout': attn_dropout,
475
- 'has_residuals': has_residuals,
476
- 'deep_layers': deep_layers,
477
- 'layers': layers
478
- }
479
-
480
- pt_config = self._build_pt_config(
481
- out_targets=out_targets,
482
-
483
- # AutoInt Params
484
- attn_embed_dim=embedding_dim,
485
- num_heads=num_heads,
486
- num_attn_blocks=num_attn_blocks,
487
- attn_dropouts=attn_dropout,
488
- has_residuals=has_residuals,
489
-
490
- # Deep MLP part (Optional in AutoInt, but usually good)
491
- deep_layers=deep_layers,
492
- layers=layers,
493
- activation="ReLU",
494
-
495
- # Head Configuration
496
- head="LinearHead",
497
- head_config={
498
- "layers": "",
499
- "activation": "ReLU",
500
- "dropout": 0.0,
501
- "use_batch_norm": False,
502
- "initialization": "kaiming"
503
- },
504
-
505
- # General Defaults
506
- embedding_dropout=0.0,
507
- batch_norm_continuous_input=False,
508
- learning_rate=1e-3
509
- )
510
-
511
- inferred_config = self._build_inferred_config(
512
- out_targets=out_targets,
513
- embedding_dim=embedding_dim
514
- )
515
-
516
- self.internal_model = _AutoInt(
517
- config=pt_config,
518
- inferred_config=inferred_config
519
- )
520
-
521
- def __repr__(self) -> str:
522
- return (f"{self.model_name}(\n"
523
- f" out_targets={self.out_targets},\n"
524
- f" embedding_dim={self.model_hparams.get('embedding_dim')},\n"
525
- f" num_heads={self.model_hparams.get('num_heads')},\n"
526
- f" num_attn_blocks={self.model_hparams.get('num_attn_blocks')},\n"
527
- f" deep_layers={self.model_hparams.get('deep_layers')}\n"
528
- f")")
529
-
530
-
531
- class PyTabNodeModel(_BasePytabWrapper):
532
- """
533
- Adapter for the Neural Oblivious Decision Ensembles (NODE) model.
534
- """
535
- def __init__(self, *,
536
- schema: FeatureSchema,
537
- out_targets: int,
538
- embedding_dim: int = 32,
539
- num_trees: int = 1024,
540
- num_layers: int = 2,
541
- tree_depth: int = 6,
542
- dropout: float = 0.1,
543
- backend_function: Literal['softmax', 'entmax15'] = 'softmax'):
544
- """
545
- Args:
546
- schema (FeatureSchema):
547
- The definitive schema object from data_exploration.
548
- out_targets (int):
549
- Number of output targets.
550
- embedding_dim (int):
551
- Dimension of the categorical embeddings. (Recommended: 16 to 64)
552
- num_trees (int):
553
- Total number of trees in the ensemble. (Recommended: 256 to 2048)
554
- num_layers (int):
555
- Number of NODE layers (stacked ensembles). (Recommended: 2 to 4)
556
- tree_depth (int):
557
- Depth of each tree. (Recommended: 4 to 8)
558
- dropout (float):
559
- Dropout rate.
560
- backend_function ('softmax' | 'entmax15'):
561
- Function for feature selection. 'entmax15' (sparse) or 'softmax' (dense).
562
- Use 'softmax' if dealing with convergence issues.
563
- """
564
- super().__init__(schema)
565
- self.model_name = "PyTabNodeModel"
566
- self.out_targets = out_targets
567
-
568
- warnings.filterwarnings("ignore", message="Ignoring head config because NODE has a specific head")
569
-
570
- self.model_hparams = {
571
- 'embedding_dim': embedding_dim,
572
- 'num_trees': num_trees,
573
- 'num_layers': num_layers,
574
- 'tree_depth': tree_depth,
575
- 'dropout': dropout,
576
- 'backend_function': backend_function
577
- }
578
-
579
- # Build Hyperparameter Config with ALL defaults
580
- pt_config = self._build_pt_config(
581
- out_targets=out_targets,
582
- embedding_dim=embedding_dim,
583
-
584
- # NODE Specific Mappings
585
- num_trees=num_trees,
586
- depth=tree_depth, # Map tree_depth -> depth
587
- num_layers=num_layers, # num_layers=1 for a single ensemble
588
- total_trees=num_trees,
589
- dropout_rate=dropout,
590
-
591
- # NODE Defaults (Manually populated to satisfy backbone requirements)
592
- additional_tree_output_dim=0,
593
- input_dropout=0.0,
594
- choice_function=backend_function,
595
- bin_function=backend_function,
596
- initialize_response="normal",
597
- initialize_selection_logits="uniform",
598
- threshold_init_beta=1.0,
599
- threshold_init_cutoff=1.0,
600
- max_features=None,
601
-
602
- # General Defaults (Required to prevent initialization errors)
603
- embedding_dropout=0.0,
604
- batch_norm_continuous_input=False,
605
- virtual_batch_size=None,
606
- learning_rate=1e-3,
607
-
608
- # NODE schema
609
- data_aware_init_batch_size=2000, # Required by NodeConfig schema
610
- augment_dim=0,
611
- )
612
-
613
- # Build Data Inference Config (Required by PyTabular v1.0+)
614
- inferred_config = self._build_inferred_config(
615
- out_targets=out_targets,
616
- embedding_dim=embedding_dim
617
- )
618
-
619
- # Instantiate the internal pytorch_tabular model
620
- self.internal_model = _NODE(
621
- config=pt_config,
622
- inferred_config=inferred_config
623
- )
624
-
625
- def perform_data_aware_initialization(self, train_dataset: Any, batch_size: int = 2000):
626
- """
627
- CRITICAL: Initializes NODE decision thresholds using a batch of data.
628
-
629
- Call this ONCE before training starts with a large batch (e.g., 2000 samples).
630
-
631
- Use the CPU
632
-
633
- Args:
634
- train_dataset: a PyTorch Dataset.
635
- batch_size: Number of samples to use for initialization.
636
- """
637
- # Use a DataLoader to robustly fetch a single batch
638
- loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
639
-
640
- try:
641
- batch = next(iter(loader))
642
- except StopIteration:
643
- _LOGGER.error("Dataset is empty. Cannot perform data-aware initialization.")
644
- return
645
-
646
- x_tensor, _ = batch
647
-
648
- # Prepare input dict
649
- # Prepare input dict matching pytorch_tabular expectations
650
- # Ensure we are on the same device as the model (CPU here)
651
- device = next(self.parameters()).device
652
- x_cont = x_tensor[:, self.numerical_indices].float().to(device)
653
- x_cat = x_tensor[:, self.categorical_indices].long().to(device)
654
-
655
- input_dict = {
656
- 'continuous': x_cont,
657
- 'categorical': x_cat
658
- }
659
-
660
- # --- MOCK DATA MODULE ---
661
- # datamodule.train_dataloader() -> yields the batch
662
- class _MockDataModule:
663
- def train_dataloader(self, batch_size=None):
664
- # Accepts 'batch_size' argument to satisfy the caller
665
- # Returns a list containing just the single pre-processed batch dictionary
666
- return [input_dict]
667
-
668
- mock_dm = _MockDataModule()
669
-
670
- _LOGGER.info(f"Running NODE Data-Aware Initialization with {batch_size} samples...")
671
- try:
672
- with torch.no_grad():
673
- # Call init on the BACKBONE, not the wrapper
674
- self.internal_model.data_aware_initialization(mock_dm)
675
- _LOGGER.info("NODE Initialization Complete. Ready to train.")
676
- except Exception as e:
677
- _LOGGER.error(f"Failed to initialize NODE model: {e}")
678
- raise e
679
-
680
- def __repr__(self) -> str:
681
- return (f"{self.model_name}(\n"
682
- f" out_targets={self.out_targets},\n"
683
- f" embedding_dim={self.model_hparams.get('embedding_dim')},\n"
684
- f" num_trees={self.model_hparams.get('num_trees')},\n"
685
- f" num_layers={self.model_hparams.get('num_layers')},\n"
686
- f" tree_depth={self.model_hparams.get('tree_depth')},\n"
687
- f" dropout={self.model_hparams.get('dropout')}\n"
688
- f" backend_function={self.model_hparams.get('backend_function')}\n"
689
- f")")
690
-
691
-
692
- def info():
693
- _script_info(__all__)
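For context on what was dropped: the removed module above wrapped four pytorch_tabular architectures (GATE, TabNet, AutoInt, NODE) so they accept a single feature tensor, using a FeatureSchema to split continuous and categorical columns inside forward(). The following is a hedged usage sketch under dragon-ml-toolbox 19.x only; the import paths are inferred from the 19.13.0 file list, the schema values are illustrative, and none of these classes ship in 20.0.0.

# Illustrative only: assumes dragon-ml-toolbox 19.x plus the optional
# pytorch_tabular and omegaconf dependencies mentioned in the removed module.
import torch
from ml_tools.schema import FeatureSchema          # assumed 19.x re-export of _core._schema
from ml_tools.ML_models_pytab import PyTabTabNet   # thin re-export module, removed in 20.0.0

# Three features: two continuous, one categorical at index 2 with cardinality 3.
schema = FeatureSchema(
    feature_names=("x1", "x2", "color"),
    continuous_feature_names=("x1", "x2"),
    categorical_feature_names=("color",),
    categorical_index_map={2: 3},
    categorical_mappings={"color": {"red": 0, "green": 1, "blue": 2}},
)

model = PyTabTabNet(schema=schema, out_targets=1, n_d=8, n_a=8, n_steps=3)

# The wrapper's forward() takes one tensor and splits it per the schema.
x = torch.randn(16, 3)
x[:, 2] = torch.randint(0, 3, (16,)).float()
print(model(x).shape)    # expected: (16, 1), one regression output per sample

For the NODE wrapper, the docstring above additionally requires calling perform_data_aware_initialization(train_dataset) once, on CPU, before training.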