nextrec-0.4.1-py3-none-any.whl → nextrec-0.4.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +250 -112
- nextrec/basic/loggers.py +63 -44
- nextrec/basic/metrics.py +270 -120
- nextrec/basic/model.py +1084 -402
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +492 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +51 -45
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +273 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +103 -38
- nextrec/models/match/dssm.py +82 -68
- nextrec/models/match/dssm_v2.py +72 -57
- nextrec/models/match/mind.py +175 -107
- nextrec/models/match/sdm.py +104 -87
- nextrec/models/match/youtube_dnn.py +73 -59
- nextrec/models/multi_task/esmm.py +69 -46
- nextrec/models/multi_task/mmoe.py +91 -53
- nextrec/models/multi_task/ple.py +117 -58
- nextrec/models/multi_task/poso.py +163 -55
- nextrec/models/multi_task/share_bottom.py +63 -36
- nextrec/models/ranking/afm.py +80 -45
- nextrec/models/ranking/autoint.py +74 -57
- nextrec/models/ranking/dcn.py +110 -48
- nextrec/models/ranking/dcn_v2.py +265 -45
- nextrec/models/ranking/deepfm.py +39 -24
- nextrec/models/ranking/dien.py +335 -146
- nextrec/models/ranking/din.py +158 -92
- nextrec/models/ranking/fibinet.py +134 -52
- nextrec/models/ranking/fm.py +68 -26
- nextrec/models/ranking/masknet.py +95 -33
- nextrec/models/ranking/pnn.py +128 -58
- nextrec/models/ranking/widedeep.py +40 -28
- nextrec/models/ranking/xdeepfm.py +67 -40
- nextrec/utils/__init__.py +59 -34
- nextrec/utils/config.py +496 -0
- nextrec/utils/device.py +30 -20
- nextrec/utils/distributed.py +36 -9
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +33 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/model.py +22 -0
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +283 -165
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
- nextrec-0.4.3.dist-info/RECORD +69 -0
- nextrec-0.4.3.dist-info/entry_points.txt +2 -0
- nextrec-0.4.1.dist-info/RECORD +0 -66
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
nextrec/models/multi_task/poso.py

@@ -46,7 +46,8 @@ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
 from nextrec.basic.activation import activation_layer
 from nextrec.basic.model import BaseModel
-
+
+from nextrec.utils.model import select_features
 
 
 class POSOGate(nn.Module):
@@ -79,7 +80,7 @@ class POSOGate(nn.Module):
         h = self.act(self.fc1(pc))
         g = torch.sigmoid(self.fc2(h))  # (B, out_dim) in (0,1)
         return self.scale_factor * g
-
+
 
 class POSOFC(nn.Module):
     """
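For readers skimming the diff: the POSOGate forward above computes scale_factor * sigmoid(fc2(act(fc1(pc)))), so each output unit gets a personalized multiplier in (0, scale_factor). A minimal standalone sketch of that mechanic; the constructor wiring (pc_dim, out_dim, hidden_dim, scale_factor) is inferred from the fields used in this hunk and the POSOMMoE hunk below, and the fixed ReLU stands in for the package's activation_layer:

import torch
import torch.nn as nn

class POSOGateSketch(nn.Module):
    """Personalization-code (PC) gate: pc -> per-unit multiplier in (0, scale_factor)."""

    def __init__(self, pc_dim: int, out_dim: int, hidden_dim: int = 32, scale_factor: float = 2.0):
        super().__init__()
        self.fc1 = nn.Linear(pc_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.act = nn.ReLU()  # stand-in for activation_layer(activation)
        self.scale_factor = scale_factor

    def forward(self, pc: torch.Tensor) -> torch.Tensor:
        h = self.act(self.fc1(pc))
        g = torch.sigmoid(self.fc2(h))  # (B, out_dim) in (0, 1)
        return self.scale_factor * g    # (B, out_dim) in (0, scale_factor)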
@@ -116,10 +117,10 @@ class POSOFC(nn.Module):
         pc: (B, pc_dim)
         return: (B, out_dim)
         """
-        h = self.act(self.linear(x))
-        g = self.gate(pc)
-        return g * h
-
+        h = self.act(self.linear(x))  # Standard FC with activation
+        g = self.gate(pc)  # (B, out_dim)
+        return g * h  # Element-wise gating
+
 
 class POSOMLP(nn.Module):
     """
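The POSOFC forward above is an ordinary linear-plus-activation layer whose hidden units are then rescaled element-wise by the PC gate. A sketch of the same data flow, reusing POSOGateSketch from the block above (names are illustrative, not the package's exact API):

class POSOFCSketch(nn.Module):
    """FC layer whose output units are gated by the personalization code."""

    def __init__(self, in_dim: int, out_dim: int, pc_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.act = nn.ReLU()
        self.gate = POSOGateSketch(pc_dim=pc_dim, out_dim=out_dim)

    def forward(self, x: torch.Tensor, pc: torch.Tensor) -> torch.Tensor:
        h = self.act(self.linear(x))  # (B, out_dim)
        g = self.gate(pc)             # (B, out_dim)
        return g * h                  # element-wise gating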
@@ -173,7 +174,7 @@ class POSOMLP(nn.Module):
         if self.dropout is not None:
             h = self.dropout(h)
         return h
-
+
 
 class POSOMMoE(nn.Module):
     """
@@ -183,7 +184,7 @@ class POSOMMoE(nn.Module):
     - Task gates aggregate the PC-masked expert outputs
 
     Concretely:
-        h_e = expert_e(x)  # (B, D)
+        h_e = expert_e(x)                  # (B, D)
         g_e = POSOGate(pc) in (0, C)^{D}   # (B, D)
         h_e_tilde = g_e ⊙ h_e              # (B, D)
         z_t = Σ_e gate_t,e(x) * h_e_tilde
@@ -192,14 +193,14 @@ class POSOMMoE(nn.Module):
     def __init__(
         self,
         input_dim: int,
-        pc_dim: int,
+        pc_dim: int,  # for poso feature dimension
         num_experts: int,
         expert_hidden_dims: list[int],
         num_tasks: int,
         activation: str = "relu",
         expert_dropout: float = 0.0,
-        gate_hidden_dim: int = 32,
-        scale_factor: float = 2.0,
+        gate_hidden_dim: int = 32,  # for poso gate hidden dimension
+        scale_factor: float = 2.0,  # for poso gate scale factor
         gate_use_softmax: bool = True,
     ) -> None:
         super().__init__()
@@ -207,15 +208,41 @@ class POSOMMoE(nn.Module):
         self.num_tasks = num_tasks
 
         # Experts built with framework MLP, same as standard MMoE
-        self.experts = nn.ModuleList(
-
+        self.experts = nn.ModuleList(
+            [
+                MLP(
+                    input_dim=input_dim,
+                    output_layer=False,
+                    dims=expert_hidden_dims,
+                    activation=activation,
+                    dropout=expert_dropout,
+                )
+                for _ in range(num_experts)
+            ]
+        )
+        self.expert_output_dim = (
+            expert_hidden_dims[-1] if expert_hidden_dims else input_dim
+        )
 
         # Task-specific gates: gate_t(x) over experts
-        self.gates = nn.ModuleList(
+        self.gates = nn.ModuleList(
+            [nn.Linear(input_dim, num_experts) for _ in range(num_tasks)]
+        )
         self.gate_use_softmax = gate_use_softmax
 
         # PC gate per expert: g_e(pc) ∈ R^D
-        self.expert_pc_gates = nn.ModuleList(
+        self.expert_pc_gates = nn.ModuleList(
+            [
+                POSOGate(
+                    pc_dim=pc_dim,
+                    out_dim=self.expert_output_dim,
+                    hidden_dim=gate_hidden_dim,
+                    scale_factor=scale_factor,
+                    activation=activation,
+                )
+                for _ in range(num_experts)
+            ]
+        )
 
     def forward(self, x: torch.Tensor, pc: torch.Tensor) -> list[torch.Tensor]:
         """
@@ -226,9 +253,9 @@ class POSOMMoE(nn.Module):
         # 1) Expert outputs with POSO PC gate
         masked_expert_outputs = []
         for e, expert in enumerate(self.experts):
-            h_e = expert(x)
-            g_e = self.expert_pc_gates[e](pc)
-            h_e_tilde = g_e * h_e
+            h_e = expert(x)  # (B, D)
+            g_e = self.expert_pc_gates[e](pc)  # (B, D)
+            h_e_tilde = g_e * h_e  # (B, D)
             masked_expert_outputs.append(h_e_tilde)
 
         masked_expert_outputs = torch.stack(masked_expert_outputs, dim=1)  # (B, E, D)
@@ -236,13 +263,13 @@
         # 2) Task gates depend on x as in standard MMoE
         task_outputs: list[torch.Tensor] = []
         for t in range(self.num_tasks):
-            logits = self.gates[t](x)
+            logits = self.gates[t](x)  # (B, E)
             if self.gate_use_softmax:
                 gate = F.softmax(logits, dim=1)
             else:
                 gate = logits
 
-            gate = gate.unsqueeze(-1)
+            gate = gate.unsqueeze(-1)  # (B, E, 1)
             z_t = torch.sum(gate * masked_expert_outputs, dim=1)  # (B, D)
             task_outputs.append(z_t)
 
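Taken together, the two forward hunks implement the docstring's z_t = Σ_e gate_t,e(x) · (g_e ⊙ h_e). A quick shape walk-through of the mixing step under assumed sizes (B=4 samples, E=3 experts, D=8 expert units):

import torch
import torch.nn.functional as F

B, E, D = 4, 3, 8
masked = torch.randn(B, E, D)        # stacked h_e_tilde = g_e * h_e
logits = torch.randn(B, E)           # self.gates[t](x) for one task t
gate = F.softmax(logits, dim=1)      # (B, E), rows sum to 1
z_t = torch.sum(gate.unsqueeze(-1) * masked, dim=1)  # (B, D)
assert z_t.shape == (B, D)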
@@ -269,15 +296,18 @@ class POSO(BaseModel):
 
     def __init__(
         self,
-
-
-
-
-
-
+        dense_features: list[DenseFeature] | None,
+        sparse_features: list[SparseFeature] | None,
+        sequence_features: list[SequenceFeature] | None,
+        main_dense_features: list[str] | None,
+        main_sparse_features: list[str] | None,
+        main_sequence_features: list[str] | None,
+        pc_dense_features: list[str] | None,
+        pc_sparse_features: list[str] | None,
+        pc_sequence_features: list[str] | None,
         tower_params_list: list[dict],
         target: list[str],
-        task: str | list[str]
+        task: str | list[str] = "binary",
         architecture: str = "mlp",
         # POSO gating defaults
         gate_hidden_dim: int = 32,
@@ -303,28 +333,32 @@ class POSO(BaseModel):
         dense_l2_reg: float = 1e-4,
         **kwargs,
     ):
-        # Keep explicit copies of main and PC features
-        self.main_dense_features = list(main_dense_features or [])
-        self.main_sparse_features = list(main_sparse_features or [])
-        self.main_sequence_features = list(main_sequence_features or [])
-        self.pc_dense_features = list(pc_dense_features or [])
-        self.pc_sparse_features = list(pc_sparse_features or [])
-        self.pc_sequence_features = list(pc_sequence_features or [])
         self.num_tasks = len(target)
 
-
-
+        # Normalize task to match num_tasks
+        resolved_task = task
+        if resolved_task is None:
+            resolved_task = self.default_task
+        elif isinstance(resolved_task, str):
+            resolved_task = [resolved_task] * self.num_tasks
+        elif len(resolved_task) == 1 and self.num_tasks > 1:
+            resolved_task = resolved_task * self.num_tasks
+        elif len(resolved_task) != self.num_tasks:
+            raise ValueError(
+                f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
+            )
 
-
-
-
+        if len(tower_params_list) != self.num_tasks:
+            raise ValueError(
+                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
+            )
 
         super().__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=
+            task=resolved_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
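The task-normalization block added here broadcasts a scalar or length-1 task spec across all targets and rejects mismatched lengths. Restated as a plain function to make the branching explicit (a sketch mirroring the hunk, not the package API):

def resolve_task(task, num_tasks, default_task):
    if task is None:
        return default_task
    if isinstance(task, str):
        return [task] * num_tasks
    if len(task) == 1 and num_tasks > 1:
        return task * num_tasks
    if len(task) != num_tasks:
        raise ValueError(
            f"Length of task ({len(task)}) must match number of targets ({num_tasks})."
        )
    return task

assert resolve_task("binary", 2, None) == ["binary", "binary"]
assert resolve_task(["binary"], 3, None) == ["binary", "binary", "binary"]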
@@ -333,15 +367,58 @@ class POSO(BaseModel):
             **kwargs,
         )
 
-        self.
+        self.main_dense_feature_names = list(main_dense_features or [])
+        self.main_sparse_feature_names = list(main_sparse_features or [])
+        self.main_sequence_feature_names = list(main_sequence_features or [])
+        self.pc_dense_feature_names = list(pc_dense_features or [])
+        self.pc_sparse_feature_names = list(pc_sparse_features or [])
+        self.pc_sequence_feature_names = list(pc_sequence_features or [])
+
+        if loss is None:
+            self.loss = "bce"
+        self.loss = loss
+
         optimizer_params = optimizer_params or {}
 
-        self.
-
-
+        self.main_dense_features = select_features(
+            self.dense_features, self.main_dense_feature_names, "main_dense_features"
+        )
+        self.main_sparse_features = select_features(
+            self.sparse_features, self.main_sparse_feature_names, "main_sparse_features"
+        )
+        self.main_sequence_features = select_features(
+            self.sequence_features,
+            self.main_sequence_feature_names,
+            "main_sequence_features",
+        )
+
+        self.pc_dense_features = select_features(
+            self.dense_features, self.pc_dense_feature_names, "pc_dense_features"
+        )
+        self.pc_sparse_features = select_features(
+            self.sparse_features, self.pc_sparse_feature_names, "pc_sparse_features"
+        )
+        self.pc_sequence_features = select_features(
+            self.sequence_features,
+            self.pc_sequence_feature_names,
+            "pc_sequence_features",
+        )
 
-        self.main_features =
-
+        self.main_features = (
+            self.main_dense_features
+            + self.main_sparse_features
+            + self.main_sequence_features
+        )
+        self.pc_features = (
+            self.pc_dense_features + self.pc_sparse_features + self.pc_sequence_features
+        )
+
+        if not self.main_features:
+            raise ValueError("POSO requires at least one main feature.")
+        if not self.pc_features:
+            raise ValueError(
+                "POSO requires at least one PC feature for personalization."
+            )
 
         self.embedding = EmbeddingLayer(features=self.all_features)
         self.main_input_dim = self.embedding.get_input_dim(self.main_features)
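This hunk replaces the old raw-name copies with feature objects resolved through the newly imported select_features. A plausible sketch of what such name-based selection does (an assumption; the real helper lives in nextrec/utils/model.py and its exact behavior is not shown in this diff):

def select_features_sketch(features, names, arg_name):
    """Pick feature objects by name, preserving order; fail loudly on unknown names."""
    by_name = {f.name: f for f in (features or [])}
    missing = [n for n in names if n not in by_name]
    if missing:
        raise ValueError(f"{arg_name}: unknown feature names {missing}")
    return [by_name[n] for n in names]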
@@ -349,7 +426,9 @@
 
         self.architecture = architecture.lower()
         if self.architecture not in {"mlp", "mmoe"}:
-            raise ValueError(
+            raise ValueError(
+                f"Unsupported architecture '{architecture}', choose from ['mlp', 'mmoe']."
+            )
 
         # Build backbones
         if self.architecture == "mlp":
@@ -358,13 +437,17 @@
             for tower_params in tower_params_list:
                 dims = tower_params.get("dims")
                 if not dims:
-                    raise ValueError(
+                    raise ValueError(
+                        "tower_params must include a non-empty 'dims' list for POSO-MLP towers."
+                    )
                 dropout = tower_params.get("dropout", 0.0)
                 tower = POSOMLP(
                     input_dim=self.main_input_dim,
                     pc_dim=self.pc_input_dim,
                     dims=dims,
-                    gate_hidden_dim=tower_params.get(
+                    gate_hidden_dim=tower_params.get(
+                        "gate_hidden_dim", gate_hidden_dim
+                    ),
                     scale_factor=tower_params.get("scale_factor", gate_scale_factor),
                     activation=tower_params.get("activation", gate_activation),
                     use_bias=tower_params.get("use_bias", gate_use_bias),
@@ -375,7 +458,9 @@
                 self.tower_heads.append(nn.Linear(tower_output_dim, 1))
         else:
             if expert_hidden_dims is None or not expert_hidden_dims:
-                raise ValueError(
+                raise ValueError(
+                    "expert_hidden_dims must be provided for MMoE architecture."
+                )
             self.mmoe = POSOMMoE(
                 input_dim=self.main_input_dim,
                 pc_dim=self.pc_input_dim,
@@ -388,12 +473,35 @@
                 scale_factor=expert_gate_scale_factor,
                 gate_use_softmax=gate_use_softmax,
             )
-            self.towers = nn.ModuleList(
+            self.towers = nn.ModuleList(
+                [
+                    MLP(
+                        input_dim=self.mmoe.expert_output_dim,
+                        output_layer=True,
+                        **tower_params,
+                    )
+                    for tower_params in tower_params_list
+                ]
+            )
             self.tower_heads = None
-        self.prediction_layer = PredictionLayer(
-
-
-
+        self.prediction_layer = PredictionLayer(
+            task_type=self.default_task,
+            task_dims=[1] * self.num_tasks,
+        )
+        include_modules = (
+            ["towers", "tower_heads"]
+            if self.architecture == "mlp"
+            else ["mmoe", "towers"]
+        )
+        self.register_regularization_weights(
+            embedding_attr="embedding", include_modules=include_modules
+        )
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            loss=loss,
+            loss_params=loss_params,
+        )
 
     def forward(self, x):
         # Embed main and PC features separately so PC can gate hidden units
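As an end-to-end check of the gating mechanics the POSO constructor wires up (the PC embedding gating a main tower's hidden units), the sketches above compose like this, with purely illustrative dimensions:

torch.manual_seed(0)
fc = POSOFCSketch(in_dim=16, out_dim=8, pc_dim=4)
x = torch.randn(2, 16)   # flattened main-feature embeddings
pc = torch.randn(2, 4)   # flattened PC-feature embeddings
out = fc(x, pc)          # personalized hidden representation
assert out.shape == (2, 8)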
nextrec/models/multi_task/share_bottom.py

@@ -56,42 +56,58 @@ class ShareBottom(BaseModel):
     def default_task(self):
         num_tasks = getattr(self, "num_tasks", None)
         if num_tasks is not None and num_tasks > 0:
-            return [
-        return [
-
-    def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            return ["binary"] * num_tasks
+        return ["binary"]
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature],
+        sparse_features: list[SparseFeature],
+        sequence_features: list[SequenceFeature],
+        bottom_params: dict,
+        tower_params_list: list[dict],
+        target: list[str],
+        task: str | list[str] | None = None,
+        optimizer: str = "adam",
+        optimizer_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        device: str = "cpu",
+        embedding_l1_reg=1e-6,
+        dense_l1_reg=1e-5,
+        embedding_l2_reg=1e-5,
+        dense_l2_reg=1e-4,
+        **kwargs,
+    ):
+
+        optimizer_params = optimizer_params or {}
+
         self.num_tasks = len(target)
 
+        resolved_task = task
+        if resolved_task is None:
+            resolved_task = self.default_task
+        elif isinstance(resolved_task, str):
+            resolved_task = [resolved_task] * self.num_tasks
+        elif len(resolved_task) == 1 and self.num_tasks > 1:
+            resolved_task = resolved_task * self.num_tasks
+        elif len(resolved_task) != self.num_tasks:
+            raise ValueError(
+                f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
+            )
+
         super(ShareBottom, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=
+            task=resolved_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            **kwargs
+            **kwargs,
         )
 
         self.loss = loss
@@ -100,7 +116,9 @@ class ShareBottom(BaseModel):
         # Number of tasks
         self.num_tasks = len(target)
         if len(tower_params_list) != self.num_tasks:
-            raise ValueError(
+            raise ValueError(
+                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
+            )
         # Embedding layer
         self.embedding = EmbeddingLayer(features=self.all_features)
         # Calculate input dimension
@@ -108,39 +126,48 @@ class ShareBottom(BaseModel):
         # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
         # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
         # input_dim = emb_dim_total + dense_input_dim
-
+
         # Shared bottom network
         self.bottom = MLP(input_dim=input_dim, output_layer=False, **bottom_params)
-
+
         # Get bottom output dimension
-        if
-            bottom_output_dim = bottom_params[
+        if "dims" in bottom_params and len(bottom_params["dims"]) > 0:
+            bottom_output_dim = bottom_params["dims"][-1]
         else:
             bottom_output_dim = input_dim
-
+
         # Task-specific towers
         self.towers = nn.ModuleList()
         for tower_params in tower_params_list:
             tower = MLP(input_dim=bottom_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(
+        self.prediction_layer = PredictionLayer(
+            task_type=self.default_task, task_dims=[1] * self.num_tasks
+        )
         # Register regularization weights
-        self.register_regularization_weights(
-
+        self.register_regularization_weights(
+            embedding_attr="embedding", include_modules=["bottom", "towers"]
+        )
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            loss=loss,
+            loss_params=loss_params,
+        )
 
     def forward(self, x):
         # Get all embeddings and flatten
         input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
-
+
         # Shared bottom
         bottom_output = self.bottom(input_flat)  # [B, bottom_dim]
-
+
         # Task-specific towers
         task_outputs = []
         for tower in self.towers:
             tower_output = tower(bottom_output)  # [B, 1]
             task_outputs.append(tower_output)
-
+
         # Stack outputs: [B, num_tasks]
         y = torch.cat(task_outputs, dim=1)
         return self.prediction_layer(y)
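The ShareBottom forward path is structurally unchanged: one shared bottom MLP feeds per-task towers whose scalar logits are concatenated and passed through the prediction layer. A minimal runnable sketch of that flow with plain stand-ins for the framework's MLP and PredictionLayer (dimensions are illustrative):

import torch
import torch.nn as nn

B, input_dim, bottom_dim, num_tasks = 4, 32, 16, 2
bottom = nn.Sequential(nn.Linear(input_dim, bottom_dim), nn.ReLU())
towers = nn.ModuleList(nn.Linear(bottom_dim, 1) for _ in range(num_tasks))

x = torch.randn(B, input_dim)                         # flattened embeddings
h = bottom(x)                                         # (B, bottom_dim), shared
y = torch.cat([tower(h) for tower in towers], dim=1)  # (B, num_tasks) logits
probs = torch.sigmoid(y)                              # binary-task prediction layer
assert probs.shape == (B, num_tasks)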