nextrec 0.4.22__py3-none-any.whl → 0.4.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/layers.py +96 -46
  3. nextrec/basic/metrics.py +128 -114
  4. nextrec/basic/model.py +94 -91
  5. nextrec/basic/summary.py +36 -2
  6. nextrec/data/dataloader.py +2 -0
  7. nextrec/data/preprocessor.py +137 -5
  8. nextrec/loss/listwise.py +19 -6
  9. nextrec/loss/pairwise.py +6 -4
  10. nextrec/loss/pointwise.py +8 -6
  11. nextrec/models/multi_task/aitm.py +0 -0
  12. nextrec/models/multi_task/apg.py +0 -0
  13. nextrec/models/multi_task/cross_stitch.py +0 -0
  14. nextrec/models/multi_task/esmm.py +5 -28
  15. nextrec/models/multi_task/mmoe.py +6 -28
  16. nextrec/models/multi_task/pepnet.py +335 -0
  17. nextrec/models/multi_task/ple.py +21 -40
  18. nextrec/models/multi_task/poso.py +17 -39
  19. nextrec/models/multi_task/share_bottom.py +5 -28
  20. nextrec/models/multi_task/snr_trans.py +0 -0
  21. nextrec/models/ranking/afm.py +3 -27
  22. nextrec/models/ranking/autoint.py +5 -38
  23. nextrec/models/ranking/dcn.py +1 -26
  24. nextrec/models/ranking/dcn_v2.py +6 -34
  25. nextrec/models/ranking/deepfm.py +2 -29
  26. nextrec/models/ranking/dien.py +2 -28
  27. nextrec/models/ranking/din.py +2 -27
  28. nextrec/models/ranking/eulernet.py +3 -30
  29. nextrec/models/ranking/ffm.py +0 -26
  30. nextrec/models/ranking/fibinet.py +8 -32
  31. nextrec/models/ranking/fm.py +0 -29
  32. nextrec/models/ranking/lr.py +0 -30
  33. nextrec/models/ranking/masknet.py +4 -30
  34. nextrec/models/ranking/pnn.py +4 -28
  35. nextrec/models/ranking/widedeep.py +0 -32
  36. nextrec/models/ranking/xdeepfm.py +0 -30
  37. nextrec/models/retrieval/dssm.py +4 -28
  38. nextrec/models/retrieval/dssm_v2.py +4 -28
  39. nextrec/models/retrieval/mind.py +2 -22
  40. nextrec/models/retrieval/sdm.py +4 -24
  41. nextrec/models/retrieval/youtube_dnn.py +4 -25
  42. nextrec/models/sequential/hstu.py +0 -18
  43. nextrec/utils/model.py +91 -4
  44. nextrec/utils/types.py +35 -0
  45. {nextrec-0.4.22.dist-info → nextrec-0.4.24.dist-info}/METADATA +8 -6
  46. nextrec-0.4.24.dist-info/RECORD +86 -0
  47. nextrec-0.4.22.dist-info/RECORD +0 -81
  48. {nextrec-0.4.22.dist-info → nextrec-0.4.24.dist-info}/WHEEL +0 -0
  49. {nextrec-0.4.22.dist-info → nextrec-0.4.24.dist-info}/entry_points.txt +0 -0
  50. {nextrec-0.4.22.dist-info → nextrec-0.4.24.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/listwise.py CHANGED
@@ -2,10 +2,11 @@
2
2
  Listwise loss functions for ranking and contrastive training.
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 29/11/2025
5
+ Checkpoint: edit on 29/12/2025
6
6
  Author: Yang Zhou, zyaztec@gmail.com
7
7
  """
8
8
 
9
+ from typing import Literal
9
10
  import torch
10
11
  import torch.nn as nn
11
12
  import torch.nn.functional as F
@@ -16,7 +17,7 @@ class SampledSoftmaxLoss(nn.Module):
16
17
  Softmax over one positive and multiple sampled negatives.
17
18
  """
18
19
 
19
- def __init__(self, reduction: str = "mean"):
20
+ def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean"):
20
21
  super().__init__()
21
22
  self.reduction = reduction
22
23
 
@@ -37,7 +38,11 @@ class InfoNCELoss(nn.Module):
37
38
  InfoNCE loss for contrastive learning with one positive and many negatives.
38
39
  """
39
40
 
40
- def __init__(self, temperature: float = 0.07, reduction: str = "mean"):
41
+ def __init__(
42
+ self,
43
+ temperature: float = 0.07,
44
+ reduction: Literal["mean", "sum", "none"] = "mean",
45
+ ):
41
46
  super().__init__()
42
47
  self.temperature = temperature
43
48
  self.reduction = reduction
@@ -61,7 +66,11 @@ class ListNetLoss(nn.Module):
61
66
  Reference: Cao et al. (ICML 2007)
62
67
  """
63
68
 
64
- def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
69
+ def __init__(
70
+ self,
71
+ temperature: float = 1.0,
72
+ reduction: Literal["mean", "sum", "none"] = "mean",
73
+ ):
65
74
  super().__init__()
66
75
  self.temperature = temperature
67
76
  self.reduction = reduction
@@ -84,7 +93,7 @@ class ListMLELoss(nn.Module):
84
93
  Reference: Xia et al. (ICML 2008)
85
94
  """
86
95
 
87
- def __init__(self, reduction: str = "mean"):
96
+ def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean"):
88
97
  super().__init__()
89
98
  self.reduction = reduction
90
99
 
@@ -117,7 +126,11 @@ class ApproxNDCGLoss(nn.Module):
117
126
  Reference: Qin et al. (2010)
118
127
  """
119
128
 
120
- def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
129
+ def __init__(
130
+ self,
131
+ temperature: float = 1.0,
132
+ reduction: Literal["mean", "sum", "none"] = "mean",
133
+ ):
121
134
  super().__init__()
122
135
  self.temperature = temperature
123
136
  self.reduction = reduction
nextrec/loss/pairwise.py CHANGED
@@ -2,7 +2,7 @@
2
2
  Pairwise loss functions for learning-to-rank and matching tasks.
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 29/11/2025
5
+ Checkpoint: edit on 29/12/2025
6
6
  Author: Yang Zhou, zyaztec@gmail.com
7
7
  """
8
8
 
@@ -18,7 +18,7 @@ class BPRLoss(nn.Module):
18
18
  Bayesian Personalized Ranking loss with support for multiple negatives.
19
19
  """
20
20
 
21
- def __init__(self, reduction: str = "mean"):
21
+ def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean"):
22
22
  super().__init__()
23
23
  self.reduction = reduction
24
24
 
@@ -42,7 +42,9 @@ class HingeLoss(nn.Module):
42
42
  Hinge loss for pairwise ranking.
43
43
  """
44
44
 
45
- def __init__(self, margin: float = 1.0, reduction: str = "mean"):
45
+ def __init__(
46
+ self, margin: float = 1.0, reduction: Literal["mean", "sum", "none"] = "mean"
47
+ ):
46
48
  super().__init__()
47
49
  self.margin = margin
48
50
  self.reduction = reduction
@@ -69,7 +71,7 @@ class TripletLoss(nn.Module):
69
71
  def __init__(
70
72
  self,
71
73
  margin: float = 1.0,
72
- reduction: str = "mean",
74
+ reduction: Literal["mean", "sum", "none"] = "mean",
73
75
  distance: Literal["euclidean", "cosine"] = "euclidean",
74
76
  ):
75
77
  super().__init__()
nextrec/loss/pointwise.py CHANGED
@@ -2,11 +2,11 @@
2
2
  Pointwise loss functions, including imbalance-aware variants.
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 29/11/2025
5
+ Checkpoint: edit on 29/12/2025
6
6
  Author: Yang Zhou, zyaztec@gmail.com
7
7
  """
8
8
 
9
- from typing import Optional, Sequence
9
+ from typing import Optional, Sequence, Literal
10
10
 
11
11
  import torch
12
12
  import torch.nn as nn
@@ -18,7 +18,9 @@ class CosineContrastiveLoss(nn.Module):
18
18
  Contrastive loss using cosine similarity for positive/negative pairs.
19
19
  """
20
20
 
21
- def __init__(self, margin: float = 0.5, reduction: str = "mean"):
21
+ def __init__(
22
+ self, margin: float = 0.5, reduction: Literal["mean", "sum", "none"] = "mean"
23
+ ):
22
24
  super().__init__()
23
25
  self.margin = margin
24
26
  self.reduction = reduction
@@ -50,7 +52,7 @@ class WeightedBCELoss(nn.Module):
50
52
  def __init__(
51
53
  self,
52
54
  pos_weight: float | torch.Tensor | None = None,
53
- reduction: str = "mean",
55
+ reduction: Literal["mean", "sum", "none"] = "mean",
54
56
  logits: bool = False,
55
57
  auto_balance: bool = False,
56
58
  ):
@@ -110,7 +112,7 @@ class FocalLoss(nn.Module):
110
112
  self,
111
113
  gamma: float = 2.0,
112
114
  alpha: Optional[float | Sequence[float] | torch.Tensor] = None,
113
- reduction: str = "mean",
115
+ reduction: Literal["mean", "sum", "none"] = "mean",
114
116
  logits: bool = False,
115
117
  ):
116
118
  super().__init__()
@@ -187,7 +189,7 @@ class ClassBalancedFocalLoss(nn.Module):
187
189
  class_counts: Sequence[int] | torch.Tensor,
188
190
  beta: float = 0.9999,
189
191
  gamma: float = 2.0,
190
- reduction: str = "mean",
192
+ reduction: Literal["mean", "sum", "none"] = "mean",
191
193
  ):
192
194
  super().__init__()
193
195
  self.gamma = gamma
File without changes
File without changes
File without changes
@@ -42,12 +42,12 @@ CVR 预测 P(conversion|click),二者相乘得到 CTCVR 并在曝光标签上
42
42
  """
43
43
 
44
44
  import torch
45
- import torch.nn as nn
46
45
 
47
46
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
48
47
  from nextrec.basic.layers import MLP, EmbeddingLayer
49
48
  from nextrec.basic.heads import TaskHead
50
49
  from nextrec.basic.model import BaseModel
50
+ from nextrec.utils.types import TaskTypeName
51
51
 
52
52
 
53
53
  class ESMM(BaseModel):
@@ -77,23 +77,12 @@ class ESMM(BaseModel):
77
77
  sequence_features: list[SequenceFeature],
78
78
  ctr_params: dict,
79
79
  cvr_params: dict,
80
+ task: TaskTypeName | list[TaskTypeName] | None = None,
80
81
  target: list[str] | None = None, # Note: ctcvr = ctr * cvr
81
- task: list[str] | None = None,
82
- optimizer: str = "adam",
83
- optimizer_params: dict | None = None,
84
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
85
- loss_params: dict | list[dict] | None = None,
86
- embedding_l1_reg=0.0,
87
- dense_l1_reg=0.0,
88
- embedding_l2_reg=0.0,
89
- dense_l2_reg=0.0,
90
82
  **kwargs,
91
83
  ):
