nextrec 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +250 -112
  7. nextrec/basic/loggers.py +63 -44
  8. nextrec/basic/metrics.py +270 -120
  9. nextrec/basic/model.py +1084 -402
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +492 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +51 -45
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +273 -96
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +103 -38
  23. nextrec/models/match/dssm.py +82 -68
  24. nextrec/models/match/dssm_v2.py +72 -57
  25. nextrec/models/match/mind.py +175 -107
  26. nextrec/models/match/sdm.py +104 -87
  27. nextrec/models/match/youtube_dnn.py +73 -59
  28. nextrec/models/multi_task/esmm.py +69 -46
  29. nextrec/models/multi_task/mmoe.py +91 -53
  30. nextrec/models/multi_task/ple.py +117 -58
  31. nextrec/models/multi_task/poso.py +163 -55
  32. nextrec/models/multi_task/share_bottom.py +63 -36
  33. nextrec/models/ranking/afm.py +80 -45
  34. nextrec/models/ranking/autoint.py +74 -57
  35. nextrec/models/ranking/dcn.py +110 -48
  36. nextrec/models/ranking/dcn_v2.py +265 -45
  37. nextrec/models/ranking/deepfm.py +39 -24
  38. nextrec/models/ranking/dien.py +335 -146
  39. nextrec/models/ranking/din.py +158 -92
  40. nextrec/models/ranking/fibinet.py +134 -52
  41. nextrec/models/ranking/fm.py +68 -26
  42. nextrec/models/ranking/masknet.py +95 -33
  43. nextrec/models/ranking/pnn.py +128 -58
  44. nextrec/models/ranking/widedeep.py +40 -28
  45. nextrec/models/ranking/xdeepfm.py +67 -40
  46. nextrec/utils/__init__.py +59 -34
  47. nextrec/utils/config.py +496 -0
  48. nextrec/utils/device.py +30 -20
  49. nextrec/utils/distributed.py +36 -9
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +33 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/model.py +22 -0
  55. nextrec/utils/optimizer.py +25 -9
  56. nextrec/utils/synthetic_data.py +283 -165
  57. nextrec/utils/tensor.py +24 -13
  58. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
  59. nextrec-0.4.3.dist-info/RECORD +69 -0
  60. nextrec-0.4.3.dist-info/entry_points.txt +2 -0
  61. nextrec-0.4.1.dist-info/RECORD +0 -66
  62. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
  63. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
nextrec/models/multi_task/poso.py

@@ -46,7 +46,8 @@ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
  from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
  from nextrec.basic.activation import activation_layer
  from nextrec.basic.model import BaseModel
- from nextrec.utils.model import merge_features
+
+ from nextrec.utils.model import select_features


  class POSOGate(nn.Module):
@@ -79,7 +80,7 @@ class POSOGate(nn.Module):
  h = self.act(self.fc1(pc))
  g = torch.sigmoid(self.fc2(h)) # (B, out_dim) in (0,1)
  return self.scale_factor * g
-
+

  class POSOFC(nn.Module):
  """
@@ -116,10 +117,10 @@ class POSOFC(nn.Module):
  pc: (B, pc_dim)
  return: (B, out_dim)
  """
- h = self.act(self.linear(x)) # Standard FC with activation
- g = self.gate(pc) # (B, out_dim)
- return g * h # Element-wise gating
-
+ h = self.act(self.linear(x)) # Standard FC with activation
+ g = self.gate(pc) # (B, out_dim)
+ return g * h # Element-wise gating
+

  class POSOMLP(nn.Module):
  """
@@ -173,7 +174,7 @@ class POSOMLP(nn.Module):
  if self.dropout is not None:
  h = self.dropout(h)
  return h
-
+

  class POSOMMoE(nn.Module):
  """
@@ -183,7 +184,7 @@ class POSOMMoE(nn.Module):
  - Task gates aggregate the PC-masked expert outputs

  Concretely:
- h_e = expert_e(x) # (B, D)
+ h_e = expert_e(x) # (B, D)
  g_e = POSOGate(pc) in (0, C)^{D} # (B, D)
  h_e_tilde = g_e ⊙ h_e # (B, D)
  z_t = Σ_e gate_t,e(x) * h_e_tilde
@@ -192,14 +193,14 @@ class POSOMMoE(nn.Module):
  def __init__(
  self,
  input_dim: int,
- pc_dim: int, # for poso feature dimension
+ pc_dim: int, # for poso feature dimension
  num_experts: int,
  expert_hidden_dims: list[int],
  num_tasks: int,
  activation: str = "relu",
  expert_dropout: float = 0.0,
- gate_hidden_dim: int = 32, # for poso gate hidden dimension
- scale_factor: float = 2.0, # for poso gate scale factor
+ gate_hidden_dim: int = 32, # for poso gate hidden dimension
+ scale_factor: float = 2.0, # for poso gate scale factor
  gate_use_softmax: bool = True,
  ) -> None:
  super().__init__()
@@ -207,15 +208,41 @@ class POSOMMoE(nn.Module):
  self.num_tasks = num_tasks

  # Experts built with framework MLP, same as standard MMoE
- self.experts = nn.ModuleList([MLP(input_dim=input_dim, output_layer=False, dims=expert_hidden_dims, activation=activation, dropout=expert_dropout,) for _ in range(num_experts)])
- self.expert_output_dim = expert_hidden_dims[-1] if expert_hidden_dims else input_dim
+ self.experts = nn.ModuleList(
+ [
+ MLP(
+ input_dim=input_dim,
+ output_layer=False,
+ dims=expert_hidden_dims,
+ activation=activation,
+ dropout=expert_dropout,
+ )
+ for _ in range(num_experts)
+ ]
+ )
+ self.expert_output_dim = (
+ expert_hidden_dims[-1] if expert_hidden_dims else input_dim
+ )

  # Task-specific gates: gate_t(x) over experts
- self.gates = nn.ModuleList([nn.Linear(input_dim, num_experts) for _ in range(num_tasks)])
+ self.gates = nn.ModuleList(
+ [nn.Linear(input_dim, num_experts) for _ in range(num_tasks)]
+ )
  self.gate_use_softmax = gate_use_softmax

  # PC gate per expert: g_e(pc) ∈ R^D
- self.expert_pc_gates = nn.ModuleList([POSOGate(pc_dim=pc_dim, out_dim=self.expert_output_dim, hidden_dim=gate_hidden_dim, scale_factor=scale_factor, activation=activation,) for _ in range(num_experts)])
+ self.expert_pc_gates = nn.ModuleList(
+ [
+ POSOGate(
+ pc_dim=pc_dim,
+ out_dim=self.expert_output_dim,
+ hidden_dim=gate_hidden_dim,
+ scale_factor=scale_factor,
+ activation=activation,
+ )
+ for _ in range(num_experts)
+ ]
+ )

  def forward(self, x: torch.Tensor, pc: torch.Tensor) -> list[torch.Tensor]:
  """
@@ -226,9 +253,9 @@ class POSOMMoE(nn.Module):
  # 1) Expert outputs with POSO PC gate
  masked_expert_outputs = []
  for e, expert in enumerate(self.experts):
- h_e = expert(x) # (B, D)
- g_e = self.expert_pc_gates[e](pc) # (B, D)
- h_e_tilde = g_e * h_e # (B, D)
+ h_e = expert(x) # (B, D)
+ g_e = self.expert_pc_gates[e](pc) # (B, D)
+ h_e_tilde = g_e * h_e # (B, D)
  masked_expert_outputs.append(h_e_tilde)

  masked_expert_outputs = torch.stack(masked_expert_outputs, dim=1) # (B, E, D)
@@ -236,13 +263,13 @@ class POSOMMoE(nn.Module):
  # 2) Task gates depend on x as in standard MMoE
  task_outputs: list[torch.Tensor] = []
  for t in range(self.num_tasks):
- logits = self.gates[t](x) # (B, E)
+ logits = self.gates[t](x) # (B, E)
  if self.gate_use_softmax:
  gate = F.softmax(logits, dim=1)
  else:
  gate = logits

- gate = gate.unsqueeze(-1) # (B, E, 1)
+ gate = gate.unsqueeze(-1) # (B, E, 1)
  z_t = torch.sum(gate * masked_expert_outputs, dim=1) # (B, D)
  task_outputs.append(z_t)

@@ -269,15 +296,18 @@ class POSO(BaseModel):

  def __init__(
  self,
- main_dense_features: list[DenseFeature] | None,
- main_sparse_features: list[SparseFeature] | None,
- main_sequence_features: list[SequenceFeature] | None,
- pc_dense_features: list[DenseFeature] | None,
- pc_sparse_features: list[SparseFeature] | None,
- pc_sequence_features: list[SequenceFeature] | None,
+ dense_features: list[DenseFeature] | None,
+ sparse_features: list[SparseFeature] | None,
+ sequence_features: list[SequenceFeature] | None,
+ main_dense_features: list[str] | None,
+ main_sparse_features: list[str] | None,
+ main_sequence_features: list[str] | None,
+ pc_dense_features: list[str] | None,
+ pc_sparse_features: list[str] | None,
+ pc_sequence_features: list[str] | None,
  tower_params_list: list[dict],
  target: list[str],
- task: str | list[str] | None = None,
+ task: str | list[str] = "binary",
  architecture: str = "mlp",
  # POSO gating defaults
  gate_hidden_dim: int = 32,
@@ -303,28 +333,32 @@ class POSO(BaseModel):
  dense_l2_reg: float = 1e-4,
  **kwargs,
  ):
- # Keep explicit copies of main and PC features
- self.main_dense_features = list(main_dense_features or [])
- self.main_sparse_features = list(main_sparse_features or [])
- self.main_sequence_features = list(main_sequence_features or [])
- self.pc_dense_features = list(pc_dense_features or [])
- self.pc_sparse_features = list(pc_sparse_features or [])
- self.pc_sequence_features = list(pc_sequence_features or [])
  self.num_tasks = len(target)

- if not self.pc_dense_features and not self.pc_sparse_features and not self.pc_sequence_features:
- raise ValueError("POSO requires at least one PC feature for personalization.")
+ # Normalize task to match num_tasks
+ resolved_task = task
+ if resolved_task is None:
+ resolved_task = self.default_task
+ elif isinstance(resolved_task, str):
+ resolved_task = [resolved_task] * self.num_tasks
+ elif len(resolved_task) == 1 and self.num_tasks > 1:
+ resolved_task = resolved_task * self.num_tasks
+ elif len(resolved_task) != self.num_tasks:
+ raise ValueError(
+ f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
+ )

- dense_features = merge_features(self.main_dense_features, self.pc_dense_features)
- sparse_features = merge_features(self.main_sparse_features, self.pc_sparse_features)
- sequence_features = merge_features(self.main_sequence_features, self.pc_sequence_features)
+ if len(tower_params_list) != self.num_tasks:
+ raise ValueError(
+ f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
+ )

  super().__init__(
  dense_features=dense_features,
  sparse_features=sparse_features,
  sequence_features=sequence_features,
  target=target,
- task=task or self.default_task,
+ task=resolved_task,
  device=device,
  embedding_l1_reg=embedding_l1_reg,
  dense_l1_reg=dense_l1_reg,
@@ -333,15 +367,58 @@ class POSO(BaseModel):
  **kwargs,
  )

- self.loss = loss if loss is not None else "bce"
+ self.main_dense_feature_names = list(main_dense_features or [])
+ self.main_sparse_feature_names = list(main_sparse_features or [])
+ self.main_sequence_feature_names = list(main_sequence_features or [])
+ self.pc_dense_feature_names = list(pc_dense_features or [])
+ self.pc_sparse_feature_names = list(pc_sparse_features or [])
+ self.pc_sequence_feature_names = list(pc_sequence_features or [])
+
+ if loss is None:
+ self.loss = "bce"
+ self.loss = loss
+
  optimizer_params = optimizer_params or {}

- self.num_tasks = len(target)
- if len(tower_params_list) != self.num_tasks:
- raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
+ self.main_dense_features = select_features(
+ self.dense_features, self.main_dense_feature_names, "main_dense_features"
+ )
+ self.main_sparse_features = select_features(
+ self.sparse_features, self.main_sparse_feature_names, "main_sparse_features"
+ )
+ self.main_sequence_features = select_features(
+ self.sequence_features,
+ self.main_sequence_feature_names,
+ "main_sequence_features",
+ )
+
+ self.pc_dense_features = select_features(
+ self.dense_features, self.pc_dense_feature_names, "pc_dense_features"
+ )
+ self.pc_sparse_features = select_features(
+ self.sparse_features, self.pc_sparse_feature_names, "pc_sparse_features"
+ )
+ self.pc_sequence_features = select_features(
+ self.sequence_features,
+ self.pc_sequence_feature_names,
+ "pc_sequence_features",
+ )

- self.main_features = self.main_dense_features + self.main_sparse_features + self.main_sequence_features
- self.pc_features = self.pc_dense_features + self.pc_sparse_features + self.pc_sequence_features
+ self.main_features = (
+ self.main_dense_features
+ + self.main_sparse_features
+ + self.main_sequence_features
+ )
+ self.pc_features = (
+ self.pc_dense_features + self.pc_sparse_features + self.pc_sequence_features
+ )
+
+ if not self.main_features:
+ raise ValueError("POSO requires at least one main feature.")
+ if not self.pc_features:
+ raise ValueError(
+ "POSO requires at least one PC feature for personalization."
+ )

  self.embedding = EmbeddingLayer(features=self.all_features)
  self.main_input_dim = self.embedding.get_input_dim(self.main_features)
@@ -349,7 +426,9 @@ class POSO(BaseModel):

  self.architecture = architecture.lower()
  if self.architecture not in {"mlp", "mmoe"}:
- raise ValueError(f"Unsupported architecture '{architecture}', choose from ['mlp', 'mmoe'].")
+ raise ValueError(
+ f"Unsupported architecture '{architecture}', choose from ['mlp', 'mmoe']."
+ )

  # Build backbones
  if self.architecture == "mlp":
@@ -358,13 +437,17 @@ class POSO(BaseModel):
  for tower_params in tower_params_list:
  dims = tower_params.get("dims")
  if not dims:
- raise ValueError("tower_params must include a non-empty 'dims' list for POSO-MLP towers.")
+ raise ValueError(
+ "tower_params must include a non-empty 'dims' list for POSO-MLP towers."
+ )
  dropout = tower_params.get("dropout", 0.0)
  tower = POSOMLP(
  input_dim=self.main_input_dim,
  pc_dim=self.pc_input_dim,
  dims=dims,
- gate_hidden_dim=tower_params.get("gate_hidden_dim", gate_hidden_dim),
+ gate_hidden_dim=tower_params.get(
+ "gate_hidden_dim", gate_hidden_dim
+ ),
  scale_factor=tower_params.get("scale_factor", gate_scale_factor),
  activation=tower_params.get("activation", gate_activation),
  use_bias=tower_params.get("use_bias", gate_use_bias),
@@ -375,7 +458,9 @@ class POSO(BaseModel):
  self.tower_heads.append(nn.Linear(tower_output_dim, 1))
  else:
  if expert_hidden_dims is None or not expert_hidden_dims:
- raise ValueError("expert_hidden_dims must be provided for MMoE architecture.")
+ raise ValueError(
+ "expert_hidden_dims must be provided for MMoE architecture."
+ )
  self.mmoe = POSOMMoE(
  input_dim=self.main_input_dim,
  pc_dim=self.pc_input_dim,
@@ -388,12 +473,35 @@ class POSO(BaseModel):
  scale_factor=expert_gate_scale_factor,
  gate_use_softmax=gate_use_softmax,
  )
- self.towers = nn.ModuleList([MLP(input_dim=self.mmoe.expert_output_dim, output_layer=True, **tower_params,) for tower_params in tower_params_list])
+ self.towers = nn.ModuleList(
+ [
+ MLP(
+ input_dim=self.mmoe.expert_output_dim,
+ output_layer=True,
+ **tower_params,
+ )
+ for tower_params in tower_params_list
+ ]
+ )
  self.tower_heads = None
- self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks,)
- include_modules = ["towers", "tower_heads"] if self.architecture == "mlp" else ["mmoe", "towers"]
- self.register_regularization_weights(embedding_attr="embedding", include_modules=include_modules)
- self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
+ self.prediction_layer = PredictionLayer(
+ task_type=self.default_task,
+ task_dims=[1] * self.num_tasks,
+ )
+ include_modules = (
+ ["towers", "tower_heads"]
+ if self.architecture == "mlp"
+ else ["mmoe", "towers"]
+ )
+ self.register_regularization_weights(
+ embedding_attr="embedding", include_modules=include_modules
+ )
+ self.compile(
+ optimizer=optimizer,
+ optimizer_params=optimizer_params,
+ loss=loss,
+ loss_params=loss_params,
+ )

  def forward(self, x):
  # Embed main and PC features separately so PC can gate hidden units
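Taken together, the poso.py hunks change POSO's public constructor: the model now receives the full dense/sparse/sequence feature lists plus name lists that pick out the main and PC groups, and select_features (replacing merge_features) resolves those names against the registered features. A minimal usage sketch under the new 0.4.3 signature; the DenseFeature/SparseFeature constructor arguments are illustrative assumptions, and only the POSO keyword arguments are taken from this diff:

    # Sketch only: the feature definitions below are hypothetical; the POSO keyword
    # arguments follow the 0.4.3 signature shown in the hunks above.
    from nextrec.basic.features import DenseFeature, SparseFeature
    from nextrec.models.multi_task.poso import POSO

    dense = [DenseFeature("age")]  # hypothetical feature definitions
    sparse = [
        SparseFeature("user_id", vocab_size=10_000, embedding_dim=16),
        SparseFeature("item_id", vocab_size=50_000, embedding_dim=16),
    ]

    model = POSO(
        dense_features=dense,                 # full feature lists (new in 0.4.3)
        sparse_features=sparse,
        sequence_features=None,
        main_dense_features=["age"],          # main/PC groups are now selected by name
        main_sparse_features=["item_id"],
        main_sequence_features=None,
        pc_dense_features=None,
        pc_sparse_features=["user_id"],       # at least one PC feature is required
        pc_sequence_features=None,
        tower_params_list=[{"dims": [64, 32]}, {"dims": [64, 32]}],  # one dict per target
        target=["click", "like"],
        task="binary",                        # a single string is broadcast to all targets
        architecture="mlp",
    )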
nextrec/models/multi_task/share_bottom.py

@@ -56,42 +56,58 @@ class ShareBottom(BaseModel):
  def default_task(self):
  num_tasks = getattr(self, "num_tasks", None)
  if num_tasks is not None and num_tasks > 0:
- return ['binary'] * num_tasks
- return ['binary']
-
- def __init__(self,
- dense_features: list[DenseFeature],
- sparse_features: list[SparseFeature],
- sequence_features: list[SequenceFeature],
- bottom_params: dict,
- tower_params_list: list[dict],
- target: list[str],
- task: str | list[str] | None = None,
- optimizer: str = "adam",
- optimizer_params: dict = {},
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
- loss_params: dict | list[dict] | None = None,
- device: str = 'cpu',
- embedding_l1_reg=1e-6,
- dense_l1_reg=1e-5,
- embedding_l2_reg=1e-5,
- dense_l2_reg=1e-4,
- **kwargs):
-
+ return ["binary"] * num_tasks
+ return ["binary"]
+
+ def __init__(
+ self,
+ dense_features: list[DenseFeature],
+ sparse_features: list[SparseFeature],
+ sequence_features: list[SequenceFeature],
+ bottom_params: dict,
+ tower_params_list: list[dict],
+ target: list[str],
+ task: str | list[str] | None = None,
+ optimizer: str = "adam",
+ optimizer_params: dict | None = None,
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+ loss_params: dict | list[dict] | None = None,
+ device: str = "cpu",
+ embedding_l1_reg=1e-6,
+ dense_l1_reg=1e-5,
+ embedding_l2_reg=1e-5,
+ dense_l2_reg=1e-4,
+ **kwargs,
+ ):
+
+ optimizer_params = optimizer_params or {}
+
  self.num_tasks = len(target)

+ resolved_task = task
+ if resolved_task is None:
+ resolved_task = self.default_task
+ elif isinstance(resolved_task, str):
+ resolved_task = [resolved_task] * self.num_tasks
+ elif len(resolved_task) == 1 and self.num_tasks > 1:
+ resolved_task = resolved_task * self.num_tasks
+ elif len(resolved_task) != self.num_tasks:
+ raise ValueError(
+ f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
+ )
+
  super(ShareBottom, self).__init__(
  dense_features=dense_features,
  sparse_features=sparse_features,
  sequence_features=sequence_features,
  target=target,
- task=task or self.default_task,
+ task=resolved_task,
  device=device,
  embedding_l1_reg=embedding_l1_reg,
  dense_l1_reg=dense_l1_reg,
  embedding_l2_reg=embedding_l2_reg,
  dense_l2_reg=dense_l2_reg,
- **kwargs
+ **kwargs,
  )

  self.loss = loss
@@ -100,7 +116,9 @@ class ShareBottom(BaseModel):
  # Number of tasks
  self.num_tasks = len(target)
  if len(tower_params_list) != self.num_tasks:
- raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
+ raise ValueError(
+ f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
+ )
  # Embedding layer
  self.embedding = EmbeddingLayer(features=self.all_features)
  # Calculate input dimension
@@ -108,39 +126,48 @@ class ShareBottom(BaseModel):
  # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
  # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
  # input_dim = emb_dim_total + dense_input_dim
-
+
  # Shared bottom network
  self.bottom = MLP(input_dim=input_dim, output_layer=False, **bottom_params)
-
+
  # Get bottom output dimension
- if 'dims' in bottom_params and len(bottom_params['dims']) > 0:
- bottom_output_dim = bottom_params['dims'][-1]
+ if "dims" in bottom_params and len(bottom_params["dims"]) > 0:
+ bottom_output_dim = bottom_params["dims"][-1]
  else:
  bottom_output_dim = input_dim
-
+
  # Task-specific towers
  self.towers = nn.ModuleList()
  for tower_params in tower_params_list:
  tower = MLP(input_dim=bottom_output_dim, output_layer=True, **tower_params)
  self.towers.append(tower)
- self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks)
+ self.prediction_layer = PredictionLayer(
+ task_type=self.default_task, task_dims=[1] * self.num_tasks
+ )
  # Register regularization weights
- self.register_regularization_weights(embedding_attr='embedding', include_modules=['bottom', 'towers'])
- self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
+ self.register_regularization_weights(
+ embedding_attr="embedding", include_modules=["bottom", "towers"]
+ )
+ self.compile(
+ optimizer=optimizer,
+ optimizer_params=optimizer_params,
+ loss=loss,
+ loss_params=loss_params,
+ )

  def forward(self, x):
  # Get all embeddings and flatten
  input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
-
+
  # Shared bottom
  bottom_output = self.bottom(input_flat) # [B, bottom_dim]
-
+
  # Task-specific towers
  task_outputs = []
  for tower in self.towers:
  tower_output = tower(bottom_output) # [B, 1]
  task_outputs.append(tower_output)
-
+
  # Stack outputs: [B, num_tasks]
  y = torch.cat(task_outputs, dim=1)
  return self.prediction_layer(y)
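Both POSO and ShareBottom now apply the same task-normalization rule before calling BaseModel.__init__: a single task string (or a one-element list) is broadcast across all targets, and a length mismatch raises a ValueError. A standalone restatement of that rule, with a hypothetical helper name and the branching copied from the added lines:

    # Hypothetical helper name; the branching mirrors the block added to both models.
    def resolve_task(task: str | list[str] | None, num_tasks: int) -> list[str]:
        if task is None:
            return ["binary"] * num_tasks  # default_task: one 'binary' per target
        if isinstance(task, str):
            return [task] * num_tasks      # broadcast a single string
        if len(task) == 1 and num_tasks > 1:
            return list(task) * num_tasks  # broadcast a one-element list
        if len(task) != num_tasks:
            raise ValueError(
                f"Length of task ({len(task)}) must match number of targets ({num_tasks})."
            )
        return list(task)

    assert resolve_task("binary", 3) == ["binary", "binary", "binary"]
    assert resolve_task(["binary", "regression"], 2) == ["binary", "regression"]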