ins_pricing-0.4.5-py3-none-any.whl → ins_pricing-0.5.1-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry; it is provided for informational purposes only.
- ins_pricing/README.md +48 -22
- ins_pricing/__init__.py +142 -90
- ins_pricing/cli/BayesOpt_entry.py +58 -46
- ins_pricing/cli/BayesOpt_incremental.py +77 -110
- ins_pricing/cli/Explain_Run.py +42 -23
- ins_pricing/cli/Explain_entry.py +551 -577
- ins_pricing/cli/Pricing_Run.py +42 -23
- ins_pricing/cli/bayesopt_entry_runner.py +51 -16
- ins_pricing/cli/utils/bootstrap.py +23 -0
- ins_pricing/cli/utils/cli_common.py +256 -256
- ins_pricing/cli/utils/cli_config.py +379 -360
- ins_pricing/cli/utils/import_resolver.py +375 -358
- ins_pricing/cli/utils/notebook_utils.py +256 -242
- ins_pricing/cli/watchdog_run.py +216 -198
- ins_pricing/frontend/__init__.py +10 -10
- ins_pricing/frontend/app.py +132 -61
- ins_pricing/frontend/config_builder.py +33 -0
- ins_pricing/frontend/example_config.json +11 -0
- ins_pricing/frontend/example_workflows.py +1 -1
- ins_pricing/frontend/runner.py +340 -388
- ins_pricing/governance/__init__.py +20 -20
- ins_pricing/governance/release.py +159 -159
- ins_pricing/modelling/README.md +1 -1
- ins_pricing/modelling/__init__.py +147 -92
- ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +31 -13
- ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
- ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +12 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +589 -552
- ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +987 -958
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +488 -548
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +349 -342
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +921 -913
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +794 -785
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +454 -446
- ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1294 -1282
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +64 -56
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +203 -198
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +333 -325
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +279 -267
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +515 -313
- ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
- ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +193 -186
- ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
- ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
- ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +636 -623
- ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
- ins_pricing/modelling/explain/__init__.py +55 -55
- ins_pricing/modelling/explain/metrics.py +27 -174
- ins_pricing/modelling/explain/permutation.py +237 -237
- ins_pricing/modelling/plotting/__init__.py +40 -36
- ins_pricing/modelling/plotting/compat.py +228 -0
- ins_pricing/modelling/plotting/curves.py +572 -572
- ins_pricing/modelling/plotting/diagnostics.py +163 -163
- ins_pricing/modelling/plotting/geo.py +362 -362
- ins_pricing/modelling/plotting/importance.py +121 -121
- ins_pricing/pricing/__init__.py +27 -27
- ins_pricing/pricing/factors.py +67 -56
- ins_pricing/production/__init__.py +35 -25
- ins_pricing/production/{predict.py → inference.py} +140 -57
- ins_pricing/production/monitoring.py +8 -21
- ins_pricing/reporting/__init__.py +11 -11
- ins_pricing/setup.py +1 -1
- ins_pricing/tests/production/test_inference.py +90 -0
- ins_pricing/utils/__init__.py +112 -78
- ins_pricing/utils/device.py +258 -237
- ins_pricing/utils/features.py +53 -0
- ins_pricing/utils/io.py +72 -0
- ins_pricing/utils/logging.py +34 -1
- ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
- ins_pricing/utils/metrics.py +158 -24
- ins_pricing/utils/numerics.py +76 -0
- ins_pricing/utils/paths.py +9 -1
- ins_pricing/utils/profiling.py +8 -4
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +1 -1
- ins_pricing-0.5.1.dist-info/RECORD +132 -0
- ins_pricing/modelling/core/BayesOpt.py +0 -146
- ins_pricing/modelling/core/__init__.py +0 -1
- ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
- ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
- ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
- ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
- ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
- ins_pricing/modelling/core/bayesopt/utils.py +0 -105
- ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
- ins_pricing/tests/production/test_predict.py +0 -233
- ins_pricing-0.4.5.dist-info/RECORD +0 -130
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
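The dominant change in 0.5.1 is a layout flattening: modules under `ins_pricing/modelling/core/bayesopt/` move to `ins_pricing/modelling/bayesopt/`, and `production/predict.py` is renamed to `production/inference.py` (with `test_predict.py` replaced by `test_inference.py`). A version-tolerant import shim for downstream code might look like the sketch below; only the module paths are confirmed by the rename list above, and the aliases are illustrative:

```python
# Hypothetical compatibility shim for code that must support both layouts.
# Only the module paths are confirmed by this diff; the aliases are ours.
try:
    # ins-pricing >= 0.5.1: flattened bayesopt package, renamed inference module.
    from ins_pricing.modelling.bayesopt import core as bayesopt_core
    from ins_pricing.production import inference as production_inference
except ImportError:
    # ins-pricing <= 0.4.5: nested "core" package, old "predict" module.
    from ins_pricing.modelling.core.bayesopt import core as bayesopt_core
    from ins_pricing.production import predict as production_inference
```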
The hunk below is the per-file diff for `ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py` (identifiable from the logger name in the new code); the module was rewritten wholesale, so the old content is fully removed and the new content fully added.

```diff
@@ -1,342 +1,349 @@
-from __future__ import annotations
-
-import math
-from typing import List, Optional, Tuple
-
-import torch
-import torch.nn as nn
-from torch.utils.data import Dataset
```

(Of the 342 removed lines, only the import block above plus stray fragments such as `if self.`, `self.`, `#`, `)`, and `return` survived extraction; the rest of the old file is not recoverable from this page.)
```diff
+from __future__ import annotations
+
+import math
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch.utils.data import Dataset
+from ins_pricing.utils import get_logger, log_print
+
+_logger = get_logger("ins_pricing.modelling.bayesopt.models.model_ft_components")
+
+
+def _log(*args, **kwargs) -> None:
+    log_print(_logger, *args, **kwargs)
+
+
+# =============================================================================
+# FT-Transformer model and sklearn-style wrapper.
+# =============================================================================
+# Define FT-Transformer model structure.
+
+
+class FeatureTokenizer(nn.Module):
+    """Map numeric/categorical/geo tokens into transformer input tokens."""
+
+    def __init__(
+        self,
+        num_numeric: int,
+        cat_cardinalities,
+        d_model: int,
+        num_geo: int = 0,
+        num_numeric_tokens: int = 1,
+    ):
+        super().__init__()
+
+        self.num_numeric = num_numeric
+        self.num_geo = num_geo
+        self.has_geo = num_geo > 0
+
+        if num_numeric > 0:
+            if int(num_numeric_tokens) <= 0:
+                raise ValueError("num_numeric_tokens must be >= 1 when numeric features exist.")
+            self.num_numeric_tokens = int(num_numeric_tokens)
+            self.has_numeric = True
+            self.num_linear = nn.Linear(num_numeric, d_model * self.num_numeric_tokens)
+        else:
+            self.num_numeric_tokens = 0
+            self.has_numeric = False
+
+        self.embeddings = nn.ModuleList([
+            nn.Embedding(card, d_model) for card in cat_cardinalities
+        ])
+
+        if self.has_geo:
+            # Map geo tokens with a linear layer to avoid one-hot on raw strings; upstream is encoded/normalized.
+            self.geo_linear = nn.Linear(num_geo, d_model)
+
+    def forward(self, X_num, X_cat, X_geo=None):
+        tokens = []
+
+        if self.has_numeric:
+            batch_size = X_num.shape[0]
+            num_token = self.num_linear(X_num)
+            num_token = num_token.view(batch_size, self.num_numeric_tokens, -1)
+            tokens.append(num_token)
+
+        for i, emb in enumerate(self.embeddings):
+            tok = emb(X_cat[:, i])
+            tokens.append(tok.unsqueeze(1))
+
+        if self.has_geo:
+            if X_geo is None:
+                raise RuntimeError("Geo tokens are enabled but X_geo was not provided.")
+            geo_token = self.geo_linear(X_geo)
+            tokens.append(geo_token.unsqueeze(1))
+
+        x = torch.cat(tokens, dim=1)
+        return x
+
```
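As added in 0.5.1, `FeatureTokenizer` emits `num_numeric_tokens` tokens for the whole numeric block, one token per categorical column, and one token for the geo block. A quick shape check with made-up sizes (the dimensions here are illustrative, not from the package):

```python
import torch

# Illustrative sizes only: 5 numeric features projected to 2 tokens,
# two categorical columns with cardinalities 10 and 4, an 8-dim geo block.
tok = FeatureTokenizer(num_numeric=5, cat_cardinalities=[10, 4], d_model=64,
                       num_geo=8, num_numeric_tokens=2)
X_num = torch.randn(32, 5)                # float32 numeric block
X_cat = torch.randint(0, 4, (32, 2))      # long categorical codes, < each cardinality
X_geo = torch.randn(32, 8)                # pre-encoded geo features
out = tok(X_num, X_cat, X_geo)
assert out.shape == (32, 2 + 2 + 1, 64)   # numeric tokens + cat tokens + geo token
```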
```diff
+# Encoder layer with residual scaling.
+
+
+class ScaledTransformerEncoderLayer(nn.Module):
+    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048,
+                 dropout: float = 0.1, residual_scale_attn: float = 1.0,
+                 residual_scale_ffn: float = 1.0, norm_first: bool = True,
+                 ):
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(
+            embed_dim=d_model,
+            num_heads=nhead,
+            dropout=dropout,
+            batch_first=True
+        )
+
+        # Feed-forward network.
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        # Normalization and dropout.
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+        self.activation = nn.GELU()
+        # If you prefer ReLU, set: self.activation = nn.ReLU()
+        self.norm_first = norm_first
+
+        # Residual scaling coefficients.
+        self.res_scale_attn = residual_scale_attn
+        self.res_scale_ffn = residual_scale_ffn
+
+    def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal: Optional[bool] = None, **_kwargs):
+        # Input tensor shape: (batch, seq_len, d_model).
+        x = src
+
+        if self.norm_first:
+            # Pre-norm before attention.
+            x = x + self._sa_block(
+                self.norm1(x),
+                src_mask,
+                src_key_padding_mask,
+                is_causal=is_causal,
+            )
+            x = x + self._ff_block(self.norm2(x))
+        else:
+            # Post-norm (usually disabled).
+            x = self.norm1(
+                x + self._sa_block(
+                    x,
+                    src_mask,
+                    src_key_padding_mask,
+                    is_causal=is_causal,
+                )
+            )
+            x = self.norm2(x + self._ff_block(x))
+
+        return x
+
+    def _sa_block(self, x, attn_mask, key_padding_mask, *, is_causal: Optional[bool] = None):
+        # Self-attention with residual scaling.
+        if is_causal is None:
+            attn_out, _ = self.self_attn(
+                x, x, x,
+                attn_mask=attn_mask,
+                key_padding_mask=key_padding_mask,
+                need_weights=False,
+            )
+        else:
+            try:
+                attn_out, _ = self.self_attn(
+                    x, x, x,
+                    attn_mask=attn_mask,
+                    key_padding_mask=key_padding_mask,
+                    need_weights=False,
+                    is_causal=is_causal,
+                )
+            except TypeError:
+                attn_out, _ = self.self_attn(
+                    x, x, x,
+                    attn_mask=attn_mask,
+                    key_padding_mask=key_padding_mask,
+                    need_weights=False,
+                )
+        return self.res_scale_attn * self.dropout1(attn_out)
+
+    def _ff_block(self, x):
+        # Feed-forward block with residual scaling.
+        x2 = self.linear2(self.dropout(self.activation(self.linear1(x))))
+        return self.res_scale_ffn * self.dropout2(x2)
+
```
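The layer mirrors `nn.TransformerEncoderLayer` but multiplies each residual branch by a configurable factor (set to `1/sqrt(n_layers)` by `FTTransformerCore` below), which damps the growth of the residual stream in deeper stacks; the `try/except TypeError` exists because older torch releases do not accept `is_causal` in `MultiheadAttention.forward`. A shape-preserving smoke test, with assumed sizes:

```python
import torch

layer = ScaledTransformerEncoderLayer(d_model=64, nhead=8,
                                      residual_scale_attn=0.5,  # 1/sqrt(4) for a 4-layer stack
                                      residual_scale_ffn=0.5)
x = torch.randn(32, 7, 64)   # (batch, tokens, d_model); batch_first=True throughout
assert layer(x).shape == x.shape
```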
```diff
+# FT-Transformer core model.
+
+
+class FTTransformerCore(nn.Module):
+    # Minimal FT-Transformer built from:
+    # 1) FeatureTokenizer: convert numeric/categorical features to tokens;
+    # 2) TransformerEncoder: model feature interactions;
+    # 3) Pooling + MLP + Softplus: positive outputs for Tweedie/Gamma tasks.
+
+    def __init__(self, num_numeric: int, cat_cardinalities, d_model: int = 64,
+                 n_heads: int = 8, n_layers: int = 4, dropout: float = 0.1,
+                 task_type: str = 'regression', num_geo: int = 0,
+                 num_numeric_tokens: int = 1
+                 ):
+        super().__init__()
+
+        self.num_numeric = int(num_numeric)
+        self.cat_cardinalities = list(cat_cardinalities or [])
+
+        self.tokenizer = FeatureTokenizer(
+            num_numeric=num_numeric,
+            cat_cardinalities=cat_cardinalities,
+            d_model=d_model,
+            num_geo=num_geo,
+            num_numeric_tokens=num_numeric_tokens
+        )
+        scale = 1.0 / math.sqrt(n_layers)  # Recommended default.
+        encoder_layer = ScaledTransformerEncoderLayer(
+            d_model=d_model,
+            nhead=n_heads,
+            dim_feedforward=d_model * 4,
+            dropout=dropout,
+            residual_scale_attn=scale,
+            residual_scale_ffn=scale,
+            norm_first=True,
+        )
+        self.encoder = nn.TransformerEncoder(
+            encoder_layer,
+            num_layers=n_layers
+        )
+        self.n_layers = n_layers
+
+        layers = [
+            # If you need a deeper head, enable the sample layers below:
+            # nn.LayerNorm(d_model),  # Extra normalization
+            # nn.Linear(d_model, d_model),  # Extra fully connected layer
+            # nn.GELU(),  # Activation
+            nn.Linear(d_model, 1),
+        ]
+
+        if task_type == 'classification':
+            # Classification outputs logits for BCEWithLogitsLoss.
+            layers.append(nn.Identity())
+        else:
+            # Regression keeps positive outputs for Tweedie/Gamma.
+            layers.append(nn.Softplus())
+
+        self.head = nn.Sequential(*layers)
+
+        # ---- Self-supervised reconstruction head (masked modeling) ----
+        self.num_recon_head = nn.Linear(
+            d_model, self.num_numeric) if self.num_numeric > 0 else None
+        self.cat_recon_heads = nn.ModuleList([
+            nn.Linear(d_model, int(card)) for card in self.cat_cardinalities
+        ])
+
+    def forward(
+            self,
+            X_num,
+            X_cat,
+            X_geo=None,
+            return_embedding: bool = False,
+            return_reconstruction: bool = False):
+
+        # Inputs:
+        #   X_num -> float32 tensor with shape (batch, num_numeric_features)
+        #   X_cat -> long tensor with shape (batch, num_categorical_features)
+        #   X_geo -> float32 tensor with shape (batch, geo_token_dim)
+
+        if self.training and not hasattr(self, '_printed_device'):
+            _log(f">>> FTTransformerCore executing on device: {X_num.device}")
+            self._printed_device = True
+
+        # => tensor shape (batch, token_num, d_model)
+        tokens = self.tokenizer(X_num, X_cat, X_geo)
+        # => tensor shape (batch, token_num, d_model)
+        x = self.encoder(tokens)
+
+        # Mean-pool tokens, then send to the head.
+        x = x.mean(dim=1)  # => tensor shape (batch, d_model)
+
+        if return_reconstruction:
+            num_pred, cat_logits = self.reconstruct(x)
+            cat_logits_out = tuple(
+                cat_logits) if cat_logits is not None else tuple()
+            if return_embedding:
+                return x, num_pred, cat_logits_out
+            return num_pred, cat_logits_out
+
+        if return_embedding:
+            return x
+
+        # => tensor shape (batch, 1); Softplus keeps it positive.
+        out = self.head(x)
+        return out
+
+    def reconstruct(self, embedding: torch.Tensor) -> Tuple[Optional[torch.Tensor], List[torch.Tensor]]:
+        """Reconstruct numeric/categorical inputs from pooled embedding (batch, d_model)."""
+        num_pred = self.num_recon_head(
+            embedding) if self.num_recon_head is not None else None
+        cat_logits = [head(embedding) for head in self.cat_recon_heads]
+        return num_pred, cat_logits
+
```
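Both forward modes in one sketch: the plain prediction path and the masked-modeling reconstruction path. Sizes are again illustrative, not from the package:

```python
import torch

model = FTTransformerCore(num_numeric=5, cat_cardinalities=[10, 4],
                          d_model=64, n_heads=8, n_layers=4, num_geo=8)
model.eval()  # deterministic: disables dropout and the one-shot device log

X_num = torch.randn(32, 5)
X_cat = torch.randint(0, 4, (32, 2))
X_geo = torch.randn(32, 8)

pred = model(X_num, X_cat, X_geo)   # (32, 1), positive via Softplus
num_pred, cat_logits = model(X_num, X_cat, X_geo, return_reconstruction=True)
assert num_pred.shape == (32, 5)                      # numeric reconstruction
assert [t.shape[1] for t in cat_logits] == [10, 4]    # per-column class logits
```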
```diff
+# TabularDataset.
+
+
+class TabularDataset(Dataset):
+    def __init__(self, X_num, X_cat, X_geo, y, w):
+
+        # Input tensors:
+        #   X_num: torch.float32, shape=(N, num_numeric_features)
+        #   X_cat: torch.long, shape=(N, num_categorical_features)
+        #   X_geo: torch.float32, shape=(N, geo_token_dim), can be empty
+        #   y: torch.float32, shape=(N, 1)
+        #   w: torch.float32, shape=(N, 1)
+
+        self.X_num = X_num
+        self.X_cat = X_cat
+        self.X_geo = X_geo
+        self.y = y
+        self.w = w
+
+    def __len__(self):
+        return self.y.shape[0]
+
+    def __getitem__(self, idx):
+        return (
+            self.X_num[idx],
+            self.X_cat[idx],
+            self.X_geo[idx],
+            self.y[idx],
+            self.w[idx],
+        )
+
+
+class MaskedTabularDataset(Dataset):
+    def __init__(self,
+                 X_num_masked: torch.Tensor,
+                 X_cat_masked: torch.Tensor,
+                 X_geo: torch.Tensor,
+                 X_num_true: Optional[torch.Tensor],
+                 num_mask: Optional[torch.Tensor],
+                 X_cat_true: Optional[torch.Tensor],
+                 cat_mask: Optional[torch.Tensor]):
+        self.X_num_masked = X_num_masked
+        self.X_cat_masked = X_cat_masked
+        self.X_geo = X_geo
+        self.X_num_true = X_num_true
+        self.num_mask = num_mask
+        self.X_cat_true = X_cat_true
+        self.cat_mask = cat_mask
+
+    def __len__(self):
+        return self.X_num_masked.shape[0]
+
+    def __getitem__(self, idx):
+        return (
+            self.X_num_masked[idx],
+            self.X_cat_masked[idx],
+            self.X_geo[idx],
+            None if self.X_num_true is None else self.X_num_true[idx],
+            None if self.num_mask is None else self.num_mask[idx],
+            None if self.X_cat_true is None else self.X_cat_true[idx],
+            None if self.cat_mask is None else self.cat_mask[idx],
+        )
```
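`MaskedTabularDataset.__getitem__` deliberately returns `None` for truth/mask slots that are absent, which PyTorch's `default_collate` refuses to batch; the 0.5.1 trainers therefore need (and presumably ship) a custom `collate_fn`. A minimal sketch of one such function, ours rather than the package's:

```python
import torch

def masked_collate(batch):
    # Transpose [(f1, ..., f7), ...] into 7 columns; stack tensor columns,
    # pass all-None columns through unchanged.
    columns = list(zip(*batch))
    return tuple(
        None if column[0] is None else torch.stack(list(column))
        for column in columns
    )

# Usage (dataset construction omitted):
# loader = torch.utils.data.DataLoader(masked_ds, batch_size=256,
#                                      collate_fn=masked_collate)
```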