nextrec 0.4.24__py3-none-any.whl → 0.4.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/asserts.py +72 -0
- nextrec/basic/loggers.py +18 -1
- nextrec/basic/model.py +191 -71
- nextrec/basic/summary.py +58 -0
- nextrec/cli.py +13 -0
- nextrec/data/data_processing.py +3 -9
- nextrec/data/dataloader.py +25 -2
- nextrec/data/preprocessor.py +283 -36
- nextrec/models/multi_task/[pre]aitm.py +173 -0
- nextrec/models/multi_task/[pre]snr_trans.py +232 -0
- nextrec/models/multi_task/[pre]star.py +192 -0
- nextrec/models/multi_task/apg.py +330 -0
- nextrec/models/multi_task/cross_stitch.py +229 -0
- nextrec/models/multi_task/escm.py +290 -0
- nextrec/models/multi_task/esmm.py +8 -21
- nextrec/models/multi_task/hmoe.py +203 -0
- nextrec/models/multi_task/mmoe.py +20 -28
- nextrec/models/multi_task/pepnet.py +68 -66
- nextrec/models/multi_task/ple.py +30 -44
- nextrec/models/multi_task/poso.py +13 -22
- nextrec/models/multi_task/share_bottom.py +14 -25
- nextrec/models/ranking/afm.py +2 -2
- nextrec/models/ranking/autoint.py +2 -4
- nextrec/models/ranking/dcn.py +2 -3
- nextrec/models/ranking/dcn_v2.py +2 -3
- nextrec/models/ranking/deepfm.py +2 -3
- nextrec/models/ranking/dien.py +7 -9
- nextrec/models/ranking/din.py +8 -10
- nextrec/models/ranking/eulernet.py +1 -2
- nextrec/models/ranking/ffm.py +1 -2
- nextrec/models/ranking/fibinet.py +2 -3
- nextrec/models/ranking/fm.py +1 -1
- nextrec/models/ranking/lr.py +1 -1
- nextrec/models/ranking/masknet.py +1 -2
- nextrec/models/ranking/pnn.py +1 -2
- nextrec/models/ranking/widedeep.py +2 -3
- nextrec/models/ranking/xdeepfm.py +2 -4
- nextrec/models/representation/rqvae.py +4 -4
- nextrec/models/retrieval/dssm.py +18 -26
- nextrec/models/retrieval/dssm_v2.py +15 -22
- nextrec/models/retrieval/mind.py +9 -15
- nextrec/models/retrieval/sdm.py +36 -33
- nextrec/models/retrieval/youtube_dnn.py +16 -24
- nextrec/models/sequential/hstu.py +2 -2
- nextrec/utils/__init__.py +5 -1
- nextrec/utils/config.py +2 -0
- nextrec/utils/model.py +16 -77
- nextrec/utils/torch_utils.py +11 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/METADATA +72 -62
- nextrec-0.4.27.dist-info/RECORD +90 -0
- nextrec/models/multi_task/aitm.py +0 -0
- nextrec/models/multi_task/snr_trans.py +0 -0
- nextrec-0.4.24.dist-info/RECORD +0 -86
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/WHEEL +0 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,290 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Date: create on 01/01/2026
|
|
3
|
+
Checkpoint: edit on 01/01/2026
|
|
4
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
5
|
+
Reference:
|
|
6
|
+
- [1] Wang H, Chang T-W, Liu T, Huang J, Chen Z, Yu C, Li R, Chu W. ESCM²: Entire Space Counterfactual Multi-Task Model for Post-Click Conversion Rate Estimation. Proceedings of the 45th International ACM SIGIR Conference on Research and Development in Information Retrieval (SIGIR ’22), 2022:363–372.
|
|
7
|
+
URL: https://arxiv.org/abs/2204.05125
|
|
8
|
+
- [2] MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation: https://github.com/alipay/MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation/
|
|
9
|
+
|
|
10
|
+
Entire Space Counterfactual Model (ESCM) extends ESMM with counterfactual
|
|
11
|
+
training objectives (e.g., IPS/DR) to debias CVR estimation. The architecture
|
|
12
|
+
keeps separate CTR/CVR towers and derives CTCVR as the product of probabilities.
|
|
13
|
+
Optional exposure propensity (IMP) prediction is included for DR-style variants.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
|
|
20
|
+
from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
21
|
+
from nextrec.basic.heads import TaskHead
|
|
22
|
+
from nextrec.basic.layers import EmbeddingLayer, MLP
|
|
23
|
+
from nextrec.basic.model import BaseModel
|
|
24
|
+
from nextrec.loss.grad_norm import get_grad_norm_shared_params
|
|
25
|
+
from nextrec.utils.model import compute_ranking_loss
|
|
26
|
+
from nextrec.utils.types import TaskTypeName
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ESCM(BaseModel):
    """
    Entire Space Counterfactual Model.

    Extends ESMM-style entire-space training with counterfactual CVR
    objectives: an IPS-weighted CVR loss (propensity = predicted CTR) and,
    when ``use_dr`` is enabled, a doubly-robust variant driven by an extra
    IMP (propensity) tower.  CTR and CVR towers share a single embedding
    layer; the CTCVR output is derived as ``ctr * cvr``.
    """

    @property
    def model_name(self) -> str:
        # Stable model identifier (presumably used by BaseModel for
        # logging/checkpointing — BaseModel is not visible here).
        return "ESCM"

    @property
    def default_task(self) -> TaskTypeName | list[TaskTypeName]:
        # One binary task per target. ``nums_task`` is assigned in __init__
        # *before* super().__init__ runs, so getattr is used defensively in
        # case BaseModel queries this property even earlier.
        nums_task = getattr(self, "nums_task", None)
        if nums_task is not None and nums_task > 0:
            return ["binary"] * nums_task
        return ["binary"]

    def __init__(
        self,
        dense_features: list[DenseFeature] | None = None,
        sparse_features: list[SparseFeature] | None = None,
        sequence_features: list[SequenceFeature] | None = None,
        ctr_mlp_params: dict | None = None,
        cvr_mlp_params: dict | None = None,
        imp_mlp_params: dict | None = None,
        use_dr: bool = False,
        target: list[str] | str | None = None,
        task: TaskTypeName | list[TaskTypeName] | None = None,
        **kwargs,
    ) -> None:
        """
        Build the ESCM towers.

        Args:
            dense_features / sparse_features / sequence_features: feature
                definitions consumed by the shared ``EmbeddingLayer``.
            ctr_mlp_params / cvr_mlp_params / imp_mlp_params: keyword
                arguments forwarded to each tower ``MLP``.
            use_dr: enable the doubly-robust objective and the IMP tower.
            target: output column names; defaults to
                ``["ctr", "cvr", "ctcvr"]`` (plus ``"imp"`` when ``use_dr``).
            task: task type(s) forwarded to ``BaseModel``.

        Raises:
            ValueError: if ``target`` resolves to an empty list, or if a
                non-canonical target list is longer than the four
                supported roles.
        """
        dense_features = dense_features or []
        sparse_features = sparse_features or []
        sequence_features = sequence_features or []
        ctr_mlp_params = ctr_mlp_params or {}
        cvr_mlp_params = cvr_mlp_params or {}
        imp_mlp_params = imp_mlp_params or {}

        if target is None:
            target = ["ctr", "cvr", "ctcvr"]
            if use_dr:
                target.append("imp")
        elif isinstance(target, str):
            target = [target]
        # NOTE(review): when ``target`` is passed explicitly, "imp" is NOT
        # auto-appended even if use_dr=True (the imp tower is still built but
        # its prediction is not an output column) — confirm this is intended.

        # Set before super().__init__ so ``default_task`` sees the task count.
        self.nums_task = len(target) if target else 1

        super().__init__(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
            task=task,
            **kwargs,
        )

        if not target:
            raise ValueError("ESCM requires at least one target.")

        # Map user-supplied target names onto the canonical roles.  If all
        # names are already canonical, keep their order; otherwise assign
        # roles positionally (ctr, cvr, ctcvr, imp).
        valid_targets = {"ctr", "cvr", "ctcvr", "imp"}
        default_roles = ["ctr", "cvr", "ctcvr", "imp"]
        if all(name in valid_targets for name in target):
            target_roles = list(target)
        else:
            if len(target) > len(default_roles):
                raise ValueError(
                    f"ESCM supports up to {len(default_roles)} targets, got {len(target)}."
                )
            target_roles = default_roles[: len(target)]

        self.target_roles = target_roles
        # Requesting an "imp" output implicitly turns on the DR path.
        self.use_dr = use_dr or ("imp" in self.target_roles)
        # Towers that get their own head output (ctcvr is derived, not a tower).
        base_targets = ["ctr", "cvr"]
        if self.use_dr:
            base_targets.append("imp")

        self.embedding = EmbeddingLayer(features=self.all_features)
        input_dim = self.embedding.input_dim

        self.ctr_tower = MLP(input_dim=input_dim, output_dim=1, **ctr_mlp_params)
        self.cvr_tower = MLP(input_dim=input_dim, output_dim=1, **cvr_mlp_params)
        if self.use_dr:
            self.imp_tower = MLP(input_dim=input_dim, output_dim=1, **imp_mlp_params)

        # One sigmoid-style head per tower output (all binary).
        self.base_task_types = ["binary"] * len(base_targets)
        self.prediction_layer = TaskHead(
            task_type=self.base_task_types, task_dims=[1] * len(base_targets)
        )

        # Only the shared embedding participates in GradNorm balancing.
        self.grad_norm_shared_modules = ["embedding"]
        reg_modules = ["ctr_tower", "cvr_tower"]
        if self.use_dr:
            reg_modules.append("imp_tower")
        self.register_regularization_weights(
            embedding_attr="embedding", include_modules=reg_modules
        )

    def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Compute per-target probabilities.

        Returns a tensor of shape ``[batch, len(target_roles)]`` whose
        columns follow ``self.target_roles`` order.
        """
        input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)

        ctr_logit = self.ctr_tower(input_flat)
        cvr_logit = self.cvr_tower(input_flat)
        base_logits = [ctr_logit, cvr_logit]
        if self.use_dr:
            imp_logit = self.imp_tower(input_flat)
            base_logits.append(imp_logit)

        # Run all tower logits through the head at once, then split back
        # into one [batch, 1] probability per tower.
        base_logits_cat = torch.cat(base_logits, dim=1)
        base_preds = self.prediction_layer(base_logits_cat)
        base_preds = base_preds.split(1, dim=1)

        pred_map = {"ctr": base_preds[0], "cvr": base_preds[1]}
        if self.use_dr:
            pred_map["imp"] = base_preds[2]

        # Entire-space decomposition: p(ctcvr) = p(ctr) * p(cvr | click).
        ctcvr_pred = pred_map["ctr"] * pred_map["cvr"]

        outputs = []
        for name in self.target_roles:
            if name == "ctcvr":
                outputs.append(ctcvr_pred)
            else:
                outputs.append(pred_map[name])
        return torch.cat(outputs, dim=1)

    def _loss_no_reduce(
        self,
        loss_fn: torch.nn.Module,
        y_pred: torch.Tensor,
        y_true: torch.Tensor,
    ) -> torch.Tensor:
        """
        Evaluate ``loss_fn`` per-sample (reduction disabled), returning a
        flat ``[batch]`` tensor.

        Temporarily flips ``loss_fn.reduction`` to ``"none"`` when the loss
        module exposes that attribute, restoring the original value after.
        """
        # NOTE(review): the restore is not exception-safe — if loss_fn raises,
        # reduction stays "none"; consider try/finally.
        if hasattr(loss_fn, "reduction"):
            reduction = loss_fn.reduction
            if reduction != "none":
                loss_fn.reduction = "none"
                loss = loss_fn(y_pred, y_true)
                loss_fn.reduction = reduction
            else:
                loss = loss_fn(y_pred, y_true)
        else:
            loss = loss_fn(y_pred, y_true)

        # A scalar means the loss reduced anyway; pass it through unchanged.
        if loss.dim() == 0:
            return loss
        # Collapse any trailing dims so each sample maps to one value.
        if loss.dim() > 1:
            loss = loss.view(loss.size(0), -1).mean(dim=1)
        return loss.view(-1)

    def _compute_cvr_loss(
        self,
        loss_fn: torch.nn.Module,
        y_pred: torch.Tensor,
        y_true: torch.Tensor,
        click_label: torch.Tensor | None,
        prop_pred: torch.Tensor | None,
        valid_mask: torch.Tensor | None,
        eps: float = 1e-7,
    ) -> torch.Tensor:
        """
        Counterfactual CVR loss.

        - IPS: weight each per-sample loss by ``click / propensity``.
        - DR (``use_dr``): add a self-imputed error correction
          ``impute + w * (loss - impute)``.

        ``y_pred``/``y_true`` are expected to be ALREADY masked by
        ``valid_mask``; ``click_label`` and ``prop_pred`` arrive full-batch
        and are masked here with the same mask, so shapes line up.
        """
        # Without a click signal fall back to a plain (reduced) loss.
        if click_label is None:
            return loss_fn(y_pred.view(-1), y_true.view(-1))

        click = click_label
        if valid_mask is not None:
            click = click[valid_mask]
        # Labels/propensities must not receive gradients through the weight.
        click = click.detach()

        if prop_pred is not None:
            prop = prop_pred
            if valid_mask is not None:
                prop = prop[valid_mask]
            prop = prop.detach()
            # Clamp away from 0/1 to keep the inverse-propensity weight finite.
            prop = torch.clamp(prop, min=eps, max=1.0 - eps)
            weight = (click / prop).view(-1)
        else:
            # No propensity: degenerate IPS, i.e. click-gated loss.
            weight = click.view(-1)

        per_sample = self._loss_no_reduce(loss_fn, y_pred, y_true).view(-1)
        if self.use_dr and prop_pred is not None:
            # Doubly-robust: impute the error with the model's own (detached)
            # prediction, then correct it on observed (clicked) samples.
            impute_target = y_pred.detach()
            impute_loss = self._loss_no_reduce(loss_fn, y_pred, impute_target).view(-1)
            return (impute_loss + weight * (per_sample - impute_loss)).mean()
        return (per_sample * weight).mean()

    def compute_loss(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """
        Multi-task loss over ``target_roles`` columns.

        The CVR column uses the counterfactual objective; all other columns
        use their configured loss (or a ranking loss when the training mode
        is pairwise/listwise).  Optionally reweights via GradNorm or static
        ``loss_weights``.
        """
        if y_true is None:
            raise ValueError(
                "[ESCM-compute_loss Error] Ground truth labels (y_true) are required."
            )

        # Normalize to 2-D [batch, tasks].
        if y_pred.dim() == 1:
            y_pred = y_pred.view(-1, 1)
        if y_true.dim() == 1:
            y_true = y_true.view(-1, 1)

        role_to_index = {role: idx for idx, role in enumerate(self.target_roles)}
        ctr_index = role_to_index.get("ctr")
        imp_index = role_to_index.get("imp")

        # Full-batch ctr/imp columns, needed by the CVR objective below.
        ctr_pred = (
            y_pred[:, ctr_index : ctr_index + 1] if ctr_index is not None else None
        )
        ctr_true = (
            y_true[:, ctr_index : ctr_index + 1] if ctr_index is not None else None
        )
        imp_pred = (
            y_pred[:, imp_index : imp_index + 1] if imp_index is not None else None
        )

        task_losses = []
        for i, role in enumerate(self.target_roles):
            y_pred_i = y_pred[:, i : i + 1]
            y_true_i = y_true[:, i : i + 1]
            valid_mask = None
            if self.ignore_label is not None:
                # Drop samples carrying the sentinel ignore label.
                valid_mask = y_true_i != self.ignore_label
                if valid_mask.dim() > 1:
                    valid_mask = valid_mask.all(dim=1)
                if not torch.any(valid_mask):
                    # All samples ignored: emit a zero loss that still keeps
                    # this task attached to the graph.
                    task_losses.append(y_pred_i.sum() * 0.0)
                    continue
                y_pred_i = y_pred_i[valid_mask]
                y_true_i = y_true_i[valid_mask]

            if role == "cvr":
                # DR uses the dedicated propensity (imp) head when present,
                # otherwise the CTR prediction doubles as propensity (IPS).
                prop_pred = imp_pred if self.use_dr else ctr_pred
                if prop_pred is None:
                    prop_pred = ctr_pred
                task_loss = self._compute_cvr_loss(
                    loss_fn=self.loss_fn[i],
                    y_pred=y_pred_i,
                    y_true=y_true_i,
                    click_label=ctr_true,
                    prop_pred=prop_pred,
                    valid_mask=valid_mask,
                )
            else:
                mode = self.training_modes[i]
                if mode in {"pairwise", "listwise"}:
                    task_loss = compute_ranking_loss(
                        training_mode=mode,
                        loss_fn=self.loss_fn[i],
                        y_pred=y_pred_i,
                        y_true=y_true_i,
                    )
                elif y_pred_i.shape[1] == 1:
                    # Flatten single-output tasks for pointwise losses.
                    task_loss = self.loss_fn[i](y_pred_i.view(-1), y_true_i.view(-1))
                else:
                    task_loss = self.loss_fn[i](y_pred_i, y_true_i)
            task_losses.append(task_loss)

        if self.grad_norm is not None:
            # Lazily resolve the shared parameters GradNorm balances against.
            if self.grad_norm_shared_params is None:
                self.grad_norm_shared_params = get_grad_norm_shared_params(
                    self, getattr(self, "grad_norm_shared_modules", None)
                )
            return self.grad_norm.compute_weighted_loss(
                task_losses, self.grad_norm_shared_params
            )
        if isinstance(self.loss_weights, (list, tuple)):
            task_losses = [
                task_loss * self.loss_weights[i]
                for i, task_loss in enumerate(task_losses)
            ]
        return torch.stack(task_losses).sum()
|
|
@@ -3,9 +3,8 @@ Date: create on 09/11/2025
|
|
|
3
3
|
Checkpoint: edit on 23/12/2025
|
|
4
4
|
Author: Yang Zhou,zyaztec@gmail.com
|
|
5
5
|
Reference:
|
|
6
|
-
[1] Ma X, Zhao L, Huang G,
|
|
7
|
-
|
|
8
|
-
(https://dl.acm.org/doi/10.1145/3209978.3210007)
|
|
6
|
+
- [1] Ma X, Zhao L, Huang G, Wang Z, Hu Z, Zhu X, Gai K. Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate. In: Proceedings of the 41st International ACM SIGIR Conference on Research and Development in Information Retrieval (SIGIR ’18), 2018, pp. 1137–1140.
|
|
7
|
+
URL: https://dl.acm.org/doi/10.1145/3209978.3210007
|
|
9
8
|
|
|
10
9
|
Entire Space Multi-task Model (ESMM) targets CVR estimation by jointly optimizing
|
|
11
10
|
CTR and CTCVR on the full impression space, mitigating sample selection bias and
|
|
@@ -75,9 +74,9 @@ class ESMM(BaseModel):
|
|
|
75
74
|
dense_features: list[DenseFeature],
|
|
76
75
|
sparse_features: list[SparseFeature],
|
|
77
76
|
sequence_features: list[SequenceFeature],
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
task:
|
|
77
|
+
ctr_mlp_params: dict,
|
|
78
|
+
cvr_mlp_params: dict,
|
|
79
|
+
task: list[TaskTypeName] | None = None,
|
|
81
80
|
target: list[str] | None = None, # Note: ctcvr = ctr * cvr
|
|
82
81
|
**kwargs,
|
|
83
82
|
):
|
|
@@ -90,25 +89,13 @@ class ESMM(BaseModel):
|
|
|
90
89
|
)
|
|
91
90
|
|
|
92
91
|
self.nums_task = len(target)
|
|
93
|
-
resolved_task = task
|
|
94
|
-
if resolved_task is None:
|
|
95
|
-
resolved_task = self.default_task
|
|
96
|
-
elif isinstance(resolved_task, str):
|
|
97
|
-
resolved_task = [resolved_task] * self.nums_task
|
|
98
|
-
elif len(resolved_task) == 1 and self.nums_task > 1:
|
|
99
|
-
resolved_task = resolved_task * self.nums_task
|
|
100
|
-
elif len(resolved_task) != self.nums_task:
|
|
101
|
-
raise ValueError(
|
|
102
|
-
f"Length of task ({len(resolved_task)}) must match number of targets ({self.nums_task})."
|
|
103
|
-
)
|
|
104
|
-
# resolved_task is now guaranteed to be a list[str]
|
|
105
92
|
|
|
106
93
|
super(ESMM, self).__init__(
|
|
107
94
|
dense_features=dense_features,
|
|
108
95
|
sparse_features=sparse_features,
|
|
109
96
|
sequence_features=sequence_features,
|
|
110
97
|
target=target,
|
|
111
|
-
task=
|
|
98
|
+
task=task, # Both CTR and CTCVR are binary classification
|
|
112
99
|
**kwargs,
|
|
113
100
|
)
|
|
114
101
|
|
|
@@ -116,10 +103,10 @@ class ESMM(BaseModel):
|
|
|
116
103
|
input_dim = self.embedding.input_dim
|
|
117
104
|
|
|
118
105
|
# CTR tower
|
|
119
|
-
self.ctr_tower = MLP(input_dim=input_dim, output_dim=1, **
|
|
106
|
+
self.ctr_tower = MLP(input_dim=input_dim, output_dim=1, **ctr_mlp_params)
|
|
120
107
|
|
|
121
108
|
# CVR tower
|
|
122
|
-
self.cvr_tower = MLP(input_dim=input_dim, output_dim=1, **
|
|
109
|
+
self.cvr_tower = MLP(input_dim=input_dim, output_dim=1, **cvr_mlp_params)
|
|
123
110
|
self.grad_norm_shared_modules = ["embedding"]
|
|
124
111
|
self.prediction_layer = TaskHead(task_type=self.task, task_dims=[1, 1])
|
|
125
112
|
# Register regularization weights
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Date: create on 01/01/2026
|
|
3
|
+
Checkpoint: edit on 01/01/2026
|
|
4
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
5
|
+
[1] Zhao Z, Liu Y, Jin R, Zhu X, He X. HMOE: Improving Multi-Scenario Learning to Rank in E-commerce by Exploiting Task Relationships in the Label Space. Proceedings of the 29th ACM International Conference on Information & Knowledge Management (CIKM ’20), 2020, pp. 2069–2078.
|
|
6
|
+
URL: https://dl.acm.org/doi/10.1145/3340531.3412713
|
|
7
|
+
[2] MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation:
|
|
8
|
+
https://github.com/alipay/MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation/
|
|
9
|
+
|
|
10
|
+
Hierarchical Mixture-of-Experts (HMOE) extends MMOE with task-to-task
|
|
11
|
+
feature aggregation. Each task builds a tower representation from expert
|
|
12
|
+
mixtures, then a task-weight network mixes all tower features with
|
|
13
|
+
stop-gradient on non-target tasks to reduce negative transfer.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
import torch.nn as nn
|
|
20
|
+
|
|
21
|
+
from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
22
|
+
from nextrec.basic.layers import MLP, EmbeddingLayer
|
|
23
|
+
from nextrec.basic.heads import TaskHead
|
|
24
|
+
from nextrec.basic.model import BaseModel
|
|
25
|
+
from nextrec.utils.model import get_mlp_output_dim
|
|
26
|
+
from nextrec.utils.types import TaskTypeName
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class HMOE(BaseModel):
    """
    Hierarchical Mixture-of-Experts.

    MMOE-style shared experts with per-task softmax gates, followed by a
    second "hierarchical" stage: for each task, a task-weight network mixes
    all tasks' tower features, applying stop-gradient to every non-target
    tower so other tasks cannot be disturbed by this task's gradient
    (negative-transfer mitigation).
    """

    @property
    def model_name(self) -> str:
        # Stable model identifier (presumably consumed by BaseModel for
        # logging/checkpointing — BaseModel is not visible here).
        return "HMOE"

    @property
    def default_task(self) -> TaskTypeName | list[TaskTypeName]:
        # One binary task per target; ``nums_task`` is set in __init__
        # before super().__init__ so this can answer during base setup.
        nums_task = getattr(self, "nums_task", None)
        if nums_task is not None and nums_task > 0:
            return ["binary"] * nums_task
        return ["binary"]

    def __init__(
        self,
        dense_features: list[DenseFeature] | None = None,
        sparse_features: list[SparseFeature] | None = None,
        sequence_features: list[SequenceFeature] | None = None,
        expert_mlp_params: dict | None = None,
        num_experts: int = 4,
        gate_mlp_params: dict | None = None,
        tower_mlp_params_list: list[dict] | None = None,
        task_weight_mlp_params: list[dict] | None = None,
        target: list[str] | str | None = None,
        task: TaskTypeName | list[TaskTypeName] | None = None,
        **kwargs,
    ) -> None:
        """
        Build the expert / gate / tower / task-weight networks.

        Args:
            expert_mlp_params: kwargs for every shared expert MLP.
            num_experts: number of shared experts.
            gate_mlp_params: kwargs for each per-task gate MLP
                (output dim is forced to ``num_experts``).
            tower_mlp_params_list: one MLP kwargs dict per task; all towers
                must end at the same output dim (checked below).
            task_weight_mlp_params: one MLP kwargs dict per task for the
                task-to-task mixing weights; REQUIRED despite the ``None``
                default (a ``ValueError`` is raised otherwise).
            target / task: forwarded to ``BaseModel``.

        Raises:
            ValueError: on list-length mismatches, missing
                ``task_weight_mlp_params``, or inconsistent tower dims.
        """
        dense_features = dense_features or []
        sparse_features = sparse_features or []
        sequence_features = sequence_features or []
        expert_mlp_params = expert_mlp_params or {}
        gate_mlp_params = gate_mlp_params or {}
        tower_mlp_params_list = tower_mlp_params_list or []

        if target is None:
            target = []
        elif isinstance(target, str):
            target = [target]

        # Set before super().__init__ so ``default_task`` sees the count.
        self.nums_task = len(target) if target else 1

        super().__init__(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
            task=task,
            **kwargs,
        )

        # NOTE(review): duplicate of the assignment above — harmless but
        # redundant; safe to drop one of the two.
        self.nums_task = len(target) if target else 1
        self.num_experts = num_experts

        if len(tower_mlp_params_list) != self.nums_task:
            raise ValueError(
                "Number of tower mlp params "
                f"({len(tower_mlp_params_list)}) must match number of tasks ({self.nums_task})."
            )

        self.embedding = EmbeddingLayer(features=self.all_features)
        input_dim = self.embedding.input_dim

        # Shared experts: each maps [B, input_dim] -> [B, expert_output_dim].
        self.experts = nn.ModuleList(
            [
                MLP(input_dim=input_dim, output_dim=None, **expert_mlp_params)
                for _ in range(num_experts)
            ]
        )
        expert_output_dim = get_mlp_output_dim(expert_mlp_params, input_dim)

        # One softmax gate per task over the experts.
        self.gates = nn.ModuleList(
            [
                MLP(input_dim=input_dim, output_dim=num_experts, **gate_mlp_params)
                for _ in range(self.nums_task)
            ]
        )
        self.grad_norm_shared_modules = [
            "embedding",
            "experts",
            "gates",
            "task_weights",
        ]

        # Copy the per-task tower configs so callers' dicts are not mutated.
        tower_params = [params.copy() for params in tower_mlp_params_list]
        tower_output_dims = [
            get_mlp_output_dim(params, expert_output_dim) for params in tower_params
        ]
        # All tower features are later summed together, so dims must agree.
        if len(set(tower_output_dims)) != 1:
            raise ValueError(
                f"All tower output dims must match, got {tower_output_dims}."
            )
        tower_output_dim = tower_output_dims[0]

        self.towers = nn.ModuleList(
            [
                MLP(input_dim=expert_output_dim, output_dim=None, **params)
                for params in tower_params
            ]
        )
        # Final per-task scalar logit from the mixed tower feature.
        self.tower_logits = nn.ModuleList(
            [nn.Linear(tower_output_dim, 1, bias=False) for _ in range(self.nums_task)]
        )

        if task_weight_mlp_params is None:
            raise ValueError("task_weight_mlp_params must be a list of dicts.")
        if len(task_weight_mlp_params) != self.nums_task:
            raise ValueError(
                "Length of task_weight_mlp_params "
                f"({len(task_weight_mlp_params)}) must match number of tasks ({self.nums_task})."
            )
        task_weight_mlp_params_list = [
            params.copy() for params in task_weight_mlp_params
        ]
        # Per-task network producing a softmax weight over ALL tasks' towers.
        self.task_weights = nn.ModuleList(
            [
                MLP(input_dim=input_dim, output_dim=self.nums_task, **params)
                for params in task_weight_mlp_params_list
            ]
        )

        self.prediction_layer = TaskHead(
            task_type=self.task, task_dims=[1] * self.nums_task
        )

        self.register_regularization_weights(
            embedding_attr="embedding",
            include_modules=[
                "experts",
                "gates",
                "task_weights",
                "towers",
                "tower_logits",
            ],
        )

    def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Compute per-task predictions, shape ``[batch, nums_task]``.
        """
        input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)

        expert_outputs = [expert(input_flat) for expert in self.experts]
        expert_outputs = torch.stack(expert_outputs, dim=0)  # [E, B, D]
        expert_outputs_t = expert_outputs.permute(1, 0, 2)  # [B, E, D]

        # Stage 1 (MMOE): per-task gated expert mixture -> per-task tower.
        tower_features = []
        for task_idx in range(self.nums_task):
            gate_logits = self.gates[task_idx](input_flat)
            gate_weights = torch.softmax(gate_logits, dim=1).unsqueeze(2)  # [B, E, 1]
            gated_output = torch.sum(gate_weights * expert_outputs_t, dim=1)  # [B, D]
            tower_features.append(self.towers[task_idx](gated_output))

        # Stage 2 (hierarchy): softmax weights over all tasks' towers,
        # computed per task from the raw input.
        task_weight_probs = [
            torch.softmax(task_weight(input_flat), dim=1)
            for task_weight in self.task_weights
        ]

        task_logits = []
        for task_idx in range(self.nums_task):
            # Own tower contributes with gradient...
            task_feat = (
                task_weight_probs[task_idx][:, task_idx].view(-1, 1)
                * tower_features[task_idx]
            )
            for other_idx in range(self.nums_task):
                if other_idx == task_idx:
                    continue
                # ...other towers are stop-gradiented (detach) so this task's
                # loss cannot push gradients into them (negative-transfer guard).
                task_feat = (
                    task_feat
                    + task_weight_probs[task_idx][:, other_idx].view(-1, 1)
                    * tower_features[other_idx].detach()
                )
            task_logits.append(self.tower_logits[task_idx](task_feat))

        logits = torch.cat(task_logits, dim=1)
        return self.prediction_layer(logits)
