dragon-ml-toolbox 20.2.0__py3-none-any.whl → 20.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/METADATA +1 -1
  2. dragon_ml_toolbox-20.3.0.dist-info/RECORD +143 -0
  3. ml_tools/ETL_cleaning/__init__.py +5 -1
  4. ml_tools/ETL_cleaning/_basic_clean.py +1 -1
  5. ml_tools/ETL_engineering/__init__.py +5 -1
  6. ml_tools/GUI_tools/__init__.py +5 -1
  7. ml_tools/IO_tools/_IO_loggers.py +12 -4
  8. ml_tools/IO_tools/__init__.py +5 -1
  9. ml_tools/MICE/__init__.py +8 -2
  10. ml_tools/MICE/_dragon_mice.py +1 -1
  11. ml_tools/ML_callbacks/__init__.py +5 -1
  12. ml_tools/ML_chain/__init__.py +5 -1
  13. ml_tools/ML_configuration/__init__.py +7 -1
  14. ml_tools/ML_configuration/_training.py +65 -1
  15. ml_tools/ML_datasetmaster/__init__.py +5 -1
  16. ml_tools/ML_datasetmaster/_base_datasetmaster.py +31 -20
  17. ml_tools/ML_datasetmaster/_datasetmaster.py +26 -9
  18. ml_tools/ML_datasetmaster/_sequence_datasetmaster.py +38 -23
  19. ml_tools/ML_evaluation/__init__.py +5 -1
  20. ml_tools/ML_evaluation_captum/__init__.py +5 -1
  21. ml_tools/ML_finalize_handler/__init__.py +5 -1
  22. ml_tools/ML_inference/__init__.py +5 -1
  23. ml_tools/ML_inference_sequence/__init__.py +5 -1
  24. ml_tools/ML_inference_vision/__init__.py +5 -1
  25. ml_tools/ML_models/__init__.py +21 -6
  26. ml_tools/ML_models/_dragon_autoint.py +302 -0
  27. ml_tools/ML_models/_dragon_gate.py +358 -0
  28. ml_tools/ML_models/_dragon_node.py +268 -0
  29. ml_tools/ML_models/_dragon_tabnet.py +255 -0
  30. ml_tools/ML_models_sequence/__init__.py +5 -1
  31. ml_tools/ML_models_vision/__init__.py +5 -1
  32. ml_tools/ML_optimization/__init__.py +11 -3
  33. ml_tools/ML_optimization/_multi_dragon.py +2 -2
  34. ml_tools/ML_optimization/_single_dragon.py +47 -67
  35. ml_tools/ML_optimization/_single_manual.py +1 -1
  36. ml_tools/ML_scaler/_ML_scaler.py +12 -7
  37. ml_tools/ML_scaler/__init__.py +5 -1
  38. ml_tools/ML_trainer/__init__.py +5 -1
  39. ml_tools/ML_trainer/_base_trainer.py +136 -13
  40. ml_tools/ML_trainer/_dragon_detection_trainer.py +31 -91
  41. ml_tools/ML_trainer/_dragon_sequence_trainer.py +24 -74
  42. ml_tools/ML_trainer/_dragon_trainer.py +24 -85
  43. ml_tools/ML_utilities/__init__.py +5 -1
  44. ml_tools/ML_utilities/_inspection.py +44 -30
  45. ml_tools/ML_vision_transformers/__init__.py +8 -2
  46. ml_tools/PSO_optimization/__init__.py +5 -1
  47. ml_tools/SQL/__init__.py +8 -2
  48. ml_tools/VIF/__init__.py +5 -1
  49. ml_tools/data_exploration/__init__.py +4 -1
  50. ml_tools/data_exploration/_cleaning.py +4 -2
  51. ml_tools/ensemble_evaluation/__init__.py +5 -1
  52. ml_tools/ensemble_inference/__init__.py +5 -1
  53. ml_tools/ensemble_learning/__init__.py +5 -1
  54. ml_tools/excel_handler/__init__.py +5 -1
  55. ml_tools/keys/__init__.py +5 -1
  56. ml_tools/math_utilities/__init__.py +5 -1
  57. ml_tools/optimization_tools/__init__.py +5 -1
  58. ml_tools/path_manager/__init__.py +8 -2
  59. ml_tools/plot_fonts/__init__.py +8 -2
  60. ml_tools/schema/__init__.py +8 -2
  61. ml_tools/schema/_feature_schema.py +3 -3
  62. ml_tools/serde/__init__.py +5 -1
  63. ml_tools/utilities/__init__.py +5 -1
  64. ml_tools/utilities/_utility_save_load.py +38 -20
  65. dragon_ml_toolbox-20.2.0.dist-info/RECORD +0 -179
  66. ml_tools/ETL_cleaning/_imprimir.py +0 -13
  67. ml_tools/ETL_engineering/_imprimir.py +0 -24
  68. ml_tools/GUI_tools/_imprimir.py +0 -12
  69. ml_tools/IO_tools/_imprimir.py +0 -14
  70. ml_tools/MICE/_imprimir.py +0 -11
  71. ml_tools/ML_callbacks/_imprimir.py +0 -12
  72. ml_tools/ML_chain/_imprimir.py +0 -12
  73. ml_tools/ML_configuration/_imprimir.py +0 -47
  74. ml_tools/ML_datasetmaster/_imprimir.py +0 -15
  75. ml_tools/ML_evaluation/_imprimir.py +0 -25
  76. ml_tools/ML_evaluation_captum/_imprimir.py +0 -10
  77. ml_tools/ML_finalize_handler/_imprimir.py +0 -8
  78. ml_tools/ML_inference/_imprimir.py +0 -11
  79. ml_tools/ML_inference_sequence/_imprimir.py +0 -8
  80. ml_tools/ML_inference_vision/_imprimir.py +0 -8
  81. ml_tools/ML_models/_advanced_models.py +0 -1086
  82. ml_tools/ML_models/_imprimir.py +0 -18
  83. ml_tools/ML_models_sequence/_imprimir.py +0 -8
  84. ml_tools/ML_models_vision/_imprimir.py +0 -16
  85. ml_tools/ML_optimization/_imprimir.py +0 -13
  86. ml_tools/ML_scaler/_imprimir.py +0 -8
  87. ml_tools/ML_trainer/_imprimir.py +0 -10
  88. ml_tools/ML_utilities/_imprimir.py +0 -16
  89. ml_tools/ML_vision_transformers/_imprimir.py +0 -14
  90. ml_tools/PSO_optimization/_imprimir.py +0 -10
  91. ml_tools/SQL/_imprimir.py +0 -8
  92. ml_tools/VIF/_imprimir.py +0 -10
  93. ml_tools/data_exploration/_imprimir.py +0 -32
  94. ml_tools/ensemble_evaluation/_imprimir.py +0 -14
  95. ml_tools/ensemble_inference/_imprimir.py +0 -9
  96. ml_tools/ensemble_learning/_imprimir.py +0 -10
  97. ml_tools/excel_handler/_imprimir.py +0 -13
  98. ml_tools/keys/_imprimir.py +0 -11
  99. ml_tools/math_utilities/_imprimir.py +0 -11
  100. ml_tools/optimization_tools/_imprimir.py +0 -13
  101. ml_tools/path_manager/_imprimir.py +0 -15
  102. ml_tools/plot_fonts/_imprimir.py +0 -8
  103. ml_tools/schema/_imprimir.py +0 -10
  104. ml_tools/serde/_imprimir.py +0 -10
  105. ml_tools/utilities/_imprimir.py +0 -18
  106. {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/WHEEL +0 -0
  107. {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/licenses/LICENSE +0 -0
  108. {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  109. {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,302 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Any
5
+
6
+ from ..schema import FeatureSchema
7
+ from .._core import get_logger
8
+ from ..keys._keys import SchemaKeys
9
+
10
+ from ._base_save_load import _ArchitectureBuilder
11
+ from ._models_advanced_helpers import (
12
+ Embedding2dLayer,
13
+ )
14
+
15
+
16
+ _LOGGER = get_logger("DragonAutoInt")
17
+
18
+
19
+ __all__ = [
20
+ "DragonAutoInt",
21
+ ]
22
+
23
+ # SOURCE CODE: Adapted and modified from:
24
+ # https://github.com/manujosephv/pytorch_tabular/blob/main/LICENSE
25
+ # https://github.com/Qwicen/node/blob/master/LICENSE.md
26
+ # https://github.com/jrzaurin/pytorch-widedeep?tab=readme-ov-file#license
27
+ # https://github.com/rixwew/pytorch-fm/blob/master/LICENSE
28
+ # https://arxiv.org/abs/1705.08741v2
29
+
30
+
31
+ class DragonAutoInt(_ArchitectureBuilder):
32
+ """
33
+ Native implementation of AutoInt (Automatic Feature Interaction Learning).
34
+
35
+ Maps categorical and continuous features into a shared embedding space,
36
+ then uses Multi-Head Self-Attention to learn high-order feature interactions.
37
+ """
38
+ def __init__(self, *,
39
+ schema: FeatureSchema,
40
+ out_targets: int,
41
+ embedding_dim: int = 32,
42
+ attn_embed_dim: int = 32,
43
+ num_heads: int = 2,
44
+ num_attn_blocks: int = 3,
45
+ attn_dropout: float = 0.1,
46
+ has_residuals: bool = True,
47
+ attention_pooling: bool = True,
48
+ deep_layers: bool = True,
49
+ layers: str = "128-64-32",
50
+ activation: str = "ReLU",
51
+ embedding_dropout: float = 0.0,
52
+ batch_norm_continuous: bool = False):
53
+ """
54
+ Args:
55
+ schema (FeatureSchema):
56
+ Schema object containing feature names and types.
57
+ out_targets (int):
58
+ Number of output targets.
59
+ embedding_dim (int, optional):
60
+ Initial embedding dimension for features.
61
+ Suggested: 16 to 64.
62
+ attn_embed_dim (int, optional):
63
+ Projection dimension for the attention mechanism.
64
+ Suggested: 16 to 64.
65
+ num_heads (int, optional):
66
+ Number of attention heads.
67
+ Suggested: 2 to 8.
68
+ num_attn_blocks (int, optional):
69
+ Number of self-attention layers (depth of interaction learning).
70
+ Suggested: 2 to 5.
71
+ attn_dropout (float, optional):
72
+ Dropout rate within the attention blocks.
73
+ Suggested: 0.0 to 0.2.
74
+ has_residuals (bool, optional):
75
+ If True, adds residual connections (ResNet style) to attention blocks.
76
+ attention_pooling (bool, optional):
77
+ If True, concatenates outputs of all attention blocks (DenseNet style).
78
+ If False, uses only the output of the last block.
79
+ deep_layers (bool, optional):
80
+ If True, adds a standard MLP (Deep Layers) before the attention mechanism
81
+ to process features initially.
82
+ layers (str, optional):
83
+ Hyphen-separated string for MLP layer sizes if deep_layers is True.
84
+ activation (str, optional):
85
+ Activation function for the MLP layers.
86
+ embedding_dropout (float, optional):
87
+ Dropout applied to the initial feature embeddings.
88
+ batch_norm_continuous (bool, optional):
89
+ If True, applies Batch Normalization to continuous features.
90
+ """
91
+ super().__init__()
92
+ self.schema = schema
93
+ self.out_targets = out_targets
94
+
95
+ self.model_hparams = {
96
+ 'embedding_dim': embedding_dim,
97
+ 'attn_embed_dim': attn_embed_dim,
98
+ 'num_heads': num_heads,
99
+ 'num_attn_blocks': num_attn_blocks,
100
+ 'attn_dropout': attn_dropout,
101
+ 'has_residuals': has_residuals,
102
+ 'attention_pooling': attention_pooling,
103
+ 'deep_layers': deep_layers,
104
+ 'layers': layers,
105
+ 'activation': activation,
106
+ 'embedding_dropout': embedding_dropout,
107
+ 'batch_norm_continuous': batch_norm_continuous
108
+ }
109
+
110
+ # -- 1. Setup Embeddings --
111
+ self.categorical_indices = []
112
+ self.cardinalities = []
113
+ if schema.categorical_index_map:
114
+ self.categorical_indices = list(schema.categorical_index_map.keys())
115
+ self.cardinalities = list(schema.categorical_index_map.values())
116
+
117
+ all_indices = set(range(len(schema.feature_names)))
118
+ self.numerical_indices = sorted(list(all_indices - set(self.categorical_indices)))
119
+ n_continuous = len(self.numerical_indices)
120
+
121
+ self.embedding_layer = Embedding2dLayer(
122
+ continuous_dim=n_continuous,
123
+ categorical_cardinality=self.cardinalities,
124
+ embedding_dim=embedding_dim,
125
+ embedding_dropout=embedding_dropout,
126
+ batch_norm_continuous_input=batch_norm_continuous
127
+ )
128
+
129
+ # -- 2. Deep Layers (Optional MLP) --
130
+ curr_units = embedding_dim
131
+ self.deep_layers_mod = None
132
+
133
+ if deep_layers:
134
+ layers_list = []
135
+ layer_sizes = [int(x) for x in layers.split("-")]
136
+ activation_fn = getattr(nn, activation, nn.ReLU)
137
+
138
+ for units in layer_sizes:
139
+ layers_list.append(nn.Linear(curr_units, units))
140
+
141
+ # Changed BatchNorm1d to LayerNorm to handle (Batch, Tokens, Embed) shape correctly
142
+ layers_list.append(nn.LayerNorm(units))
143
+
144
+ layers_list.append(activation_fn())
145
+ layers_list.append(nn.Dropout(embedding_dropout))
146
+ curr_units = units
147
+
148
+ self.deep_layers_mod = nn.Sequential(*layers_list)
149
+
150
+ # -- 3. Attention Backbone --
151
+ self.attn_proj = nn.Linear(curr_units, attn_embed_dim)
152
+
153
+ self.self_attns = nn.ModuleList([
154
+ nn.MultiheadAttention(
155
+ embed_dim=attn_embed_dim,
156
+ num_heads=num_heads,
157
+ dropout=attn_dropout
158
+ )
159
+ for _ in range(num_attn_blocks)
160
+ ])
161
+
162
+ # Residuals
163
+ self.has_residuals = has_residuals
164
+ self.attention_pooling = attention_pooling
165
+
166
+ if has_residuals:
167
+ # If pooling, we project input to match the concatenated output size
168
+ # If not pooling, we project input to match the single block output size
169
+ res_dim = attn_embed_dim * num_attn_blocks if attention_pooling else attn_embed_dim
170
+ self.V_res_embedding = nn.Linear(curr_units, res_dim)
171
+
172
+ # -- 4. Output Dimension Calculation --
173
+ num_features = n_continuous + len(self.cardinalities)
174
+
175
+ # Output is flattened: (Num_Features * Attn_Dim)
176
+ final_dim = num_features * attn_embed_dim
177
+ if attention_pooling:
178
+ final_dim = final_dim * num_attn_blocks
179
+
180
+ self.output_dim = final_dim
181
+ self.head = nn.Linear(final_dim, out_targets)
182
+
183
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
184
+ x_cont = x[:, self.numerical_indices].float()
185
+ x_cat = x[:, self.categorical_indices].long()
186
+
187
+ # 1. Embed -> (Batch, Num_Features, Embed_Dim)
188
+ x = self.embedding_layer(x_cont, x_cat)
189
+
190
+ # 2. Deep Layers
191
+ if self.deep_layers_mod:
192
+ x = self.deep_layers_mod(x)
193
+
194
+ # 3. Attention Projection -> (Batch, Num_Features, Attn_Dim)
195
+ cross_term = self.attn_proj(x)
196
+
197
+ # Transpose for MultiheadAttention (Seq, Batch, Embed)
198
+ cross_term = cross_term.transpose(0, 1)
199
+
200
+ attention_ops = []
201
+ for self_attn in self.self_attns:
202
+ # Self Attention: Query=Key=Value=cross_term
203
+ # Output: (Seq, Batch, Embed)
204
+ out, _ = self_attn(cross_term, cross_term, cross_term)
205
+ cross_term = out # Sequential connection
206
+ if self.attention_pooling:
207
+ attention_ops.append(out)
208
+
209
+ if self.attention_pooling:
210
+ # Concatenate all attention outputs along the embedding dimension
211
+ cross_term = torch.cat(attention_ops, dim=-1)
212
+
213
+ # Transpose back -> (Batch, Num_Features, Final_Attn_Dim)
214
+ cross_term = cross_term.transpose(0, 1)
215
+
216
+ # 4. Residual Connection
217
+ if self.has_residuals:
218
+ V_res = self.V_res_embedding(x)
219
+ cross_term = cross_term + V_res
220
+
221
+ # 5. Flatten and Head
222
+ # ReLU before flattening as per original implementation
223
+ cross_term = F.relu(cross_term)
224
+ cross_term = cross_term.reshape(cross_term.size(0), -1)
225
+
226
+ return self.head(cross_term)
227
+
228
+ def data_aware_initialization(self, train_dataset, num_samples: int = 2000, verbose: int = 3):
229
+ """
230
+ Performs data-aware initialization for the final head bias.
231
+ """
232
+ # 1. Prepare Data
233
+ if verbose >= 2:
234
+ _LOGGER.info(f"Performing AutoInt data-aware initialization on up to {num_samples} samples...")
235
+ device = next(self.parameters()).device
236
+
237
+ # 2. Extract Targets
238
+ if hasattr(train_dataset, "labels") and isinstance(train_dataset.labels, torch.Tensor):
239
+ limit = min(len(train_dataset.labels), num_samples)
240
+ targets = train_dataset.labels[:limit]
241
+ else:
242
+ indices = range(min(len(train_dataset), num_samples))
243
+ y_accum = []
244
+ for i in indices:
245
+ sample = train_dataset[i]
246
+ # Handle tuple (X, y) or dict
247
+ if isinstance(sample, (tuple, list)) and len(sample) >= 2:
248
+ y_val = sample[1]
249
+ elif isinstance(sample, dict):
250
+ y_val = sample.get('target', sample.get('y', None))
251
+ else:
252
+ y_val = None
253
+
254
+ if y_val is not None:
255
+ if not isinstance(y_val, torch.Tensor):
256
+ y_val = torch.tensor(y_val)
257
+ y_accum.append(y_val)
258
+
259
+ if not y_accum:
260
+ if verbose >= 1:
261
+ _LOGGER.warning("Could not extract targets for AutoInt initialization. Skipping.")
262
+ return
263
+
264
+ targets = torch.stack(y_accum)
265
+
266
+ targets = targets.to(device).float()
267
+
268
+ # 3. Initialize Head Bias
269
+ with torch.no_grad():
270
+ mean_target = torch.mean(targets, dim=0)
271
+ if hasattr(self.head, 'bias') and self.head.bias is not None:
272
+ if self.head.bias.shape == mean_target.shape:
273
+ self.head.bias.data = mean_target
274
+ if verbose >= 2:
275
+ _LOGGER.info("AutoInt Initialization Complete. Ready to train.")
276
+ _LOGGER.debug(f"Initialized AutoInt head bias to {mean_target.cpu().numpy()}")
277
+ elif self.head.bias.numel() == 1 and mean_target.numel() == 1:
278
+ self.head.bias.data = mean_target.view(self.head.bias.shape)
279
+ if verbose >= 2:
280
+ _LOGGER.info("AutoInt Initialization Complete. Ready to train.")
281
+ _LOGGER.debug(f"Initialized AutoInt head bias to {mean_target.item()}")
282
+ else:
283
+ if verbose >= 1:
284
+ _LOGGER.warning("AutoInt Head does not have a bias parameter. Skipping initialization.")
285
+
286
+ def get_architecture_config(self) -> dict[str, Any]:
287
+ """Returns the full configuration of the model."""
288
+ schema_dict = {
289
+ 'feature_names': self.schema.feature_names,
290
+ 'continuous_feature_names': self.schema.continuous_feature_names,
291
+ 'categorical_feature_names': self.schema.categorical_feature_names,
292
+ 'categorical_index_map': self.schema.categorical_index_map,
293
+ 'categorical_mappings': self.schema.categorical_mappings
294
+ }
295
+
296
+ config = {
297
+ SchemaKeys.SCHEMA_DICT: schema_dict,
298
+ 'out_targets': self.out_targets,
299
+ **self.model_hparams
300
+ }
301
+ return config
302
+
@@ -0,0 +1,358 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Any, Literal
4
+
5
+ from ..schema import FeatureSchema
6
+ from .._core import get_logger
7
+ from ..keys._keys import SchemaKeys
8
+
9
+ from ._base_save_load import _ArchitectureBuilder
10
+ from ._models_advanced_helpers import (
11
+ Embedding1dLayer,
12
+ GatedFeatureLearningUnit,
13
+ NeuralDecisionTree,
14
+ entmax15,
15
+ entmoid15,
16
+ sparsemax,
17
+ sparsemoid,
18
+ t_softmax,
19
+ SimpleLinearHead,
20
+ _GateHead
21
+ )
22
+
23
+
24
+ _LOGGER = get_logger("DragonGateModel")
25
+
26
+
27
+ __all__ = [
28
+ "DragonGateModel",
29
+ ]
30
+
31
+ # SOURCE CODE: Adapted and modified from:
32
+ # https://github.com/manujosephv/pytorch_tabular/blob/main/LICENSE
33
+ # https://github.com/Qwicen/node/blob/master/LICENSE.md
34
+ # https://github.com/jrzaurin/pytorch-widedeep?tab=readme-ov-file#license
35
+ # https://github.com/rixwew/pytorch-fm/blob/master/LICENSE
36
+ # https://arxiv.org/abs/1705.08741v2
37
+
38
+
39
class DragonGateModel(_ArchitectureBuilder):
    """
    Native implementation of the Gated Additive Tree Ensemble (GATE).

    Combines Gated Feature Learning Units (GFLU) for feature interaction learning
    with Differentiable Decision Trees for prediction.
    """
    # Feature-mask / selection activations (map config strings to callables).
    ACTIVATION_MAP = {
        "entmax": entmax15,
        "sparsemax": sparsemax,
        "softmax": lambda x: nn.functional.softmax(x, dim=-1),
        "t-softmax": t_softmax,
    }

    # Soft-binning activations used inside the differentiable trees.
    BINARY_ACTIVATION_MAP = {
        "entmoid": entmoid15,
        "sparsemoid": sparsemoid,
        "sigmoid": torch.sigmoid,
    }

    def __init__(self, *,
                 schema: FeatureSchema,
                 out_targets: int,
                 embedding_dim: int = 16,
                 gflu_stages: int = 6,
                 gflu_dropout: float = 0.1,
                 num_trees: int = 20,
                 tree_depth: int = 4,
                 tree_dropout: float = 0.1,
                 chain_trees: bool = False,
                 tree_wise_attention: bool = True,
                 tree_wise_attention_dropout: float = 0.1,
                 binning_activation: Literal['entmoid', 'sparsemoid', 'sigmoid'] = "entmoid",
                 feature_mask_function: Literal['entmax', 'sparsemax', 'softmax', 't-softmax'] = "entmax",
                 share_head_weights: bool = True,
                 batch_norm_continuous: bool = True):
        """
        Args:
            schema (FeatureSchema):
                Schema object containing feature names and types.
            out_targets (int):
                Number of output targets (e.g., 1 for regression/binary, N for multi-class).
            embedding_dim (int, optional):
                Embedding dimension for categorical features.
                Suggested: 8 to 64.
            gflu_stages (int, optional):
                Number of Gated Feature Learning Unit stages in the backbone.
                Higher values allow learning deeper feature interactions.
                Suggested: 2 to 10.
            gflu_dropout (float, optional):
                Dropout rate applied within GFLU stages.
                Suggested: 0.0 to 0.3.
            num_trees (int, optional):
                Number of Neural Decision Trees to use in the ensemble.
                Suggested: 10 to 50.
            tree_depth (int, optional):
                Depth of the decision trees. Deeper trees capture more complex logic
                but may overfit.
                Suggested: 3 to 6.
            tree_dropout (float, optional):
                Dropout rate applied to the tree leaves.
                Suggested: 0.1 to 0.3.
            chain_trees (bool, optional):
                If True, feeds the output of tree T-1 into tree T (Boosting-style).
                If False, trees run in parallel (Bagging-style).
            tree_wise_attention (bool, optional):
                If True, applies Self-Attention across the outputs of the trees
                to weigh their contributions dynamically.
            tree_wise_attention_dropout (float, optional):
                Dropout rate for the tree-wise attention mechanism.
                Suggested: 0.1.
            binning_activation (str, optional):
                Activation function for the soft binning in trees.
                Options: 'entmoid' (sparse), 'sparsemoid', 'sigmoid'.
            feature_mask_function (str, optional):
                Activation function for feature selection/masking.
                Options: 'entmax' (sparse), 'sparsemax', 'softmax', 't-softmax'.
            share_head_weights (bool, optional):
                If True, all trees share the same linear projection head weights.
            batch_norm_continuous (bool, optional):
                If True, applies Batch Normalization to continuous features before embedding.
        """
        super().__init__()
        self.schema = schema
        self.out_targets = out_targets

        # -- Configuration for saving --
        self.model_hparams = {
            'embedding_dim': embedding_dim,
            'gflu_stages': gflu_stages,
            'gflu_dropout': gflu_dropout,
            'num_trees': num_trees,
            'tree_depth': tree_depth,
            'tree_dropout': tree_dropout,
            'chain_trees': chain_trees,
            'tree_wise_attention': tree_wise_attention,
            'tree_wise_attention_dropout': tree_wise_attention_dropout,
            'binning_activation': binning_activation,
            'feature_mask_function': feature_mask_function,
            'share_head_weights': share_head_weights,
            'batch_norm_continuous': batch_norm_continuous
        }

        # -- 1. Setup Data Processing --
        self.categorical_indices = []
        self.cardinalities = []
        if schema.categorical_index_map:
            self.categorical_indices = list(schema.categorical_index_map.keys())
            self.cardinalities = list(schema.categorical_index_map.values())

        # Continuous features = every schema column not declared categorical.
        all_indices = set(range(len(schema.feature_names)))
        self.numerical_indices = sorted(list(all_indices - set(self.categorical_indices)))

        # Each categorical feature gets the same embedding width.
        embedding_dims = [(c, embedding_dim) for c in self.cardinalities]
        n_continuous = len(self.numerical_indices)

        # -- 2. Embedding Layer --
        self.embedding_layer = Embedding1dLayer(
            continuous_dim=n_continuous,
            categorical_embedding_dims=embedding_dims,
            batch_norm_continuous_input=batch_norm_continuous
        )

        # Calculate total feature dimension after embedding
        total_embedded_cat_dim = sum([d for _, d in embedding_dims])
        self.n_features = n_continuous + total_embedded_cat_dim

        # -- 3. GFLU Backbone --
        self.gflu_stages = gflu_stages
        if gflu_stages > 0:
            self.gflus = GatedFeatureLearningUnit(
                n_features_in=self.n_features,
                n_stages=gflu_stages,
                feature_mask_function=self.ACTIVATION_MAP[feature_mask_function],
                dropout=gflu_dropout,
                feature_sparsity=0.3,  # Standard default
                learnable_sparsity=True
            )

        # -- 4. Neural Decision Trees --
        self.num_trees = num_trees
        self.chain_trees = chain_trees
        self.tree_depth = tree_depth

        if num_trees > 0:
            # Calculate input dim for trees (chaining adds to input)
            tree_input_dim = self.n_features

            self.trees = nn.ModuleList()
            for _ in range(num_trees):
                tree = NeuralDecisionTree(
                    depth=tree_depth,
                    n_features=tree_input_dim,
                    dropout=tree_dropout,
                    binning_activation=self.BINARY_ACTIVATION_MAP[binning_activation],
                    feature_mask_function=self.ACTIVATION_MAP[feature_mask_function],
                )
                self.trees.append(tree)
                if chain_trees:
                    # Next tree sees original features + output of this tree (2^depth leaves)
                    tree_input_dim += 2**tree_depth

            # Each tree emits one value per leaf.
            self.tree_output_dim = 2**tree_depth

            # Optional: Tree-wise Attention
            self.tree_wise_attention = tree_wise_attention
            if tree_wise_attention:
                self.tree_attention = nn.MultiheadAttention(
                    embed_dim=self.tree_output_dim,
                    num_heads=1,
                    batch_first=False,  # (Seq, Batch, Feature) standard for PyTorch Attn
                    dropout=tree_wise_attention_dropout
                )
        else:
            # With no trees the "tree output" is simply the GFLU feature vector.
            self.tree_output_dim = self.n_features

        # -- 5. Prediction Head --
        if num_trees > 0:
            self.head = _GateHead(
                input_dim=self.tree_output_dim,
                output_dim=out_targets,
                num_trees=num_trees,
                share_head_weights=share_head_weights
            )
        else:
            # Fallback if no trees (just GFLU -> Linear)
            self.head = SimpleLinearHead(self.n_features, out_targets)
            # Add T0 manually for consistency if needed, but SimpleLinear covers bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Split inputs
        x_cont = x[:, self.numerical_indices].float()
        x_cat = x[:, self.categorical_indices].long()

        # 1. Embeddings
        x = self.embedding_layer(x_cont, x_cat)

        # 2. GFLU
        if self.gflu_stages > 0:
            x = self.gflus(x)

        # 3. Trees
        if self.num_trees > 0:
            tree_outputs = []
            tree_input = x

            for i in range(self.num_trees):
                # Tree returns (leaf_nodes, feature_masks)
                # leaf_nodes shape: (Batch, 2^depth)
                leaf_nodes, _ = self.trees[i](tree_input)

                tree_outputs.append(leaf_nodes.unsqueeze(-1))

                if self.chain_trees:
                    # Boosting-style: later trees also see earlier leaf outputs.
                    tree_input = torch.cat([tree_input, leaf_nodes], dim=1)

            # Stack: (Batch, Output_Dim_Tree, Num_Trees)
            tree_outputs = torch.cat(tree_outputs, dim=-1)

            # 4. Attention
            if self.tree_wise_attention:
                # Permute for MultiheadAttention: (Num_Trees, Batch, Output_Dim_Tree)
                # Treating 'Trees' as the sequence length
                attn_input = tree_outputs.permute(2, 0, 1)

                attn_output, _ = self.tree_attention(attn_input, attn_input, attn_input)

                # Permute back: (Batch, Output_Dim_Tree, Num_Trees)
                tree_outputs = attn_output.permute(1, 2, 0)

            # 5. Head
            return self.head(tree_outputs)

        else:
            # No trees, just linear on top of GFLU
            return self.head(x)

    def data_aware_initialization(self, train_dataset, num_samples: int = 2000, verbose: int = 3):
        """
        Performs data-aware initialization for the global bias T0.
        This often speeds up convergence significantly.

        Args:
            train_dataset: Dataset exposing either a ``labels`` tensor or
                yielding ``(X, y)`` tuples / dicts with a 'target'/'y' key.
            num_samples (int, optional): Maximum number of samples to average.
            verbose (int, optional): Verbosity; 0 silences even warnings.
        """
        # 1. Prepare Data
        if verbose >= 2:
            _LOGGER.info(f"Performing GATE data-aware initialization on up to {num_samples} samples...")
        device = next(self.parameters()).device

        # 2. Extract Targets
        # Fast path: direct tensor access (Works with DragonDataset/_PytorchDataset)
        if hasattr(train_dataset, "labels") and isinstance(train_dataset.labels, torch.Tensor):
            limit = min(len(train_dataset.labels), num_samples)
            targets = train_dataset.labels[:limit]
        else:
            # Slow path: Iterate
            indices = range(min(len(train_dataset), num_samples))
            y_accum = []
            for i in indices:
                # Expecting (features, targets) tuple
                sample = train_dataset[i]
                if isinstance(sample, (tuple, list)) and len(sample) >= 2:
                    # Standard (X, y) tuple
                    y_val = sample[1]
                elif isinstance(sample, dict):
                    # Try common keys
                    y_val = sample.get('target', sample.get('y', None))
                else:
                    y_val = None

                if y_val is not None:
                    # Ensure it's a tensor
                    if not isinstance(y_val, torch.Tensor):
                        y_val = torch.tensor(y_val)
                    y_accum.append(y_val)

            if not y_accum:
                if verbose >= 1:
                    _LOGGER.warning("Could not extract targets for T0 initialization. Skipping.")
                return

            targets = torch.stack(y_accum)

        targets = targets.to(device).float()

        # 3. Initialize the global bias T0 (only exists when trees are used;
        # _GateHead is assumed to expose a T0 parameter — the type: ignore
        # hints below suggest it is not statically typed. TODO confirm.)
        with torch.no_grad():
            if self.num_trees > 0:
                # Initialize T0 to mean of targets
                mean_target = torch.mean(targets, dim=0)

                # Check shapes to avoid broadcasting errors
                if self.head.T0.shape == mean_target.shape:
                    self.head.T0.data = mean_target
                    if verbose >= 2:
                        _LOGGER.info(f"GATE Initialization Complete. Ready to train.")
                elif self.head.T0.numel() == 1 and mean_target.numel() == 1:  # type: ignore
                    # scalar case
                    self.head.T0.data = mean_target.view(self.head.T0.shape)  # type: ignore
                    if verbose >= 2:
                        _LOGGER.info("GATE Initialization Complete. Ready to train.")
                else:
                    _LOGGER.debug(f"Target shape mismatch for T0 init. Model: {self.head.T0.shape}, Data: {mean_target.shape}")
                    if verbose >= 1:
                        _LOGGER.warning(f"GATE initialization skipped due to shape mismatch:\n Model: {self.head.T0.shape}\n Data: {mean_target.shape}")

    def get_architecture_config(self) -> dict[str, Any]:
        """Returns the full configuration of the model."""
        schema_dict = {
            'feature_names': self.schema.feature_names,
            'continuous_feature_names': self.schema.continuous_feature_names,
            'categorical_feature_names': self.schema.categorical_feature_names,
            'categorical_index_map': self.schema.categorical_index_map,
            'categorical_mappings': self.schema.categorical_mappings
        }

        config = {
            SchemaKeys.SCHEMA_DICT: schema_dict,
            'out_targets': self.out_targets,
            **self.model_hparams
        }
        return config