92
84
 
93
85
  target = target or ["ctr", "ctcvr"]
94
- optimizer_params = optimizer_params or {}
95
- if loss is None:
96
- loss = "bce"
97
86
 
98
87
  if len(target) != 2:
99
88
  raise ValueError(
@@ -120,35 +109,23 @@ class ESMM(BaseModel):
120
109
  sequence_features=sequence_features,
121
110
  target=target,
122
111
  task=resolved_task, # Both CTR and CTCVR are binary classification
123
- embedding_l1_reg=embedding_l1_reg,
124
- dense_l1_reg=dense_l1_reg,
125
- embedding_l2_reg=embedding_l2_reg,
126
- dense_l2_reg=dense_l2_reg,
127
112
  **kwargs,
128
113
  )
129
114
 
130
- self.loss = loss
131
-
132
115
  self.embedding = EmbeddingLayer(features=self.all_features)
133
116
  input_dim = self.embedding.input_dim
134
117
 
135
118
  # CTR tower
136
- self.ctr_tower = MLP(input_dim=input_dim, output_layer=True, **ctr_params)
119
+ self.ctr_tower = MLP(input_dim=input_dim, output_dim=1, **ctr_params)
137
120
 
138
121
  # CVR tower
139
- self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
122
+ self.cvr_tower = MLP(input_dim=input_dim, output_dim=1, **cvr_params)
140
123
  self.grad_norm_shared_modules = ["embedding"]
