dragon-ml-toolbox 19.14.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.14.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1909
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
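Most top-level modules were converted in place from single files to packages (e.g. `ETL_cleaning.py → ETL_cleaning/__init__.py`), which keeps the dotted import path itself valid across the upgrade; whether every name is still re-exported depends on the contents of each new `__init__.py`. A minimal sketch of the import-path mechanics (the re-exported names are assumptions to verify against the new packages):

```python
# Import-path mechanics only; which names each 20.0.0 __init__.py
# re-exports is an assumption to check against the new package.
import ml_tools.ETL_cleaning   # 19.14.0: ml_tools/ETL_cleaning.py
                               # 20.0.0:  ml_tools/ETL_cleaning/__init__.py

# Deleted modules no longer resolve, per the removals listed above:
# import ml_tools.ML_models_pytab   # raises ModuleNotFoundError in 20.0.0
```

The only hunk rendered below is the 693-line deletion of `ml_tools/_core/_ML_models_pytab.py`, the pytorch_tabular adapter module.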
--- a/ml_tools/_core/_ML_models_pytab.py
+++ /dev/null
@@ -1,693 +0,0 @@
-import torch
-from torch import nn
-from torch.utils.data import DataLoader
-from typing import Union, Dict, Any, Literal
-from pathlib import Path
-import json
-import warnings
-
-from ._ML_models import _ArchitectureHandlerMixin
-from ._path_manager import make_fullpath
-from ._keys import PytorchModelArchitectureKeys
-from ._schema import FeatureSchema
-from ._script_info import _script_info
-from ._logger import get_logger
-
-
-_LOGGER = get_logger("Pytorch Tabular")
-
-
-# Imports from pytorch_tabular
-try:
-    from omegaconf import DictConfig
-    from pytorch_tabular.models import (
-        GatedAdditiveTreeEnsembleModel as _GATE,
-        NODEModel as _NODE,
-        TabNetModel as _TabNet,
-        AutoIntModel as _AutoInt
-    )
-except ImportError:
-    _LOGGER.error(f"GATE and NODE require 'pip install pytorch_tabular omegaconf' dependencies.")
-    raise ImportError()
-else:
-    # Silence pytorch_tabular INFO logs up to error level
-    import logging
-    logging.getLogger("pytorch_tabular").setLevel(logging.ERROR)
-    logging.getLogger("pytorch_tabular.models.node.node_model").setLevel(logging.ERROR)
-
-
-__all__ = [
-    "PyTabGateModel",
-    "PyTabTabNet",
-    "PyTabAutoInt",
-    "PyTabNodeModel"
-]
-
-
-class _BasePytabWrapper(nn.Module, _ArchitectureHandlerMixin):
-    """
-    Internal Base Class: Do not use directly.
-
-    This is an adapter to make pytorch_tabular models compatible with the
-    dragon-ml-toolbox pipeline.
-    """
-    def __init__(self, schema: FeatureSchema):
-        super().__init__()
-
-        self.schema = schema
-        self.model_name = "Base" # To be overridden by child
-        self.internal_model: nn.Module = None # type: ignore # To be set by child
-        self.model_hparams: Dict = dict() # To be set by child
-
-        # --- Derive indices from schema ---
-        categorical_map = schema.categorical_index_map
-
-        if categorical_map:
-            # The order of keys/values is implicitly linked and must be preserved
-            self.categorical_indices = list(categorical_map.keys())
-            self.cardinalities = list(categorical_map.values())
-        else:
-            self.categorical_indices = []
-            self.cardinalities = []
-
-        # Derive numerical indices by finding what's not categorical
-        all_indices = set(range(len(schema.feature_names)))
-        categorical_indices_set = set(self.categorical_indices)
-        self.numerical_indices = sorted(list(all_indices - categorical_indices_set))
-
-    def _build_pt_config(self, out_targets: int, **kwargs) -> DictConfig:
-        """Helper to create the minimal config dict for a pytorch_tabular model."""
-        task = "regression"
-
-        config_dict = {
-            # --- Data / Schema Params ---
-            'task': task,
-            'continuous_cols': list(self.schema.continuous_feature_names),
-            'categorical_cols': list(self.schema.categorical_feature_names),
-            'continuous_dim': len(self.numerical_indices),
-            'categorical_dim': len(self.categorical_indices),
-            'categorical_cardinality': self.cardinalities,
-            'target': ['dummy_target'], # Required, but not used
-
-            # --- Model Params ---
-            'output_dim': out_targets,
-            'target_range': None,
-            **kwargs
-        }
-
-        if 'loss' not in config_dict:
-            config_dict['loss'] = 'MSELoss' # Dummy
-        if 'metrics' not in config_dict:
-            config_dict['metrics'] = []
-
-        return DictConfig(config_dict)
-
-    def _build_inferred_config(self, out_targets: int, embedding_dim: int = None) -> DictConfig:
-        """
-        Helper to create the inferred_config required by pytorch_tabular v1.0+.
-        Includes explicit embedding_dims calculation to satisfy BaseModel assertions.
-        """
-        # 1. Calculate embedding_dims list of tuples: [(cardinality, dim), ...]
-        if self.categorical_indices:
-            if embedding_dim is not None:
-                # Use the user-provided fixed dimension for all categorical features
-                embedding_dims = [(card, embedding_dim) for card in self.cardinalities]
-            else:
-                # Default heuristic: min(50, (card + 1) // 2)
-                embedding_dims = [(card, min(50, (card + 1) // 2)) for card in self.cardinalities]
-        else:
-            embedding_dims = []
-
-        # 2. Calculate the total dimension of concatenated embeddings
-        # This fixes the 'Missing key embedded_cat_dim' error
-        embedded_cat_dim = sum([dim for _, dim in embedding_dims])
-
-        return DictConfig({
-            "continuous_dim": len(self.numerical_indices),
-            "categorical_dim": len(self.categorical_indices),
-            "categorical_cardinality": self.cardinalities,
-            "output_dim": out_targets,
-            "embedding_dims": embedding_dims,
-            "embedded_cat_dim": embedded_cat_dim,
-        })
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Accepts a single tensor and converts it to the dict
-        that pytorch_tabular models expect.
-        """
-        x_cont = x[:, self.numerical_indices].float()
-        x_cat = x[:, self.categorical_indices].long()
-
-        input_dict = {
-            'continuous': x_cont,
-            'categorical': x_cat
-        }
-
-        model_output_dict = self.internal_model(input_dict)
-        return model_output_dict['logits']
-
-    def get_architecture_config(self) -> Dict[str, Any]:
-        """Returns the full configuration of the model."""
-        schema_dict = {
-            'feature_names': self.schema.feature_names,
-            'continuous_feature_names': self.schema.continuous_feature_names,
-            'categorical_feature_names': self.schema.categorical_feature_names,
-            'categorical_index_map': self.schema.categorical_index_map,
-            'categorical_mappings': self.schema.categorical_mappings
-        }
-
-        config = {
-            'schema_dict': schema_dict,
-            'out_targets': self.out_targets,
-            **self.model_hparams
-        }
-        return config
-
-    @classmethod
-    def load(cls: type, file_or_dir: Union[str, Path], verbose: bool = True) -> nn.Module:
-        """Loads a model architecture from a JSON file."""
-        user_path = make_fullpath(file_or_dir)
-
-        if user_path.is_dir():
-            json_filename = PytorchModelArchitectureKeys.SAVENAME + ".json"
-            target_path = make_fullpath(user_path / json_filename, enforce="file")
-        elif user_path.is_file():
-            target_path = user_path
-        else:
-            _LOGGER.error(f"Invalid path: '{file_or_dir}'")
-            raise IOError()
-
-        with open(target_path, 'r') as f:
-            saved_data = json.load(f)
-
-        saved_class_name = saved_data[PytorchModelArchitectureKeys.MODEL]
-        config = saved_data[PytorchModelArchitectureKeys.CONFIG]
-
-        if saved_class_name != cls.__name__:
-            _LOGGER.error(f"Model class mismatch. File specifies '{saved_class_name}', but '{cls.__name__}' was expected.")
-            raise ValueError()
-
-        # --- RECONSTRUCTION LOGIC ---
-        if 'schema_dict' not in config:
-            _LOGGER.error("Invalid architecture file: missing 'schema_dict'.")
-            raise ValueError("Missing 'schema_dict' in config.")
-
-        schema_data = config.pop('schema_dict')
-
-        raw_index_map = schema_data['categorical_index_map']
-        if raw_index_map is not None:
-            rehydrated_index_map = {int(k): v for k, v in raw_index_map.items()}
-        else:
-            rehydrated_index_map = None
-
-        schema = FeatureSchema(
-            feature_names=tuple(schema_data['feature_names']),
-            continuous_feature_names=tuple(schema_data['continuous_feature_names']),
-            categorical_feature_names=tuple(schema_data['categorical_feature_names']),
-            categorical_index_map=rehydrated_index_map,
-            categorical_mappings=schema_data['categorical_mappings']
-        )
-
-        config['schema'] = schema
-        # --- End Reconstruction ---
-
-        model = cls(**config)
-        if verbose:
-            _LOGGER.info(f"Successfully loaded architecture for '{saved_class_name}'")
-        return model
-
-    def __repr__(self) -> str:
-        internal_model_str = str(self.internal_model)
-        internal_repr = internal_model_str.split('\n')[0]
-        return f"{self.model_name}(internal_model={internal_repr})"
-
-
-class PyTabGateModel(_BasePytabWrapper):
-    """
-    Adapter for the Gated Additive Tree Ensemble (GATE) model.
-    """
-    def __init__(self, *,
-                 schema: FeatureSchema,
-                 out_targets: int,
-                 embedding_dim: int = 32,
-                 gflu_stages: int = 4,
-                 num_trees: int = 20,
-                 tree_depth: int = 4,
-                 dropout: float = 0.1):
-        """
-        Args:
-            schema (FeatureSchema):
-                The definitive schema object from data_exploration.
-            out_targets (int):
-                Number of output targets.
-            embedding_dim (int):
-                Dimension of the categorical embeddings. (Recommended: 16 to 64)
-            gflu_stages (int):
-                Number of Gated Feature Learning Units (GFLU) stages. (Recommended: 2 to 6)
-            num_trees (int):
-                Number of trees in the ensemble. (Recommended: 10 to 50)
-            tree_depth (int):
-                Depth of each tree. (Recommended: 4 to 6)
-            dropout (float):
-                Dropout rate for the GFLU.
-        """
-        super().__init__(schema)
-
-        warnings.filterwarnings("ignore", message="Implicit dimension choice for softmax")
-        warnings.filterwarnings("ignore", message="Ignoring head config")
-
-
-        self.model_name = "PyTabGateModel"
-        self.out_targets = out_targets
-
-        self.model_hparams = {
-            'embedding_dim': embedding_dim,
-            'gflu_stages': gflu_stages,
-            'num_trees': num_trees,
-            'tree_depth': tree_depth,
-            'dropout': dropout
-        }
-
-        # Build Hyperparameter Config with defaults
-        pt_config = self._build_pt_config(
-            out_targets=out_targets,
-            embedding_dim=embedding_dim,
-
-            # GATE Specific Mappings
-            gflu_stages=gflu_stages,
-            num_trees=num_trees,
-            tree_depth=tree_depth,
-            gflu_dropout=dropout,
-            tree_dropout=dropout,
-            tree_wise_attention=True,
-            tree_wise_attention_dropout=dropout,
-
-            # GATE Defaults
-            chain_trees=False,
-            binning_activation="sigmoid",
-            feature_mask_function="softmax",
-            share_head_weights=True,
-
-            # Sparsity
-            gflu_feature_init_sparsity=0.3,
-            tree_feature_init_sparsity=0.3,
-            learnable_sparsity=True,
-
-            # Head Configuration
-            head="LinearHead",
-            head_config={
-                "layers": "",
-                "activation": "ReLU",
-                "dropout": 0.0,
-                "use_batch_norm": False,
-                "initialization": "kaiming"
-            },
-
-            # General Defaults (Required to prevent initialization errors)
-            embedding_dropout=0.0,
-            batch_norm_continuous_input=False,
-            virtual_batch_size=None,
-            learning_rate=1e-3,
-            target_range=None,
-        )
-
-        # Build Data Inference Config (Required by PyTabular v1.0+)
-        inferred_config = self._build_inferred_config(
-            out_targets=out_targets,
-            embedding_dim=embedding_dim
-        )
-
-        # Instantiate the internal pytorch_tabular model
-        self.internal_model = _GATE(
-            config=pt_config,
-            inferred_config=inferred_config
-        )
-
-    def __repr__(self) -> str:
-        return (f"{self.model_name}(\n"
-                f" out_targets={self.out_targets},\n"
-                f" embedding_dim={self.model_hparams.get('embedding_dim')},\n"
-                f" gflu_stages={self.model_hparams.get('gflu_stages')},\n"
-                f" num_trees={self.model_hparams.get('num_trees')},\n"
-                f" tree_depth={self.model_hparams.get('tree_depth')},\n"
-                f" dropout={self.model_hparams.get('dropout')}\n"
-                ")")
-
-
-class PyTabTabNet(_BasePytabWrapper):
-    """
-    Adapter for Google's TabNet (Attentive Interpretable Tabular Learning).
-
-    TabNet uses sequential attention to choose which features to reason
-    from at each decision step, enabling interpretability.
-    """
-    def __init__(self, *,
-                 schema: FeatureSchema,
-                 out_targets: int,
-                 n_d: int = 8,
-                 n_a: int = 8,
-                 n_steps: int = 3,
-                 gamma: float = 1.3,
-                 n_independent: int = 2,
-                 n_shared: int = 2,
-                 virtual_batch_size: int = 128,
-                 mask_type: Literal['sparsemax', 'entmax', 'softmax'] = 'sparsemax'):
-        """
-        Args:
-            schema (FeatureSchema): The definitive schema object.
-            out_targets (int): Number of output targets.
-            n_d (int): Dimension of the prediction layer (usually 8-64).
-            n_a (int): Dimension of the attention layer (usually equal to n_d).
-            n_steps (int): Number of sequential attention steps (usually 3-10).
-            gamma (float): Relaxation parameter for sparsity (usually 1.0-2.0).
-            n_independent (int): Number of independent GLU layers in each block.
-            n_shared (int): Number of shared GLU layers in each block.
-            virtual_batch_size (int): Batch size for Ghost Batch Normalization.
-            mask_type (str): Masking function.
-                - 'sparsemax' for sparse feature selection.
-                - 'entmax' for moderately sparse selection.
-                - 'softmax' for dense selection (safest).
-        """
-        super().__init__(schema)
-        self.model_name = "PyTabTabNet"
-        self.out_targets = out_targets
-
-        self.model_hparams = {
-            'n_d': n_d,
-            'n_a': n_a,
-            'n_steps': n_steps,
-            'gamma': gamma,
-            'n_independent': n_independent,
-            'n_shared': n_shared,
-            'virtual_batch_size': virtual_batch_size,
-            'mask_type': mask_type
-        }
-
-        # TabNet does not use standard embeddings, so we don't pass embedding_dim
-        pt_config = self._build_pt_config(
-            out_targets=out_targets,
-
-            # TabNet Params
-            n_d=n_d,
-            n_a=n_a,
-            n_steps=n_steps,
-            gamma=gamma,
-            n_independent=n_independent,
-            n_shared=n_shared,
-            virtual_batch_size=virtual_batch_size,
-
-            # TabNet Defaults
-            mask_type=mask_type,
-
-            # Head Configuration
-            head="LinearHead",
-            head_config={
-                "layers": "",
-                "activation": "ReLU",
-                "dropout": 0.0,
-                "use_batch_norm": False,
-                "initialization": "kaiming"
-            },
-
-            # General Defaults
-            batch_norm_continuous_input=False,
-            learning_rate=1e-3
-        )
-
-        inferred_config = self._build_inferred_config(out_targets=out_targets)
-
-        self.internal_model = _TabNet(
-            config=pt_config,
-            inferred_config=inferred_config
-        )
-
-    def __repr__(self) -> str:
-        return (f"{self.model_name}(\n"
-                f" out_targets={self.out_targets},\n"
-                f" n_d={self.model_hparams.get('n_d')},\n"
-                f" n_a={self.model_hparams.get('n_a')},\n"
-                f" n_steps={self.model_hparams.get('n_steps')},\n"
-                f" gamma={self.model_hparams.get('gamma')},\n"
-                f" virtual_batch_size={self.model_hparams.get('virtual_batch_size')}\n"
-                f" mask_type='{self.model_hparams.get('mask_type')}'\n"
-                f")")
-
-
-class PyTabAutoInt(_BasePytabWrapper):
-    """
-    Adapter for AutoInt (Automatic Feature Interaction Learning).
-
-    Uses Multi-Head Self-Attention to automatically learn high-order
-    feature interactions.
-    """
-    def __init__(self, *,
-                 schema: FeatureSchema,
-                 out_targets: int,
-                 embedding_dim: int = 32,
-                 num_heads: int = 2,
-                 num_attn_blocks: int = 3,
-                 attn_dropout: float = 0.1,
-                 has_residuals: bool = True,
-                 deep_layers: bool = True,
-                 layers: str = "128-64-32"):
-        """
-        Args:
-            schema (FeatureSchema): The definitive schema object.
-            out_targets (int): Number of output targets.
-            embedding_dim (int): Dimension of feature embeddings (attn_embed_dim).
-            num_heads (int): Number of attention heads.
-            num_attn_blocks (int): Number of attention layers.
-            attn_dropout (float): Dropout between attention layers.
-            has_residuals (bool): If True, adds residual connections.
-            deep_layers (bool): If True, adds a standard MLP after attention.
-            layers (str): Hyphen-separated layer sizes for the deep MLP part.
-        """
-        super().__init__(schema)
-        self.model_name = "PyTabAutoInt"
-        self.out_targets = out_targets
-
-        self.model_hparams = {
-            'embedding_dim': embedding_dim,
-            'num_heads': num_heads,
-            'num_attn_blocks': num_attn_blocks,
-            'attn_dropout': attn_dropout,
-            'has_residuals': has_residuals,
-            'deep_layers': deep_layers,
-            'layers': layers
-        }
-
-        pt_config = self._build_pt_config(
-            out_targets=out_targets,
-
-            # AutoInt Params
-            attn_embed_dim=embedding_dim,
-            num_heads=num_heads,
-            num_attn_blocks=num_attn_blocks,
-            attn_dropouts=attn_dropout,
-            has_residuals=has_residuals,
-
-            # Deep MLP part (Optional in AutoInt, but usually good)
-            deep_layers=deep_layers,
-            layers=layers,
-            activation="ReLU",
-
-            # Head Configuration
-            head="LinearHead",
-            head_config={
-                "layers": "",
-                "activation": "ReLU",
-                "dropout": 0.0,
-                "use_batch_norm": False,
-                "initialization": "kaiming"
-            },
-
-            # General Defaults
-            embedding_dropout=0.0,
-            batch_norm_continuous_input=False,
-            learning_rate=1e-3
-        )
-
-        inferred_config = self._build_inferred_config(
-            out_targets=out_targets,
-            embedding_dim=embedding_dim
-        )
-
-        self.internal_model = _AutoInt(
-            config=pt_config,
-            inferred_config=inferred_config
-        )
-
-    def __repr__(self) -> str:
-        return (f"{self.model_name}(\n"
-                f" out_targets={self.out_targets},\n"
-                f" embedding_dim={self.model_hparams.get('embedding_dim')},\n"
-                f" num_heads={self.model_hparams.get('num_heads')},\n"
-                f" num_attn_blocks={self.model_hparams.get('num_attn_blocks')},\n"
-                f" deep_layers={self.model_hparams.get('deep_layers')}\n"
-                f")")
-
-
-class PyTabNodeModel(_BasePytabWrapper):
-    """
-    Adapter for the Neural Oblivious Decision Ensembles (NODE) model.
-    """
-    def __init__(self, *,
-                 schema: FeatureSchema,
-                 out_targets: int,
-                 embedding_dim: int = 32,
-                 num_trees: int = 1024,
-                 num_layers: int = 2,
-                 tree_depth: int = 6,
-                 dropout: float = 0.1,
-                 backend_function: Literal['softmax', 'entmax15'] = 'softmax'):
-        """
-        Args:
-            schema (FeatureSchema):
-                The definitive schema object from data_exploration.
-            out_targets (int):
-                Number of output targets.
-            embedding_dim (int):
-                Dimension of the categorical embeddings. (Recommended: 16 to 64)
-            num_trees (int):
-                Total number of trees in the ensemble. (Recommended: 256 to 2048)
-            num_layers (int):
-                Number of NODE layers (stacked ensembles). (Recommended: 2 to 4)
-            tree_depth (int):
-                Depth of each tree. (Recommended: 4 to 8)
-            dropout (float):
-                Dropout rate.
-            backend_function ('softmax' | 'entmax15'):
-                Function for feature selection. 'entmax15' (sparse) or 'softmax' (dense).
-                Use 'softmax' if dealing with convergence issues.
-        """
-        super().__init__(schema)
-        self.model_name = "PyTabNodeModel"
-        self.out_targets = out_targets
-
-        warnings.filterwarnings("ignore", message="Ignoring head config because NODE has a specific head")
-
-        self.model_hparams = {
-            'embedding_dim': embedding_dim,
-            'num_trees': num_trees,
-            'num_layers': num_layers,
-            'tree_depth': tree_depth,
-            'dropout': dropout,
-            'backend_function': backend_function
-        }
-
-        # Build Hyperparameter Config with ALL defaults
-        pt_config = self._build_pt_config(
-            out_targets=out_targets,
-            embedding_dim=embedding_dim,
-
-            # NODE Specific Mappings
-            num_trees=num_trees,
-            depth=tree_depth, # Map tree_depth -> depth
-            num_layers=num_layers, # num_layers=1 for a single ensemble
-            total_trees=num_trees,
-            dropout_rate=dropout,
-
-            # NODE Defaults (Manually populated to satisfy backbone requirements)
-            additional_tree_output_dim=0,
-            input_dropout=0.0,
-            choice_function=backend_function,
-            bin_function=backend_function,
-            initialize_response="normal",
-            initialize_selection_logits="uniform",
-            threshold_init_beta=1.0,
-            threshold_init_cutoff=1.0,
-            max_features=None,
-
-            # General Defaults (Required to prevent initialization errors)
-            embedding_dropout=0.0,
-            batch_norm_continuous_input=False,
-            virtual_batch_size=None,
-            learning_rate=1e-3,
-
-            # NODE schema
-            data_aware_init_batch_size=2000, # Required by NodeConfig schema
-            augment_dim=0,
-        )
-
-        # Build Data Inference Config (Required by PyTabular v1.0+)
-        inferred_config = self._build_inferred_config(
-            out_targets=out_targets,
-            embedding_dim=embedding_dim
-        )
-
-        # Instantiate the internal pytorch_tabular model
-        self.internal_model = _NODE(
-            config=pt_config,
-            inferred_config=inferred_config
-        )
-
-    def perform_data_aware_initialization(self, train_dataset: Any, batch_size: int = 2000):
-        """
-        CRITICAL: Initializes NODE decision thresholds using a batch of data.
-
-        Call this ONCE before training starts with a large batch (e.g., 2000 samples).
-
-        Use the CPU
-
-        Args:
-            train_dataset: a PyTorch Dataset.
-            batch_size: Number of samples to use for initialization.
-        """
-        # Use a DataLoader to robustly fetch a single batch
-        loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
-
-        try:
-            batch = next(iter(loader))
-        except StopIteration:
-            _LOGGER.error("Dataset is empty. Cannot perform data-aware initialization.")
-            return
-
-        x_tensor, _ = batch
-
-        # Prepare input dict
-        # Prepare input dict matching pytorch_tabular expectations
-        # Ensure we are on the same device as the model (CPU here)
-        device = next(self.parameters()).device
-        x_cont = x_tensor[:, self.numerical_indices].float().to(device)
-        x_cat = x_tensor[:, self.categorical_indices].long().to(device)
-
-        input_dict = {
-            'continuous': x_cont,
-            'categorical': x_cat
-        }
-
-        # --- MOCK DATA MODULE ---
-        # datamodule.train_dataloader() -> yields the batch
-        class _MockDataModule:
-            def train_dataloader(self, batch_size=None):
-                # Accepts 'batch_size' argument to satisfy the caller
-                # Returns a list containing just the single pre-processed batch dictionary
-                return [input_dict]
-
-        mock_dm = _MockDataModule()
-
-        _LOGGER.info(f"Running NODE Data-Aware Initialization with {batch_size} samples...")
-        try:
-            with torch.no_grad():
-                # Call init on the BACKBONE, not the wrapper
-                self.internal_model.data_aware_initialization(mock_dm)
-            _LOGGER.info("NODE Initialization Complete. Ready to train.")
-        except Exception as e:
-            _LOGGER.error(f"Failed to initialize NODE model: {e}")
-            raise e
-
-    def __repr__(self) -> str:
-        return (f"{self.model_name}(\n"
-                f" out_targets={self.out_targets},\n"
-                f" embedding_dim={self.model_hparams.get('embedding_dim')},\n"
-                f" num_trees={self.model_hparams.get('num_trees')},\n"
-                f" num_layers={self.model_hparams.get('num_layers')},\n"
-                f" tree_depth={self.model_hparams.get('tree_depth')},\n"
-                f" dropout={self.model_hparams.get('dropout')}\n"
-                f" backend_function={self.model_hparams.get('backend_function')}\n"
-                f")")
-
-
-def info():
-    _script_info(__all__)