|
|
@@ -3,9 +3,8 @@ Date: create on 09/11/2025
|
|
|
3
3
|
Checkpoint: edit on 23/12/2025
|
|
4
4
|
Author: Yang Zhou,zyaztec@gmail.com
|
|
5
5
|
Reference:
|
|
6
|
-
[1] Ma J, Zhao Z, Yi X,
|
|
7
|
-
|
|
8
|
-
(https://dl.acm.org/doi/10.1145/3219819.3220007)
|
|
6
|
+
- [1] Ma J, Zhao Z, Yi X, Chen J, Hong L, Chi E H. Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts. In: Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD ’18), 2018, pp. 1930–1939.
|
|
7
|
+
URL: https://dl.acm.org/doi/10.1145/3219819.3220007
|
|
9
8
|
|
|
10
9
|
Multi-gate Mixture-of-Experts (MMoE) extends shared-bottom multi-task learning by
|
|
11
10
|
introducing multiple experts and task-specific softmax gates. Each task learns its
|
|
@@ -49,6 +48,7 @@ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
|
49
48
|
from nextrec.basic.layers import MLP, EmbeddingLayer
|
|
50
49
|
from nextrec.basic.heads import TaskHead
|
|
51
50
|
from nextrec.basic.model import BaseModel
|
|
51
|
+
from nextrec.utils.types import TaskTypeName
|
|
52
52
|
|
|
53
53
|
|
|
54
54
|
class MMOE(BaseModel):
|
|
@@ -77,19 +77,19 @@ class MMOE(BaseModel):
|
|
|
77
77
|
dense_features: list[DenseFeature] | None = None,
|
|
78
78
|
sparse_features: list[SparseFeature] | None = None,
|
|
79
79
|
sequence_features: list[SequenceFeature] | None = None,
|
|
80
|
-
|
|
80
|
+
expert_mlp_params: dict | None = None,
|
|
81
81
|
num_experts: int = 3,
|
|
82
|
-
|
|
82
|
+
tower_mlp_params_list: list[dict] | None = None,
|
|
83
83
|
target: list[str] | str | None = None,
|
|
84
|
-
task:
|
|
84
|
+
task: TaskTypeName | list[TaskTypeName] | None = None,
|
|
85
85
|
**kwargs,
|
|
86
86
|
):
|
|
87
87
|
|
|
88
88
|
dense_features = dense_features or []
|
|
89
89
|
sparse_features = sparse_features or []
|
|
90
90
|
sequence_features = sequence_features or []
|
|
91
|
-
|
|
92
|
-
|
|
91
|
+
expert_mlp_params = expert_mlp_params or {}
|
|
92
|
+
tower_mlp_params_list = tower_mlp_params_list or []
|
|
93
93
|
|
|
94
94
|
if target is None:
|
|
95
95
|
target = []
|
|
@@ -98,24 +98,12 @@ class MMOE(BaseModel):
|
|
|
98
98
|
|
|
99
99
|
self.nums_task = len(target) if target else 1
|
|
100
100
|
|
|
101
|
-
resolved_task = task
|
|
102
|
-
if resolved_task is None:
|
|
103
|
-
resolved_task = self.default_task
|
|
104
|
-
elif isinstance(resolved_task, str):
|
|
105
|
-
resolved_task = [resolved_task] * self.nums_task
|
|
106
|
-
elif len(resolved_task) == 1 and self.nums_task > 1:
|
|
107
|
-
resolved_task = resolved_task * self.nums_task
|
|
108
|
-
elif len(resolved_task) != self.nums_task:
|
|
109
|
-
raise ValueError(
|
|
110
|
-
f"Length of task ({len(resolved_task)}) must match number of targets ({self.nums_task})."
|
|
111
|
-
)
|
|
112
|
-
|
|
113
101
|
super(MMOE, self).__init__(
|
|
114
102
|
dense_features=dense_features,
|
|
115
103
|
sparse_features=sparse_features,
|
|
116
104
|
sequence_features=sequence_features,
|
|
117
105
|
target=target,
|
|
118
|
-
task=
|
|
106
|
+
task=task,
|
|
119
107
|
**kwargs,
|
|
120
108
|
)
|
|
121
109
|
|
|
@@ -123,9 +111,10 @@ class MMOE(BaseModel):
|
|
|
123
111
|
self.nums_task = len(target)
|
|
124
112
|
self.num_experts = num_experts
|
|
125
113
|
|
|
126
|
-
if len(
|
|
114
|
+
if len(tower_mlp_params_list) != self.nums_task:
|
|
127
115
|
raise ValueError(
|
|
128
|
-
|
|
116
|
+
"Number of tower mlp params "
|
|
117
|
+
f"({len(tower_mlp_params_list)}) must match number of tasks ({self.nums_task})"
|
|
129
118
|
)
|
|
130
119
|
|
|
131
120
|
self.embedding = EmbeddingLayer(features=self.all_features)
|
|
@@ -134,12 +123,15 @@ class MMOE(BaseModel):
|
|
|
134
123
|
# Expert networks (shared by all tasks)
|
|
135
124
|
self.experts = nn.ModuleList()
|
|
136
125
|
for _ in range(num_experts):
|
|
137
|
-
expert = MLP(input_dim=input_dim, output_dim=None, **
|
|
126
|
+
expert = MLP(input_dim=input_dim, output_dim=None, **expert_mlp_params)
|
|
138
127
|
self.experts.append(expert)
|
|
139
128
|
|
|
140
129
|
# Get expert output dimension
|
|
141
|
-
if
|
|
142
|
-
|
|
130
|
+
if (
|
|
131
|
+
"hidden_dims" in expert_mlp_params
|
|
132
|
+
and len(expert_mlp_params["hidden_dims"]) > 0
|
|
133
|
+
):
|
|
134
|
+
expert_output_dim = expert_mlp_params["hidden_dims"][-1]
|
|
143
135
|
else:
|
|
144
136
|
expert_output_dim = input_dim
|
|
145
137
|
|
|
@@ -152,8 +144,8 @@ class MMOE(BaseModel):
|
|
|
152
144
|
|
|
153
145
|
# Task-specific towers
|
|
154
146
|
self.towers = nn.ModuleList()
|
|
155
|
-
for
|
|
156
|
-
tower = MLP(input_dim=expert_output_dim, output_dim=1, **
|
|
147
|
+
for tower_mlp_params in tower_mlp_params_list:
|
|
148
|
+
tower = MLP(input_dim=expert_output_dim, output_dim=1, **tower_mlp_params)
|
|
157
149
|
self.towers.append(tower)
|
|
158
150
|
self.prediction_layer = TaskHead(
|
|
159
151
|
task_type=self.task, task_dims=[1] * self.nums_task
|