141
- self.prediction_layer = TaskHead(task_type=self.default_task, task_dims=[1, 1])
124
+ self.prediction_layer = TaskHead(task_type=self.task, task_dims=[1, 1])
142
125
  # Register regularization weights
143
126
  self.register_regularization_weights(
144
127
  embedding_attr="embedding", include_modules=["ctr_tower", "cvr_tower"]
145
128
  )
146
- self.compile(
147
- optimizer=optimizer,
148
- optimizer_params=optimizer_params,
149
- loss=loss,
150
- loss_params=loss_params,
151
- )
152
129
 
153
130
  def forward(self, x):
154
131
  # Get all embeddings and flatten
@@ -82,14 +82,6 @@ class MMOE(BaseModel):
82
82
  tower_params_list: list[dict] | None = None,
83
83
  target: list[str] | str | None = None,
84
84
  task: str | list[str] = "binary",
85
- optimizer: str = "adam",
86
- optimizer_params: dict | None = None,
87
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
88
- loss_params: dict | list[dict] | None = None,
89
- embedding_l1_reg=0.0,
90
- dense_l1_reg=0.0,
91
- embedding_l2_reg=0.0,
92
- dense_l2_reg=0.0,
93
85
  **kwargs,
94
86
  ):
95
87
 
@@ -98,9 +90,7 @@ class MMOE(BaseModel):
98
90
  sequence_features = sequence_features or []
