nextrec 0.4.25__py3-none-any.whl → 0.4.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/asserts.py +72 -0
  3. nextrec/basic/loggers.py +18 -1
  4. nextrec/basic/model.py +54 -51
  5. nextrec/data/batch_utils.py +23 -3
  6. nextrec/data/dataloader.py +3 -8
  7. nextrec/models/multi_task/[pre]aitm.py +173 -0
  8. nextrec/models/multi_task/[pre]snr_trans.py +232 -0
  9. nextrec/models/multi_task/[pre]star.py +192 -0
  10. nextrec/models/multi_task/apg.py +330 -0
  11. nextrec/models/multi_task/cross_stitch.py +229 -0
  12. nextrec/models/multi_task/escm.py +290 -0
  13. nextrec/models/multi_task/esmm.py +8 -21
  14. nextrec/models/multi_task/hmoe.py +203 -0
  15. nextrec/models/multi_task/mmoe.py +20 -28
  16. nextrec/models/multi_task/pepnet.py +81 -76
  17. nextrec/models/multi_task/ple.py +30 -44
  18. nextrec/models/multi_task/poso.py +13 -22
  19. nextrec/models/multi_task/share_bottom.py +14 -25
  20. nextrec/models/ranking/afm.py +2 -2
  21. nextrec/models/ranking/autoint.py +2 -4
  22. nextrec/models/ranking/dcn.py +2 -3
  23. nextrec/models/ranking/dcn_v2.py +2 -3
  24. nextrec/models/ranking/deepfm.py +2 -3
  25. nextrec/models/ranking/dien.py +7 -9
  26. nextrec/models/ranking/din.py +8 -10
  27. nextrec/models/ranking/eulernet.py +1 -2
  28. nextrec/models/ranking/ffm.py +1 -2
  29. nextrec/models/ranking/fibinet.py +2 -3
  30. nextrec/models/ranking/fm.py +1 -1
  31. nextrec/models/ranking/lr.py +1 -1
  32. nextrec/models/ranking/masknet.py +1 -2
  33. nextrec/models/ranking/pnn.py +1 -2
  34. nextrec/models/ranking/widedeep.py +2 -3
  35. nextrec/models/ranking/xdeepfm.py +2 -4
  36. nextrec/models/representation/rqvae.py +4 -4
  37. nextrec/models/retrieval/dssm.py +18 -26
  38. nextrec/models/retrieval/dssm_v2.py +15 -22
  39. nextrec/models/retrieval/mind.py +9 -15
  40. nextrec/models/retrieval/sdm.py +36 -33
  41. nextrec/models/retrieval/youtube_dnn.py +16 -24
  42. nextrec/models/sequential/hstu.py +2 -2
  43. nextrec/utils/__init__.py +5 -1
  44. nextrec/utils/model.py +9 -14
  45. {nextrec-0.4.25.dist-info → nextrec-0.4.28.dist-info}/METADATA +72 -62
  46. nextrec-0.4.28.dist-info/RECORD +90 -0
  47. nextrec/models/multi_task/aitm.py +0 -0
  48. nextrec/models/multi_task/snr_trans.py +0 -0
  49. nextrec-0.4.25.dist-info/RECORD +0 -86
  50. {nextrec-0.4.25.dist-info → nextrec-0.4.28.dist-info}/WHEEL +0 -0
  51. {nextrec-0.4.25.dist-info → nextrec-0.4.28.dist-info}/entry_points.txt +0 -0
  52. {nextrec-0.4.25.dist-info → nextrec-0.4.28.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,290 @@
1
+ """
2
+ Date: create on 01/01/2026
3
+ Checkpoint: edit on 01/01/2026
4
+ Author: Yang Zhou, zyaztec@gmail.com
5
+ Reference:
6
+ - [1] Wang H, Chang T-W, Liu T, Huang J, Chen Z, Yu C, Li R, Chu W. ESCM²: Entire Space Counterfactual Multi-Task Model for Post-Click Conversion Rate Estimation. Proceedings of the 45th International ACM SIGIR Conference on Research and Development in Information Retrieval (SIGIR '22), 2022:363–372.
7
+ URL: https://arxiv.org/abs/2204.05125
8
+ - [2] MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation: https://github.com/alipay/MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation/
9
+
10
+ Entire Space Counterfactual Model (ESCM) extends ESMM with counterfactual
11
+ training objectives (e.g., IPS/DR) to debias CVR estimation. The architecture
12
+ keeps separate CTR/CVR towers and derives CTCVR as the product of probabilities.
13
+ Optional exposure propensity (IMP) prediction is included for DR-style variants.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import torch
19
+
20
+ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
21
+ from nextrec.basic.heads import TaskHead
22
+ from nextrec.basic.layers import EmbeddingLayer, MLP
23
+ from nextrec.basic.model import BaseModel
24
+ from nextrec.loss.grad_norm import get_grad_norm_shared_params
25
+ from nextrec.utils.model import compute_ranking_loss
26
+ from nextrec.utils.types import TaskTypeName
27
+
28
+
29
+ class ESCM(BaseModel):
30
+ """
31
+ Entire Space Counterfactual Model.
32
+ """
33
+
34
+ @property
35
+ def model_name(self) -> str:
36
+ return "ESCM"
37
+
38
+ @property
39
+ def default_task(self) -> TaskTypeName | list[TaskTypeName]:
40
+ nums_task = getattr(self, "nums_task", None)
41
+ if nums_task is not None and nums_task > 0:
42
+ return ["binary"] * nums_task
43
+ return ["binary"]
44
+
45
+ def __init__(
46
+ self,
47
+ dense_features: list[DenseFeature] | None = None,
48
+ sparse_features: list[SparseFeature] | None = None,
49
+ sequence_features: list[SequenceFeature] | None = None,
50
+ ctr_mlp_params: dict | None = None,
51
+ cvr_mlp_params: dict | None = None,
52
+ imp_mlp_params: dict | None = None,
53
+ use_dr: bool = False,
54
+ target: list[str] | str | None = None,
55
+ task: TaskTypeName | list[TaskTypeName] | None = None,
56
+ **kwargs,
57
+ ) -> None:
58
+ dense_features = dense_features or []
59
+ sparse_features = sparse_features or []
60
+ sequence_features = sequence_features or []
61
+ ctr_mlp_params = ctr_mlp_params or {}
62
+ cvr_mlp_params = cvr_mlp_params or {}
63
+ imp_mlp_params = imp_mlp_params or {}
64
+
65
+ if target is None:
66
+ target = ["ctr", "cvr", "ctcvr"]
67
+ if use_dr:
68
+ target.append("imp")
69
+ elif isinstance(target, str):
70
+ target = [target]
71
+
72
+ self.nums_task = len(target) if target else 1
73
+
74
+ super().__init__(
75
+ dense_features=dense_features,
76
+ sparse_features=sparse_features,
77
+ sequence_features=sequence_features,
78
+ target=target,
79
+ task=task,
80
+ **kwargs,
81
+ )
82
+
83
+ if not target:
84
+ raise ValueError("ESCM requires at least one target.")
85
+
86
+ valid_targets = {"ctr", "cvr", "ctcvr", "imp"}
87
+ default_roles = ["ctr", "cvr", "ctcvr", "imp"]
88
+ if all(name in valid_targets for name in target):
89
+ target_roles = list(target)
90
+ else:
91
+ if len(target) > len(default_roles):
92
+ raise ValueError(
93
+ f"ESCM supports up to {len(default_roles)} targets, got {len(target)}."
94
+ )
95
+ target_roles = default_roles[: len(target)]
96
+
97
+ self.target_roles = target_roles
98
+ self.use_dr = use_dr or ("imp" in self.target_roles)
99
+ base_targets = ["ctr", "cvr"]
100
+ if self.use_dr:
101
+ base_targets.append("imp")
102
+
103
+ self.embedding = EmbeddingLayer(features=self.all_features)
104
+ input_dim = self.embedding.input_dim
105
+
106
+ self.ctr_tower = MLP(input_dim=input_dim, output_dim=1, **ctr_mlp_params)
107
+ self.cvr_tower = MLP(input_dim=input_dim, output_dim=1, **cvr_mlp_params)
108
+ if self.use_dr:
109
+ self.imp_tower = MLP(input_dim=input_dim, output_dim=1, **imp_mlp_params)
110
+
111
+ self.base_task_types = ["binary"] * len(base_targets)
112
+ self.prediction_layer = TaskHead(
113
+ task_type=self.base_task_types, task_dims=[1] * len(base_targets)
114
+ )
115
+
116
+ self.grad_norm_shared_modules = ["embedding"]
117
+ reg_modules = ["ctr_tower", "cvr_tower"]
118
+ if self.use_dr:
119
+ reg_modules.append("imp_tower")
120
+ self.register_regularization_weights(
121
+ embedding_attr="embedding", include_modules=reg_modules
122
+ )
123
+
124
+ def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
125
+ input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
126
+
127
+ ctr_logit = self.ctr_tower(input_flat)
128
+ cvr_logit = self.cvr_tower(input_flat)
129
+ base_logits = [ctr_logit, cvr_logit]
130
+ if self.use_dr:
131
+ imp_logit = self.imp_tower(input_flat)
132
+ base_logits.append(imp_logit)
133
+
134
+ base_logits_cat = torch.cat(base_logits, dim=1)
135
+ base_preds = self.prediction_layer(base_logits_cat)
136
+ base_preds = base_preds.split(1, dim=1)
137
+
138
+ pred_map = {"ctr": base_preds[0], "cvr": base_preds[1]}
139
+ if self.use_dr:
140
+ pred_map["imp"] = base_preds[2]
141
+
142
+ ctcvr_pred = pred_map["ctr"] * pred_map["cvr"]
143
+
144
+ outputs = []
145
+ for name in self.target_roles:
146
+ if name == "ctcvr":
147
+ outputs.append(ctcvr_pred)
148
+ else:
149
+ outputs.append(pred_map[name])
150
+ return torch.cat(outputs, dim=1)
151
+
152
+ def _loss_no_reduce(
153
+ self,
154
+ loss_fn: torch.nn.Module,
155
+ y_pred: torch.Tensor,
156
+ y_true: torch.Tensor,
157
+ ) -> torch.Tensor:
158
+ if hasattr(loss_fn, "reduction"):
159
+ reduction = loss_fn.reduction
160
+ if reduction != "none":
161
+ loss_fn.reduction = "none"
162
+ loss = loss_fn(y_pred, y_true)
163
+ loss_fn.reduction = reduction
164
+ else:
165
+ loss = loss_fn(y_pred, y_true)
166
+ else:
167
+ loss = loss_fn(y_pred, y_true)
168
+
169
+ if loss.dim() == 0:
170
+ return loss
171
+ if loss.dim() > 1:
172
+ loss = loss.view(loss.size(0), -1).mean(dim=1)
173
+ return loss.view(-1)
174
+
175
+ def _compute_cvr_loss(
176
+ self,
177
+ loss_fn: torch.nn.Module,
178
+ y_pred: torch.Tensor,
179
+ y_true: torch.Tensor,
180
+ click_label: torch.Tensor | None,
181
+ prop_pred: torch.Tensor | None,
182
+ valid_mask: torch.Tensor | None,
183
+ eps: float = 1e-7,
184
+ ) -> torch.Tensor:
185
+ if click_label is None:
186
+ return loss_fn(y_pred.view(-1), y_true.view(-1))
187
+
188
+ click = click_label
189
+ if valid_mask is not None:
190
+ click = click[valid_mask]
191
+ click = click.detach()
192
+
193
+ if prop_pred is not None:
194
+ prop = prop_pred
195
+ if valid_mask is not None:
196
+ prop = prop[valid_mask]
197
+ prop = prop.detach()
198
+ prop = torch.clamp(prop, min=eps, max=1.0 - eps)
199
+ weight = (click / prop).view(-1)
200
+ else:
201
+ weight = click.view(-1)
202
+
203
+ per_sample = self._loss_no_reduce(loss_fn, y_pred, y_true).view(-1)
204
+ if self.use_dr and prop_pred is not None:
205
+ impute_target = y_pred.detach()
206
+ impute_loss = self._loss_no_reduce(loss_fn, y_pred, impute_target).view(-1)
207
+ return (impute_loss + weight * (per_sample - impute_loss)).mean()
208
+ return (per_sample * weight).mean()
209
+
210
+ def compute_loss(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
211
+ if y_true is None:
212
+ raise ValueError(
213
+ "[ESCM-compute_loss Error] Ground truth labels (y_true) are required."
214
+ )
215
+
216
+ if y_pred.dim() == 1:
217
+ y_pred = y_pred.view(-1, 1)
218
+ if y_true.dim() == 1:
219
+ y_true = y_true.view(-1, 1)
220
+
221
+ role_to_index = {role: idx for idx, role in enumerate(self.target_roles)}
222
+ ctr_index = role_to_index.get("ctr")
223
+ imp_index = role_to_index.get("imp")
224
+
225
+ ctr_pred = (
226
+ y_pred[:, ctr_index : ctr_index + 1] if ctr_index is not None else None
227
+ )
228
+ ctr_true = (
229
+ y_true[:, ctr_index : ctr_index + 1] if ctr_index is not None else None
230
+ )
231
+ imp_pred = (
232
+ y_pred[:, imp_index : imp_index + 1] if imp_index is not None else None
233
+ )
234
+
235
+ task_losses = []
236
+ for i, role in enumerate(self.target_roles):
237
+ y_pred_i = y_pred[:, i : i + 1]
238
+ y_true_i = y_true[:, i : i + 1]
239
+ valid_mask = None
240
+ if self.ignore_label is not None:
241
+ valid_mask = y_true_i != self.ignore_label
242
+ if valid_mask.dim() > 1:
243
+ valid_mask = valid_mask.all(dim=1)
244
+ if not torch.any(valid_mask):
245
+ task_losses.append(y_pred_i.sum() * 0.0)
246
+ continue
247
+ y_pred_i = y_pred_i[valid_mask]
248
+ y_true_i = y_true_i[valid_mask]
249
+
250
+ if role == "cvr":
251
+ prop_pred = imp_pred if self.use_dr else ctr_pred
252
+ if prop_pred is None:
253
+ prop_pred = ctr_pred
254
+ task_loss = self._compute_cvr_loss(
255
+ loss_fn=self.loss_fn[i],
256
+ y_pred=y_pred_i,
257
+ y_true=y_true_i,
258
+ click_label=ctr_true,
259
+ prop_pred=prop_pred,
260
+ valid_mask=valid_mask,
261
+ )
262
+ else:
263
+ mode = self.training_modes[i]
264
+ if mode in {"pairwise", "listwise"}:
265
+ task_loss = compute_ranking_loss(
266
+ training_mode=mode,
267
+ loss_fn=self.loss_fn[i],
268
+ y_pred=y_pred_i,
269
+ y_true=y_true_i,
270
+ )
271
+ elif y_pred_i.shape[1] == 1:
272
+ task_loss = self.loss_fn[i](y_pred_i.view(-1), y_true_i.view(-1))
273
+ else:
274
+ task_loss = self.loss_fn[i](y_pred_i, y_true_i)
275
+ task_losses.append(task_loss)
276
+
277
+ if self.grad_norm is not None:
278
+ if self.grad_norm_shared_params is None:
279
+ self.grad_norm_shared_params = get_grad_norm_shared_params(
280
+ self, getattr(self, "grad_norm_shared_modules", None)
281
+ )
282
+ return self.grad_norm.compute_weighted_loss(
283
+ task_losses, self.grad_norm_shared_params
284
+ )
285
+ if isinstance(self.loss_weights, (list, tuple)):
286
+ task_losses = [
287
+ task_loss * self.loss_weights[i]
288
+ for i, task_loss in enumerate(task_losses)
289
+ ]
290
+ return torch.stack(task_losses).sum()
@@ -3,9 +3,8 @@ Date: create on 09/11/2025
3
3
  Checkpoint: edit on 23/12/2025
4
4
  Author: Yang Zhou,zyaztec@gmail.com
5
5
  Reference:
6
- [1] Ma X, Zhao L, Huang G, et al. Entire space multi-task model: An effective approach
7
- for estimating post-click conversion rate[C]//SIGIR. 2018: 1137-1140.
8
- (https://dl.acm.org/doi/10.1145/3209978.3210007)
6
+ - [1] Ma X, Zhao L, Huang G, Wang Z, Hu Z, Zhu X, Gai K. Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate. In: Proceedings of the 41st International ACM SIGIR Conference on Research and Development in Information Retrieval (SIGIR '18), 2018, pp. 1137–1140.
7
+ URL: https://dl.acm.org/doi/10.1145/3209978.3210007
9
8
 
10
9
  Entire Space Multi-task Model (ESMM) targets CVR estimation by jointly optimizing
11
10
  CTR and CTCVR on the full impression space, mitigating sample selection bias and
@@ -75,9 +74,9 @@ class ESMM(BaseModel):
75
74
  dense_features: list[DenseFeature],
76
75
  sparse_features: list[SparseFeature],
77
76
  sequence_features: list[SequenceFeature],
78
- ctr_params: dict,
79
- cvr_params: dict,
80
- task: TaskTypeName | list[TaskTypeName] | None = None,
77
+ ctr_mlp_params: dict,
78
+ cvr_mlp_params: dict,
79
+ task: list[TaskTypeName] | None = None,
81
80
  target: list[str] | None = None, # Note: ctcvr = ctr * cvr
82
81
  **kwargs,
83
82
  ):
@@ -90,25 +89,13 @@ class ESMM(BaseModel):
90
89
  )
91
90
 
92
91
  self.nums_task = len(target)
93
- resolved_task = task
94
- if resolved_task is None:
95
- resolved_task = self.default_task
96
- elif isinstance(resolved_task, str):
97
- resolved_task = [resolved_task] * self.nums_task
98
- elif len(resolved_task) == 1 and self.nums_task > 1:
99
- resolved_task = resolved_task * self.nums_task
100
- elif len(resolved_task) != self.nums_task:
101
- raise ValueError(
102
- f"Length of task ({len(resolved_task)}) must match number of targets ({self.nums_task})."
103
- )
104
- # resolved_task is now guaranteed to be a list[str]
105
92
 
106
93
  super(ESMM, self).__init__(
107
94
  dense_features=dense_features,
108
95
  sparse_features=sparse_features,
109
96
  sequence_features=sequence_features,
110
97
  target=target,
111
- task=resolved_task, # Both CTR and CTCVR are binary classification
98
+ task=task, # Both CTR and CTCVR are binary classification
112
99
  **kwargs,
113
100
  )
114
101
 
@@ -116,10 +103,10 @@ class ESMM(BaseModel):
116
103
  input_dim = self.embedding.input_dim
117
104
 
118
105
  # CTR tower
119
- self.ctr_tower = MLP(input_dim=input_dim, output_dim=1, **ctr_params)
106
+ self.ctr_tower = MLP(input_dim=input_dim, output_dim=1, **ctr_mlp_params)
120
107
 
121
108
  # CVR tower
122
- self.cvr_tower = MLP(input_dim=input_dim, output_dim=1, **cvr_params)
109
+ self.cvr_tower = MLP(input_dim=input_dim, output_dim=1, **cvr_mlp_params)
123
110
  self.grad_norm_shared_modules = ["embedding"]
124
111
  self.prediction_layer = TaskHead(task_type=self.task, task_dims=[1, 1])
125
112
  # Register regularization weights
@@ -0,0 +1,203 @@
1
+ """
2
+ Date: create on 01/01/2026
3
+ Checkpoint: edit on 01/01/2026
4
+ Author: Yang Zhou, zyaztec@gmail.com
5
+ [1] Zhao Z, Liu Y, Jin R, Zhu X, He X. HMOE: Improving Multi-Scenario Learning to Rank in E-commerce by Exploiting Task Relationships in the Label Space. Proceedings of the 29th ACM International Conference on Information & Knowledge Management (CIKM '20), 2020, pp. 2069–2078.
6
+ URL: https://dl.acm.org/doi/10.1145/3340531.3412713
7
+ [2] MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation:
8
+ https://github.com/alipay/MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation/
9
+
10
+ Hierarchical Mixture-of-Experts (HMOE) extends MMOE with task-to-task
11
+ feature aggregation. Each task builds a tower representation from expert
12
+ mixtures, then a task-weight network mixes all tower features with
13
+ stop-gradient on non-target tasks to reduce negative transfer.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
22
+ from nextrec.basic.layers import MLP, EmbeddingLayer
23
+ from nextrec.basic.heads import TaskHead
24
+ from nextrec.basic.model import BaseModel
25
+ from nextrec.utils.model import get_mlp_output_dim
26
+ from nextrec.utils.types import TaskTypeName
27
+
28
+
29
+ class HMOE(BaseModel):
30
+ """
31
+ Hierarchical Mixture-of-Experts.
32
+ """
33
+
34
+ @property
35
+ def model_name(self) -> str:
36
+ return "HMOE"
37
+
38
+ @property
39
+ def default_task(self) -> TaskTypeName | list[TaskTypeName]:
40
+ nums_task = getattr(self, "nums_task", None)
41
+ if nums_task is not None and nums_task > 0:
42
+ return ["binary"] * nums_task
43
+ return ["binary"]
44
+
45
+ def __init__(
46
+ self,
47
+ dense_features: list[DenseFeature] | None = None,
48
+ sparse_features: list[SparseFeature] | None = None,
49
+ sequence_features: list[SequenceFeature] | None = None,
50
+ expert_mlp_params: dict | None = None,
51
+ num_experts: int = 4,
52
+ gate_mlp_params: dict | None = None,
53
+ tower_mlp_params_list: list[dict] | None = None,
54
+ task_weight_mlp_params: list[dict] | None = None,
55
+ target: list[str] | str | None = None,
56
+ task: TaskTypeName | list[TaskTypeName] | None = None,
57
+ **kwargs,
58
+ ) -> None:
59
+ dense_features = dense_features or []
60
+ sparse_features = sparse_features or []
61
+ sequence_features = sequence_features or []
62
+ expert_mlp_params = expert_mlp_params or {}
63
+ gate_mlp_params = gate_mlp_params or {}
64
+ tower_mlp_params_list = tower_mlp_params_list or []
65
+
66
+ if target is None:
67
+ target = []
68
+ elif isinstance(target, str):
69
+ target = [target]
70
+
71
+ self.nums_task = len(target) if target else 1
72
+
73
+ super().__init__(
74
+ dense_features=dense_features,
75
+ sparse_features=sparse_features,
76
+ sequence_features=sequence_features,
77
+ target=target,
78
+ task=task,
79
+ **kwargs,
80
+ )
81
+
82
+ self.nums_task = len(target) if target else 1
83
+ self.num_experts = num_experts
84
+
85
+ if len(tower_mlp_params_list) != self.nums_task:
86
+ raise ValueError(
87
+ "Number of tower mlp params "
88
+ f"({len(tower_mlp_params_list)}) must match number of tasks ({self.nums_task})."
89
+ )
90
+
91
+ self.embedding = EmbeddingLayer(features=self.all_features)
92
+ input_dim = self.embedding.input_dim
93
+
94
+ self.experts = nn.ModuleList(
95
+ [
96
+ MLP(input_dim=input_dim, output_dim=None, **expert_mlp_params)
97
+ for _ in range(num_experts)
98
+ ]
99
+ )
100
+ expert_output_dim = get_mlp_output_dim(expert_mlp_params, input_dim)
101
+
102
+ self.gates = nn.ModuleList(
103
+ [
104
+ MLP(input_dim=input_dim, output_dim=num_experts, **gate_mlp_params)
105
+ for _ in range(self.nums_task)
106
+ ]
107
+ )
108
+ self.grad_norm_shared_modules = [
109
+ "embedding",
110
+ "experts",
111
+ "gates",
112
+ "task_weights",
113
+ ]
114
+
115
+ tower_params = [params.copy() for params in tower_mlp_params_list]
116
+ tower_output_dims = [
117
+ get_mlp_output_dim(params, expert_output_dim) for params in tower_params
118
+ ]
119
+ if len(set(tower_output_dims)) != 1:
120
+ raise ValueError(
121
+ f"All tower output dims must match, got {tower_output_dims}."
122
+ )
123
+ tower_output_dim = tower_output_dims[0]
124
+
125
+ self.towers = nn.ModuleList(
126
+ [
127
+ MLP(input_dim=expert_output_dim, output_dim=None, **params)
128
+ for params in tower_params
129
+ ]
130
+ )
131
+ self.tower_logits = nn.ModuleList(
132
+ [nn.Linear(tower_output_dim, 1, bias=False) for _ in range(self.nums_task)]
133
+ )
134
+
135
+ if task_weight_mlp_params is None:
136
+ raise ValueError("task_weight_mlp_params must be a list of dicts.")
137
+ if len(task_weight_mlp_params) != self.nums_task:
138
+ raise ValueError(
139
+ "Length of task_weight_mlp_params "
140
+ f"({len(task_weight_mlp_params)}) must match number of tasks ({self.nums_task})."
141
+ )
142
+ task_weight_mlp_params_list = [
143
+ params.copy() for params in task_weight_mlp_params
144
+ ]
145
+ self.task_weights = nn.ModuleList(
146
+ [
147
+ MLP(input_dim=input_dim, output_dim=self.nums_task, **params)
148
+ for params in task_weight_mlp_params_list
149
+ ]
150
+ )
151
+
152
+ self.prediction_layer = TaskHead(
153
+ task_type=self.task, task_dims=[1] * self.nums_task
154
+ )
155
+
156
+ self.register_regularization_weights(
157
+ embedding_attr="embedding",
158
+ include_modules=[
159
+ "experts",
160
+ "gates",
161
+ "task_weights",
162
+ "towers",
163
+ "tower_logits",
164
+ ],
165
+ )
166
+
167
+ def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
168
+ input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
169
+
170
+ expert_outputs = [expert(input_flat) for expert in self.experts]
171
+ expert_outputs = torch.stack(expert_outputs, dim=0) # [E, B, D]
172
+ expert_outputs_t = expert_outputs.permute(1, 0, 2) # [B, E, D]
173
+
174
+ tower_features = []
175
+ for task_idx in range(self.nums_task):
176
+ gate_logits = self.gates[task_idx](input_flat)
177
+ gate_weights = torch.softmax(gate_logits, dim=1).unsqueeze(2)
178
+ gated_output = torch.sum(gate_weights * expert_outputs_t, dim=1)
179
+ tower_features.append(self.towers[task_idx](gated_output))
180
+
181
+ task_weight_probs = [
182
+ torch.softmax(task_weight(input_flat), dim=1)
183
+ for task_weight in self.task_weights
184
+ ]
185
+
186
+ task_logits = []
187
+ for task_idx in range(self.nums_task):
188
+ task_feat = (
189
+ task_weight_probs[task_idx][:, task_idx].view(-1, 1)
190
+ * tower_features[task_idx]
191
+ )
192
+ for other_idx in range(self.nums_task):
193
+ if other_idx == task_idx:
194
+ continue
195
+ task_feat = (
196
+ task_feat
197
+ + task_weight_probs[task_idx][:, other_idx].view(-1, 1)
198
+ * tower_features[other_idx].detach()
199
+ )
200
+ task_logits.append(self.tower_logits[task_idx](task_feat))
201
+
202
+ logits = torch.cat(task_logits, dim=1)
203
+ return self.prediction_layer(logits)
@@ -3,9 +3,8 @@ Date: create on 09/11/2025
3
3
  Checkpoint: edit on 23/12/2025
4
4
  Author: Yang Zhou,zyaztec@gmail.com
5
5
  Reference:
6
- [1] Ma J, Zhao Z, Yi X, et al. Modeling task relationships in multi-task learning with
7
- multi-gate mixture-of-experts[C]//KDD. 2018: 1930-1939.
8
- (https://dl.acm.org/doi/10.1145/3219819.3220007)
6
+ - [1] Ma J, Zhao Z, Yi X, Chen J, Hong L, Chi E H. Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts. In: Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD '18), 2018, pp. 1930–1939.
7
+ URL: https://dl.acm.org/doi/10.1145/3219819.3220007
9
8
 
10
9
  Multi-gate Mixture-of-Experts (MMoE) extends shared-bottom multi-task learning by
11
10
  introducing multiple experts and task-specific softmax gates. Each task learns its
@@ -49,6 +48,7 @@ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
49
48
  from nextrec.basic.layers import MLP, EmbeddingLayer
50
49
  from nextrec.basic.heads import TaskHead
51
50
  from nextrec.basic.model import BaseModel
51
+ from nextrec.utils.types import TaskTypeName
52
52
 
53
53
 
54
54
  class MMOE(BaseModel):
@@ -77,19 +77,19 @@ class MMOE(BaseModel):
77
77
  dense_features: list[DenseFeature] | None = None,
78
78
  sparse_features: list[SparseFeature] | None = None,
79
79
  sequence_features: list[SequenceFeature] | None = None,
80
- expert_params: dict | None = None,
80
+ expert_mlp_params: dict | None = None,
81
81
  num_experts: int = 3,
82
- tower_params_list: list[dict] | None = None,
82
+ tower_mlp_params_list: list[dict] | None = None,
83
83
  target: list[str] | str | None = None,
84
- task: str | list[str] = "binary",
84
+ task: TaskTypeName | list[TaskTypeName] | None = None,
85
85
  **kwargs,
86
86
  ):
87
87
 
88
88
  dense_features = dense_features or []
89
89
  sparse_features = sparse_features or []
90
90
  sequence_features = sequence_features or []
91
- expert_params = expert_params or {}
92
- tower_params_list = tower_params_list or []
91
+ expert_mlp_params = expert_mlp_params or {}
92
+ tower_mlp_params_list = tower_mlp_params_list or []
93
93
 
94
94
  if target is None:
95
95
  target = []
@@ -98,24 +98,12 @@ class MMOE(BaseModel):
98
98
 
99
99
  self.nums_task = len(target) if target else 1
100
100
 
101
- resolved_task = task
102
- if resolved_task is None:
103
- resolved_task = self.default_task
104
- elif isinstance(resolved_task, str):
105
- resolved_task = [resolved_task] * self.nums_task
106
- elif len(resolved_task) == 1 and self.nums_task > 1:
107
- resolved_task = resolved_task * self.nums_task
108
- elif len(resolved_task) != self.nums_task:
109
- raise ValueError(
110
- f"Length of task ({len(resolved_task)}) must match number of targets ({self.nums_task})."
111
- )
112
-
113
101
  super(MMOE, self).__init__(
114
102
  dense_features=dense_features,
115
103
  sparse_features=sparse_features,
116
104
  sequence_features=sequence_features,
117
105
  target=target,
118
- task=resolved_task,
106
+ task=task,
119
107
  **kwargs,
120
108
  )
121
109
 
@@ -123,9 +111,10 @@ class MMOE(BaseModel):
123
111
  self.nums_task = len(target)
124
112
  self.num_experts = num_experts
125
113
 
126
- if len(tower_params_list) != self.nums_task:
114
+ if len(tower_mlp_params_list) != self.nums_task:
127
115
  raise ValueError(
128
- f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.nums_task})"
116
+ "Number of tower mlp params "
117
+ f"({len(tower_mlp_params_list)}) must match number of tasks ({self.nums_task})"
129
118
  )
