dragon-ml-toolbox 20.1.1__py3-none-any.whl → 20.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. {dragon_ml_toolbox-20.1.1.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/METADATA +1 -1
  2. dragon_ml_toolbox-20.3.0.dist-info/RECORD +143 -0
  3. ml_tools/ETL_cleaning/__init__.py +5 -1
  4. ml_tools/ETL_cleaning/_basic_clean.py +1 -1
  5. ml_tools/ETL_engineering/__init__.py +5 -1
  6. ml_tools/GUI_tools/__init__.py +5 -1
  7. ml_tools/IO_tools/_IO_loggers.py +12 -4
  8. ml_tools/IO_tools/__init__.py +5 -1
  9. ml_tools/MICE/__init__.py +8 -2
  10. ml_tools/MICE/_dragon_mice.py +1 -1
  11. ml_tools/ML_callbacks/__init__.py +5 -1
  12. ml_tools/ML_chain/__init__.py +5 -1
  13. ml_tools/ML_configuration/__init__.py +7 -1
  14. ml_tools/ML_configuration/_training.py +65 -1
  15. ml_tools/ML_datasetmaster/__init__.py +5 -1
  16. ml_tools/ML_datasetmaster/_base_datasetmaster.py +37 -20
  17. ml_tools/ML_datasetmaster/_datasetmaster.py +26 -9
  18. ml_tools/ML_datasetmaster/_sequence_datasetmaster.py +38 -23
  19. ml_tools/ML_evaluation/__init__.py +5 -1
  20. ml_tools/ML_evaluation_captum/__init__.py +5 -1
  21. ml_tools/ML_finalize_handler/__init__.py +5 -1
  22. ml_tools/ML_inference/__init__.py +5 -1
  23. ml_tools/ML_inference_sequence/__init__.py +5 -1
  24. ml_tools/ML_inference_vision/__init__.py +5 -1
  25. ml_tools/ML_models/__init__.py +21 -6
  26. ml_tools/ML_models/_dragon_autoint.py +302 -0
  27. ml_tools/ML_models/_dragon_gate.py +358 -0
  28. ml_tools/ML_models/_dragon_node.py +268 -0
  29. ml_tools/ML_models/_dragon_tabnet.py +255 -0
  30. ml_tools/ML_models_sequence/__init__.py +5 -1
  31. ml_tools/ML_models_vision/__init__.py +5 -1
  32. ml_tools/ML_optimization/__init__.py +11 -3
  33. ml_tools/ML_optimization/_multi_dragon.py +2 -2
  34. ml_tools/ML_optimization/_single_dragon.py +47 -67
  35. ml_tools/ML_optimization/_single_manual.py +1 -1
  36. ml_tools/ML_scaler/_ML_scaler.py +29 -9
  37. ml_tools/ML_scaler/__init__.py +5 -1
  38. ml_tools/ML_trainer/__init__.py +5 -1
  39. ml_tools/ML_trainer/_base_trainer.py +136 -13
  40. ml_tools/ML_trainer/_dragon_detection_trainer.py +31 -91
  41. ml_tools/ML_trainer/_dragon_sequence_trainer.py +24 -74
  42. ml_tools/ML_trainer/_dragon_trainer.py +24 -85
  43. ml_tools/ML_utilities/__init__.py +5 -1
  44. ml_tools/ML_utilities/_inspection.py +44 -30
  45. ml_tools/ML_vision_transformers/__init__.py +8 -2
  46. ml_tools/PSO_optimization/__init__.py +5 -1
  47. ml_tools/SQL/__init__.py +8 -2
  48. ml_tools/VIF/__init__.py +5 -1
  49. ml_tools/data_exploration/__init__.py +4 -1
  50. ml_tools/data_exploration/_cleaning.py +4 -2
  51. ml_tools/ensemble_evaluation/__init__.py +5 -1
  52. ml_tools/ensemble_inference/__init__.py +5 -1
  53. ml_tools/ensemble_learning/__init__.py +5 -1
  54. ml_tools/excel_handler/__init__.py +5 -1
  55. ml_tools/keys/__init__.py +5 -1
  56. ml_tools/math_utilities/__init__.py +5 -1
  57. ml_tools/optimization_tools/__init__.py +5 -1
  58. ml_tools/path_manager/__init__.py +8 -2
  59. ml_tools/plot_fonts/__init__.py +8 -2
  60. ml_tools/schema/__init__.py +8 -2
  61. ml_tools/schema/_feature_schema.py +3 -3
  62. ml_tools/serde/__init__.py +5 -1
  63. ml_tools/utilities/__init__.py +5 -1
  64. ml_tools/utilities/_utility_save_load.py +38 -20
  65. dragon_ml_toolbox-20.1.1.dist-info/RECORD +0 -179
  66. ml_tools/ETL_cleaning/_imprimir.py +0 -13
  67. ml_tools/ETL_engineering/_imprimir.py +0 -24
  68. ml_tools/GUI_tools/_imprimir.py +0 -12
  69. ml_tools/IO_tools/_imprimir.py +0 -14
  70. ml_tools/MICE/_imprimir.py +0 -11
  71. ml_tools/ML_callbacks/_imprimir.py +0 -12
  72. ml_tools/ML_chain/_imprimir.py +0 -12
  73. ml_tools/ML_configuration/_imprimir.py +0 -47
  74. ml_tools/ML_datasetmaster/_imprimir.py +0 -15
  75. ml_tools/ML_evaluation/_imprimir.py +0 -25
  76. ml_tools/ML_evaluation_captum/_imprimir.py +0 -10
  77. ml_tools/ML_finalize_handler/_imprimir.py +0 -8
  78. ml_tools/ML_inference/_imprimir.py +0 -11
  79. ml_tools/ML_inference_sequence/_imprimir.py +0 -8
  80. ml_tools/ML_inference_vision/_imprimir.py +0 -8
  81. ml_tools/ML_models/_advanced_models.py +0 -1086
  82. ml_tools/ML_models/_imprimir.py +0 -18
  83. ml_tools/ML_models_sequence/_imprimir.py +0 -8
  84. ml_tools/ML_models_vision/_imprimir.py +0 -16
  85. ml_tools/ML_optimization/_imprimir.py +0 -13
  86. ml_tools/ML_scaler/_imprimir.py +0 -8
  87. ml_tools/ML_trainer/_imprimir.py +0 -10
  88. ml_tools/ML_utilities/_imprimir.py +0 -16
  89. ml_tools/ML_vision_transformers/_imprimir.py +0 -14
  90. ml_tools/PSO_optimization/_imprimir.py +0 -10
  91. ml_tools/SQL/_imprimir.py +0 -8
  92. ml_tools/VIF/_imprimir.py +0 -10
  93. ml_tools/data_exploration/_imprimir.py +0 -32
  94. ml_tools/ensemble_evaluation/_imprimir.py +0 -14
  95. ml_tools/ensemble_inference/_imprimir.py +0 -9
  96. ml_tools/ensemble_learning/_imprimir.py +0 -10
  97. ml_tools/excel_handler/_imprimir.py +0 -13
  98. ml_tools/keys/_imprimir.py +0 -11
  99. ml_tools/math_utilities/_imprimir.py +0 -11
  100. ml_tools/optimization_tools/_imprimir.py +0 -13
  101. ml_tools/path_manager/_imprimir.py +0 -15
  102. ml_tools/plot_fonts/_imprimir.py +0 -8
  103. ml_tools/schema/_imprimir.py +0 -10
  104. ml_tools/serde/_imprimir.py +0 -10
  105. ml_tools/utilities/_imprimir.py +0 -18
  106. {dragon_ml_toolbox-20.1.1.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/WHEEL +0 -0
  107. {dragon_ml_toolbox-20.1.1.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/licenses/LICENSE +0 -0
  108. {dragon_ml_toolbox-20.1.1.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  109. {dragon_ml_toolbox-20.1.1.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/top_level.txt +0 -0
@@ -1,1086 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from typing import Any, Optional, Literal
5
-
6
- from ..schema import FeatureSchema
7
- from .._core import get_logger
8
- from ..keys._keys import SchemaKeys
9
-
10
- from ._base_save_load import _ArchitectureBuilder
11
- from ._models_advanced_helpers import (
12
- Embedding1dLayer,
13
- GatedFeatureLearningUnit,
14
- NeuralDecisionTree,
15
- entmax15,
16
- entmoid15,
17
- sparsemax,
18
- sparsemoid,
19
- t_softmax,
20
- SimpleLinearHead,
21
- DenseODSTBlock,
22
- Embedding2dLayer,
23
- FeatTransformer,
24
- AttentiveTransformer,
25
- initialize_non_glu,
26
- _GateHead
27
- )
28
-
29
-
30
- _LOGGER = get_logger("DragonModel")
31
-
32
-
33
- __all__ = [
34
- "DragonGateModel",
35
- "DragonNodeModel",
36
- "DragonAutoInt",
37
- "DragonTabNet"
38
- ]
39
-
40
- # SOURCE CODE: Adapted and modified from:
41
- # https://github.com/manujosephv/pytorch_tabular/blob/main/LICENSE
42
- # https://github.com/Qwicen/node/blob/master/LICENSE.md
43
- # https://github.com/jrzaurin/pytorch-widedeep?tab=readme-ov-file#license
44
- # https://github.com/rixwew/pytorch-fm/blob/master/LICENSE
45
- # https://arxiv.org/abs/1705.08741v2
46
-
47
-
48
- class DragonGateModel(_ArchitectureBuilder):
49
- """
50
- Native implementation of the Gated Additive Tree Ensemble (GATE).
51
-
52
- Combines Gated Feature Learning Units (GFLU) for feature interaction learning
53
- with Differentiable Decision Trees for prediction.
54
- """
55
- ACTIVATION_MAP = {
56
- "entmax": entmax15,
57
- "sparsemax": sparsemax,
58
- # "softmax": nn.functional.softmax,
59
- "softmax": lambda x: nn.functional.softmax(x, dim=-1),
60
- "t-softmax": t_softmax,
61
- }
62
-
63
- BINARY_ACTIVATION_MAP = {
64
- "entmoid": entmoid15,
65
- "sparsemoid": sparsemoid,
66
- "sigmoid": torch.sigmoid,
67
- }
68
-
69
- def __init__(self, *,
70
- schema: FeatureSchema,
71
- out_targets: int,
72
- embedding_dim: int = 16,
73
- gflu_stages: int = 6,
74
- gflu_dropout: float = 0.1,
75
- num_trees: int = 20,
76
- tree_depth: int = 4,
77
- tree_dropout: float = 0.1,
78
- chain_trees: bool = False,
79
- tree_wise_attention: bool = True,
80
- tree_wise_attention_dropout: float = 0.1,
81
- binning_activation: Literal['entmoid', 'sparsemoid', 'sigmoid'] = "entmoid",
82
- feature_mask_function: Literal['entmax', 'sparsemax', 'softmax', 't-softmax'] = "entmax",
83
- share_head_weights: bool = True,
84
- batch_norm_continuous: bool = True):
85
- """
86
- Args:
87
- schema (FeatureSchema):
88
- Schema object containing feature names and types.
89
- out_targets (int):
90
- Number of output targets (e.g., 1 for regression/binary, N for multi-class).
91
- embedding_dim (int, optional):
92
- Embedding dimension for categorical features.
93
- Suggested: 8 to 64.
94
- gflu_stages (int, optional):
95
- Number of Gated Feature Learning Unit stages in the backbone.
96
- Higher values allow learning deeper feature interactions.
97
- Suggested: 2 to 10.
98
- gflu_dropout (float, optional):
99
- Dropout rate applied within GFLU stages.
100
- Suggested: 0.0 to 0.3.
101
- num_trees (int, optional):
102
- Number of Neural Decision Trees to use in the ensemble.
103
- Suggested: 10 to 50.
104
- tree_depth (int, optional):
105
- Depth of the decision trees. Deeper trees capture more complex logic
106
- but may overfit.
107
- Suggested: 3 to 6.
108
- tree_dropout (float, optional):
109
- Dropout rate applied to the tree leaves.
110
- Suggested: 0.1 to 0.3.
111
- chain_trees (bool, optional):
112
- If True, feeds the output of tree T-1 into tree T (Boosting-style).
113
- If False, trees run in parallel (Bagging-style).
114
- tree_wise_attention (bool, optional):
115
- If True, applies Self-Attention across the outputs of the trees
116
- to weigh their contributions dynamically.
117
- tree_wise_attention_dropout (float, optional):
118
- Dropout rate for the tree-wise attention mechanism.
119
- Suggested: 0.1.
120
- binning_activation (str, optional):
121
- Activation function for the soft binning in trees.
122
- Options: 'entmoid' (sparse), 'sparsemoid', 'sigmoid'.
123
- feature_mask_function (str, optional):
124
- Activation function for feature selection/masking.
125
- Options: 'entmax' (sparse), 'sparsemax', 'softmax', 't-softmax'.
126
- share_head_weights (bool, optional):
127
- If True, all trees share the same linear projection head weights.
128
- batch_norm_continuous (bool, optional):
129
- If True, applies Batch Normalization to continuous features before embedding.
130
- """
131
- super().__init__()
132
- self.schema = schema
133
- self.out_targets = out_targets
134
-
135
- # -- Configuration for saving --
136
- self.model_hparams = {
137
- 'embedding_dim': embedding_dim,
138
- 'gflu_stages': gflu_stages,
139
- 'gflu_dropout': gflu_dropout,
140
- 'num_trees': num_trees,
141
- 'tree_depth': tree_depth,
142
- 'tree_dropout': tree_dropout,
143
- 'chain_trees': chain_trees,
144
- 'tree_wise_attention': tree_wise_attention,
145
- 'tree_wise_attention_dropout': tree_wise_attention_dropout,
146
- 'binning_activation': binning_activation,
147
- 'feature_mask_function': feature_mask_function,
148
- 'share_head_weights': share_head_weights,
149
- 'batch_norm_continuous': batch_norm_continuous
150
- }
151
-
152
- # -- 1. Setup Data Processing --
153
- self.categorical_indices = []
154
- self.cardinalities = []
155
- if schema.categorical_index_map:
156
- self.categorical_indices = list(schema.categorical_index_map.keys())
157
- self.cardinalities = list(schema.categorical_index_map.values())
158
-
159
- all_indices = set(range(len(schema.feature_names)))
160
- self.numerical_indices = sorted(list(all_indices - set(self.categorical_indices)))
161
-
162
- embedding_dims = [(c, embedding_dim) for c in self.cardinalities]
163
- n_continuous = len(self.numerical_indices)
164
-
165
- # -- 2. Embedding Layer --
166
- self.embedding_layer = Embedding1dLayer(
167
- continuous_dim=n_continuous,
168
- categorical_embedding_dims=embedding_dims,
169
- batch_norm_continuous_input=batch_norm_continuous
170
- )
171
-
172
- # Calculate total feature dimension after embedding
173
- total_embedded_cat_dim = sum([d for _, d in embedding_dims])
174
- self.n_features = n_continuous + total_embedded_cat_dim
175
-
176
- # -- 3. GFLU Backbone --
177
- self.gflu_stages = gflu_stages
178
- if gflu_stages > 0:
179
- self.gflus = GatedFeatureLearningUnit(
180
- n_features_in=self.n_features,
181
- n_stages=gflu_stages,
182
- feature_mask_function=self.ACTIVATION_MAP[feature_mask_function],
183
- dropout=gflu_dropout,
184
- feature_sparsity=0.3, # Standard default
185
- learnable_sparsity=True
186
- )
187
-
188
- # -- 4. Neural Decision Trees --
189
- self.num_trees = num_trees
190
- self.chain_trees = chain_trees
191
- self.tree_depth = tree_depth
192
-
193
- if num_trees > 0:
194
- # Calculate input dim for trees (chaining adds to input)
195
- tree_input_dim = self.n_features
196
-
197
- self.trees = nn.ModuleList()
198
- for _ in range(num_trees):
199
- tree = NeuralDecisionTree(
200
- depth=tree_depth,
201
- n_features=tree_input_dim,
202
- dropout=tree_dropout,
203
- binning_activation=self.BINARY_ACTIVATION_MAP[binning_activation],
204
- feature_mask_function=self.ACTIVATION_MAP[feature_mask_function],
205
- )
206
- self.trees.append(tree)
207
- if chain_trees:
208
- # Next tree sees original features + output of this tree (2^depth leaves)
209
- tree_input_dim += 2**tree_depth
210
-
211
- self.tree_output_dim = 2**tree_depth
212
-
213
- # Optional: Tree-wise Attention
214
- self.tree_wise_attention = tree_wise_attention
215
- if tree_wise_attention:
216
- self.tree_attention = nn.MultiheadAttention(
217
- embed_dim=self.tree_output_dim,
218
- num_heads=1,
219
- batch_first=False, # (Seq, Batch, Feature) standard for PyTorch Attn
220
- dropout=tree_wise_attention_dropout
221
- )
222
- else:
223
- self.tree_output_dim = self.n_features
224
-
225
- # -- 5. Prediction Head --
226
- if num_trees > 0:
227
- self.head = _GateHead(
228
- input_dim=self.tree_output_dim,
229
- output_dim=out_targets,
230
- num_trees=num_trees,
231
- share_head_weights=share_head_weights
232
- )
233
- else:
234
- # Fallback if no trees (just GFLU -> Linear)
235
- self.head = SimpleLinearHead(self.n_features, out_targets)
236
- # Add T0 manually for consistency if needed, but SimpleLinear covers bias
237
-
238
- def forward(self, x: torch.Tensor) -> torch.Tensor:
239
- # Split inputs
240
- x_cont = x[:, self.numerical_indices].float()
241
- x_cat = x[:, self.categorical_indices].long()
242
-
243
- # 1. Embeddings
244
- x = self.embedding_layer(x_cont, x_cat)
245
-
246
- # 2. GFLU
247
- if self.gflu_stages > 0:
248
- x = self.gflus(x)
249
-
250
- # 3. Trees
251
- if self.num_trees > 0:
252
- tree_outputs = []
253
- tree_input = x
254
-
255
- for i in range(self.num_trees):
256
- # Tree returns (leaf_nodes, feature_masks)
257
- # leaf_nodes shape: (Batch, 2^depth)
258
- leaf_nodes, _ = self.trees[i](tree_input)
259
-
260
- tree_outputs.append(leaf_nodes.unsqueeze(-1))
261
-
262
- if self.chain_trees:
263
- tree_input = torch.cat([tree_input, leaf_nodes], dim=1)
264
-
265
- # Stack: (Batch, Output_Dim_Tree, Num_Trees)
266
- tree_outputs = torch.cat(tree_outputs, dim=-1)
267
-
268
- # 4. Attention
269
- if self.tree_wise_attention:
270
- # Permute for MultiheadAttention: (Num_Trees, Batch, Output_Dim_Tree)
271
- # Treating 'Trees' as the sequence length
272
- attn_input = tree_outputs.permute(2, 0, 1)
273
-
274
- attn_output, _ = self.tree_attention(attn_input, attn_input, attn_input)
275
-
276
- # Permute back: (Batch, Output_Dim_Tree, Num_Trees)
277
- tree_outputs = attn_output.permute(1, 2, 0)
278
-
279
- # 5. Head
280
- return self.head(tree_outputs)
281
-
282
- else:
283
- # No trees, just linear on top of GFLU
284
- return self.head(x)
285
-
286
- def data_aware_initialization(self, train_dataset, num_samples: int = 2000):
287
- """
288
- Performs data-aware initialization for the global bias T0.
289
- This often speeds up convergence significantly.
290
- """
291
- # 1. Prepare Data
292
- _LOGGER.info(f"Performing GATE data-aware initialization on up to {num_samples} samples...")
293
- device = next(self.parameters()).device
294
-
295
- # 2. Extract Targets
296
- # Fast path: direct tensor access (Works with DragonDataset/_PytorchDataset)
297
- if hasattr(train_dataset, "labels") and isinstance(train_dataset.labels, torch.Tensor):
298
- limit = min(len(train_dataset.labels), num_samples)
299
- targets = train_dataset.labels[:limit]
300
- else:
301
- # Slow path: Iterate
302
- indices = range(min(len(train_dataset), num_samples))
303
- y_accum = []
304
- for i in indices:
305
- # Expecting (features, targets) tuple
306
- sample = train_dataset[i]
307
- if isinstance(sample, (tuple, list)) and len(sample) >= 2:
308
- # Standard (X, y) tuple
309
- y_val = sample[1]
310
- elif isinstance(sample, dict):
311
- # Try common keys
312
- y_val = sample.get('target', sample.get('y', None))
313
- else:
314
- y_val = None
315
-
316
- if y_val is not None:
317
- # Ensure it's a tensor
318
- if not isinstance(y_val, torch.Tensor):
319
- y_val = torch.tensor(y_val)
320
- y_accum.append(y_val)
321
-
322
- if not y_accum:
323
- _LOGGER.warning("Could not extract targets for T0 initialization. Skipping.")
324
- return
325
-
326
- targets = torch.stack(y_accum)
327
-
328
- targets = targets.to(device).float()
329
-
330
- with torch.no_grad():
331
- if self.num_trees > 0:
332
- # Initialize T0 to mean of targets
333
- mean_target = torch.mean(targets, dim=0)
334
-
335
- # Check shapes to avoid broadcasting errors
336
- if self.head.T0.shape == mean_target.shape:
337
- self.head.T0.data = mean_target
338
- _LOGGER.info(f"Initialized T0 to {mean_target.cpu().numpy()}")
339
- elif self.head.T0.numel() == 1 and mean_target.numel() == 1: # type: ignore
340
- # scalar case
341
- self.head.T0.data = mean_target.view(self.head.T0.shape) # type: ignore
342
- _LOGGER.info("GATE Initialization Complete. Ready to train.")
343
- # _LOGGER.info(f"Initialized T0 to {mean_target.item()}")
344
- else:
345
- _LOGGER.debug(f"Target shape mismatch for T0 init. Model: {self.head.T0.shape}, Data: {mean_target.shape}")
346
- _LOGGER.warning(f"GATE initialization skipped due to shape mismatch:\n Model: {self.head.T0.shape}\n Data: {mean_target.shape}")
347
-
348
- def get_architecture_config(self) -> dict[str, Any]:
349
- """Returns the full configuration of the model."""
350
- schema_dict = {
351
- 'feature_names': self.schema.feature_names,
352
- 'continuous_feature_names': self.schema.continuous_feature_names,
353
- 'categorical_feature_names': self.schema.categorical_feature_names,
354
- 'categorical_index_map': self.schema.categorical_index_map,
355
- 'categorical_mappings': self.schema.categorical_mappings
356
- }
357
-
358
- config = {
359
- SchemaKeys.SCHEMA_DICT: schema_dict,
360
- 'out_targets': self.out_targets,
361
- **self.model_hparams
362
- }
363
- return config
364
-
365
-
366
- class DragonNodeModel(_ArchitectureBuilder):
367
- """
368
- Native implementation of Neural Oblivious Decision Ensembles (NODE).
369
-
370
- The 'Dense' architecture concatenates the outputs of previous layers to the
371
- features of subsequent layers, allowing for deep feature interaction learning.
372
- """
373
- ACTIVATION_MAP = {
374
- "entmax": entmax15,
375
- "sparsemax": sparsemax,
376
- "softmax": F.softmax,
377
- }
378
-
379
- BINARY_ACTIVATION_MAP = {
380
- "entmoid": entmoid15,
381
- "sparsemoid": sparsemoid,
382
- "sigmoid": torch.sigmoid,
383
- }
384
-
385
- def __init__(self, *,
386
- schema: FeatureSchema,
387
- out_targets: int,
388
- embedding_dim: int = 24,
389
- num_trees: int = 1024,
390
- num_layers: int = 2,
391
- tree_depth: int = 6,
392
- additional_tree_output_dim: int = 3,
393
- max_features: Optional[int] = None,
394
- input_dropout: float = 0.0,
395
- embedding_dropout: float = 0.0,
396
- choice_function: Literal['entmax', 'sparsemax', 'softmax'] = 'entmax',
397
- bin_function: Literal['entmoid', 'sparsemoid', 'sigmoid'] = 'entmoid',
398
- batch_norm_continuous: bool = False):
399
- """
400
- Args:
401
- schema (FeatureSchema):
402
- Schema object containing feature names and types.
403
- out_targets (int):
404
- Number of output targets.
405
- embedding_dim (int, optional):
406
- Embedding dimension for categorical features.
407
- Suggested: 16 to 64.
408
- num_trees (int, optional):
409
- Number of Oblivious Decision Trees per layer. NODE relies on a large number
410
- of trees (wider layers) compared to standard forests.
411
- Suggested: 512 to 2048.
412
- num_layers (int, optional):
413
- Number of DenseODST layers. Since layers are densely connected, deeper
414
- networks increase memory usage significantly.
415
- Suggested: 2 to 5.
416
- tree_depth (int, optional):
417
- Depth of the oblivious trees. Oblivious trees are symmetric, so
418
- parameters scale with 2^depth.
419
- Suggested: 4 to 8.
420
- additional_tree_output_dim (int, optional):
421
- Extra output channels per tree. These are used for internal representation
422
- in deeper layers but discarded for the final prediction.
423
- Suggested: 1 to 5.
424
- max_features (int, optional):
425
- Max features to keep in the dense connection to prevent explosion in
426
- feature dimension for deeper layers. If None, keeps all.
427
- input_dropout (float, optional):
428
- Dropout applied to the input of the Dense Block.
429
- Suggested: 0.0 to 0.2.
430
- embedding_dropout (float, optional):
431
- Dropout applied specifically to embeddings.
432
- Suggested: 0.0 to 0.2.
433
- choice_function (str, optional):
434
- Activation for feature selection. 'entmax' allows sparse feature selection.
435
- Options: 'entmax', 'sparsemax', 'softmax'.
436
- bin_function (str, optional):
437
- Activation for the soft binning steps.
438
- Options: 'entmoid', 'sparsemoid', 'sigmoid'.
439
- batch_norm_continuous (bool, optional):
440
- If True, applies Batch Normalization to continuous features.
441
- """
442
- super().__init__()
443
- self.schema = schema
444
- self.out_targets = out_targets
445
-
446
- # -- Configuration for saving --
447
- self.model_hparams = {
448
- 'embedding_dim': embedding_dim,
449
- 'num_trees': num_trees,
450
- 'num_layers': num_layers,
451
- 'tree_depth': tree_depth,
452
- 'additional_tree_output_dim': additional_tree_output_dim,
453
- 'max_features': max_features,
454
- 'input_dropout': input_dropout,
455
- 'embedding_dropout': embedding_dropout,
456
- 'choice_function': choice_function,
457
- 'bin_function': bin_function,
458
- 'batch_norm_continuous': batch_norm_continuous
459
- }
460
-
461
- # -- 1. Setup Embeddings --
462
- self.categorical_indices = []
463
- self.cardinalities = []
464
- if schema.categorical_index_map:
465
- self.categorical_indices = list(schema.categorical_index_map.keys())
466
- self.cardinalities = list(schema.categorical_index_map.values())
467
-
468
- all_indices = set(range(len(schema.feature_names)))
469
- self.numerical_indices = sorted(list(all_indices - set(self.categorical_indices)))
470
-
471
- embedding_dims = [(c, embedding_dim) for c in self.cardinalities]
472
- n_continuous = len(self.numerical_indices)
473
-
474
- self.embedding_layer = Embedding1dLayer(
475
- continuous_dim=n_continuous,
476
- categorical_embedding_dims=embedding_dims,
477
- embedding_dropout=embedding_dropout,
478
- batch_norm_continuous_input=batch_norm_continuous
479
- )
480
-
481
- total_embedded_dim = n_continuous + sum([d for _, d in embedding_dims])
482
-
483
- # -- 2. Backbone (Dense ODST) --
484
- # The tree output dim includes the target dim + auxiliary dims for deep learning
485
- self.tree_dim = out_targets + additional_tree_output_dim
486
-
487
- self.backbone = DenseODSTBlock(
488
- input_dim=total_embedded_dim,
489
- num_trees=num_trees,
490
- num_layers=num_layers,
491
- tree_output_dim=self.tree_dim,
492
- max_features=max_features,
493
- input_dropout=input_dropout,
494
- flatten_output=False, # We want (Batch, Num_Layers * Num_Trees, Tree_Dim)
495
- depth=tree_depth,
496
- # Activations
497
- choice_function=self.ACTIVATION_MAP[choice_function],
498
- bin_function=self.BINARY_ACTIVATION_MAP[bin_function],
499
- # Init strategies (defaults)
500
- initialize_response_=nn.init.normal_,
501
- initialize_selection_logits_=nn.init.uniform_,
502
- )
503
-
504
- # Note: NODE has a fixed Head (averaging) which is defined in forward()
505
-
506
- def forward(self, x: torch.Tensor) -> torch.Tensor:
507
- # Split inputs
508
- x_cont = x[:, self.numerical_indices].float()
509
- x_cat = x[:, self.categorical_indices].long()
510
-
511
- # 1. Embeddings
512
- x = self.embedding_layer(x_cont, x_cat)
513
-
514
- # 2. Backbone
515
- # Output shape: (Batch, Total_Trees, Tree_Dim)
516
- x = self.backbone(x)
517
-
518
- # 3. Head (Averaging)
519
- # We take the first 'out_targets' channels and average them across all trees
520
- # subset: x[..., :out_targets]
521
- # mean: .mean(dim=-2) -> average over Total_Trees dimension
522
- return x[..., :self.out_targets].mean(dim=-2)
523
-
524
- def data_aware_initialization(self, train_dataset, num_samples: int = 2000):
525
- """
526
- Performs data-aware initialization for the ODST trees using a dataset.
527
- Crucial for NODE convergence.
528
- """
529
- # 1. Prepare Data
530
- _LOGGER.info(f"Performing NODE data-aware initialization on up to {num_samples} samples...")
531
- device = next(self.parameters()).device
532
-
533
- # 2. Extract Features
534
- # Fast path: If the dataset exposes the full feature tensor (like _PytorchDataset)
535
- if hasattr(train_dataset, "features") and isinstance(train_dataset.features, torch.Tensor):
536
- # Slice directly
537
- limit = min(len(train_dataset.features), num_samples)
538
- x_input = train_dataset.features[:limit]
539
- else:
540
- # Slow path: Iterate and stack (Generic Dataset)
541
- indices = range(min(len(train_dataset), num_samples))
542
- x_accum = []
543
- for i in indices:
544
- # Expecting (features, targets) tuple from standard datasets
545
- sample = train_dataset[i]
546
- if isinstance(sample, (tuple, list)):
547
- x_accum.append(sample[0])
548
- elif isinstance(sample, dict) and 'features' in sample:
549
- x_accum.append(sample['features'])
550
- elif isinstance(sample, dict) and 'x' in sample:
551
- x_accum.append(sample['x'])
552
- else:
553
- # Fallback: assume the sample itself is the feature
554
- x_accum.append(sample)
555
-
556
- if not x_accum:
557
- _LOGGER.warning("Dataset empty or format unrecognized. Skipping NODE initialization.")
558
- return
559
-
560
- x_input = torch.stack(x_accum)
561
-
562
- x_input = x_input.to(device).float()
563
-
564
- # 3. Process features (Split -> Embed)
565
- x_cont = x_input[:, self.numerical_indices].float()
566
- x_cat = x_input[:, self.categorical_indices].long()
567
-
568
- with torch.no_grad():
569
- x_embedded = self.embedding_layer(x_cont, x_cat)
570
-
571
- # 4. Initialize Backbone
572
- if hasattr(self.backbone, 'initialize'):
573
- self.backbone.initialize(x_embedded)
574
- _LOGGER.info("NODE Initialization Complete. Ready to train.")
575
- else:
576
- _LOGGER.warning("NODE Backbone does not have an 'initialize' method. Skipping.")
577
-
578
- def get_architecture_config(self) -> dict[str, Any]:
579
- """Returns the full configuration of the model."""
580
- schema_dict = {
581
- 'feature_names': self.schema.feature_names,
582
- 'continuous_feature_names': self.schema.continuous_feature_names,
583
- 'categorical_feature_names': self.schema.categorical_feature_names,
584
- 'categorical_index_map': self.schema.categorical_index_map,
585
- 'categorical_mappings': self.schema.categorical_mappings
586
- }
587
-
588
- config = {
589
- SchemaKeys.SCHEMA_DICT: schema_dict,
590
- 'out_targets': self.out_targets,
591
- **self.model_hparams
592
- }
593
- return config
594
-
595
-
596
- class DragonAutoInt(_ArchitectureBuilder):
597
- """
598
- Native implementation of AutoInt (Automatic Feature Interaction Learning).
599
-
600
- Maps categorical and continuous features into a shared embedding space,
601
- then uses Multi-Head Self-Attention to learn high-order feature interactions.
602
- """
603
- def __init__(self, *,
604
- schema: FeatureSchema,
605
- out_targets: int,
606
- embedding_dim: int = 32,
607
- attn_embed_dim: int = 32,
608
- num_heads: int = 2,
609
- num_attn_blocks: int = 3,
610
- attn_dropout: float = 0.1,
611
- has_residuals: bool = True,
612
- attention_pooling: bool = True,
613
- deep_layers: bool = True,
614
- layers: str = "128-64-32",
615
- activation: str = "ReLU",
616
- embedding_dropout: float = 0.0,
617
- batch_norm_continuous: bool = False):
618
- """
619
- Args:
620
- schema (FeatureSchema):
621
- Schema object containing feature names and types.
622
- out_targets (int):
623
- Number of output targets.
624
- embedding_dim (int, optional):
625
- Initial embedding dimension for features.
626
- Suggested: 16 to 64.
627
- attn_embed_dim (int, optional):
628
- Projection dimension for the attention mechanism.
629
- Suggested: 16 to 64.
630
- num_heads (int, optional):
631
- Number of attention heads.
632
- Suggested: 2 to 8.
633
- num_attn_blocks (int, optional):
634
- Number of self-attention layers (depth of interaction learning).
635
- Suggested: 2 to 5.
636
- attn_dropout (float, optional):
637
- Dropout rate within the attention blocks.
638
- Suggested: 0.0 to 0.2.
639
- has_residuals (bool, optional):
640
- If True, adds residual connections (ResNet style) to attention blocks.
641
- attention_pooling (bool, optional):
642
- If True, concatenates outputs of all attention blocks (DenseNet style).
643
- If False, uses only the output of the last block.
644
- deep_layers (bool, optional):
645
- If True, adds a standard MLP (Deep Layers) before the attention mechanism
646
- to process features initially.
647
- layers (str, optional):
648
- Hyphen-separated string for MLP layer sizes if deep_layers is True.
649
- activation (str, optional):
650
- Activation function for the MLP layers.
651
- embedding_dropout (float, optional):
652
- Dropout applied to the initial feature embeddings.
653
- batch_norm_continuous (bool, optional):
654
- If True, applies Batch Normalization to continuous features.
655
- """
656
- super().__init__()
657
- self.schema = schema
658
- self.out_targets = out_targets
659
-
660
- self.model_hparams = {
661
- 'embedding_dim': embedding_dim,
662
- 'attn_embed_dim': attn_embed_dim,
663
- 'num_heads': num_heads,
664
- 'num_attn_blocks': num_attn_blocks,
665
- 'attn_dropout': attn_dropout,
666
- 'has_residuals': has_residuals,
667
- 'attention_pooling': attention_pooling,
668
- 'deep_layers': deep_layers,
669
- 'layers': layers,
670
- 'activation': activation,
671
- 'embedding_dropout': embedding_dropout,
672
- 'batch_norm_continuous': batch_norm_continuous
673
- }
674
-
675
- # -- 1. Setup Embeddings --
676
- self.categorical_indices = []
677
- self.cardinalities = []
678
- if schema.categorical_index_map:
679
- self.categorical_indices = list(schema.categorical_index_map.keys())
680
- self.cardinalities = list(schema.categorical_index_map.values())
681
-
682
- all_indices = set(range(len(schema.feature_names)))
683
- self.numerical_indices = sorted(list(all_indices - set(self.categorical_indices)))
684
- n_continuous = len(self.numerical_indices)
685
-
686
- self.embedding_layer = Embedding2dLayer(
687
- continuous_dim=n_continuous,
688
- categorical_cardinality=self.cardinalities,
689
- embedding_dim=embedding_dim,
690
- embedding_dropout=embedding_dropout,
691
- batch_norm_continuous_input=batch_norm_continuous
692
- )
693
-
694
- # -- 2. Deep Layers (Optional MLP) --
695
- curr_units = embedding_dim
696
- self.deep_layers_mod = None
697
-
698
- if deep_layers:
699
- layers_list = []
700
- layer_sizes = [int(x) for x in layers.split("-")]
701
- activation_fn = getattr(nn, activation, nn.ReLU)
702
-
703
- for units in layer_sizes:
704
- layers_list.append(nn.Linear(curr_units, units))
705
-
706
- # Changed BatchNorm1d to LayerNorm to handle (Batch, Tokens, Embed) shape correctly
707
- layers_list.append(nn.LayerNorm(units))
708
-
709
- layers_list.append(activation_fn())
710
- layers_list.append(nn.Dropout(embedding_dropout))
711
- curr_units = units
712
-
713
- self.deep_layers_mod = nn.Sequential(*layers_list)
714
-
715
- # -- 3. Attention Backbone --
716
- self.attn_proj = nn.Linear(curr_units, attn_embed_dim)
717
-
718
- self.self_attns = nn.ModuleList([
719
- nn.MultiheadAttention(
720
- embed_dim=attn_embed_dim,
721
- num_heads=num_heads,
722
- dropout=attn_dropout
723
- )
724
- for _ in range(num_attn_blocks)
725
- ])
726
-
727
- # Residuals
728
- self.has_residuals = has_residuals
729
- self.attention_pooling = attention_pooling
730
-
731
- if has_residuals:
732
- # If pooling, we project input to match the concatenated output size
733
- # If not pooling, we project input to match the single block output size
734
- res_dim = attn_embed_dim * num_attn_blocks if attention_pooling else attn_embed_dim
735
- self.V_res_embedding = nn.Linear(curr_units, res_dim)
736
-
737
- # -- 4. Output Dimension Calculation --
738
- num_features = n_continuous + len(self.cardinalities)
739
-
740
- # Output is flattened: (Num_Features * Attn_Dim)
741
- final_dim = num_features * attn_embed_dim
742
- if attention_pooling:
743
- final_dim = final_dim * num_attn_blocks
744
-
745
- self.output_dim = final_dim
746
- self.head = nn.Linear(final_dim, out_targets)
747
-
748
- def forward(self, x: torch.Tensor) -> torch.Tensor:
749
- x_cont = x[:, self.numerical_indices].float()
750
- x_cat = x[:, self.categorical_indices].long()
751
-
752
- # 1. Embed -> (Batch, Num_Features, Embed_Dim)
753
- x = self.embedding_layer(x_cont, x_cat)
754
-
755
- # 2. Deep Layers
756
- if self.deep_layers_mod:
757
- x = self.deep_layers_mod(x)
758
-
759
- # 3. Attention Projection -> (Batch, Num_Features, Attn_Dim)
760
- cross_term = self.attn_proj(x)
761
-
762
- # Transpose for MultiheadAttention (Seq, Batch, Embed)
763
- cross_term = cross_term.transpose(0, 1)
764
-
765
- attention_ops = []
766
- for self_attn in self.self_attns:
767
- # Self Attention: Query=Key=Value=cross_term
768
- # Output: (Seq, Batch, Embed)
769
- out, _ = self_attn(cross_term, cross_term, cross_term)
770
- cross_term = out # Sequential connection
771
- if self.attention_pooling:
772
- attention_ops.append(out)
773
-
774
- if self.attention_pooling:
775
- # Concatenate all attention outputs along the embedding dimension
776
- cross_term = torch.cat(attention_ops, dim=-1)
777
-
778
- # Transpose back -> (Batch, Num_Features, Final_Attn_Dim)
779
- cross_term = cross_term.transpose(0, 1)
780
-
781
- # 4. Residual Connection
782
- if self.has_residuals:
783
- V_res = self.V_res_embedding(x)
784
- cross_term = cross_term + V_res
785
-
786
- # 5. Flatten and Head
787
- # ReLU before flattening as per original implementation
788
- cross_term = F.relu(cross_term)
789
- cross_term = cross_term.reshape(cross_term.size(0), -1)
790
-
791
- return self.head(cross_term)
792
-
793
- def data_aware_initialization(self, train_dataset, num_samples: int = 2000):
794
- """
795
- Performs data-aware initialization for the final head bias.
796
- """
797
- # 1. Prepare Data
798
- _LOGGER.info(f"Performing AutoInt data-aware initialization on up to {num_samples} samples...")
799
- device = next(self.parameters()).device
800
-
801
- # 2. Extract Targets
802
- if hasattr(train_dataset, "labels") and isinstance(train_dataset.labels, torch.Tensor):
803
- limit = min(len(train_dataset.labels), num_samples)
804
- targets = train_dataset.labels[:limit]
805
- else:
806
- indices = range(min(len(train_dataset), num_samples))
807
- y_accum = []
808
- for i in indices:
809
- sample = train_dataset[i]
810
- # Handle tuple (X, y) or dict
811
- if isinstance(sample, (tuple, list)) and len(sample) >= 2:
812
- y_val = sample[1]
813
- elif isinstance(sample, dict):
814
- y_val = sample.get('target', sample.get('y', None))
815
- else:
816
- y_val = None
817
-
818
- if y_val is not None:
819
- if not isinstance(y_val, torch.Tensor):
820
- y_val = torch.tensor(y_val)
821
- y_accum.append(y_val)
822
-
823
- if not y_accum:
824
- _LOGGER.warning("Could not extract targets for AutoInt initialization. Skipping.")
825
- return
826
-
827
- targets = torch.stack(y_accum)
828
-
829
- targets = targets.to(device).float()
830
-
831
- # 3. Initialize Head Bias
832
- with torch.no_grad():
833
- mean_target = torch.mean(targets, dim=0)
834
- if hasattr(self.head, 'bias') and self.head.bias is not None:
835
- if self.head.bias.shape == mean_target.shape:
836
- self.head.bias.data = mean_target
837
- _LOGGER.info("AutoInt Initialization Complete. Ready to train.")
838
- _LOGGER.debug(f"Initialized AutoInt head bias to {mean_target.cpu().numpy()}")
839
- elif self.head.bias.numel() == 1 and mean_target.numel() == 1:
840
- self.head.bias.data = mean_target.view(self.head.bias.shape)
841
- _LOGGER.info("AutoInt Initialization Complete. Ready to train.")
842
- _LOGGER.debug(f"Initialized AutoInt head bias to {mean_target.item()}")
843
- else:
844
- _LOGGER.warning("AutoInt Head does not have a bias parameter. Skipping initialization.")
845
-
846
- def get_architecture_config(self) -> dict[str, Any]:
847
- """Returns the full configuration of the model."""
848
- schema_dict = {
849
- 'feature_names': self.schema.feature_names,
850
- 'continuous_feature_names': self.schema.continuous_feature_names,
851
- 'categorical_feature_names': self.schema.categorical_feature_names,
852
- 'categorical_index_map': self.schema.categorical_index_map,
853
- 'categorical_mappings': self.schema.categorical_mappings
854
- }
855
-
856
- config = {
857
- SchemaKeys.SCHEMA_DICT: schema_dict,
858
- 'out_targets': self.out_targets,
859
- **self.model_hparams
860
- }
861
- return config
862
-
863
-
864
- class DragonTabNet(_ArchitectureBuilder):
865
- """
866
- Native implementation of TabNet (Attentive Interpretable Tabular Learning).
867
-
868
- Includes the Initial Splitter, Ghost Batch Norm, and GLU scaling.
869
- """
870
- def __init__(self, *,
871
- schema: FeatureSchema,
872
- out_targets: int,
873
- n_d: int = 8,
874
- n_a: int = 8,
875
- n_steps: int = 3,
876
- gamma: float = 1.3,
877
- n_independent: int = 2,
878
- n_shared: int = 2,
879
- virtual_batch_size: int = 128,
880
- momentum: float = 0.02,
881
- mask_type: Literal['sparsemax', 'entmax', 'softmax'] = 'sparsemax',
882
- batch_norm_continuous: bool = False):
883
- """
884
- Args:
885
- schema (FeatureSchema):
886
- Schema object containing feature names and types.
887
- out_targets (int):
888
- Number of output targets.
889
- n_d (int, optional):
890
- Dimension of the prediction layer (decision step).
891
- Suggested: 8 to 64.
892
- n_a (int, optional):
893
- Dimension of the attention layer (masking step).
894
- Suggested: 8 to 64.
895
- n_steps (int, optional):
896
- Number of sequential attention steps (architecture depth).
897
- Suggested: 3 to 10.
898
- gamma (float, optional):
899
- Relaxation parameter for sparsity in the mask.
900
- Suggested: 1.0 to 2.0.
901
- n_independent (int, optional):
902
- Number of independent Gated Linear Unit (GLU) layers in each block.
903
- Suggested: 1 to 5.
904
- n_shared (int, optional):
905
- Number of shared GLU layers across all blocks.
906
- Suggested: 1 to 5.
907
- virtual_batch_size (int, optional):
908
- Batch size for Ghost Batch Normalization.
909
- Suggested: 128 to 1024.
910
- momentum (float, optional):
911
- Momentum for Batch Normalization.
912
- Suggested: 0.01 to 0.4.
913
- mask_type (str, optional):
914
- Masking function to use. 'sparsemax' enforces sparsity.
915
- Options: 'sparsemax', 'entmax', 'softmax'.
916
- batch_norm_continuous (bool, optional):
917
- If True, applies Batch Normalization to continuous features before processing.
918
- """
919
- super().__init__()
920
- self.schema = schema
921
- self.out_targets = out_targets
922
-
923
- # Save config
924
- self.model_hparams = {
925
- 'n_d': n_d,
926
- 'n_a': n_a,
927
- 'n_steps': n_steps,
928
- 'gamma': gamma,
929
- 'n_independent': n_independent,
930
- 'n_shared': n_shared,
931
- 'virtual_batch_size': virtual_batch_size,
932
- 'momentum': momentum,
933
- 'mask_type': mask_type,
934
- 'batch_norm_continuous': batch_norm_continuous
935
- }
936
-
937
- # -- 1. Setup Input Features --
938
- self.categorical_indices = []
939
- self.cardinalities = []
940
- if schema.categorical_index_map:
941
- self.categorical_indices = list(schema.categorical_index_map.keys())
942
- self.cardinalities = list(schema.categorical_index_map.values())
943
-
944
- all_indices = set(range(len(schema.feature_names)))
945
- self.numerical_indices = sorted(list(all_indices - set(self.categorical_indices)))
946
-
947
- # Standard TabNet Embeddings:
948
- # We use a simple embedding for each categorical feature and concat with continuous.
949
- self.cat_embeddings = nn.ModuleList([
950
- nn.Embedding(card, 1) for card in self.cardinalities
951
- ])
952
-
953
- self.n_continuous = len(self.numerical_indices)
954
- self.input_dim = self.n_continuous + len(self.cardinalities)
955
-
956
- # -- 2. TabNet Backbone Components --
957
- self.n_d = n_d
958
- self.n_a = n_a
959
- self.n_steps = n_steps
960
- self.gamma = gamma
961
- self.epsilon = 1e-15
962
-
963
- # Initial BN
964
- self.initial_bn = nn.BatchNorm1d(self.input_dim, momentum=0.01)
965
-
966
- # Shared GLU Layers
967
- if n_shared > 0:
968
- self.shared_feat_transform = nn.ModuleList()
969
- for i in range(n_shared):
970
- if i == 0:
971
- self.shared_feat_transform.append(
972
- nn.Linear(self.input_dim, 2 * (n_d + n_a), bias=False)
973
- )
974
- else:
975
- self.shared_feat_transform.append(
976
- nn.Linear(n_d + n_a, 2 * (n_d + n_a), bias=False)
977
- )
978
- else:
979
- self.shared_feat_transform = None
980
-
981
- # Initial Splitter
982
- # This processes the input BEFORE the first step to generate the initial attention vector 'a'
983
- self.initial_splitter = FeatTransformer(
984
- self.input_dim,
985
- n_d + n_a,
986
- self.shared_feat_transform,
987
- n_glu_independent=n_independent,
988
- virtual_batch_size=virtual_batch_size,
989
- momentum=momentum,
990
- )
991
-
992
- # Steps
993
- self.feat_transformers = nn.ModuleList()
994
- self.att_transformers = nn.ModuleList()
995
-
996
- for step in range(n_steps):
997
- transformer = FeatTransformer(
998
- self.input_dim,
999
- n_d + n_a,
1000
- self.shared_feat_transform,
1001
- n_glu_independent=n_independent,
1002
- virtual_batch_size=virtual_batch_size,
1003
- momentum=momentum,
1004
- )
1005
- attention = AttentiveTransformer(
1006
- n_a,
1007
- self.input_dim, # We assume group_dim = input_dim (no grouping)
1008
- virtual_batch_size=virtual_batch_size,
1009
- momentum=momentum,
1010
- mask_type=mask_type,
1011
- )
1012
- self.feat_transformers.append(transformer)
1013
- self.att_transformers.append(attention)
1014
-
1015
- # -- 3. Final Mapping Head --
1016
- self.final_mapping = nn.Linear(n_d, out_targets, bias=False)
1017
- initialize_non_glu(self.final_mapping, n_d, out_targets)
1018
-
1019
- def forward(self, x: torch.Tensor) -> torch.Tensor:
1020
- # -- Preprocessing --
1021
- x_cont = x[:, self.numerical_indices].float()
1022
- x_cat = x[:, self.categorical_indices].long()
1023
-
1024
- cat_list = []
1025
- for i, embed in enumerate(self.cat_embeddings):
1026
- cat_list.append(embed(x_cat[:, i])) # (B, 1)
1027
-
1028
- if cat_list:
1029
- x_in = torch.cat([x_cont, *cat_list], dim=1)
1030
- else:
1031
- x_in = x_cont
1032
-
1033
-
1034
- # -- TabNet Encoder Pass --
1035
- x_bn = self.initial_bn(x_in)
1036
- # Initial Split
1037
- # The splitter produces [d, a]. We only need 'a' to start the loop.
1038
- att = self.initial_splitter(x_bn)[:, self.n_d :]
1039
- priors = torch.ones(x_bn.shape, device=x.device)
1040
- out_accumulated = 0
1041
- self.regularization_loss = 0
1042
-
1043
- for step in range(self.n_steps):
1044
- # 1. Attention
1045
- mask = self.att_transformers[step](priors, att)
1046
- # 2. Accumulate sparsity loss matching original implementation
1047
- loss = torch.sum(torch.mul(mask, torch.log(mask + self.epsilon)), dim=1)
1048
- self.regularization_loss += torch.mean(loss)
1049
- # 3. Update Prior
1050
- priors = torch.mul(self.gamma - mask, priors)
1051
- # 4. Masking
1052
- masked_x = torch.mul(mask, x_bn)
1053
- # 5. Feature Transformer
1054
- out = self.feat_transformers[step](masked_x)
1055
- # 6. Split Output
1056
- d = nn.ReLU()(out[:, :self.n_d])
1057
- att = out[:, self.n_d:]
1058
- # 7. Accumulate Decision
1059
- out_accumulated = out_accumulated + d
1060
-
1061
- self.regularization_loss /= self.n_steps
1062
- return self.final_mapping(out_accumulated)
1063
-
1064
- def data_aware_initialization(self, train_dataset, num_samples: int = 2000):
1065
- """
1066
- TabNet does not require data-aware initialization. Method Implemented for compatibility.
1067
- """
1068
- _LOGGER.info("TabNet does not require data-aware initialization. Skipping.")
1069
-
1070
- def get_architecture_config(self) -> dict[str, Any]:
1071
- """Returns the full configuration of the model."""
1072
- schema_dict = {
1073
- 'feature_names': self.schema.feature_names,
1074
- 'continuous_feature_names': self.schema.continuous_feature_names,
1075
- 'categorical_feature_names': self.schema.categorical_feature_names,
1076
- 'categorical_index_map': self.schema.categorical_index_map,
1077
- 'categorical_mappings': self.schema.categorical_mappings
1078
- }
1079
-
1080
- config = {
1081
- SchemaKeys.SCHEMA_DICT: schema_dict,
1082
- 'out_targets': self.out_targets,
1083
- **self.model_hparams
1084
- }
1085
- return config
1086
-