99
91
  expert_params = expert_params or {}
100
92
  tower_params_list = tower_params_list or []
101
- optimizer_params = optimizer_params or {}
102
- if loss is None:
103
- loss = "bce"
93
+
104
94
  if target is None:
105
95
  target = []
106
96
  elif isinstance(target, str):
@@ -126,15 +116,9 @@ class MMOE(BaseModel):
126
116
  sequence_features=sequence_features,
127
117
  target=target,
128
118
  task=resolved_task,
129
- embedding_l1_reg=embedding_l1_reg,
130
- dense_l1_reg=dense_l1_reg,
131
- embedding_l2_reg=embedding_l2_reg,
132
- dense_l2_reg=dense_l2_reg,
133
119
  **kwargs,
134
120
  )
135
121
 
136
- self.loss = loss
137
-
138
122
  # Number of tasks and experts
139
123
  self.nums_task = len(target)
140
124
  self.num_experts = num_experts
@@ -150,12 +134,12 @@ class MMOE(BaseModel):
150
134
  # Expert networks (shared by all tasks)
151
135
  self.experts = nn.ModuleList()
152
136
  for _ in range(num_experts):
153
- expert = MLP(input_dim=input_dim, output_layer=False, **expert_params)
137
+ expert = MLP(input_dim=input_dim, output_dim=None, **expert_params)
154
138
  self.experts.append(expert)
155
139
 
156
140
  # Get expert output dimension