130
119
 
131
120
  self.embedding = EmbeddingLayer(features=self.all_features)
@@ -134,12 +123,15 @@ class MMOE(BaseModel):
134
123
  # Expert networks (shared by all tasks)
135
124
  self.experts = nn.ModuleList()
136
125
  for _ in range(num_experts):
137
- expert = MLP(input_dim=input_dim, output_dim=None, **expert_params)
126
+ expert = MLP(input_dim=input_dim, output_dim=None, **expert_mlp_params)
138
127
  self.experts.append(expert)
139
128
 
140
129
  # Get expert output dimension
141
- if "hidden_dims" in expert_params and len(expert_params["hidden_dims"]) > 0:
142
- expert_output_dim = expert_params["hidden_dims"][-1]
130
+ if (
131
+ "hidden_dims" in expert_mlp_params
132
+ and len(expert_mlp_params["hidden_dims"]) > 0
133
+ ):
134
+ expert_output_dim = expert_mlp_params["hidden_dims"][-1]
143
135
  else:
144
136
  expert_output_dim = input_dim
145
137
 
@@ -152,8 +144,8 @@ class MMOE(BaseModel):
152
144
 
153
145
  # Task-specific towers
154
146
  self.towers = nn.ModuleList()
155
- for tower_params in tower_params_list:
156
- tower = MLP(input_dim=expert_output_dim, output_dim=1, **tower_params)
147
+ for tower_mlp_params in tower_mlp_params_list:
148
+ tower = MLP(input_dim=expert_output_dim, output_dim=1, **tower_mlp_params)
157
149
  self.towers.append(tower)
158
150
  self.prediction_layer = TaskHead(
159
151
  task_type=self.task, task_dims=[1] * self.nums_task