dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1901
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from .._core import get_logger
|
|
6
|
+
|
|
7
|
+
from ._base_save_load import _ArchitectureHandlerMixin
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
_LOGGER = get_logger("DragonModel: MLP")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"_BaseMLP",
|
|
15
|
+
"_BaseAttention",
|
|
16
|
+
"_AttentionLayer",
|
|
17
|
+
"_MultiHeadAttentionLayer",
|
|
18
|
+
]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class _BaseMLP(nn.Module, _ArchitectureHandlerMixin):
    """
    Base class for Multilayer Perceptrons.

    Validates and stores the core hyperparameters, builds the hidden MLP
    stack (``self.mlp``) and a swappable prediction head
    (``self.output_layer``). Subclasses define their own pre-processing
    and forward pass.
    """
    def __init__(self,
                 in_features: int,
                 out_targets: int,
                 hidden_layers: list[int],
                 drop_out: float) -> None:
        super().__init__()

        # --- Validate hyperparameters ---
        if not isinstance(in_features, int) or in_features < 1:
            _LOGGER.error("'in_features' must be a positive integer.")
            raise ValueError()
        if not isinstance(out_targets, int) or out_targets < 1:
            _LOGGER.error("'out_targets' must be a positive integer.")
            raise ValueError()
        if not isinstance(hidden_layers, list) or not all(isinstance(n, int) for n in hidden_layers):
            _LOGGER.error("'hidden_layers' must be a list of integers.")
            raise TypeError()
        if not (0.0 <= drop_out < 1.0):
            _LOGGER.error("'drop_out' must be a float between 0.0 and 1.0.")
            raise ValueError()

        # --- Store configuration (read back by get_architecture_config) ---
        self.in_features = in_features
        self.out_targets = out_targets
        self.hidden_layers = hidden_layers
        self.drop_out = drop_out

        # --- Build the hidden stack: Linear -> BatchNorm -> ReLU -> Dropout per layer ---
        blocks: list[nn.Module] = []
        width = in_features
        for units in hidden_layers:
            blocks.append(nn.Linear(width, units))
            blocks.append(nn.BatchNorm1d(units))
            blocks.append(nn.ReLU())
            blocks.append(nn.Dropout(p=drop_out))
            width = units

        self.mlp = nn.Sequential(*blocks)
        # Customizable prediction head, useful for transfer learning and fine-tuning.
        self.output_layer = nn.Linear(width, out_targets)

    def get_architecture_config(self) -> dict[str, Any]:
        """Returns the base configuration of the model."""
        return {
            'in_features': self.in_features,
            'out_targets': self.out_targets,
            'hidden_layers': self.hidden_layers,
            'drop_out': self.drop_out,
        }

    def _repr_helper(self, name: str, mlp_layers: list[str]):
        # Append the head's output width, or flag a replaced (non-Linear) head.
        head = self.output_layer
        if isinstance(head, nn.Linear):
            mlp_layers.append(str(head.out_features))
        else:
            mlp_layers.append("Custom Prediction Head")

        # Produces a string like: 10 -> 40 -> 80 -> 40 -> 2
        arch_str = ' -> '.join(mlp_layers)

        return f"{name}(arch: {arch_str})"