157
- if "dims" in expert_params and len(expert_params["dims"]) > 0:
158
- expert_output_dim = expert_params["dims"][-1]
141
+ if "hidden_dims" in expert_params and len(expert_params["hidden_dims"]) > 0:
142
+ expert_output_dim = expert_params["hidden_dims"][-1]
159
143
  else:
160
144
  expert_output_dim = input_dim
161
145
 
@@ -169,21 +153,15 @@ class MMOE(BaseModel):
169
153
  # Task-specific towers
170
154
  self.towers = nn.ModuleList()
171
155
  for tower_params in tower_params_list:
172
- tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
156
+ tower = MLP(input_dim=expert_output_dim, output_dim=1, **tower_params)
173
157
  self.towers.append(tower)
174
158
  self.prediction_layer = TaskHead(
175
- task_type=self.default_task, task_dims=[1] * self.nums_task
159
+ task_type=self.task, task_dims=[1] * self.nums_task
176
160
  )
177
161
  # Register regularization weights
178
162
  self.register_regularization_weights(
179
163
  embedding_attr="embedding", include_modules=["experts", "gates", "towers"]
180
164
  )
181
- self.compile(
182
- optimizer=optimizer,
183
- optimizer_params=optimizer_params,
184
- loss=self.loss,
185
- loss_params=loss_params,
186
- )
187
165
 
188
166
  def forward(self, x):
189
167
  # Get all embeddings and flatten
