nextrec 0.3.6__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +244 -113
- nextrec/basic/loggers.py +62 -43
- nextrec/basic/metrics.py +268 -119
- nextrec/basic/model.py +1373 -443
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +498 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +42 -24
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +303 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +106 -40
- nextrec/models/match/dssm.py +82 -69
- nextrec/models/match/dssm_v2.py +72 -58
- nextrec/models/match/mind.py +175 -108
- nextrec/models/match/sdm.py +104 -88
- nextrec/models/match/youtube_dnn.py +73 -60
- nextrec/models/multi_task/esmm.py +53 -39
- nextrec/models/multi_task/mmoe.py +70 -47
- nextrec/models/multi_task/ple.py +107 -50
- nextrec/models/multi_task/poso.py +121 -41
- nextrec/models/multi_task/share_bottom.py +54 -38
- nextrec/models/ranking/afm.py +172 -45
- nextrec/models/ranking/autoint.py +84 -61
- nextrec/models/ranking/dcn.py +59 -42
- nextrec/models/ranking/dcn_v2.py +64 -23
- nextrec/models/ranking/deepfm.py +36 -26
- nextrec/models/ranking/dien.py +158 -102
- nextrec/models/ranking/din.py +88 -60
- nextrec/models/ranking/fibinet.py +55 -35
- nextrec/models/ranking/fm.py +32 -26
- nextrec/models/ranking/masknet.py +95 -34
- nextrec/models/ranking/pnn.py +34 -31
- nextrec/models/ranking/widedeep.py +37 -29
- nextrec/models/ranking/xdeepfm.py +63 -41
- nextrec/utils/__init__.py +61 -32
- nextrec/utils/config.py +490 -0
- nextrec/utils/device.py +52 -12
- nextrec/utils/distributed.py +141 -0
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +32 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +531 -0
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
- nextrec-0.4.2.dist-info/RECORD +69 -0
- nextrec-0.4.2.dist-info/entry_points.txt +2 -0
- nextrec-0.3.6.dist-info/RECORD +0 -64
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/models/multi_task/poso.py

@@ -79,7 +79,7 @@ class POSOGate(nn.Module):
         h = self.act(self.fc1(pc))
         g = torch.sigmoid(self.fc2(h))  # (B, out_dim) in (0,1)
         return self.scale_factor * g
-
+

 class POSOFC(nn.Module):
     """
@@ -116,10 +116,10 @@ class POSOFC(nn.Module):
         pc: (B, pc_dim)
         return: (B, out_dim)
         """
-        h = self.act(self.linear(x))
-        g = self.gate(pc)
-        return g * h
-
+        h = self.act(self.linear(x))  # Standard FC with activation
+        g = self.gate(pc)  # (B, out_dim)
+        return g * h  # Element-wise gating
+

 class POSOMLP(nn.Module):
     """
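
For readers skimming the diff, the two modules touched above implement POSO's "personalization code" (PC) gating: a small two-layer network maps the PC features to a per-unit gate in (0, scale_factor), and the gated FC layer multiplies that gate into its activations. A minimal standalone sketch of the same pattern, assuming the constructor arguments visible in the hunks (`pc_dim`, `out_dim`, `hidden_dim`, `scale_factor`); the shipped modules may differ in details such as activation lookup and bias handling:

```python
import torch
import torch.nn as nn


class POSOGateSketch(nn.Module):
    """pc -> fc1 -> act -> fc2 -> sigmoid, scaled into (0, scale_factor)."""

    def __init__(self, pc_dim: int, out_dim: int, hidden_dim: int = 32,
                 scale_factor: float = 2.0) -> None:
        super().__init__()
        self.fc1 = nn.Linear(pc_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.act = nn.ReLU()
        self.scale_factor = scale_factor

    def forward(self, pc: torch.Tensor) -> torch.Tensor:
        h = self.act(self.fc1(pc))
        g = torch.sigmoid(self.fc2(h))   # (B, out_dim) in (0, 1)
        return self.scale_factor * g     # (B, out_dim) in (0, scale_factor)


class POSOFCSketch(nn.Module):
    """Fully connected layer whose hidden units are gated by the PC features."""

    def __init__(self, in_dim: int, out_dim: int, pc_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.act = nn.ReLU()
        self.gate = POSOGateSketch(pc_dim=pc_dim, out_dim=out_dim)

    def forward(self, x: torch.Tensor, pc: torch.Tensor) -> torch.Tensor:
        h = self.act(self.linear(x))     # standard FC with activation
        g = self.gate(pc)                # (B, out_dim)
        return g * h                     # element-wise gating
```
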
@@ -173,7 +173,7 @@ class POSOMLP(nn.Module):
         if self.dropout is not None:
             h = self.dropout(h)
         return h
-
+

 class POSOMMoE(nn.Module):
     """
@@ -183,7 +183,7 @@ class POSOMMoE(nn.Module):
     - Task gates aggregate the PC-masked expert outputs

     Concretely:
-        h_e = expert_e(x)  # (B, D)
+        h_e = expert_e(x)  # (B, D)
         g_e = POSOGate(pc) in (0, C)^{D}  # (B, D)
         h_e_tilde = g_e ⊙ h_e  # (B, D)
         z_t = Σ_e gate_t,e(x) * h_e_tilde
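
The docstring above is the whole algorithm for `POSOMMoE`. The same math with concrete shapes, as a self-contained illustration (B=batch, E=experts, D=expert output dim, T=tasks; random tensors stand in for the expert, PC-gate, and task-gate outputs). The vectorized einsum is equivalent to the per-task loop in the forward pass further down:

```python
import torch
import torch.nn.functional as F

B, E, D, T = 4, 3, 8, 2                           # batch, experts, expert dim, tasks
h = torch.randn(B, E, D)                          # h_e = expert_e(x), stacked over e
g_pc = 2.0 * torch.sigmoid(torch.randn(B, E, D))  # g_e = POSOGate(pc), values in (0, 2)
h_tilde = g_pc * h                                # h_e_tilde = g_e ⊙ h_e

gate = F.softmax(torch.randn(B, T, E), dim=-1)    # gate_t,e(x), normalized over experts
z = torch.einsum("bte,bed->btd", gate, h_tilde)   # z_t = Σ_e gate_t,e(x) * h_e_tilde
print(z.shape)                                    # torch.Size([4, 2, 8]): one (B, D) slice per task
```
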
@@ -192,14 +192,14 @@ class POSOMMoE(nn.Module):
     def __init__(
         self,
         input_dim: int,
-        pc_dim: int,
+        pc_dim: int,  # for poso feature dimension
         num_experts: int,
         expert_hidden_dims: list[int],
         num_tasks: int,
         activation: str = "relu",
         expert_dropout: float = 0.0,
-        gate_hidden_dim: int = 32,
-        scale_factor: float = 2.0,
+        gate_hidden_dim: int = 32,  # for poso gate hidden dimension
+        scale_factor: float = 2.0,  # for poso gate scale factor
         gate_use_softmax: bool = True,
     ) -> None:
         super().__init__()
@@ -207,15 +207,41 @@ class POSOMMoE(nn.Module):
         self.num_tasks = num_tasks

         # Experts built with framework MLP, same as standard MMoE
-        self.experts = nn.ModuleList(
-
+        self.experts = nn.ModuleList(
+            [
+                MLP(
+                    input_dim=input_dim,
+                    output_layer=False,
+                    dims=expert_hidden_dims,
+                    activation=activation,
+                    dropout=expert_dropout,
+                )
+                for _ in range(num_experts)
+            ]
+        )
+        self.expert_output_dim = (
+            expert_hidden_dims[-1] if expert_hidden_dims else input_dim
+        )

         # Task-specific gates: gate_t(x) over experts
-        self.gates = nn.ModuleList(
+        self.gates = nn.ModuleList(
+            [nn.Linear(input_dim, num_experts) for _ in range(num_tasks)]
+        )
         self.gate_use_softmax = gate_use_softmax

         # PC gate per expert: g_e(pc) ∈ R^D
-        self.expert_pc_gates = nn.ModuleList(
+        self.expert_pc_gates = nn.ModuleList(
+            [
+                POSOGate(
+                    pc_dim=pc_dim,
+                    out_dim=self.expert_output_dim,
+                    hidden_dim=gate_hidden_dim,
+                    scale_factor=scale_factor,
+                    activation=activation,
+                )
+                for _ in range(num_experts)
+            ]
+        )

     def forward(self, x: torch.Tensor, pc: torch.Tensor) -> list[torch.Tensor]:
         """
@@ -226,9 +252,9 @@ class POSOMMoE(nn.Module):
         # 1) Expert outputs with POSO PC gate
         masked_expert_outputs = []
         for e, expert in enumerate(self.experts):
-            h_e = expert(x)
-            g_e = self.expert_pc_gates[e](pc)
-            h_e_tilde = g_e * h_e
+            h_e = expert(x)  # (B, D)
+            g_e = self.expert_pc_gates[e](pc)  # (B, D)
+            h_e_tilde = g_e * h_e  # (B, D)
             masked_expert_outputs.append(h_e_tilde)

         masked_expert_outputs = torch.stack(masked_expert_outputs, dim=1)  # (B, E, D)
@@ -236,13 +262,13 @@ class POSOMMoE(nn.Module):
         # 2) Task gates depend on x as in standard MMoE
         task_outputs: list[torch.Tensor] = []
         for t in range(self.num_tasks):
-            logits = self.gates[t](x)
+            logits = self.gates[t](x)  # (B, E)
             if self.gate_use_softmax:
                 gate = F.softmax(logits, dim=1)
             else:
                 gate = logits

-            gate = gate.unsqueeze(-1)
+            gate = gate.unsqueeze(-1)  # (B, E, 1)
             z_t = torch.sum(gate * masked_expert_outputs, dim=1)  # (B, D)
             task_outputs.append(z_t)

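
Taken together, the constructor and forward hunks define the public surface of the new `POSOMMoE`. A hedged usage sketch, assuming the class is importable from the module listed at the top of this diff (`nextrec/models/multi_task/poso.py`) and using only the keyword arguments shown above:

```python
import torch
from nextrec.models.multi_task.poso import POSOMMoE  # assumed import path

mmoe = POSOMMoE(
    input_dim=64,                  # flattened main-feature embedding dim
    pc_dim=16,                     # flattened PC-feature embedding dim
    num_experts=4,
    expert_hidden_dims=[128, 64],  # expert_output_dim becomes 64
    num_tasks=2,
    gate_hidden_dim=32,
    scale_factor=2.0,
)

x = torch.randn(8, 64)             # main input  (B, input_dim)
pc = torch.randn(8, 16)            # PC input    (B, pc_dim)
task_outputs = mmoe(x, pc)         # list with num_tasks tensors, each (B, expert_output_dim)
print(len(task_outputs), task_outputs[0].shape)  # 2 torch.Size([8, 64])
```
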
@@ -261,8 +287,11 @@ class POSO(BaseModel):
         return "POSO"

     @property
-    def
-
+    def default_task(self) -> list[str]:
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ["binary"] * num_tasks
+        return ["binary"]

     def __init__(
         self,
@@ -274,7 +303,7 @@ class POSO(BaseModel):
         pc_sequence_features: list[SequenceFeature] | None,
         tower_params_list: list[dict],
         target: list[str],
-        task: str | list[str] =
+        task: str | list[str] | None = None,
         architecture: str = "mlp",
         # POSO gating defaults
         gate_hidden_dim: int = 32,
@@ -307,26 +336,38 @@ class POSO(BaseModel):
         self.pc_dense_features = list(pc_dense_features or [])
         self.pc_sparse_features = list(pc_sparse_features or [])
         self.pc_sequence_features = list(pc_sequence_features or [])
+        self.num_tasks = len(target)

-        if
-
+        if (
+            not self.pc_dense_features
+            and not self.pc_sparse_features
+            and not self.pc_sequence_features
+        ):
+            raise ValueError(
+                "POSO requires at least one PC feature for personalization."
+            )

-        dense_features = merge_features(
-
-
+        dense_features = merge_features(
+            self.main_dense_features, self.pc_dense_features
+        )
+        sparse_features = merge_features(
+            self.main_sparse_features, self.pc_sparse_features
+        )
+        sequence_features = merge_features(
+            self.main_sequence_features, self.pc_sequence_features
+        )

         super().__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task,
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
             **kwargs,
         )

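
The recurring pattern in 0.4.2 is that `task` is now optional and falls back to the new `default_task` property, which returns one "binary" task per target. In plain Python, the resolution the constructor performs looks like this (illustrative values, not nextrec code):

```python
target = ["click", "like", "purchase"]
num_tasks = len(target)
default_task = ["binary"] * num_tasks      # what the default_task property returns

task = None                                # caller did not pass a task
print(task or default_task)                # ['binary', 'binary', 'binary']

task = ["binary", "binary", "regression"]  # an explicit task list still wins
print(task or default_task)                # ['binary', 'binary', 'regression']
```
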
@@ -335,10 +376,18 @@ class POSO(BaseModel):

         self.num_tasks = len(target)
         if len(tower_params_list) != self.num_tasks:
-            raise ValueError(
+            raise ValueError(
+                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
+            )

-        self.main_features =
-
+        self.main_features = (
+            self.main_dense_features
+            + self.main_sparse_features
+            + self.main_sequence_features
+        )
+        self.pc_features = (
+            self.pc_dense_features + self.pc_sparse_features + self.pc_sequence_features
+        )

         self.embedding = EmbeddingLayer(features=self.all_features)
         self.main_input_dim = self.embedding.get_input_dim(self.main_features)
@@ -346,7 +395,9 @@ class POSO(BaseModel):

         self.architecture = architecture.lower()
         if self.architecture not in {"mlp", "mmoe"}:
-            raise ValueError(
+            raise ValueError(
+                f"Unsupported architecture '{architecture}', choose from ['mlp', 'mmoe']."
+            )

         # Build backbones
         if self.architecture == "mlp":
@@ -355,13 +406,17 @@ class POSO(BaseModel):
             for tower_params in tower_params_list:
                 dims = tower_params.get("dims")
                 if not dims:
-                    raise ValueError(
+                    raise ValueError(
+                        "tower_params must include a non-empty 'dims' list for POSO-MLP towers."
+                    )
                 dropout = tower_params.get("dropout", 0.0)
                 tower = POSOMLP(
                     input_dim=self.main_input_dim,
                     pc_dim=self.pc_input_dim,
                     dims=dims,
-                    gate_hidden_dim=tower_params.get(
+                    gate_hidden_dim=tower_params.get(
+                        "gate_hidden_dim", gate_hidden_dim
+                    ),
                     scale_factor=tower_params.get("scale_factor", gate_scale_factor),
                     activation=tower_params.get("activation", gate_activation),
                     use_bias=tower_params.get("use_bias", gate_use_bias),
|
|
|
372
427
|
self.tower_heads.append(nn.Linear(tower_output_dim, 1))
|
|
373
428
|
else:
|
|
374
429
|
if expert_hidden_dims is None or not expert_hidden_dims:
|
|
375
|
-
raise ValueError(
|
|
430
|
+
raise ValueError(
|
|
431
|
+
"expert_hidden_dims must be provided for MMoE architecture."
|
|
432
|
+
)
|
|
376
433
|
self.mmoe = POSOMMoE(
|
|
377
434
|
input_dim=self.main_input_dim,
|
|
378
435
|
pc_dim=self.pc_input_dim,
|
|
@@ -385,12 +442,35 @@ class POSO(BaseModel):
                 scale_factor=expert_gate_scale_factor,
                 gate_use_softmax=gate_use_softmax,
             )
-            self.towers = nn.ModuleList(
+            self.towers = nn.ModuleList(
+                [
+                    MLP(
+                        input_dim=self.mmoe.expert_output_dim,
+                        output_layer=True,
+                        **tower_params,
+                    )
+                    for tower_params in tower_params_list
+                ]
+            )
             self.tower_heads = None
-        self.prediction_layer = PredictionLayer(
-
-
-
+        self.prediction_layer = PredictionLayer(
+            task_type=self.default_task,
+            task_dims=[1] * self.num_tasks,
+        )
+        include_modules = (
+            ["towers", "tower_heads"]
+            if self.architecture == "mlp"
+            else ["mmoe", "towers"]
+        )
+        self.register_regularization_weights(
+            embedding_attr="embedding", include_modules=include_modules
+        )
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            loss=loss,
+            loss_params=loss_params,
+        )

     def forward(self, x):
         # Embed main and PC features separately so PC can gate hidden units
nextrec/models/multi_task/share_bottom.py

@@ -53,42 +53,47 @@ class ShareBottom(BaseModel):
         return "ShareBottom"

     @property
-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def default_task(self):
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ["binary"] * num_tasks
+        return ["binary"]
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature],
+        sparse_features: list[SparseFeature],
+        sequence_features: list[SequenceFeature],
+        bottom_params: dict,
+        tower_params_list: list[dict],
+        target: list[str],
+        task: str | list[str] | None = None,
+        optimizer: str = "adam",
+        optimizer_params: dict = {},
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        device: str = "cpu",
+        embedding_l1_reg=1e-6,
+        dense_l1_reg=1e-5,
+        embedding_l2_reg=1e-5,
+        dense_l2_reg=1e-4,
+        **kwargs,
+    ):
+
+        self.num_tasks = len(target)
+
         super(ShareBottom, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task,
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-
-            **kwargs
+            **kwargs,
         )

         self.loss = loss
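
As with POSO, the ShareBottom constructor now takes `optimizer`, `optimizer_params`, `loss`, and `loss_params` directly and compiles the model at the end of `__init__`, and `task` defaults to `None` (resolved to one "binary" task per target). A hedged sketch of the plain-Python pieces of a call, using only keyword names from the signature above (feature lists omitted because their constructors are not part of this hunk):

```python
bottom_params = {"dims": [256, 128], "dropout": 0.1, "activation": "relu"}
tower_params_list = [
    {"dims": [64], "dropout": 0.1},   # tower for target "click"
    {"dims": [64], "dropout": 0.1},   # tower for target "like"
]
target = ["click", "like"]
task = None                           # resolves to ["binary"] * len(target)
loss = "bce"                          # default; per the type hint a per-task list is also accepted
assert len(tower_params_list) == len(target)
```
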
@@ -97,7 +102,9 @@ class ShareBottom(BaseModel):
         # Number of tasks
         self.num_tasks = len(target)
         if len(tower_params_list) != self.num_tasks:
-            raise ValueError(
+            raise ValueError(
+                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
+            )
         # Embedding layer
         self.embedding = EmbeddingLayer(features=self.all_features)
         # Calculate input dimension
@@ -105,39 +112,48 @@ class ShareBottom(BaseModel):
         # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
         # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
         # input_dim = emb_dim_total + dense_input_dim
-
+
         # Shared bottom network
         self.bottom = MLP(input_dim=input_dim, output_layer=False, **bottom_params)
-
+
         # Get bottom output dimension
-        if
-            bottom_output_dim = bottom_params[
+        if "dims" in bottom_params and len(bottom_params["dims"]) > 0:
+            bottom_output_dim = bottom_params["dims"][-1]
         else:
             bottom_output_dim = input_dim
-
+
         # Task-specific towers
         self.towers = nn.ModuleList()
         for tower_params in tower_params_list:
             tower = MLP(input_dim=bottom_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(
+        self.prediction_layer = PredictionLayer(
+            task_type=self.default_task, task_dims=[1] * self.num_tasks
+        )
         # Register regularization weights
-        self.register_regularization_weights(
-
+        self.register_regularization_weights(
+            embedding_attr="embedding", include_modules=["bottom", "towers"]
+        )
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            loss=loss,
+            loss_params=loss_params,
+        )

     def forward(self, x):
         # Get all embeddings and flatten
         input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
-
+
         # Shared bottom
         bottom_output = self.bottom(input_flat)  # [B, bottom_dim]
-
+
         # Task-specific towers
         task_outputs = []
         for tower in self.towers:
             tower_output = tower(bottom_output)  # [B, 1]
             task_outputs.append(tower_output)
-
+
         # Stack outputs: [B, num_tasks]
         y = torch.cat(task_outputs, dim=1)
         return self.prediction_layer(y)