|
+
|
|
93
|
+
class _BaseAttention(_BaseMLP):
    """
    Abstract base for MLP models that apply an attention mechanism before
    the main MLP layers. Subclasses must assign ``self.attention``.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Subclasses are expected to replace these defaults.
        self.attention = None
        self.has_interpretable_attention = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the standard forward pass."""
        # Reuse the attention pass and discard the weights.
        return self.forward_attention(x)[0]

    def forward_attention(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Returns logits and attention weights."""
        # Single shared implementation: attention -> MLP -> prediction head.
        weighted, attention_weights = self.attention(x)  # type: ignore
        hidden = self.mlp(weighted)
        logits = self.output_layer(hidden)
        return logits, attention_weights
117
|
+
|
|
118
|
+
class _AttentionLayer(nn.Module):
|
|
119
|
+
"""
|
|
120
|
+
Calculates attention weights and applies them to the input features, incorporating a residual connection for improved stability and performance.
|
|
121
|
+
|
|
122
|
+
Returns both the final output and the weights for interpretability.
|
|
123
|
+
"""
|
|
124
|
+
def __init__(self, num_features: int):
|
|
125
|
+
super().__init__()
|
|
126
|
+
# The hidden layer size is a hyperparameter
|
|
127
|
+
hidden_size = max(16, num_features // 4)
|
|
128
|
+
|
|
129
|
+
# Learn to produce attention scores
|
|
130
|
+
self.attention_net = nn.Sequential(
|
|
131
|
+
nn.Linear(num_features, hidden_size),
|
|
132
|
+
nn.Tanh(),
|
|
133
|
+
nn.Linear(hidden_size, num_features) # Output one score per feature
|
|
134
|
+
)
|
|
135
|
+
self.softmax = nn.Softmax(dim=1)
|
|
136
|
+
|
|
137
|
+
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
138
|
+
# x shape: (batch_size, num_features)
|
|
139
|
+
|
|
140
|
+
# Get one raw "importance" score per feature
|
|
141
|
+
attention_scores = self.attention_net(x)
|
|
142
|
+
|
|
143
|
+
# Apply the softmax module to get weights that sum to 1
|
|
144
|
+
attention_weights = self.softmax(attention_scores)
|
|
145
|
+
|
|
146
|
+
# Weighted features (attention mechanism's output)
|
|
147
|
+
weighted_features = x * attention_weights
|
|
148
|
+
|
|
149
|
+
# Residual connection
|
|
150
|
+
residual_connection = x + weighted_features
|
|
151
|
+
|
|
152
|
+
return residual_connection, attention_weights
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class _MultiHeadAttentionLayer(nn.Module):
|
|
156
|
+
"""
|
|
157
|
+
A wrapper for the standard `torch.nn.MultiheadAttention` layer.
|
|
158
|
+
|
|
159
|
+
This layer treats the entire input feature vector as a single item in a
|
|
160
|
+
sequence and applies self-attention to it. It is followed by a residual
|
|
161
|
+
connection and layer normalization, which is a standard block in
|
|
162
|
+
Transformer-style models.
|
|
163
|
+
"""
|
|
164
|
+
def __init__(self, num_features: int, num_heads: int, dropout: float):
|
|
165
|
+
super().__init__()
|
|
166
|
+
self.attention = nn.MultiheadAttention(
|
|
167
|
+
embed_dim=num_features,
|
|
168
|
+
num_heads=num_heads,
|
|
169
|
+
dropout=dropout,
|
|
170
|
+
batch_first=True # Crucial for (batch, seq, feature) input
|
|
171
|
+
)
|
|
172
|
+
self.layer_norm = nn.LayerNorm(num_features)
|
|
173
|
+
|
|
174
|
+
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
175
|
+
# x shape: (batch_size, num_features)
|
|
176
|
+
|
|
177
|
+
# nn.MultiheadAttention expects a sequence dimension.
|
|
178
|
+
# We add a sequence dimension of length 1.
|
|
179
|
+
# x_reshaped shape: (batch_size, 1, num_features)
|
|
180
|
+
x_reshaped = x.unsqueeze(1)
|
|
181
|
+
|
|
182
|
+
# Apply self-attention. query, key, and value are all the same.
|
|
183
|
+
# attn_output shape: (batch_size, 1, num_features)
|
|
184
|
+
# attn_weights shape: (batch_size, 1, 1)
|
|
185
|
+
attn_output, attn_weights = self.attention(
|
|
186
|
+
query=x_reshaped,
|
|
187
|
+
key=x_reshaped,
|
|
188
|
+
value=x_reshaped,
|
|
189
|
+
need_weights=True,
|
|
190
|
+
average_attn_weights=True # Average weights across heads
|
|
191
|
+
)
|
|
192
|
+
|
|
193
|
+
# Add residual connection and apply layer normalization (Post-LN)
|
|
194
|
+
out = self.layer_norm(x + attn_output.squeeze(1))
|
|
195
|
+
|
|
196
|
+
# Squeeze weights for a consistent output shape
|
|
197
|
+
return out, attn_weights.squeeze()
|
|
198
|
+
|
|
@@ -1,42 +1,40 @@
|
|
|
1
1
|
from torch import nn
|
|
2
|
-
from typing import
|
|
3
|
-
import json
|
|
2
|
+
from typing import Union, Any
|
|
4
3
|
from pathlib import Path
|
|
4
|
+
import json
|
|
5
5
|
from abc import ABC, abstractmethod
|
|
6
6
|
|
|
7
|
-
from
|
|
8
|
-
from ._keys import PytorchModelArchitectureKeys
|
|
9
|
-
from ._schema import FeatureSchema
|
|
10
|
-
from ._logger import get_logger
|
|
7
|
+
from ..schema import FeatureSchema
|
|
11
8
|
|
|
9
|
+
from .._core import get_logger
|
|
10
|
+
from ..path_manager import make_fullpath
|
|
11
|
+
from ..keys._keys import PytorchModelArchitectureKeys, SchemaKeys
|
|
12
12
|
|
|
13
|
-
|
|
13
|
+
from .._core._schema_load_ops import prepare_schema_from_json
|
|
14
14
|
|
|
15
15
|
|
|
16
|
+
_LOGGER = get_logger("DragonModel: Save/Load")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
__all__ = [
|
|
20
|
+
"_ArchitectureHandlerMixin",
|
|
21
|
+
"_ArchitectureBuilder",
|
|
22
|
+
]
|
|
23
|
+
|
|
16
24
|
##################################
|
|
17
|
-
#
|
|
18
|
-
|
|
19
|
-
class _ArchitectureBuilder(nn.Module, ABC):
|
|
25
|
+
# Mixin class for saving and loading basic model architectures
|
|
26
|
+
class _ArchitectureHandlerMixin:
|
|
20
27
|
"""
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
Implements:
|
|
24
|
-
- JSON serialization and JSON deserialization with automatic FeatureSchema reconstruction.
|
|
25
|
-
- Standardized string representation (__repr__) showing hyperparameters.
|
|
28
|
+
A mixin class to provide save and load functionality for model architectures.
|
|
26
29
|
"""
|
|
27
|
-
def __init__(self):
|
|
28
|
-
super().__init__()
|
|
29
|
-
# Placeholder for hyperparameters, to be populated by child classes
|
|
30
|
-
self.model_hparams: Dict[str, Any] = {}
|
|
31
|
-
|
|
32
30
|
# abstract method that must be implemented by children
|
|
33
31
|
@abstractmethod
|
|
34
|
-
def get_architecture_config(self) ->
|
|
32
|
+
def get_architecture_config(self) -> dict[str, Any]:
|
|
35
33
|
"To be implemented by children"
|
|
36
34
|
pass
|
|
37
|
-
|
|
38
|
-
def
|
|
39
|
-
"""Saves the model's architecture to
|
|
35
|
+
|
|
36
|
+
def save_architecture(self: nn.Module, directory: Union[str, Path], verbose: bool = True): # type: ignore
|
|
37
|
+
"""Saves the model's architecture to an "architecture.json" file."""
|
|
40
38
|
if not hasattr(self, 'get_architecture_config'):
|
|
41
39
|
_LOGGER.error(f"Model '{self.__class__.__name__}' must have a 'get_architecture_config()' method to use this functionality.")
|
|
42
40
|
raise AttributeError()
|
|
@@ -59,9 +57,10 @@ class _ArchitectureBuilder(nn.Module, ABC):
|
|
|
59
57
|
_LOGGER.info(f"Architecture for '{self.__class__.__name__}' saved as '{full_path.name}'")
|
|
60
58
|
|
|
61
59
|
@classmethod
|
|
62
|
-
def
|
|
60
|
+
def load_architecture(cls: type, file_or_dir: Union[str, Path], verbose: bool = True) -> nn.Module:
|
|
63
61
|
"""
|
|
64
|
-
Loads a model architecture from a JSON file.
|
|
62
|
+
Loads a model architecture from a JSON file.
|
|
63
|
+
If a directory is provided, the function will attempt to load the JSON file "architecture.json" inside.
|
|
65
64
|
"""
|
|
66
65
|
user_path = make_fullpath(file_or_dir)
|
|
67
66
|
|
|
@@ -84,35 +83,59 @@ class _ArchitectureBuilder(nn.Module, ABC):
|
|
|
84
83
|
_LOGGER.error(f"Model class mismatch. File specifies '{saved_class_name}', but '{cls.__name__}' was expected.")
|
|
85
84
|
raise ValueError()
|
|
86
85
|
|
|
87
|
-
#
|
|
88
|
-
|
|
89
|
-
raise ValueError("Missing 'schema_dict' in config.")
|
|
90
|
-
|
|
91
|
-
schema_data = config.pop('schema_dict')
|
|
92
|
-
|
|
93
|
-
raw_index_map = schema_data['categorical_index_map']
|
|
94
|
-
if raw_index_map is not None:
|
|
95
|
-
# JSON keys are strings, convert back to int
|
|
96
|
-
rehydrated_index_map = {int(k): v for k, v in raw_index_map.items()}
|
|
97
|
-
else:
|
|
98
|
-
rehydrated_index_map = None
|
|
99
|
-
|
|
100
|
-
schema = FeatureSchema(
|
|
101
|
-
feature_names=tuple(schema_data['feature_names']),
|
|
102
|
-
continuous_feature_names=tuple(schema_data['continuous_feature_names']),
|
|
103
|
-
categorical_feature_names=tuple(schema_data['categorical_feature_names']),
|
|
104
|
-
categorical_index_map=rehydrated_index_map,
|
|
105
|
-
categorical_mappings=schema_data['categorical_mappings']
|
|
106
|
-
)
|
|
107
|
-
|
|
108
|
-
config['schema'] = schema
|
|
109
|
-
# --- End Reconstruction ---
|
|
86
|
+
# Hook to allow children classes to modify config before init (reconstruction)
|
|
87
|
+
config = cls._prepare_config_for_load(config)
|
|
110
88
|
|
|
111
89
|
model = cls(**config)
|
|
112
90
|
if verbose:
|
|
113
91
|
_LOGGER.info(f"Successfully loaded architecture for '{saved_class_name}'")
|
|
114
92
|
return model
|
|
115
93
|
|
|
94
|
+
@classmethod
|
|
95
|
+
def _prepare_config_for_load(cls, config: dict[str, Any]) -> dict[str, Any]:
|
|
96
|
+
"""
|
|
97
|
+
Hook method to process configuration data before model instantiation.
|
|
98
|
+
Base implementation simply returns the config as-is.
|
|
99
|
+
"""
|
|
100
|
+
return config
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
##################################
|
|
104
|
+
# Base class for loading and saving advanced models
|
|
105
|
+
##################################
|
|
106
|
+
class _ArchitectureBuilder(_ArchitectureHandlerMixin, nn.Module, ABC):
|
|
107
|
+
"""
|
|
108
|
+
Base class for Dragon models that unifies architecture handling.
|
|
109
|
+
|
|
110
|
+
Implements:
|
|
111
|
+
- JSON serialization and JSON deserialization with automatic FeatureSchema reconstruction.
|
|
112
|
+
- Standardized string representation (__repr__) showing hyperparameters.
|
|
113
|
+
"""
|
|
114
|
+
def __init__(self):
|
|
115
|
+
super().__init__()
|
|
116
|
+
# Placeholder for hyperparameters, to be populated by child classes
|
|
117
|
+
self.model_hparams: dict[str, Any] = {}
|
|
118
|
+
|
|
119
|
+
@classmethod
|
|
120
|
+
def _prepare_config_for_load(cls, config: dict[str, Any]) -> dict[str, Any]:
|
|
121
|
+
"""
|
|
122
|
+
Overrides the mixin hook to reconstruct the FeatureSchema object
|
|
123
|
+
from the raw dictionary data found in the JSON.
|
|
124
|
+
"""
|
|
125
|
+
if SchemaKeys.SCHEMA_DICT not in config:
|
|
126
|
+
_LOGGER.error(f"The model architecture is missing the '{SchemaKeys.SCHEMA_DICT}' key.")
|
|
127
|
+
raise ValueError()
|
|
128
|
+
|
|
129
|
+
schema_data = config.pop(SchemaKeys.SCHEMA_DICT)
|
|
130
|
+
|
|
131
|
+
# Use shared helper to prepare arguments (handles tuple/int conversion)
|
|
132
|
+
schema_kwargs = prepare_schema_from_json(schema_data)
|
|
133
|
+
|
|
134
|
+
schema = FeatureSchema(**schema_kwargs)
|
|
135
|
+
|
|
136
|
+
config['schema'] = schema
|
|
137
|
+
return config
|
|
138
|
+
|
|
116
139
|
def __repr__(self):
|
|
117
140
|
# 1. Format hyperparameters
|
|
118
141
|
hparams_str = ",\n ".join([f"{k}={v}" for k, v in self.model_hparams.items()])
|
|
@@ -130,3 +153,4 @@ class _ArchitectureBuilder(nn.Module, ABC):
|
|
|
130
153
|
main_str += "\n".join(child_lines) + "\n"
|
|
131
154
|
main_str += ")"
|
|
132
155
|
return main_str
|
|
156
|
+
|
|
@@ -0,0 +1,248 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from ..schema import FeatureSchema
|
|
6
|
+
|
|
7
|
+
from .._core import get_logger
|
|
8
|
+
from ..keys._keys import SchemaKeys
|
|
9
|
+
|
|
10
|
+
from ._base_save_load import _ArchitectureBuilder
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
_LOGGER = get_logger("DragonTabularTransformer")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
__all__ = [
|
|
17
|
+
"DragonTabularTransformer"
|
|
18
|
+
]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class DragonTabularTransformer(_ArchitectureBuilder):
|
|
22
|
+
"""
|
|
23
|
+
A Transformer-based model for tabular data tasks.
|
|
24
|
+
|
|
25
|
+
This model uses a Feature Tokenizer to convert all input features into a
|
|
26
|
+
sequence of embeddings, prepends a [CLS] token, and processes the
|
|
27
|
+
sequence with a standard Transformer Encoder.
|
|
28
|
+
"""
|
|
29
|
+
def __init__(self, *,
|
|
30
|
+
schema: FeatureSchema,
|
|
31
|
+
out_targets: int,
|
|
32
|
+
embedding_dim: int = 256,
|
|
33
|
+
num_heads: int = 8,
|
|
34
|
+
num_layers: int = 6,
|
|
35
|
+
dropout: float = 0.2):
|
|
36
|
+
"""
|
|
37
|
+
Args:
|
|
38
|
+
schema (FeatureSchema):
|
|
39
|
+
The definitive FeatureSchema object.
|
|
40
|
+
out_targets (int):
|
|
41
|
+
Number of output targets.
|
|
42
|
+
embedding_dim (int):
|
|
43
|
+
The dimension for all feature embeddings. Must be divisible by num_heads. Common values: (64, 128, 192, 256, etc.)
|
|
44
|
+
num_heads (int):
|
|
45
|
+
The number of heads in the multi-head attention mechanism. Common values: (4, 8, 16)
|
|
46
|
+
num_layers (int):
|
|
47
|
+
The number of sub-encoder-layers in the transformer encoder. Common values: (4, 8, 12)
|
|
48
|
+
dropout (float):
|
|
49
|
+
The dropout value.
|
|
50
|
+
|
|
51
|
+
## Note:
|
|
52
|
+
|
|
53
|
+
**Embedding Dimension:** "Width" of the model. It's the N-dimension vector that will be used to represent each one of the features.
|
|
54
|
+
- Each continuous feature gets its own learnable N-dimension vector.
|
|
55
|
+
- Each categorical feature gets an embedding table that maps every category (e.g., "color=red", "color=blue") to a unique N-dimension vector.
|
|
56
|
+
|
|
57
|
+
**Attention Heads:** Controls the "Multi-Head Attention" mechanism. Instead of looking at all the feature interactions at once, the model splits its attention into N parallel heads.
|
|
58
|
+
- Embedding Dimensions get divided by the number of Attention Heads, resulting in the dimensions assigned per head.
|
|
59
|
+
|
|
60
|
+
**Number of Layers:** "Depth" of the model. Number of identical `TransformerEncoderLayer` blocks that are stacked on top of each other.
|
|
61
|
+
- Layer 1: The attention heads find simple, direct interactions between the features.
|
|
62
|
+
- Layer 2: Takes the output of Layer 1 and finds interactions between those interactions and so on.
|
|
63
|
+
- Trade-off: More layers are more powerful but are slower to train and more prone to overfitting. If the training loss goes down but the validation loss goes up, you might have too many layers (or need more dropout).
|
|
64
|
+
|
|
65
|
+
"""
|
|
66
|
+
# _ArchitectureBuilder init sets up self.model_hparams
|
|
67
|
+
super().__init__()
|
|
68
|
+
|
|
69
|
+
# --- Get info from schema ---
|
|
70
|
+
in_features = len(schema.feature_names)
|
|
71
|
+
categorical_index_map = schema.categorical_index_map
|
|
72
|
+
|
|
73
|
+
# --- Validation ---
|
|
74
|
+
if categorical_index_map and (max(categorical_index_map.keys()) >= in_features):
|
|
75
|
+
_LOGGER.error(f"A categorical index ({max(categorical_index_map.keys())}) is out of bounds for the provided input features ({in_features}).")
|
|
76
|
+
raise ValueError()
|
|
77
|
+
|
|
78
|
+
# --- Save configuration ---
|
|
79
|
+
self.schema = schema # <-- Save the whole schema
|
|
80
|
+
self.out_targets = out_targets
|
|
81
|
+
self.embedding_dim = embedding_dim
|
|
82
|
+
self.num_heads = num_heads
|
|
83
|
+
self.num_layers = num_layers
|
|
84
|
+
self.dropout = dropout
|
|
85
|
+
|
|
86
|
+
# --- 1. Feature Tokenizer (now takes the schema) ---
|
|
87
|
+
self.tokenizer = _FeatureTokenizer(
|
|
88
|
+
schema=schema,
|
|
89
|
+
embedding_dim=embedding_dim
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
# --- 2. CLS Token ---
|
|
93
|
+
self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
|
|
94
|
+
|
|
95
|
+
# --- 3. Transformer Encoder ---
|
|
96
|
+
encoder_layer = nn.TransformerEncoderLayer(
|
|
97
|
+
d_model=embedding_dim,
|
|
98
|
+
nhead=num_heads,
|
|
99
|
+
dropout=dropout,
|
|
100
|
+
batch_first=True # Crucial for (batch, seq, feature) input
|
|
101
|
+
)
|
|
102
|
+
self.transformer_encoder = nn.TransformerEncoder(
|
|
103
|
+
encoder_layer=encoder_layer,
|
|
104
|
+
num_layers=num_layers
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
# --- 4. Prediction Head ---
|
|
108
|
+
self.output_layer = nn.Linear(embedding_dim, out_targets)
|
|
109
|
+
|
|
110
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run the tokenizer -> [CLS] -> encoder -> head pipeline.

    Args:
        x: Raw feature tensor of shape (batch, num_features).

    Returns:
        Logits of shape (batch, out_targets).
    """
    n_samples = x.shape[0]

    # Tokenize the raw features: (batch, num_features, embedding_dim)
    feature_tokens = self.tokenizer(x)

    # Broadcast the learned [CLS] token over the batch and prepend it,
    # giving (batch, num_features + 1, embedding_dim).
    sequence = torch.cat(
        [self.cls_token.expand(n_samples, -1, -1), feature_tokens],
        dim=1,
    )

    # Contextualize every token, then read out only the [CLS] position
    # (index 0) and project it to the target logits.
    encoded = self.transformer_encoder(sequence)
    return self.output_layer(encoded[:, 0])
|
|
138
|
+
|
|
139
|
+
def get_architecture_config(self) -> dict[str, Any]:
    """Return a JSON-serializable dict describing the full architecture.

    The FeatureSchema is flattened into a plain dict so the whole config
    can round-trip through JSON (tuples are serialized as lists).
    """
    schema = self.schema
    serialized_schema = {
        'feature_names': schema.feature_names,
        'continuous_feature_names': schema.continuous_feature_names,
        'categorical_feature_names': schema.categorical_feature_names,
        'categorical_index_map': schema.categorical_index_map,
        'categorical_mappings': schema.categorical_mappings,
    }

    # Schema first, then the scalar hyperparameters.
    config: dict[str, Any] = {SchemaKeys.SCHEMA_DICT: serialized_schema}
    config.update(
        out_targets=self.out_targets,
        embedding_dim=self.embedding_dim,
        num_heads=self.num_heads,
        num_layers=self.num_layers,
        dropout=self.dropout,
    )
    return config
|
|
159
|
+
|
|
160
|
+
def __repr__(self) -> str:
    """Return a compact, developer-facing summary of the architecture."""
    # Describe each stage of the pipeline in order, then join with arrows.
    stages = (
        f"Tokenizer(features={len(self.schema.feature_names)}, dim={self.embedding_dim})",
        "[CLS]",
        f"TransformerEncoder(layers={self.num_layers}, heads={self.num_heads})",
        f"PredictionHead(outputs={self.out_targets})",
    )
    return f"DragonTabularTransformer(arch: {' -> '.join(stages)})"
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
class _FeatureTokenizer(nn.Module):
    """
    Maps a raw input tensor with mixed numerical and categorical columns
    (in any column order) to a sequence of per-feature embeddings.
    """
    def __init__(self,
                 schema: FeatureSchema,
                 embedding_dim: int):
        """
        Args:
            schema (FeatureSchema):
                The definitive schema object from data_exploration.
            embedding_dim (int):
                The dimension for all feature embeddings.
        """
        super().__init__()

        index_map = schema.categorical_index_map

        # Split the schema's categorical map into aligned index/cardinality lists.
        if index_map:
            self.categorical_indices = list(index_map.keys())
            cardinalities = list(index_map.values())
        else:
            self.categorical_indices = []
            cardinalities = []

        # Whatever column is not categorical is treated as numerical.
        remaining = set(range(len(schema.feature_names))) - set(self.categorical_indices)
        self.numerical_indices = sorted(remaining)

        self.embedding_dim = embedding_dim

        # One learnable embedding vector per numerical feature.
        self.numerical_embeddings = nn.Parameter(
            torch.randn(len(self.numerical_indices), embedding_dim)
        )

        # One lookup table per categorical feature.
        self.categorical_embeddings = nn.ModuleList(
            [nn.Embedding(num_embeddings=c, embedding_dim=embedding_dim)
             for c in cardinalities]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Embed every feature column of ``x`` and return the concatenated
        token sequence of shape (batch, num_features, embedding_dim).
        """
        # Slice out each feature family using the precomputed column indices.
        numeric_cols = x[:, self.numerical_indices].float()
        category_cols = x[:, self.categorical_indices].long()

        # Numerical: scale each feature's learned embedding by its value.
        # (batch, n_num, 1) * (n_num, dim) -> (batch, n_num, dim)
        numeric_tokens = numeric_cols.unsqueeze(-1) * self.numerical_embeddings

        # Categorical: one embedding lookup per column, keeping a seq axis.
        category_tokens = [
            layer(category_cols[:, i]).unsqueeze(1)
            for i, layer in enumerate(self.categorical_embeddings)
        ]

        # Assemble the final sequence (numerical tokens first, then categorical),
        # skipping whichever family is empty.
        if not self.categorical_indices:
            return numeric_tokens
        if not self.numerical_indices:
            return torch.cat(category_tokens, dim=1)
        return torch.cat([numeric_tokens, torch.cat(category_tokens, dim=1)], dim=1)
|
|
248
|
+
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from .._core import _imprimir_disponibles
|
|
2
|
+
|
|
3
|
+
# Public model names listed by `info()`, grouped by family.
_GRUPOS = [
    # MLP and Attention Models
    "DragonMLP",
    "DragonAttentionMLP",
    "DragonMultiHeadAttentionNet",
    # Tabular Transformer Model
    "DragonTabularTransformer",
    # Advanced Models
    "DragonGateModel",
    "DragonNodeModel",
    "DragonAutoInt",
    "DragonTabNet",
]
|
|
16
|
+
|
|
17
|
+
def info():
    """Print the model names available in this module."""
    _imprimir_disponibles(_GRUPOS)
|