ins-pricing 0.4.5-py3-none-any.whl → 0.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. ins_pricing/README.md +48 -22
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +58 -46
  4. ins_pricing/cli/BayesOpt_incremental.py +77 -110
  5. ins_pricing/cli/Explain_Run.py +42 -23
  6. ins_pricing/cli/Explain_entry.py +551 -577
  7. ins_pricing/cli/Pricing_Run.py +42 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +51 -16
  9. ins_pricing/cli/utils/bootstrap.py +23 -0
  10. ins_pricing/cli/utils/cli_common.py +256 -256
  11. ins_pricing/cli/utils/cli_config.py +379 -360
  12. ins_pricing/cli/utils/import_resolver.py +375 -358
  13. ins_pricing/cli/utils/notebook_utils.py +256 -242
  14. ins_pricing/cli/watchdog_run.py +216 -198
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/app.py +132 -61
  17. ins_pricing/frontend/config_builder.py +33 -0
  18. ins_pricing/frontend/example_config.json +11 -0
  19. ins_pricing/frontend/example_workflows.py +1 -1
  20. ins_pricing/frontend/runner.py +340 -388
  21. ins_pricing/governance/__init__.py +20 -20
  22. ins_pricing/governance/release.py +159 -159
  23. ins_pricing/modelling/README.md +1 -1
  24. ins_pricing/modelling/__init__.py +147 -92
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +31 -13
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +12 -0
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +589 -552
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +987 -958
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +488 -548
  32. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +349 -342
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +921 -913
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +794 -785
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +454 -446
  37. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1294 -1282
  39. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +64 -56
  40. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +203 -198
  41. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +333 -325
  42. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +279 -267
  43. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +515 -313
  44. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  45. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  46. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +193 -186
  47. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  48. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  49. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  50. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +636 -623
  51. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  52. ins_pricing/modelling/explain/__init__.py +55 -55
  53. ins_pricing/modelling/explain/metrics.py +27 -174
  54. ins_pricing/modelling/explain/permutation.py +237 -237
  55. ins_pricing/modelling/plotting/__init__.py +40 -36
  56. ins_pricing/modelling/plotting/compat.py +228 -0
  57. ins_pricing/modelling/plotting/curves.py +572 -572
  58. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  59. ins_pricing/modelling/plotting/geo.py +362 -362
  60. ins_pricing/modelling/plotting/importance.py +121 -121
  61. ins_pricing/pricing/__init__.py +27 -27
  62. ins_pricing/pricing/factors.py +67 -56
  63. ins_pricing/production/__init__.py +35 -25
  64. ins_pricing/production/{predict.py → inference.py} +140 -57
  65. ins_pricing/production/monitoring.py +8 -21
  66. ins_pricing/reporting/__init__.py +11 -11
  67. ins_pricing/setup.py +1 -1
  68. ins_pricing/tests/production/test_inference.py +90 -0
  69. ins_pricing/utils/__init__.py +112 -78
  70. ins_pricing/utils/device.py +258 -237
  71. ins_pricing/utils/features.py +53 -0
  72. ins_pricing/utils/io.py +72 -0
  73. ins_pricing/utils/logging.py +34 -1
  74. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  75. ins_pricing/utils/metrics.py +158 -24
  76. ins_pricing/utils/numerics.py +76 -0
  77. ins_pricing/utils/paths.py +9 -1
  78. ins_pricing/utils/profiling.py +8 -4
  79. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +1 -1
  80. ins_pricing-0.5.1.dist-info/RECORD +132 -0
  81. ins_pricing/modelling/core/BayesOpt.py +0 -146
  82. ins_pricing/modelling/core/__init__.py +0 -1
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.5.dist-info/RECORD +0 -130
  92. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
  93. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
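The renames above suggest that the bayesopt package moved from ins_pricing.modelling.core.bayesopt to ins_pricing.modelling.bayesopt, and that production/predict.py became production/inference.py. A hypothetical migration sketch for downstream imports follows; it is inferred only from the file list above and has not been verified against the 0.5.1 package:

    # Hypothetical import migration implied by the renames (assumption, not stated in the diff).
    # 0.4.5 layout:
    #   from ins_pricing.modelling.core.bayesopt import core
    #   from ins_pricing.production import predict
    # 0.5.1 layout:
    from ins_pricing.modelling.bayesopt import core
    from ins_pricing.production import inference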
ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py
@@ -1,342 +1,349 @@
- from __future__ import annotations
-
- import math
- from typing import List, Optional, Tuple
-
- import torch
- import torch.nn as nn
- from torch.utils.data import Dataset
-
-
- # =============================================================================
- # FT-Transformer model and sklearn-style wrapper.
- # =============================================================================
- # Define FT-Transformer model structure.
-
-
- class FeatureTokenizer(nn.Module):
-     """Map numeric/categorical/geo tokens into transformer input tokens."""
-
-     def __init__(
-         self,
-         num_numeric: int,
-         cat_cardinalities,
-         d_model: int,
-         num_geo: int = 0,
-         num_numeric_tokens: int = 1,
-     ):
-         super().__init__()
-
-         self.num_numeric = num_numeric
-         self.num_geo = num_geo
-         self.has_geo = num_geo > 0
-
-         if num_numeric > 0:
-             if int(num_numeric_tokens) <= 0:
-                 raise ValueError("num_numeric_tokens must be >= 1 when numeric features exist.")
-             self.num_numeric_tokens = int(num_numeric_tokens)
-             self.has_numeric = True
-             self.num_linear = nn.Linear(num_numeric, d_model * self.num_numeric_tokens)
-         else:
-             self.num_numeric_tokens = 0
-             self.has_numeric = False
-
-         self.embeddings = nn.ModuleList([
-             nn.Embedding(card, d_model) for card in cat_cardinalities
-         ])
-
-         if self.has_geo:
-             # Map geo tokens with a linear layer to avoid one-hot on raw strings; upstream is encoded/normalized.
-             self.geo_linear = nn.Linear(num_geo, d_model)
-
-     def forward(self, X_num, X_cat, X_geo=None):
-         tokens = []
-
-         if self.has_numeric:
-             batch_size = X_num.shape[0]
-             num_token = self.num_linear(X_num)
-             num_token = num_token.view(batch_size, self.num_numeric_tokens, -1)
-             tokens.append(num_token)
-
-         for i, emb in enumerate(self.embeddings):
-             tok = emb(X_cat[:, i])
-             tokens.append(tok.unsqueeze(1))
-
-         if self.has_geo:
-             if X_geo is None:
-                 raise RuntimeError("Geo tokens are enabled but X_geo was not provided.")
-             geo_token = self.geo_linear(X_geo)
-             tokens.append(geo_token.unsqueeze(1))
-
-         x = torch.cat(tokens, dim=1)
-         return x
-
- # Encoder layer with residual scaling.
-
-
- class ScaledTransformerEncoderLayer(nn.Module):
-     def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048,
-                  dropout: float = 0.1, residual_scale_attn: float = 1.0,
-                  residual_scale_ffn: float = 1.0, norm_first: bool = True,
-                  ):
-         super().__init__()
-         self.self_attn = nn.MultiheadAttention(
-             embed_dim=d_model,
-             num_heads=nhead,
-             dropout=dropout,
-             batch_first=True
-         )
-
-         # Feed-forward network.
-         self.linear1 = nn.Linear(d_model, dim_feedforward)
-         self.dropout = nn.Dropout(dropout)
-         self.linear2 = nn.Linear(dim_feedforward, d_model)
-
-         # Normalization and dropout.
-         self.norm1 = nn.LayerNorm(d_model)
-         self.norm2 = nn.LayerNorm(d_model)
-         self.dropout1 = nn.Dropout(dropout)
-         self.dropout2 = nn.Dropout(dropout)
-
-         self.activation = nn.GELU()
-         # If you prefer ReLU, set: self.activation = nn.ReLU()
-         self.norm_first = norm_first
-
-         # Residual scaling coefficients.
-         self.res_scale_attn = residual_scale_attn
-         self.res_scale_ffn = residual_scale_ffn
-
-     def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal: Optional[bool] = None, **_kwargs):
-         # Input tensor shape: (batch, seq_len, d_model).
-         x = src
-
-         if self.norm_first:
-             # Pre-norm before attention.
-             x = x + self._sa_block(
-                 self.norm1(x),
-                 src_mask,
-                 src_key_padding_mask,
-                 is_causal=is_causal,
-             )
-             x = x + self._ff_block(self.norm2(x))
-         else:
-             # Post-norm (usually disabled).
-             x = self.norm1(
-                 x + self._sa_block(
-                     x,
-                     src_mask,
-                     src_key_padding_mask,
-                     is_causal=is_causal,
-                 )
-             )
-             x = self.norm2(x + self._ff_block(x))
-
-         return x
-
-     def _sa_block(self, x, attn_mask, key_padding_mask, *, is_causal: Optional[bool] = None):
-         # Self-attention with residual scaling.
-         if is_causal is None:
-             attn_out, _ = self.self_attn(
-                 x, x, x,
-                 attn_mask=attn_mask,
-                 key_padding_mask=key_padding_mask,
-                 need_weights=False,
-             )
-         else:
-             try:
-                 attn_out, _ = self.self_attn(
-                     x, x, x,
-                     attn_mask=attn_mask,
-                     key_padding_mask=key_padding_mask,
-                     need_weights=False,
-                     is_causal=is_causal,
-                 )
-             except TypeError:
-                 attn_out, _ = self.self_attn(
-                     x, x, x,
-                     attn_mask=attn_mask,
-                     key_padding_mask=key_padding_mask,
-                     need_weights=False,
-                 )
-         return self.res_scale_attn * self.dropout1(attn_out)
-
-     def _ff_block(self, x):
-         # Feed-forward block with residual scaling.
-         x2 = self.linear2(self.dropout(self.activation(self.linear1(x))))
-         return self.res_scale_ffn * self.dropout2(x2)
-
- # FT-Transformer core model.
-
-
- class FTTransformerCore(nn.Module):
-     # Minimal FT-Transformer built from:
-     # 1) FeatureTokenizer: convert numeric/categorical features to tokens;
-     # 2) TransformerEncoder: model feature interactions;
-     # 3) Pooling + MLP + Softplus: positive outputs for Tweedie/Gamma tasks.
-
-     def __init__(self, num_numeric: int, cat_cardinalities, d_model: int = 64,
-                  n_heads: int = 8, n_layers: int = 4, dropout: float = 0.1,
-                  task_type: str = 'regression', num_geo: int = 0,
-                  num_numeric_tokens: int = 1
-                  ):
-         super().__init__()
-
-         self.num_numeric = int(num_numeric)
-         self.cat_cardinalities = list(cat_cardinalities or [])
-
-         self.tokenizer = FeatureTokenizer(
-             num_numeric=num_numeric,
-             cat_cardinalities=cat_cardinalities,
-             d_model=d_model,
-             num_geo=num_geo,
-             num_numeric_tokens=num_numeric_tokens
-         )
-         scale = 1.0 / math.sqrt(n_layers)  # Recommended default.
-         encoder_layer = ScaledTransformerEncoderLayer(
-             d_model=d_model,
-             nhead=n_heads,
-             dim_feedforward=d_model * 4,
-             dropout=dropout,
-             residual_scale_attn=scale,
-             residual_scale_ffn=scale,
-             norm_first=True,
-         )
-         self.encoder = nn.TransformerEncoder(
-             encoder_layer,
-             num_layers=n_layers
-         )
-         self.n_layers = n_layers
-
-         layers = [
-             # If you need a deeper head, enable the sample layers below:
-             # nn.LayerNorm(d_model), # Extra normalization
-             # nn.Linear(d_model, d_model), # Extra fully connected layer
-             # nn.GELU(), # Activation
-             nn.Linear(d_model, 1),
-         ]
-
-         if task_type == 'classification':
-             # Classification outputs logits for BCEWithLogitsLoss.
-             layers.append(nn.Identity())
-         else:
-             # Regression keeps positive outputs for Tweedie/Gamma.
-             layers.append(nn.Softplus())
-
-         self.head = nn.Sequential(*layers)
-
-         # ---- Self-supervised reconstruction head (masked modeling) ----
-         self.num_recon_head = nn.Linear(
-             d_model, self.num_numeric) if self.num_numeric > 0 else None
-         self.cat_recon_heads = nn.ModuleList([
-             nn.Linear(d_model, int(card)) for card in self.cat_cardinalities
-         ])
-
-     def forward(
-             self,
-             X_num,
-             X_cat,
-             X_geo=None,
-             return_embedding: bool = False,
-             return_reconstruction: bool = False):
-
-         # Inputs:
-         # X_num -> float32 tensor with shape (batch, num_numeric_features)
-         # X_cat -> long tensor with shape (batch, num_categorical_features)
-         # X_geo -> float32 tensor with shape (batch, geo_token_dim)
-
-         if self.training and not hasattr(self, '_printed_device'):
-             print(f">>> FTTransformerCore executing on device: {X_num.device}")
-             self._printed_device = True
-
-         # => tensor shape (batch, token_num, d_model)
-         tokens = self.tokenizer(X_num, X_cat, X_geo)
-         # => tensor shape (batch, token_num, d_model)
-         x = self.encoder(tokens)
-
-         # Mean-pool tokens, then send to the head.
-         x = x.mean(dim=1)  # => tensor shape (batch, d_model)
-
-         if return_reconstruction:
-             num_pred, cat_logits = self.reconstruct(x)
-             cat_logits_out = tuple(
-                 cat_logits) if cat_logits is not None else tuple()
-             if return_embedding:
-                 return x, num_pred, cat_logits_out
-             return num_pred, cat_logits_out
-
-         if return_embedding:
-             return x
-
-         # => tensor shape (batch, 1); Softplus keeps it positive.
-         out = self.head(x)
-         return out
-
-     def reconstruct(self, embedding: torch.Tensor) -> Tuple[Optional[torch.Tensor], List[torch.Tensor]]:
-         """Reconstruct numeric/categorical inputs from pooled embedding (batch, d_model)."""
-         num_pred = self.num_recon_head(
-             embedding) if self.num_recon_head is not None else None
-         cat_logits = [head(embedding) for head in self.cat_recon_heads]
-         return num_pred, cat_logits
-
- # TabularDataset.
-
-
- class TabularDataset(Dataset):
-     def __init__(self, X_num, X_cat, X_geo, y, w):
-
-         # Input tensors:
-         # X_num: torch.float32, shape=(N, num_numeric_features)
-         # X_cat: torch.long, shape=(N, num_categorical_features)
-         # X_geo: torch.float32, shape=(N, geo_token_dim), can be empty
-         # y: torch.float32, shape=(N, 1)
-         # w: torch.float32, shape=(N, 1)
-
-         self.X_num = X_num
-         self.X_cat = X_cat
-         self.X_geo = X_geo
-         self.y = y
-         self.w = w
-
-     def __len__(self):
-         return self.y.shape[0]
-
-     def __getitem__(self, idx):
-         return (
-             self.X_num[idx],
-             self.X_cat[idx],
-             self.X_geo[idx],
-             self.y[idx],
-             self.w[idx],
-         )
-
-
- class MaskedTabularDataset(Dataset):
-     def __init__(self,
-                  X_num_masked: torch.Tensor,
-                  X_cat_masked: torch.Tensor,
-                  X_geo: torch.Tensor,
-                  X_num_true: Optional[torch.Tensor],
-                  num_mask: Optional[torch.Tensor],
-                  X_cat_true: Optional[torch.Tensor],
-                  cat_mask: Optional[torch.Tensor]):
-         self.X_num_masked = X_num_masked
-         self.X_cat_masked = X_cat_masked
-         self.X_geo = X_geo
-         self.X_num_true = X_num_true
-         self.num_mask = num_mask
-         self.X_cat_true = X_cat_true
-         self.cat_mask = cat_mask
-
-     def __len__(self):
-         return self.X_num_masked.shape[0]
-
-     def __getitem__(self, idx):
-         return (
-             self.X_num_masked[idx],
-             self.X_cat_masked[idx],
-             self.X_geo[idx],
-             None if self.X_num_true is None else self.X_num_true[idx],
-             None if self.num_mask is None else self.num_mask[idx],
-             None if self.X_cat_true is None else self.X_cat_true[idx],
-             None if self.cat_mask is None else self.cat_mask[idx],
-         )
+ from __future__ import annotations
+
+ import math
+ from typing import List, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset
+ from ins_pricing.utils import get_logger, log_print
+
+ _logger = get_logger("ins_pricing.modelling.bayesopt.models.model_ft_components")
+
+
+ def _log(*args, **kwargs) -> None:
+     log_print(_logger, *args, **kwargs)
+
+
+ # =============================================================================
+ # FT-Transformer model and sklearn-style wrapper.
+ # =============================================================================
+ # Define FT-Transformer model structure.
+
+
+ class FeatureTokenizer(nn.Module):
+     """Map numeric/categorical/geo tokens into transformer input tokens."""
+
+     def __init__(
+         self,
+         num_numeric: int,
+         cat_cardinalities,
+         d_model: int,
+         num_geo: int = 0,
+         num_numeric_tokens: int = 1,
+     ):
+         super().__init__()
+
+         self.num_numeric = num_numeric
+         self.num_geo = num_geo
+         self.has_geo = num_geo > 0
+
+         if num_numeric > 0:
+             if int(num_numeric_tokens) <= 0:
+                 raise ValueError("num_numeric_tokens must be >= 1 when numeric features exist.")
+             self.num_numeric_tokens = int(num_numeric_tokens)
+             self.has_numeric = True
+             self.num_linear = nn.Linear(num_numeric, d_model * self.num_numeric_tokens)
+         else:
+             self.num_numeric_tokens = 0
+             self.has_numeric = False
+
+         self.embeddings = nn.ModuleList([
+             nn.Embedding(card, d_model) for card in cat_cardinalities
+         ])
+
+         if self.has_geo:
+             # Map geo tokens with a linear layer to avoid one-hot on raw strings; upstream is encoded/normalized.
+             self.geo_linear = nn.Linear(num_geo, d_model)
+
+     def forward(self, X_num, X_cat, X_geo=None):
+         tokens = []
+
+         if self.has_numeric:
+             batch_size = X_num.shape[0]
+             num_token = self.num_linear(X_num)
+             num_token = num_token.view(batch_size, self.num_numeric_tokens, -1)
+             tokens.append(num_token)
+
+         for i, emb in enumerate(self.embeddings):
+             tok = emb(X_cat[:, i])
+             tokens.append(tok.unsqueeze(1))
+
+         if self.has_geo:
+             if X_geo is None:
+                 raise RuntimeError("Geo tokens are enabled but X_geo was not provided.")
+             geo_token = self.geo_linear(X_geo)
+             tokens.append(geo_token.unsqueeze(1))
+
+         x = torch.cat(tokens, dim=1)
+         return x
+
+ # Encoder layer with residual scaling.
+
+
+ class ScaledTransformerEncoderLayer(nn.Module):
+     def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048,
+                  dropout: float = 0.1, residual_scale_attn: float = 1.0,
+                  residual_scale_ffn: float = 1.0, norm_first: bool = True,
+                  ):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(
+             embed_dim=d_model,
+             num_heads=nhead,
+             dropout=dropout,
+             batch_first=True
+         )
+
+         # Feed-forward network.
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+         # Normalization and dropout.
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+
+         self.activation = nn.GELU()
+         # If you prefer ReLU, set: self.activation = nn.ReLU()
+         self.norm_first = norm_first
+
+         # Residual scaling coefficients.
+         self.res_scale_attn = residual_scale_attn
+         self.res_scale_ffn = residual_scale_ffn
+
+     def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal: Optional[bool] = None, **_kwargs):
+         # Input tensor shape: (batch, seq_len, d_model).
+         x = src
+
+         if self.norm_first:
+             # Pre-norm before attention.
+             x = x + self._sa_block(
+                 self.norm1(x),
+                 src_mask,
+                 src_key_padding_mask,
+                 is_causal=is_causal,
+             )
+             x = x + self._ff_block(self.norm2(x))
+         else:
+             # Post-norm (usually disabled).
+             x = self.norm1(
+                 x + self._sa_block(
+                     x,
+                     src_mask,
+                     src_key_padding_mask,
+                     is_causal=is_causal,
+                 )
+             )
+             x = self.norm2(x + self._ff_block(x))
+
+         return x
+
+     def _sa_block(self, x, attn_mask, key_padding_mask, *, is_causal: Optional[bool] = None):
+         # Self-attention with residual scaling.
+         if is_causal is None:
+             attn_out, _ = self.self_attn(
+                 x, x, x,
+                 attn_mask=attn_mask,
+                 key_padding_mask=key_padding_mask,
+                 need_weights=False,
+             )
+         else:
+             try:
+                 attn_out, _ = self.self_attn(
+                     x, x, x,
+                     attn_mask=attn_mask,
+                     key_padding_mask=key_padding_mask,
+                     need_weights=False,
+                     is_causal=is_causal,
+                 )
+             except TypeError:
+                 attn_out, _ = self.self_attn(
+                     x, x, x,
+                     attn_mask=attn_mask,
+                     key_padding_mask=key_padding_mask,
+                     need_weights=False,
+                 )
+         return self.res_scale_attn * self.dropout1(attn_out)
+
+     def _ff_block(self, x):
+         # Feed-forward block with residual scaling.
+         x2 = self.linear2(self.dropout(self.activation(self.linear1(x))))
+         return self.res_scale_ffn * self.dropout2(x2)
+
+ # FT-Transformer core model.
+
+
+ class FTTransformerCore(nn.Module):
+     # Minimal FT-Transformer built from:
+     # 1) FeatureTokenizer: convert numeric/categorical features to tokens;
+     # 2) TransformerEncoder: model feature interactions;
+     # 3) Pooling + MLP + Softplus: positive outputs for Tweedie/Gamma tasks.
+
+     def __init__(self, num_numeric: int, cat_cardinalities, d_model: int = 64,
+                  n_heads: int = 8, n_layers: int = 4, dropout: float = 0.1,
+                  task_type: str = 'regression', num_geo: int = 0,
+                  num_numeric_tokens: int = 1
+                  ):
+         super().__init__()
+
+         self.num_numeric = int(num_numeric)
+         self.cat_cardinalities = list(cat_cardinalities or [])
+
+         self.tokenizer = FeatureTokenizer(
+             num_numeric=num_numeric,
+             cat_cardinalities=cat_cardinalities,
+             d_model=d_model,
+             num_geo=num_geo,
+             num_numeric_tokens=num_numeric_tokens
+         )
+         scale = 1.0 / math.sqrt(n_layers)  # Recommended default.
+         encoder_layer = ScaledTransformerEncoderLayer(
+             d_model=d_model,
+             nhead=n_heads,
+             dim_feedforward=d_model * 4,
+             dropout=dropout,
+             residual_scale_attn=scale,
+             residual_scale_ffn=scale,
+             norm_first=True,
+         )
+         self.encoder = nn.TransformerEncoder(
+             encoder_layer,
+             num_layers=n_layers
+         )
+         self.n_layers = n_layers
+
+         layers = [
+             # If you need a deeper head, enable the sample layers below:
+             # nn.LayerNorm(d_model), # Extra normalization
+             # nn.Linear(d_model, d_model), # Extra fully connected layer
+             # nn.GELU(), # Activation
+             nn.Linear(d_model, 1),
+         ]
+
+         if task_type == 'classification':
+             # Classification outputs logits for BCEWithLogitsLoss.
+             layers.append(nn.Identity())
+         else:
+             # Regression keeps positive outputs for Tweedie/Gamma.
+             layers.append(nn.Softplus())
+
+         self.head = nn.Sequential(*layers)
+
+         # ---- Self-supervised reconstruction head (masked modeling) ----
+         self.num_recon_head = nn.Linear(
+             d_model, self.num_numeric) if self.num_numeric > 0 else None
+         self.cat_recon_heads = nn.ModuleList([
+             nn.Linear(d_model, int(card)) for card in self.cat_cardinalities
+         ])
+
+     def forward(
+             self,
+             X_num,
+             X_cat,
+             X_geo=None,
+             return_embedding: bool = False,
+             return_reconstruction: bool = False):
+
+         # Inputs:
+         # X_num -> float32 tensor with shape (batch, num_numeric_features)
+         # X_cat -> long tensor with shape (batch, num_categorical_features)
+         # X_geo -> float32 tensor with shape (batch, geo_token_dim)
+
+         if self.training and not hasattr(self, '_printed_device'):
+             _log(f">>> FTTransformerCore executing on device: {X_num.device}")
+             self._printed_device = True
+
+         # => tensor shape (batch, token_num, d_model)
+         tokens = self.tokenizer(X_num, X_cat, X_geo)
+         # => tensor shape (batch, token_num, d_model)
+         x = self.encoder(tokens)
+
+         # Mean-pool tokens, then send to the head.
+         x = x.mean(dim=1)  # => tensor shape (batch, d_model)
+
+         if return_reconstruction:
+             num_pred, cat_logits = self.reconstruct(x)
+             cat_logits_out = tuple(
+                 cat_logits) if cat_logits is not None else tuple()
+             if return_embedding:
+                 return x, num_pred, cat_logits_out
+             return num_pred, cat_logits_out
+
+         if return_embedding:
+             return x
+
+         # => tensor shape (batch, 1); Softplus keeps it positive.
+         out = self.head(x)
+         return out
+
+     def reconstruct(self, embedding: torch.Tensor) -> Tuple[Optional[torch.Tensor], List[torch.Tensor]]:
+         """Reconstruct numeric/categorical inputs from pooled embedding (batch, d_model)."""
+         num_pred = self.num_recon_head(
+             embedding) if self.num_recon_head is not None else None
+         cat_logits = [head(embedding) for head in self.cat_recon_heads]
+         return num_pred, cat_logits
+
+ # TabularDataset.
+
+
+ class TabularDataset(Dataset):
+     def __init__(self, X_num, X_cat, X_geo, y, w):
+
+         # Input tensors:
+         # X_num: torch.float32, shape=(N, num_numeric_features)
+         # X_cat: torch.long, shape=(N, num_categorical_features)
+         # X_geo: torch.float32, shape=(N, geo_token_dim), can be empty
+         # y: torch.float32, shape=(N, 1)
+         # w: torch.float32, shape=(N, 1)
+
+         self.X_num = X_num
+         self.X_cat = X_cat
+         self.X_geo = X_geo
+         self.y = y
+         self.w = w
+
+     def __len__(self):
+         return self.y.shape[0]
+
+     def __getitem__(self, idx):
+         return (
+             self.X_num[idx],
+             self.X_cat[idx],
+             self.X_geo[idx],
+             self.y[idx],
+             self.w[idx],
+         )
+
+
+ class MaskedTabularDataset(Dataset):
+     def __init__(self,
+                  X_num_masked: torch.Tensor,
+                  X_cat_masked: torch.Tensor,
+                  X_geo: torch.Tensor,
+                  X_num_true: Optional[torch.Tensor],
+                  num_mask: Optional[torch.Tensor],
+                  X_cat_true: Optional[torch.Tensor],
+                  cat_mask: Optional[torch.Tensor]):
+         self.X_num_masked = X_num_masked
+         self.X_cat_masked = X_cat_masked
+         self.X_geo = X_geo
+         self.X_num_true = X_num_true
+         self.num_mask = num_mask
+         self.X_cat_true = X_cat_true
+         self.cat_mask = cat_mask
+
+     def __len__(self):
+         return self.X_num_masked.shape[0]
+
+     def __getitem__(self, idx):
+         return (
+             self.X_num_masked[idx],
+             self.X_cat_masked[idx],
+             self.X_geo[idx],
+             None if self.X_num_true is None else self.X_num_true[idx],
+             None if self.num_mask is None else self.num_mask[idx],
+             None if self.X_cat_true is None else self.X_cat_true[idx],
+             None if self.cat_mask is None else self.cat_mask[idx],
+         )
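For orientation, a minimal usage sketch of the FTTransformerCore class shown in this diff, assuming only the constructor and forward signatures visible above; the import path follows the renamed 0.5.1 layout from the file list, and the toy dimensions are made up:

    import torch
    from ins_pricing.modelling.bayesopt.models.model_ft_components import FTTransformerCore

    # Hypothetical toy setup: 3 numeric features, 2 categorical columns with
    # cardinalities 5 and 7, regression head (Softplus keeps outputs positive).
    model = FTTransformerCore(num_numeric=3, cat_cardinalities=[5, 7], d_model=64,
                              n_heads=8, n_layers=2, task_type='regression')
    X_num = torch.randn(16, 3)                 # float32, shape (batch, num_numeric)
    X_cat = torch.randint(0, 5, (16, 2))       # long, shape (batch, num_categorical)
    out = model(X_num, X_cat)                  # shape (batch, 1)
    emb = model(X_num, X_cat, return_embedding=True)  # pooled embedding, shape (batch, d_model)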