@@ -0,0 +1,335 @@
1
+ """
2
+ Date: create on 09/11/2025
3
+ Checkpoint: edit on 30/12/2025
4
+ Author: Yang Zhou, zyaztec@gmail.com
5
+ Reference:
6
+ [1] Yang et al. "PEPNet: Parameter and Embedding Personalized Network for Multi-Task Learning", 2021.
7
+ [2] MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation:
8
+ https://github.com/alipay/MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation/blob/main/model/pepnet.py
9
+
10
+ PEPNet (Parameter and Embedding Personalized Network) is a multi-task learning
11
+ model that personalizes both input features and layer transformations with
12
+ context (scene/domain, user, item). It applies a shared feature gate to the
13
+ backbone embedding and then uses per-task gated MLP blocks (PPNet blocks) whose
14
+ gates are conditioned on task-specific context. This enables task-aware routing
15
+ at both feature and layer levels, improving adaptation across scenarios/tasks.
16
+
17
+ Workflow:
18
+ (1) Embed all features and build the backbone input
19
+ (2) Build task context embedding from domain/user/item features
20
+ (3) Feature gate masks backbone input using domain context
21
+ (4) Each task tower applies layer-wise gates conditioned on context + backbone embedding output
22
+ (5) Task heads produce per-task predictions
23
+
24
+ Key Advantages:
25
+ - Two-level personalization: feature gate + layer gates
26
+ - Context-driven routing for multi-scenario/multi-task recommendation
27
+ - Task towers share embeddings while adapting via gates
28
+ - Gate input uses stop-grad on backbone embedding output for stable training
29
+ - Compatible with heterogeneous features via unified embeddings
30
+
31
+ PEPNet(Parameter and Embedding Personalized Network)通过场景/用户/物品等上下文
32
+ 对输入特征与网络层进行双层门控个性化。先用共享特征门控调整主干输入,再在每个
33
+ 任务塔中使用条件门控的 MLP 层(PPNet block),实现任务与场景感知的逐层路由。
34
+
35
+ 流程:
36
+ (1) 对全部特征做 embedding,得到主干输入
37
+ (2) 由场景/用户/物品特征构建任务上下文向量
38
+ (3) 共享特征门控按场景调制主干输入
39
+ (4) 任务塔逐层门控,结合上下文与主干 embedding 输出进行路由
40
+ (5) 任务头输出各任务预测结果
41
+
42
+ 主要优点:
43
+ - 特征级与层级双重个性化
44
+ - 上下文驱动的多场景/多任务适配
45
+ - 共享 embedding 的同时通过门控实现任务定制
46
+ - 对主干 embedding 输出 stop-grad,提高训练稳定性
47
+ - 统一 embedding 支持多类特征
48
+ """
49
+
50
+ from __future__ import annotations
51
+
52
+ import torch
53
+ import torch.nn as nn
54
+
55
+ from nextrec.basic.activation import activation_layer
56
+ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
57
+ from nextrec.basic.layers import EmbeddingLayer, GateMLP
58
+ from nextrec.basic.heads import TaskHead
59
+ from nextrec.basic.model import BaseModel
60
+ from nextrec.utils.model import select_features
61
+ from nextrec.utils.types import ActivationName, TaskTypeName
62
+
63
+
64
+ class PPNetBlock(nn.Module):
65
+ """
66
+ PEPNet block with per-layer gates conditioned on task context.
67
+ """
68
+
69
+ def __init__(
70
+ self,
71
+ input_dim: int,
72
+ output_dim: int,
73
+ gate_input_dim: int,
74
+ gate_hidden_dim: int | None,
75
+ hidden_units: list[int] | None = None,
76
+ hidden_activations: ActivationName | list[ActivationName] = "relu",
77
+ dropout_rates: float | list[float] = 0.0,
78
+ batch_norm: bool = False,
79
+ use_bias: bool = True,
80
+ gate_activation: ActivationName = "relu",
81
+ gate_dropout: float = 0.0,
82
+ gate_use_bn: bool = False,
83
+ ) -> None:
84
+ super().__init__()
85
+ hidden_units = hidden_units or []
86
+
87
+ if isinstance(dropout_rates, list):
88
+ if len(dropout_rates) != len(hidden_units):
89
+ raise ValueError("dropout_rates length must match hidden_units length.")
90
+ dropout_list = dropout_rates
91
+ else:
92
+ dropout_list = [dropout_rates] * len(hidden_units)
93
+
94
+ if isinstance(hidden_activations, list):
95
+ if len(hidden_activations) != len(hidden_units):
96
+ raise ValueError(
97
+ "hidden_activations length must match hidden_units length."
98
+ )
99
+ activation_list = hidden_activations
100
+ else:
101
+ activation_list = [hidden_activations] * len(hidden_units)
102
+
103
+ self.gate_layers = nn.ModuleList()
104
+ self.mlp_layers = nn.ModuleList()
105
+
106
+ layer_units = [input_dim] + hidden_units
107
+ for idx in range(len(layer_units) - 1):
108
+ dense_layers: list[nn.Module] = [
109
+ nn.Linear(layer_units[idx], layer_units[idx + 1], bias=use_bias)
110
+ ]
111
+ if batch_norm:
112
+ dense_layers.append(nn.BatchNorm1d(layer_units[idx + 1]))
113
+ dense_layers.append(activation_layer(activation_list[idx]))
114
+ if dropout_list[idx] > 0:
115
+ dense_layers.append(nn.Dropout(p=dropout_list[idx]))
116
+
117
+ self.gate_layers.append(
118
+ GateMLP(
119
+ input_dim=gate_input_dim,
120
+ hidden_dim=gate_hidden_dim,
121
+ output_dim=layer_units[idx],
122
+ activation=gate_activation,
123
+ dropout=gate_dropout,
124
+ use_bn=gate_use_bn,
125
+ scale_factor=2.0,
126
+ )
127
+ )
128
+ self.mlp_layers.append(nn.Sequential(*dense_layers))
129
+
130
+ self.gate_layers.append(
131
+ GateMLP(
132
+ input_dim=gate_input_dim,
133
+ hidden_dim=gate_hidden_dim,
134
+ output_dim=layer_units[-1],
135
+ activation=gate_activation,
136
+ dropout=gate_dropout,
137
+ use_bn=gate_use_bn,
138
+ scale_factor=1.0,
139
+ )
140
+ )
141
+ self.mlp_layers.append(nn.Linear(layer_units[-1], output_dim, bias=use_bias))
142
+
143
+ def forward(self, o_ep: torch.Tensor, o_prior: torch.Tensor) -> torch.Tensor:
144
+ """
145
+ o_ep: EPNet output embedding (will be stop-grad in gate input)
146
+ o_prior: prior/task context embedding
147
+ """
148
+ gate_input = torch.cat([o_prior, o_ep.detach()], dim=-1)
149
+
150
+ hidden = o_ep
151
+ for gate, mlp in zip(self.gate_layers, self.mlp_layers):
152
+ gw = gate(gate_input)
153
+ hidden = mlp(hidden * gw)
154
+ return hidden
155
+
156
+
157
+ class PEPNet(BaseModel):
158
+ """
159
+ PEPNet: feature-gated multi-task tower with task-conditioned gates.
160
+ """
161
+
162
+ @property
163
+ def model_name(self) -> str:
164
+ return "PepNet"
165
+
166
+ @property
167
+ def default_task(self) -> TaskTypeName | list[TaskTypeName]:
168
+ nums_task = self.nums_task if hasattr(self, "nums_task") else None
169
+ if nums_task is not None and nums_task > 0:
170
+ return ["binary"] * nums_task
171
+ return ["binary"]
172
+
173
+ def __init__(
174
+ self,
175
+ dense_features: list[DenseFeature] | None = None,
176
+ sparse_features: list[SparseFeature] | None = None,
177
+ sequence_features: list[SequenceFeature] | None = None,
178
+ target: list[str] | str | None = None,
179
+ task: TaskTypeName | list[TaskTypeName] | None = None,
180
+ dnn_hidden_units: list[int] | None = None,
181
+ dnn_activation: ActivationName = "relu",
182
+ dnn_dropout: float | list[float] = 0.0,
183
+ dnn_use_bn: bool = False,
184
+ feature_gate_hidden_dim: int = 128,
185
+ gate_hidden_dim: int | None = None,
186
+ gate_activation: ActivationName = "relu",
187
+ gate_dropout: float = 0.0,
188
+ gate_use_bn: bool = False,
189
+ domain_features: list[str] | str | None = None,
190
+ user_features: list[str] | str | None = None,
191
+ item_features: list[str] | str | None = None,
192
+ use_bias: bool = True,
193
+ **kwargs,
194
+ ) -> None:
195
+ dense_features = dense_features or []
196
+ sparse_features = sparse_features or []
197
+ sequence_features = sequence_features or []
198
+ dnn_hidden_units = dnn_hidden_units or [256, 128]
199
+
200
+ if target is None:
201
+ target = []
202
+ elif isinstance(target, str):
203
+ target = [target]
204
+
205
+ self.nums_task = len(target) if target else 1
206
+ resolved_task = task
207
+ if resolved_task is None:
208
+ resolved_task = self.default_task
209
+ elif isinstance(resolved_task, str):
210
+ resolved_task = [resolved_task] * self.nums_task
211
+ elif len(resolved_task) == 1 and self.nums_task > 1:
212
+ resolved_task = resolved_task * self.nums_task
213
+ elif len(resolved_task) != self.nums_task:
214
+ raise ValueError(
215
+ f"Length of task ({len(resolved_task)}) must match number of targets ({self.nums_task})."
216
+ )
217
+
218
+ super().__init__(
219
+ dense_features=dense_features,
220
+ sparse_features=sparse_features,
221
+ sequence_features=sequence_features,
222
+ target=target,
223
+ task=resolved_task,
224
+ **kwargs,
225
+ )
226
+
227
+ if isinstance(domain_features, str):
228
+ domain_features = [domain_features]
229
+ if isinstance(user_features, str):
230
+ user_features = [user_features]
231
+ if isinstance(item_features, str):
232
+ item_features = [item_features]
233
+
234
+ self.scene_feature_names = list(domain_features or [])
235
+ self.user_feature_names = list(user_features or [])
236
+ self.item_feature_names = list(item_features or [])
237
+
238
+ if not self.scene_feature_names:
239
+ raise ValueError("PepNet requires at least one scene feature name.")
240
+
241
+ self.domain_features = select_features(
242
+ self.all_features, self.scene_feature_names, "domain_features"
243
+ )
244
+ self.user_features = select_features(
245
+ self.all_features, self.user_feature_names, "user_features"
246
+ )
247
+ self.item_features = select_features(
248
+ self.all_features, self.item_feature_names, "item_features"
249
+ )
250
+
251
+ if not self.all_features:
252
+ raise ValueError("PepNet requires at least one input feature.")
253
+
254
+ self.embedding = EmbeddingLayer(features=self.all_features)
255
+ input_dim = self.embedding.get_input_dim(self.all_features)
256
+ domain_dim = self.embedding.get_input_dim(self.domain_features)
257
+ user_dim = (
258
+ self.embedding.get_input_dim(self.user_features)
259
+ if self.user_features
260
+ else 0
261
+ )
262
+ item_dim = (
263
+ self.embedding.get_input_dim(self.item_features)
264
+ if self.item_features
265
+ else 0
266
+ )
267
+ task_dim = domain_dim + user_dim + item_dim
268
+
269
+ self.feature_gate = GateMLP(
270
+ input_dim=input_dim + domain_dim,
271
+ hidden_dim=feature_gate_hidden_dim,
272
+ output_dim=input_dim,
273
+ activation=gate_activation,
274
+ dropout=gate_dropout,
275
+ use_bn=gate_use_bn,
276
+ )
277
+
278
+ self.ppn_blocks = nn.ModuleList(
279
+ [
280
+ PPNetBlock(
281
+ input_dim=input_dim,
282
+ output_dim=1,
283
+ gate_input_dim=input_dim + task_dim,
284
+ gate_hidden_dim=gate_hidden_dim,
285
+ hidden_units=dnn_hidden_units,
286
+ hidden_activations=dnn_activation,
287
+ dropout_rates=dnn_dropout,
288
+ batch_norm=dnn_use_bn,
289
+ use_bias=use_bias,
290
+ gate_activation=gate_activation,
291
+ gate_dropout=gate_dropout,
292
+ gate_use_bn=gate_use_bn,
293
+ )
294
+ for _ in range(self.nums_task)
295
+ ]
296
+ )
297
+
298
+ self.prediction_layer = TaskHead(
299
+ task_type=self.task, task_dims=[1] * self.nums_task
300
+ )
301
+ self.grad_norm_shared_modules = ["embedding", "feature_gate"]
302
+ self.register_regularization_weights(
303
+ embedding_attr="embedding", include_modules=["feature_gate", "ppn_blocks"]
304
+ )
305
+
306
+ def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
307
+ dnn_input = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
308
+ domain_emb = self.embedding(
309
+ x=x, features=self.domain_features, squeeze_dim=True
310
+ ).detach()
311
+
312
+ task_parts = [domain_emb]
313
+ if self.user_features:
314
+ task_parts.append(
315
+ self.embedding(
316
+ x=x, features=self.user_features, squeeze_dim=True
317
+ ).detach()
318
+ )
319
+ if self.item_features:
320
+ task_parts.append(
321
+ self.embedding(
322
+ x=x, features=self.item_features, squeeze_dim=True
323
+ ).detach()
324
+ )
325
+ task_sf_emb = torch.cat(task_parts, dim=-1)
326
+
327
+ gate_input = torch.cat([dnn_input.detach(), domain_emb], dim=-1)
328
+ dnn_input = self.feature_gate(gate_input) * dnn_input
329
+
330
+ task_logits = []
331
+ for block in self.ppn_blocks:
332
+ task_logits.append(block(o_ep=dnn_input, o_prior=task_sf_emb))
333
+
334
+ y = torch.cat(task_logits, dim=1)
335
+ return self.prediction_layer(y)