dragon-ml-toolbox 20.2.0__py3-none-any.whl → 20.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/METADATA +1 -1
- dragon_ml_toolbox-20.3.0.dist-info/RECORD +143 -0
- ml_tools/ETL_cleaning/__init__.py +5 -1
- ml_tools/ETL_cleaning/_basic_clean.py +1 -1
- ml_tools/ETL_engineering/__init__.py +5 -1
- ml_tools/GUI_tools/__init__.py +5 -1
- ml_tools/IO_tools/_IO_loggers.py +12 -4
- ml_tools/IO_tools/__init__.py +5 -1
- ml_tools/MICE/__init__.py +8 -2
- ml_tools/MICE/_dragon_mice.py +1 -1
- ml_tools/ML_callbacks/__init__.py +5 -1
- ml_tools/ML_chain/__init__.py +5 -1
- ml_tools/ML_configuration/__init__.py +7 -1
- ml_tools/ML_configuration/_training.py +65 -1
- ml_tools/ML_datasetmaster/__init__.py +5 -1
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +31 -20
- ml_tools/ML_datasetmaster/_datasetmaster.py +26 -9
- ml_tools/ML_datasetmaster/_sequence_datasetmaster.py +38 -23
- ml_tools/ML_evaluation/__init__.py +5 -1
- ml_tools/ML_evaluation_captum/__init__.py +5 -1
- ml_tools/ML_finalize_handler/__init__.py +5 -1
- ml_tools/ML_inference/__init__.py +5 -1
- ml_tools/ML_inference_sequence/__init__.py +5 -1
- ml_tools/ML_inference_vision/__init__.py +5 -1
- ml_tools/ML_models/__init__.py +21 -6
- ml_tools/ML_models/_dragon_autoint.py +302 -0
- ml_tools/ML_models/_dragon_gate.py +358 -0
- ml_tools/ML_models/_dragon_node.py +268 -0
- ml_tools/ML_models/_dragon_tabnet.py +255 -0
- ml_tools/ML_models_sequence/__init__.py +5 -1
- ml_tools/ML_models_vision/__init__.py +5 -1
- ml_tools/ML_optimization/__init__.py +11 -3
- ml_tools/ML_optimization/_multi_dragon.py +2 -2
- ml_tools/ML_optimization/_single_dragon.py +47 -67
- ml_tools/ML_optimization/_single_manual.py +1 -1
- ml_tools/ML_scaler/_ML_scaler.py +12 -7
- ml_tools/ML_scaler/__init__.py +5 -1
- ml_tools/ML_trainer/__init__.py +5 -1
- ml_tools/ML_trainer/_base_trainer.py +136 -13
- ml_tools/ML_trainer/_dragon_detection_trainer.py +31 -91
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +24 -74
- ml_tools/ML_trainer/_dragon_trainer.py +24 -85
- ml_tools/ML_utilities/__init__.py +5 -1
- ml_tools/ML_utilities/_inspection.py +44 -30
- ml_tools/ML_vision_transformers/__init__.py +8 -2
- ml_tools/PSO_optimization/__init__.py +5 -1
- ml_tools/SQL/__init__.py +8 -2
- ml_tools/VIF/__init__.py +5 -1
- ml_tools/data_exploration/__init__.py +4 -1
- ml_tools/data_exploration/_cleaning.py +4 -2
- ml_tools/ensemble_evaluation/__init__.py +5 -1
- ml_tools/ensemble_inference/__init__.py +5 -1
- ml_tools/ensemble_learning/__init__.py +5 -1
- ml_tools/excel_handler/__init__.py +5 -1
- ml_tools/keys/__init__.py +5 -1
- ml_tools/math_utilities/__init__.py +5 -1
- ml_tools/optimization_tools/__init__.py +5 -1
- ml_tools/path_manager/__init__.py +8 -2
- ml_tools/plot_fonts/__init__.py +8 -2
- ml_tools/schema/__init__.py +8 -2
- ml_tools/schema/_feature_schema.py +3 -3
- ml_tools/serde/__init__.py +5 -1
- ml_tools/utilities/__init__.py +5 -1
- ml_tools/utilities/_utility_save_load.py +38 -20
- dragon_ml_toolbox-20.2.0.dist-info/RECORD +0 -179
- ml_tools/ETL_cleaning/_imprimir.py +0 -13
- ml_tools/ETL_engineering/_imprimir.py +0 -24
- ml_tools/GUI_tools/_imprimir.py +0 -12
- ml_tools/IO_tools/_imprimir.py +0 -14
- ml_tools/MICE/_imprimir.py +0 -11
- ml_tools/ML_callbacks/_imprimir.py +0 -12
- ml_tools/ML_chain/_imprimir.py +0 -12
- ml_tools/ML_configuration/_imprimir.py +0 -47
- ml_tools/ML_datasetmaster/_imprimir.py +0 -15
- ml_tools/ML_evaluation/_imprimir.py +0 -25
- ml_tools/ML_evaluation_captum/_imprimir.py +0 -10
- ml_tools/ML_finalize_handler/_imprimir.py +0 -8
- ml_tools/ML_inference/_imprimir.py +0 -11
- ml_tools/ML_inference_sequence/_imprimir.py +0 -8
- ml_tools/ML_inference_vision/_imprimir.py +0 -8
- ml_tools/ML_models/_advanced_models.py +0 -1086
- ml_tools/ML_models/_imprimir.py +0 -18
- ml_tools/ML_models_sequence/_imprimir.py +0 -8
- ml_tools/ML_models_vision/_imprimir.py +0 -16
- ml_tools/ML_optimization/_imprimir.py +0 -13
- ml_tools/ML_scaler/_imprimir.py +0 -8
- ml_tools/ML_trainer/_imprimir.py +0 -10
- ml_tools/ML_utilities/_imprimir.py +0 -16
- ml_tools/ML_vision_transformers/_imprimir.py +0 -14
- ml_tools/PSO_optimization/_imprimir.py +0 -10
- ml_tools/SQL/_imprimir.py +0 -8
- ml_tools/VIF/_imprimir.py +0 -10
- ml_tools/data_exploration/_imprimir.py +0 -32
- ml_tools/ensemble_evaluation/_imprimir.py +0 -14
- ml_tools/ensemble_inference/_imprimir.py +0 -9
- ml_tools/ensemble_learning/_imprimir.py +0 -10
- ml_tools/excel_handler/_imprimir.py +0 -13
- ml_tools/keys/_imprimir.py +0 -11
- ml_tools/math_utilities/_imprimir.py +0 -11
- ml_tools/optimization_tools/_imprimir.py +0 -13
- ml_tools/path_manager/_imprimir.py +0 -15
- ml_tools/plot_fonts/_imprimir.py +0 -8
- ml_tools/schema/_imprimir.py +0 -10
- ml_tools/serde/_imprimir.py +0 -10
- ml_tools/utilities/_imprimir.py +0 -18
- {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_models/_dragon_autoint.py (new file)
@@ -0,0 +1,302 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any

from ..schema import FeatureSchema
from .._core import get_logger
from ..keys._keys import SchemaKeys

from ._base_save_load import _ArchitectureBuilder
from ._models_advanced_helpers import (
    Embedding2dLayer,
)


_LOGGER = get_logger("DragonAutoInt")


__all__ = [
    "DragonAutoInt",
]

# SOURCE CODE: Adapted and modified from:
# https://github.com/manujosephv/pytorch_tabular/blob/main/LICENSE
# https://github.com/Qwicen/node/blob/master/LICENSE.md
# https://github.com/jrzaurin/pytorch-widedeep?tab=readme-ov-file#license
# https://github.com/rixwew/pytorch-fm/blob/master/LICENSE
# https://arxiv.org/abs/1705.08741v2


class DragonAutoInt(_ArchitectureBuilder):
    """
    Native implementation of AutoInt (Automatic Feature Interaction Learning).

    Maps categorical and continuous features into a shared embedding space,
    then uses Multi-Head Self-Attention to learn high-order feature interactions.
    """
    def __init__(self, *,
                 schema: FeatureSchema,
                 out_targets: int,
                 embedding_dim: int = 32,
                 attn_embed_dim: int = 32,
                 num_heads: int = 2,
                 num_attn_blocks: int = 3,
                 attn_dropout: float = 0.1,
                 has_residuals: bool = True,
                 attention_pooling: bool = True,
                 deep_layers: bool = True,
                 layers: str = "128-64-32",
                 activation: str = "ReLU",
                 embedding_dropout: float = 0.0,
                 batch_norm_continuous: bool = False):
        """
        Args:
            schema (FeatureSchema):
                Schema object containing feature names and types.
            out_targets (int):
                Number of output targets.
            embedding_dim (int, optional):
                Initial embedding dimension for features.
                Suggested: 16 to 64.
            attn_embed_dim (int, optional):
                Projection dimension for the attention mechanism.
                Suggested: 16 to 64.
            num_heads (int, optional):
                Number of attention heads.
                Suggested: 2 to 8.
            num_attn_blocks (int, optional):
                Number of self-attention layers (depth of interaction learning).
                Suggested: 2 to 5.
            attn_dropout (float, optional):
                Dropout rate within the attention blocks.
                Suggested: 0.0 to 0.2.
            has_residuals (bool, optional):
                If True, adds residual connections (ResNet style) to attention blocks.
            attention_pooling (bool, optional):
                If True, concatenates outputs of all attention blocks (DenseNet style).
                If False, uses only the output of the last block.
            deep_layers (bool, optional):
                If True, adds a standard MLP (Deep Layers) before the attention mechanism
                to process features initially.
            layers (str, optional):
                Hyphen-separated string for MLP layer sizes if deep_layers is True.
            activation (str, optional):
                Activation function for the MLP layers.
            embedding_dropout (float, optional):
                Dropout applied to the initial feature embeddings.
            batch_norm_continuous (bool, optional):
                If True, applies Batch Normalization to continuous features.
        """
        super().__init__()
        self.schema = schema
        self.out_targets = out_targets

        self.model_hparams = {
            'embedding_dim': embedding_dim,
            'attn_embed_dim': attn_embed_dim,
            'num_heads': num_heads,
            'num_attn_blocks': num_attn_blocks,
            'attn_dropout': attn_dropout,
            'has_residuals': has_residuals,
            'attention_pooling': attention_pooling,
            'deep_layers': deep_layers,
            'layers': layers,
            'activation': activation,
            'embedding_dropout': embedding_dropout,
            'batch_norm_continuous': batch_norm_continuous
        }

        # -- 1. Setup Embeddings --
        self.categorical_indices = []
        self.cardinalities = []
        if schema.categorical_index_map:
            self.categorical_indices = list(schema.categorical_index_map.keys())
            self.cardinalities = list(schema.categorical_index_map.values())

        all_indices = set(range(len(schema.feature_names)))
        self.numerical_indices = sorted(list(all_indices - set(self.categorical_indices)))
        n_continuous = len(self.numerical_indices)

        self.embedding_layer = Embedding2dLayer(
            continuous_dim=n_continuous,
            categorical_cardinality=self.cardinalities,
            embedding_dim=embedding_dim,
            embedding_dropout=embedding_dropout,
            batch_norm_continuous_input=batch_norm_continuous
        )

        # -- 2. Deep Layers (Optional MLP) --
        curr_units = embedding_dim
        self.deep_layers_mod = None

        if deep_layers:
            layers_list = []
            layer_sizes = [int(x) for x in layers.split("-")]
            activation_fn = getattr(nn, activation, nn.ReLU)

            for units in layer_sizes:
                layers_list.append(nn.Linear(curr_units, units))

                # Changed BatchNorm1d to LayerNorm to handle (Batch, Tokens, Embed) shape correctly
                layers_list.append(nn.LayerNorm(units))

                layers_list.append(activation_fn())
                layers_list.append(nn.Dropout(embedding_dropout))
                curr_units = units

            self.deep_layers_mod = nn.Sequential(*layers_list)

        # -- 3. Attention Backbone --
        self.attn_proj = nn.Linear(curr_units, attn_embed_dim)

        self.self_attns = nn.ModuleList([
            nn.MultiheadAttention(
                embed_dim=attn_embed_dim,
                num_heads=num_heads,
                dropout=attn_dropout
            )
            for _ in range(num_attn_blocks)
        ])

        # Residuals
        self.has_residuals = has_residuals
        self.attention_pooling = attention_pooling

        if has_residuals:
            # If pooling, we project input to match the concatenated output size
            # If not pooling, we project input to match the single block output size
            res_dim = attn_embed_dim * num_attn_blocks if attention_pooling else attn_embed_dim
            self.V_res_embedding = nn.Linear(curr_units, res_dim)

        # -- 4. Output Dimension Calculation --
        num_features = n_continuous + len(self.cardinalities)

        # Output is flattened: (Num_Features * Attn_Dim)
        final_dim = num_features * attn_embed_dim
        if attention_pooling:
            final_dim = final_dim * num_attn_blocks

        self.output_dim = final_dim
        self.head = nn.Linear(final_dim, out_targets)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_cont = x[:, self.numerical_indices].float()
        x_cat = x[:, self.categorical_indices].long()

        # 1. Embed -> (Batch, Num_Features, Embed_Dim)
        x = self.embedding_layer(x_cont, x_cat)

        # 2. Deep Layers
        if self.deep_layers_mod:
            x = self.deep_layers_mod(x)

        # 3. Attention Projection -> (Batch, Num_Features, Attn_Dim)
        cross_term = self.attn_proj(x)

        # Transpose for MultiheadAttention (Seq, Batch, Embed)
        cross_term = cross_term.transpose(0, 1)

        attention_ops = []
        for self_attn in self.self_attns:
            # Self Attention: Query=Key=Value=cross_term
            # Output: (Seq, Batch, Embed)
            out, _ = self_attn(cross_term, cross_term, cross_term)
            cross_term = out  # Sequential connection
            if self.attention_pooling:
                attention_ops.append(out)

        if self.attention_pooling:
            # Concatenate all attention outputs along the embedding dimension
            cross_term = torch.cat(attention_ops, dim=-1)

        # Transpose back -> (Batch, Num_Features, Final_Attn_Dim)
        cross_term = cross_term.transpose(0, 1)

        # 4. Residual Connection
        if self.has_residuals:
            V_res = self.V_res_embedding(x)
            cross_term = cross_term + V_res

        # 5. Flatten and Head
        # ReLU before flattening as per original implementation
        cross_term = F.relu(cross_term)
        cross_term = cross_term.reshape(cross_term.size(0), -1)

        return self.head(cross_term)

    def data_aware_initialization(self, train_dataset, num_samples: int = 2000, verbose: int = 3):
        """
        Performs data-aware initialization for the final head bias.
        """
        # 1. Prepare Data
        if verbose >= 2:
            _LOGGER.info(f"Performing AutoInt data-aware initialization on up to {num_samples} samples...")
        device = next(self.parameters()).device

        # 2. Extract Targets
        if hasattr(train_dataset, "labels") and isinstance(train_dataset.labels, torch.Tensor):
            limit = min(len(train_dataset.labels), num_samples)
            targets = train_dataset.labels[:limit]
        else:
            indices = range(min(len(train_dataset), num_samples))
            y_accum = []
            for i in indices:
                sample = train_dataset[i]
                # Handle tuple (X, y) or dict
                if isinstance(sample, (tuple, list)) and len(sample) >= 2:
                    y_val = sample[1]
                elif isinstance(sample, dict):
                    y_val = sample.get('target', sample.get('y', None))
                else:
                    y_val = None

                if y_val is not None:
                    if not isinstance(y_val, torch.Tensor):
                        y_val = torch.tensor(y_val)
                    y_accum.append(y_val)

            if not y_accum:
                if verbose >= 1:
                    _LOGGER.warning("Could not extract targets for AutoInt initialization. Skipping.")
                return

            targets = torch.stack(y_accum)

        targets = targets.to(device).float()

        # 3. Initialize Head Bias
        with torch.no_grad():
            mean_target = torch.mean(targets, dim=0)
            if hasattr(self.head, 'bias') and self.head.bias is not None:
                if self.head.bias.shape == mean_target.shape:
                    self.head.bias.data = mean_target
                    if verbose >= 2:
                        _LOGGER.info("AutoInt Initialization Complete. Ready to train.")
                    _LOGGER.debug(f"Initialized AutoInt head bias to {mean_target.cpu().numpy()}")
                elif self.head.bias.numel() == 1 and mean_target.numel() == 1:
                    self.head.bias.data = mean_target.view(self.head.bias.shape)
                    if verbose >= 2:
                        _LOGGER.info("AutoInt Initialization Complete. Ready to train.")
                    _LOGGER.debug(f"Initialized AutoInt head bias to {mean_target.item()}")
            else:
                if verbose >= 1:
                    _LOGGER.warning("AutoInt Head does not have a bias parameter. Skipping initialization.")

    def get_architecture_config(self) -> dict[str, Any]:
        """Returns the full configuration of the model."""
        schema_dict = {
            'feature_names': self.schema.feature_names,
            'continuous_feature_names': self.schema.continuous_feature_names,
            'categorical_feature_names': self.schema.categorical_feature_names,
            'categorical_index_map': self.schema.categorical_index_map,
            'categorical_mappings': self.schema.categorical_mappings
        }

        config = {
            SchemaKeys.SCHEMA_DICT: schema_dict,
            'out_targets': self.out_targets,
            **self.model_hparams
        }
        return config
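Usage sketch (not part of the diff): assuming a pre-built FeatureSchema instance named `schema` describing the input features, DragonAutoInt can be instantiated and exercised roughly as below. Only the constructor and method signatures come from the file above; the schema, the hypothetical `train_dataset`, and the chosen hyperparameter values are illustrative assumptions.

import torch
from ml_tools.ML_models._dragon_autoint import DragonAutoInt

# `schema` is an assumed, pre-built FeatureSchema (its construction is not shown in this diff).
model = DragonAutoInt(
    schema=schema,
    out_targets=1,        # e.g. a single regression target
    embedding_dim=32,
    num_heads=2,
    num_attn_blocks=3,
)

# Optional warm start of the head bias from training targets (hypothetical `train_dataset`):
# model.data_aware_initialization(train_dataset, num_samples=2000)

# One row per sample, one column per feature in schema order; categorical columns
# must hold valid integer category codes (forward() casts them with .long()).
x = torch.randn(8, len(schema.feature_names))  # random values are fine for an all-continuous schema
y_hat = model(x)                               # shape: (8, 1)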
ml_tools/ML_models/_dragon_gate.py (new file)
@@ -0,0 +1,358 @@

import torch
import torch.nn as nn
from typing import Any, Literal

from ..schema import FeatureSchema
from .._core import get_logger
from ..keys._keys import SchemaKeys

from ._base_save_load import _ArchitectureBuilder
from ._models_advanced_helpers import (
    Embedding1dLayer,
    GatedFeatureLearningUnit,
    NeuralDecisionTree,
    entmax15,
    entmoid15,
    sparsemax,
    sparsemoid,
    t_softmax,
    SimpleLinearHead,
    _GateHead
)


_LOGGER = get_logger("DragonGateModel")


__all__ = [
    "DragonGateModel",
]

# SOURCE CODE: Adapted and modified from:
# https://github.com/manujosephv/pytorch_tabular/blob/main/LICENSE
# https://github.com/Qwicen/node/blob/master/LICENSE.md
# https://github.com/jrzaurin/pytorch-widedeep?tab=readme-ov-file#license
# https://github.com/rixwew/pytorch-fm/blob/master/LICENSE
# https://arxiv.org/abs/1705.08741v2


class DragonGateModel(_ArchitectureBuilder):
    """
    Native implementation of the Gated Additive Tree Ensemble (GATE).

    Combines Gated Feature Learning Units (GFLU) for feature interaction learning
    with Differentiable Decision Trees for prediction.
    """
    ACTIVATION_MAP = {
        "entmax": entmax15,
        "sparsemax": sparsemax,
        "softmax": lambda x: nn.functional.softmax(x, dim=-1),
        "t-softmax": t_softmax,
    }

    BINARY_ACTIVATION_MAP = {
        "entmoid": entmoid15,
        "sparsemoid": sparsemoid,
        "sigmoid": torch.sigmoid,
    }

    def __init__(self, *,
                 schema: FeatureSchema,
                 out_targets: int,
                 embedding_dim: int = 16,
                 gflu_stages: int = 6,
                 gflu_dropout: float = 0.1,
                 num_trees: int = 20,
                 tree_depth: int = 4,
                 tree_dropout: float = 0.1,
                 chain_trees: bool = False,
                 tree_wise_attention: bool = True,
                 tree_wise_attention_dropout: float = 0.1,
                 binning_activation: Literal['entmoid', 'sparsemoid', 'sigmoid'] = "entmoid",
                 feature_mask_function: Literal['entmax', 'sparsemax', 'softmax', 't-softmax'] = "entmax",
                 share_head_weights: bool = True,
                 batch_norm_continuous: bool = True):
        """
        Args:
            schema (FeatureSchema):
                Schema object containing feature names and types.
            out_targets (int):
                Number of output targets (e.g., 1 for regression/binary, N for multi-class).
            embedding_dim (int, optional):
                Embedding dimension for categorical features.
                Suggested: 8 to 64.
            gflu_stages (int, optional):
                Number of Gated Feature Learning Unit stages in the backbone.
                Higher values allow learning deeper feature interactions.
                Suggested: 2 to 10.
            gflu_dropout (float, optional):
                Dropout rate applied within GFLU stages.
                Suggested: 0.0 to 0.3.
            num_trees (int, optional):
                Number of Neural Decision Trees to use in the ensemble.
                Suggested: 10 to 50.
            tree_depth (int, optional):
                Depth of the decision trees. Deeper trees capture more complex logic
                but may overfit.
                Suggested: 3 to 6.
            tree_dropout (float, optional):
                Dropout rate applied to the tree leaves.
                Suggested: 0.1 to 0.3.
            chain_trees (bool, optional):
                If True, feeds the output of tree T-1 into tree T (Boosting-style).
                If False, trees run in parallel (Bagging-style).
            tree_wise_attention (bool, optional):
                If True, applies Self-Attention across the outputs of the trees
                to weigh their contributions dynamically.
            tree_wise_attention_dropout (float, optional):
                Dropout rate for the tree-wise attention mechanism.
                Suggested: 0.1.
            binning_activation (str, optional):
                Activation function for the soft binning in trees.
                Options: 'entmoid' (sparse), 'sparsemoid', 'sigmoid'.
            feature_mask_function (str, optional):
                Activation function for feature selection/masking.
                Options: 'entmax' (sparse), 'sparsemax', 'softmax', 't-softmax'.
            share_head_weights (bool, optional):
                If True, all trees share the same linear projection head weights.
            batch_norm_continuous (bool, optional):
                If True, applies Batch Normalization to continuous features before embedding.
        """
        super().__init__()
        self.schema = schema
        self.out_targets = out_targets

        # -- Configuration for saving --
        self.model_hparams = {
            'embedding_dim': embedding_dim,
            'gflu_stages': gflu_stages,
            'gflu_dropout': gflu_dropout,
            'num_trees': num_trees,
            'tree_depth': tree_depth,
            'tree_dropout': tree_dropout,
            'chain_trees': chain_trees,
            'tree_wise_attention': tree_wise_attention,
            'tree_wise_attention_dropout': tree_wise_attention_dropout,
            'binning_activation': binning_activation,
            'feature_mask_function': feature_mask_function,
            'share_head_weights': share_head_weights,
            'batch_norm_continuous': batch_norm_continuous
        }

        # -- 1. Setup Data Processing --
        self.categorical_indices = []
        self.cardinalities = []
        if schema.categorical_index_map:
            self.categorical_indices = list(schema.categorical_index_map.keys())
            self.cardinalities = list(schema.categorical_index_map.values())

        all_indices = set(range(len(schema.feature_names)))
        self.numerical_indices = sorted(list(all_indices - set(self.categorical_indices)))

        embedding_dims = [(c, embedding_dim) for c in self.cardinalities]
        n_continuous = len(self.numerical_indices)

        # -- 2. Embedding Layer --
        self.embedding_layer = Embedding1dLayer(
            continuous_dim=n_continuous,
            categorical_embedding_dims=embedding_dims,
            batch_norm_continuous_input=batch_norm_continuous
        )

        # Calculate total feature dimension after embedding
        total_embedded_cat_dim = sum([d for _, d in embedding_dims])
        self.n_features = n_continuous + total_embedded_cat_dim

        # -- 3. GFLU Backbone --
        self.gflu_stages = gflu_stages
        if gflu_stages > 0:
            self.gflus = GatedFeatureLearningUnit(
                n_features_in=self.n_features,
                n_stages=gflu_stages,
                feature_mask_function=self.ACTIVATION_MAP[feature_mask_function],
                dropout=gflu_dropout,
                feature_sparsity=0.3,  # Standard default
                learnable_sparsity=True
            )

        # -- 4. Neural Decision Trees --
        self.num_trees = num_trees
        self.chain_trees = chain_trees
        self.tree_depth = tree_depth

        if num_trees > 0:
            # Calculate input dim for trees (chaining adds to input)
            tree_input_dim = self.n_features

            self.trees = nn.ModuleList()
            for _ in range(num_trees):
                tree = NeuralDecisionTree(
                    depth=tree_depth,
                    n_features=tree_input_dim,
                    dropout=tree_dropout,
                    binning_activation=self.BINARY_ACTIVATION_MAP[binning_activation],
                    feature_mask_function=self.ACTIVATION_MAP[feature_mask_function],
                )
                self.trees.append(tree)
                if chain_trees:
                    # Next tree sees original features + output of this tree (2^depth leaves)
                    tree_input_dim += 2**tree_depth

            self.tree_output_dim = 2**tree_depth

            # Optional: Tree-wise Attention
            self.tree_wise_attention = tree_wise_attention
            if tree_wise_attention:
                self.tree_attention = nn.MultiheadAttention(
                    embed_dim=self.tree_output_dim,
                    num_heads=1,
                    batch_first=False,  # (Seq, Batch, Feature) standard for PyTorch Attn
                    dropout=tree_wise_attention_dropout
                )
        else:
            self.tree_output_dim = self.n_features

        # -- 5. Prediction Head --
        if num_trees > 0:
            self.head = _GateHead(
                input_dim=self.tree_output_dim,
                output_dim=out_targets,
                num_trees=num_trees,
                share_head_weights=share_head_weights
            )
        else:
            # Fallback if no trees (just GFLU -> Linear)
            self.head = SimpleLinearHead(self.n_features, out_targets)
            # Add T0 manually for consistency if needed, but SimpleLinear covers bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Split inputs
        x_cont = x[:, self.numerical_indices].float()
        x_cat = x[:, self.categorical_indices].long()

        # 1. Embeddings
        x = self.embedding_layer(x_cont, x_cat)

        # 2. GFLU
        if self.gflu_stages > 0:
            x = self.gflus(x)

        # 3. Trees
        if self.num_trees > 0:
            tree_outputs = []
            tree_input = x

            for i in range(self.num_trees):
                # Tree returns (leaf_nodes, feature_masks)
                # leaf_nodes shape: (Batch, 2^depth)
                leaf_nodes, _ = self.trees[i](tree_input)

                tree_outputs.append(leaf_nodes.unsqueeze(-1))

                if self.chain_trees:
                    tree_input = torch.cat([tree_input, leaf_nodes], dim=1)

            # Stack: (Batch, Output_Dim_Tree, Num_Trees)
            tree_outputs = torch.cat(tree_outputs, dim=-1)

            # 4. Attention
            if self.tree_wise_attention:
                # Permute for MultiheadAttention: (Num_Trees, Batch, Output_Dim_Tree)
                # Treating 'Trees' as the sequence length
                attn_input = tree_outputs.permute(2, 0, 1)

                attn_output, _ = self.tree_attention(attn_input, attn_input, attn_input)

                # Permute back: (Batch, Output_Dim_Tree, Num_Trees)
                tree_outputs = attn_output.permute(1, 2, 0)

            # 5. Head
            return self.head(tree_outputs)

        else:
            # No trees, just linear on top of GFLU
            return self.head(x)

    def data_aware_initialization(self, train_dataset, num_samples: int = 2000, verbose: int = 3):
        """
        Performs data-aware initialization for the global bias T0.
        This often speeds up convergence significantly.
        """
        # 1. Prepare Data
        if verbose >= 2:
            _LOGGER.info(f"Performing GATE data-aware initialization on up to {num_samples} samples...")
        device = next(self.parameters()).device

        # 2. Extract Targets
        # Fast path: direct tensor access (Works with DragonDataset/_PytorchDataset)
        if hasattr(train_dataset, "labels") and isinstance(train_dataset.labels, torch.Tensor):
            limit = min(len(train_dataset.labels), num_samples)
            targets = train_dataset.labels[:limit]
        else:
            # Slow path: Iterate
            indices = range(min(len(train_dataset), num_samples))
            y_accum = []
            for i in indices:
                # Expecting (features, targets) tuple
                sample = train_dataset[i]
                if isinstance(sample, (tuple, list)) and len(sample) >= 2:
                    # Standard (X, y) tuple
                    y_val = sample[1]
                elif isinstance(sample, dict):
                    # Try common keys
                    y_val = sample.get('target', sample.get('y', None))
                else:
                    y_val = None

                if y_val is not None:
                    # Ensure it's a tensor
                    if not isinstance(y_val, torch.Tensor):
                        y_val = torch.tensor(y_val)
                    y_accum.append(y_val)

            if not y_accum:
                if verbose >= 1:
                    _LOGGER.warning("Could not extract targets for T0 initialization. Skipping.")
                return

            targets = torch.stack(y_accum)

        targets = targets.to(device).float()

        with torch.no_grad():
            if self.num_trees > 0:
                # Initialize T0 to mean of targets
                mean_target = torch.mean(targets, dim=0)

                # Check shapes to avoid broadcasting errors
                if self.head.T0.shape == mean_target.shape:
                    self.head.T0.data = mean_target
                    if verbose >= 2:
                        _LOGGER.info("GATE Initialization Complete. Ready to train.")
                elif self.head.T0.numel() == 1 and mean_target.numel() == 1:  # type: ignore
                    # scalar case
                    self.head.T0.data = mean_target.view(self.head.T0.shape)  # type: ignore
                    if verbose >= 2:
                        _LOGGER.info("GATE Initialization Complete. Ready to train.")
                else:
                    _LOGGER.debug(f"Target shape mismatch for T0 init. Model: {self.head.T0.shape}, Data: {mean_target.shape}")
                    if verbose >= 1:
                        _LOGGER.warning(f"GATE initialization skipped due to shape mismatch:\n Model: {self.head.T0.shape}\n Data: {mean_target.shape}")

    def get_architecture_config(self) -> dict[str, Any]:
        """Returns the full configuration of the model."""
        schema_dict = {
            'feature_names': self.schema.feature_names,
            'continuous_feature_names': self.schema.continuous_feature_names,
            'categorical_feature_names': self.schema.categorical_feature_names,
            'categorical_index_map': self.schema.categorical_index_map,
            'categorical_mappings': self.schema.categorical_mappings
        }

        config = {
            SchemaKeys.SCHEMA_DICT: schema_dict,
            'out_targets': self.out_targets,
            **self.model_hparams
        }
        return config
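A similar hypothetical sketch for the GATE model: only the constructor and method signatures are taken from the file above, while `schema`, `train_dataset`, and the hyperparameter values are assumptions for illustration.

import torch
from ml_tools.ML_models._dragon_gate import DragonGateModel

model = DragonGateModel(
    schema=schema,        # assumed, pre-built FeatureSchema
    out_targets=1,
    gflu_stages=4,
    num_trees=10,
    tree_depth=4,
    chain_trees=False,    # parallel (bagging-style) trees
)

# Optional warm start of the global bias T0 from training targets (hypothetical `train_dataset`):
# model.data_aware_initialization(train_dataset, num_samples=2000)

x = torch.randn(8, len(schema.feature_names))  # categorical columns need valid integer codes
y_hat = model(x)                               # shape: (8, 1)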