nextrec-0.4.24-py3-none-any.whl → nextrec-0.4.27-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/asserts.py +72 -0
- nextrec/basic/loggers.py +18 -1
- nextrec/basic/model.py +191 -71
- nextrec/basic/summary.py +58 -0
- nextrec/cli.py +13 -0
- nextrec/data/data_processing.py +3 -9
- nextrec/data/dataloader.py +25 -2
- nextrec/data/preprocessor.py +283 -36
- nextrec/models/multi_task/[pre]aitm.py +173 -0
- nextrec/models/multi_task/[pre]snr_trans.py +232 -0
- nextrec/models/multi_task/[pre]star.py +192 -0
- nextrec/models/multi_task/apg.py +330 -0
- nextrec/models/multi_task/cross_stitch.py +229 -0
- nextrec/models/multi_task/escm.py +290 -0
- nextrec/models/multi_task/esmm.py +8 -21
- nextrec/models/multi_task/hmoe.py +203 -0
- nextrec/models/multi_task/mmoe.py +20 -28
- nextrec/models/multi_task/pepnet.py +68 -66
- nextrec/models/multi_task/ple.py +30 -44
- nextrec/models/multi_task/poso.py +13 -22
- nextrec/models/multi_task/share_bottom.py +14 -25
- nextrec/models/ranking/afm.py +2 -2
- nextrec/models/ranking/autoint.py +2 -4
- nextrec/models/ranking/dcn.py +2 -3
- nextrec/models/ranking/dcn_v2.py +2 -3
- nextrec/models/ranking/deepfm.py +2 -3
- nextrec/models/ranking/dien.py +7 -9
- nextrec/models/ranking/din.py +8 -10
- nextrec/models/ranking/eulernet.py +1 -2
- nextrec/models/ranking/ffm.py +1 -2
- nextrec/models/ranking/fibinet.py +2 -3
- nextrec/models/ranking/fm.py +1 -1
- nextrec/models/ranking/lr.py +1 -1
- nextrec/models/ranking/masknet.py +1 -2
- nextrec/models/ranking/pnn.py +1 -2
- nextrec/models/ranking/widedeep.py +2 -3
- nextrec/models/ranking/xdeepfm.py +2 -4
- nextrec/models/representation/rqvae.py +4 -4
- nextrec/models/retrieval/dssm.py +18 -26
- nextrec/models/retrieval/dssm_v2.py +15 -22
- nextrec/models/retrieval/mind.py +9 -15
- nextrec/models/retrieval/sdm.py +36 -33
- nextrec/models/retrieval/youtube_dnn.py +16 -24
- nextrec/models/sequential/hstu.py +2 -2
- nextrec/utils/__init__.py +5 -1
- nextrec/utils/config.py +2 -0
- nextrec/utils/model.py +16 -77
- nextrec/utils/torch_utils.py +11 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/METADATA +72 -62
- nextrec-0.4.27.dist-info/RECORD +90 -0
- nextrec/models/multi_task/aitm.py +0 -0
- nextrec/models/multi_task/snr_trans.py +0 -0
- nextrec-0.4.24.dist-info/RECORD +0 -86
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/WHEEL +0 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/licenses/LICENSE +0 -0
nextrec/models/multi_task/apg.py
CHANGED
@@ -0,0 +1,330 @@
"""
Date: create on 01/01/2026
Checkpoint: edit on 01/01/2026
Author: Yang Zhou, zyaztec@gmail.com
Reference:
- [1] Yan B, Wang P, Zhang K, Li F, Deng H, Xu J, Zheng B. APG: Adaptive Parameter Generation Network for Click-Through Rate Prediction. Advances in Neural Information Processing Systems 35 (NeurIPS 2022), 2022.
  URL: https://arxiv.org/abs/2203.16218
- [2] MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation: https://github.com/alipay/MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation/
"""

from __future__ import annotations

import math
import torch
import torch.nn as nn

from nextrec.basic.activation import activation_layer
from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
from nextrec.basic.layers import EmbeddingLayer, MLP
from nextrec.basic.heads import TaskHead
from nextrec.basic.model import BaseModel
from nextrec.utils.model import select_features
from nextrec.utils.types import ActivationName, TaskTypeName


class APGLayer(nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        scene_emb_dim: int,
        activation: ActivationName = "relu",
        generate_activation: ActivationName | None = None,
        inner_activation: ActivationName | None = None,
        use_uv_shared: bool = True,
        use_mf_p: bool = False,
        mf_k: int = 16,
        mf_p: int = 4,
    ) -> None:
        super().__init__()
        self.use_uv_shared = use_uv_shared
        self.use_mf_p = use_mf_p
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.activation = (
            activation_layer(activation) if activation is not None else nn.Identity()
        )
        self.inner_activation = (
            activation_layer(inner_activation) if inner_activation is not None else None
        )

        min_dim = min(int(input_dim), int(output_dim))
        self.p_dim = math.ceil(float(min_dim) / float(mf_p))
        self.k_dim = math.ceil(float(min_dim) / float(mf_k))

        if use_uv_shared:
            if use_mf_p:
                self.shared_weight_np = nn.Parameter(
                    torch.empty(self.input_dim, self.p_dim)
                )
                self.shared_bias_np = nn.Parameter(torch.zeros(self.p_dim))
                self.shared_weight_pk = nn.Parameter(
                    torch.empty(self.p_dim, self.k_dim)
                )
                self.shared_bias_pk = nn.Parameter(torch.zeros(self.k_dim))

                self.shared_weight_kp = nn.Parameter(
                    torch.empty(self.k_dim, self.p_dim)
                )
                self.shared_bias_kp = nn.Parameter(torch.zeros(self.p_dim))
                self.shared_weight_pm = nn.Parameter(
                    torch.empty(self.p_dim, self.output_dim)
                )
                self.shared_bias_pm = nn.Parameter(torch.zeros(self.output_dim))
            else:
                self.shared_weight_nk = nn.Parameter(
                    torch.empty(self.input_dim, self.k_dim)
                )
                self.shared_bias_nk = nn.Parameter(torch.zeros(self.k_dim))
                self.shared_weight_km = nn.Parameter(
                    torch.empty(self.k_dim, self.output_dim)
                )
                self.shared_bias_km = nn.Parameter(torch.zeros(self.output_dim))
        self.specific_weight_kk = MLP(
            input_dim=scene_emb_dim,
            hidden_dims=None,
            output_dim=self.k_dim * self.k_dim,
            activation="relu",
            output_activation=generate_activation or "none",
        )
        self.specific_bias_kk = MLP(
            input_dim=scene_emb_dim,
            hidden_dims=None,
            output_dim=self.k_dim,
            activation="relu",
            output_activation=generate_activation or "none",
        )
        if not use_uv_shared:
            self.specific_weight_nk = MLP(
                input_dim=scene_emb_dim,
                hidden_dims=None,
                output_dim=self.input_dim * self.k_dim,
                activation="relu",
                output_activation=generate_activation or "none",
            )
            self.specific_bias_nk = MLP(
                input_dim=scene_emb_dim,
                hidden_dims=None,
                output_dim=self.k_dim,
                activation="relu",
                output_activation=generate_activation or "none",
            )
            self.specific_weight_km = MLP(
                input_dim=scene_emb_dim,
                hidden_dims=None,
                output_dim=self.k_dim * self.output_dim,
                activation="relu",
                output_activation=generate_activation or "none",
            )
            self.specific_bias_km = MLP(
                input_dim=scene_emb_dim,
                hidden_dims=None,
                output_dim=self.output_dim,
                activation="relu",
                output_activation=generate_activation or "none",
            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.use_uv_shared:
            if self.use_mf_p:
                nn.init.xavier_uniform_(self.shared_weight_np)
                nn.init.zeros_(self.shared_bias_np)
                nn.init.xavier_uniform_(self.shared_weight_pk)
                nn.init.zeros_(self.shared_bias_pk)
                nn.init.xavier_uniform_(self.shared_weight_kp)
                nn.init.zeros_(self.shared_bias_kp)
                nn.init.xavier_uniform_(self.shared_weight_pm)
                nn.init.zeros_(self.shared_bias_pm)
            else:
                nn.init.xavier_uniform_(self.shared_weight_nk)
                nn.init.zeros_(self.shared_bias_nk)
                nn.init.xavier_uniform_(self.shared_weight_km)
                nn.init.zeros_(self.shared_bias_km)

    def forward(self, inputs: torch.Tensor, scene_emb: torch.Tensor) -> torch.Tensor:
        specific_weight_kk = self.specific_weight_kk(scene_emb)
        specific_weight_kk = specific_weight_kk.view(-1, self.k_dim, self.k_dim)
        specific_bias_kk = self.specific_bias_kk(scene_emb)

        if self.use_uv_shared:
            if self.use_mf_p:
                output_np = inputs @ self.shared_weight_np + self.shared_bias_np
                if self.inner_activation is not None:
                    output_np = self.inner_activation(output_np)
                output_pk = output_np @ self.shared_weight_pk + self.shared_bias_pk
                if self.inner_activation is not None:
                    output_pk = self.inner_activation(output_pk)
                output_kk = (
                    torch.bmm(output_pk.unsqueeze(1), specific_weight_kk).squeeze(1)
                    + specific_bias_kk
                )
                if self.inner_activation is not None:
                    output_kk = self.inner_activation(output_kk)
                output_kp = output_kk @ self.shared_weight_kp + self.shared_bias_kp
                if self.inner_activation is not None:
                    output_kp = self.inner_activation(output_kp)
                output = output_kp @ self.shared_weight_pm + self.shared_bias_pm
            else:
                output_nk = inputs @ self.shared_weight_nk + self.shared_bias_nk
                if self.inner_activation is not None:
                    output_nk = self.inner_activation(output_nk)
                output_kk = (
                    torch.bmm(output_nk.unsqueeze(1), specific_weight_kk).squeeze(1)
                    + specific_bias_kk
                )
                if self.inner_activation is not None:
                    output_kk = self.inner_activation(output_kk)
                output = output_kk @ self.shared_weight_km + self.shared_bias_km
        else:
            specific_weight_nk = self.specific_weight_nk(scene_emb)
            specific_weight_nk = specific_weight_nk.view(-1, self.input_dim, self.k_dim)
            specific_bias_nk = self.specific_bias_nk(scene_emb)
            specific_weight_km = self.specific_weight_km(scene_emb)
            specific_weight_km = specific_weight_km.view(
                -1, self.k_dim, self.output_dim
            )
            specific_bias_km = self.specific_bias_km(scene_emb)

            output_nk = (
                torch.bmm(inputs.unsqueeze(1), specific_weight_nk).squeeze(1)
                + specific_bias_nk
            )
            if self.inner_activation is not None:
                output_nk = self.inner_activation(output_nk)
            output_kk = (
                torch.bmm(output_nk.unsqueeze(1), specific_weight_kk).squeeze(1)
                + specific_bias_kk
            )
            if self.inner_activation is not None:
                output_kk = self.inner_activation(output_kk)
            output = (
                torch.bmm(output_kk.unsqueeze(1), specific_weight_km).squeeze(1)
                + specific_bias_km
            )

        return self.activation(output)


class APG(BaseModel):
    """
    Adaptive Parameter Generation (APG) model.

    APG stacks APG layers whose middle transformation matrix is generated from
    a scene embedding, enabling scenario-conditioned multi-task learning.
    """

    @property
    def model_name(self) -> str:
        return "APG"

    @property
    def default_task(self) -> TaskTypeName | list[TaskTypeName]:
        nums_task = self.nums_task if hasattr(self, "nums_task") else None
        if nums_task is not None and nums_task > 0:
            return ["binary"] * nums_task
        return ["binary"]

    def __init__(
        self,
        dense_features: list[DenseFeature] | None = None,
        sparse_features: list[SparseFeature] | None = None,
        sequence_features: list[SequenceFeature] | None = None,
        target: list[str] | str | None = None,
        task: TaskTypeName | list[TaskTypeName] | None = None,
        mlp_params: dict | None = None,
        inner_activation: ActivationName | None = None,
        generate_activation: ActivationName | None = None,
        scene_features: list[str] | str | None = None,
        detach_scene: bool = True,
        use_uv_shared: bool = True,
        use_mf_p: bool = False,
        mf_k: int = 16,
        mf_p: int = 4,
        **kwargs,
    ) -> None:
        dense_features = dense_features or []
        sparse_features = sparse_features or []
        sequence_features = sequence_features or []
        mlp_params = mlp_params or {}
        mlp_params.setdefault("hidden_dims", [256, 128])
        mlp_params.setdefault("activation", "relu")

        if target is None:
            target = []
        elif isinstance(target, str):
            target = [target]

        self.nums_task = len(target) if target else 1

        super().__init__(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
            task=task,
            **kwargs,
        )

        if not scene_features:
            raise ValueError("APG requires scene_features to generate parameters.")
        if isinstance(scene_features, str):
            scene_features = [scene_features]
        self.scene_features = select_features(
            self.all_features, scene_features, "scene_features"
        )
        self.detach_scene = detach_scene

        if len(mlp_params["hidden_dims"]) == 0:
            raise ValueError("mlp_params['hidden_dims'] cannot be empty for APG.")

        self.embedding = EmbeddingLayer(features=self.all_features)
        input_dim = self.embedding.input_dim
        scene_emb_dim = self.embedding.compute_output_dim(self.scene_features)

        layer_units = [input_dim] + list(mlp_params["hidden_dims"])
        self.apg_layers = nn.ModuleList(
            [
                APGLayer(
                    input_dim=layer_units[idx],
                    output_dim=layer_units[idx + 1],
                    scene_emb_dim=scene_emb_dim,
                    activation=mlp_params["activation"],
                    generate_activation=generate_activation,
                    inner_activation=inner_activation,
                    use_uv_shared=use_uv_shared,
                    use_mf_p=use_mf_p,
                    mf_k=mf_k,
                    mf_p=mf_p,
                )
                for idx in range(len(mlp_params["hidden_dims"]))
            ]
        )

        self.towers = nn.ModuleList(
            [nn.Linear(mlp_params["hidden_dims"][-1], 1) for _ in range(self.nums_task)]
        )
        self.prediction_layer = TaskHead(
            task_type=self.task, task_dims=[1] * self.nums_task
        )

        self.grad_norm_shared_modules = ["embedding", "apg_layers"]
        self.register_regularization_weights(
            embedding_attr="embedding", include_modules=["apg_layers", "towers"]
        )

    def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
        input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
        scene_emb = self.embedding(x=x, features=self.scene_features, squeeze_dim=True)
        if self.detach_scene:
            scene_emb = scene_emb.detach()

        apg_output = input_flat
        for layer in self.apg_layers:
            apg_output = layer(apg_output, scene_emb)

        task_outputs = [tower(apg_output) for tower in self.towers]
        logits = torch.cat(task_outputs, dim=1)
        return self.prediction_layer(logits)
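To make the adaptive-parameter idea concrete: in the shared-U/V configuration above (use_uv_shared=True, use_mf_p=False), only a small k × k core matrix is generated from the scene embedding, while the N × k and k × M factors are shared across scenes. The following minimal, self-contained PyTorch sketch mirrors that path; MiniAPGLayer and all of its dimensions are illustrative stand-ins, not part of nextrec:

import torch
import torch.nn as nn

class MiniAPGLayer(nn.Module):
    """Sketch of the shared-U/V path: only the small k x k core is
    generated per scene; the N x k and k x M factors are shared."""

    def __init__(self, input_dim: int, output_dim: int, scene_dim: int, k: int = 8):
        super().__init__()
        self.k = k
        self.u = nn.Parameter(torch.empty(input_dim, k))   # shared N x k factor
        self.v = nn.Parameter(torch.empty(k, output_dim))  # shared k x M factor
        nn.init.xavier_uniform_(self.u)
        nn.init.xavier_uniform_(self.v)
        # hypernetwork: scene embedding -> flattened k x k core
        self.core_gen = nn.Linear(scene_dim, k * k)

    def forward(self, x: torch.Tensor, scene: torch.Tensor) -> torch.Tensor:
        core = self.core_gen(scene).view(-1, self.k, self.k)  # (B, k, k)
        h = x @ self.u                                         # (B, k)
        h = torch.bmm(h.unsqueeze(1), core).squeeze(1)         # (B, k), scene-specific
        return torch.relu(h @ self.v)                          # (B, M)

layer = MiniAPGLayer(input_dim=32, output_dim=16, scene_dim=4)
x, scene = torch.randn(8, 32), torch.randn(8, 4)
print(layer(x, scene).shape)  # torch.Size([8, 16])

Generating only the k × k core keeps the per-scene generated parameter count at k², versus N × M for a fully generated weight matrix, which is the memory/compute trade-off the APG paper motivates.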
nextrec/models/multi_task/cross_stitch.py
CHANGED
@@ -0,0 +1,229 @@
"""
Date: create on 01/01/2026
Checkpoint: edit on 01/01/2026
Author: Yang Zhou, zyaztec@gmail.com
Reference:
- [1] Misra I, Shrivastava A, Gupta A, Hebert M. Cross-Stitch Networks for Multi-Task Learning. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR 2016), 2016, pp. 3994–4003.
  URL: https://www.cv-foundation.org/openaccess/content_cvpr_2016/html/Misra_Cross-Stitch_Networks_for_CVPR_2016_paper.html
- [2] MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation: https://github.com/alipay/MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation/

Cross-Stitch networks mix task-specific representations with a learnable
linear combination at each layer, enabling soft sharing while preserving
task-specific subspaces.
"""

from __future__ import annotations

import torch
import torch.nn as nn

from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
from nextrec.basic.layers import EmbeddingLayer, MLP
from nextrec.basic.heads import TaskHead
from nextrec.basic.model import BaseModel
from nextrec.utils.types import TaskTypeName


class CrossStitchLayer(nn.Module):
    """
    Cross-stitch layer to linearly mix task-specific representations.
    """

    def __init__(self, input_dims: list[int]) -> None:
        super().__init__()
        if len(input_dims) < 2:
            raise ValueError("CrossStitchLayer requires at least 2 inputs.")
        self.input_dims = list(input_dims)
        if len(set(self.input_dims)) != 1:
            raise ValueError(
                "CrossStitchLayer expects all input dims to be equal to align channels."
            )
        self.num_tasks = len(self.input_dims)
        self.unit_dim = self.input_dims[0]
        identity = torch.eye(self.num_tasks).unsqueeze(-1)
        weight = identity.repeat(1, 1, self.unit_dim)
        self.cross_stitch_weight = nn.Parameter(weight)

    def forward(self, inputs: list[torch.Tensor]) -> list[torch.Tensor]:
        if len(inputs) != len(self.input_dims):
            raise ValueError(
                f"CrossStitchLayer expects {len(self.input_dims)} inputs, got {len(inputs)}"
            )
        stacked = torch.stack(inputs, dim=1)
        mixed = torch.einsum("b s d, t s d -> b t d", stacked, self.cross_stitch_weight)
        return [mixed[:, task_idx, :] for task_idx in range(self.num_tasks)]


class CrossStitch(BaseModel):
    """
    Cross-Stitch Networks for multi-task learning.
    """

    @property
    def model_name(self) -> str:
        return "CrossStitch"

    @property
    def default_task(self) -> TaskTypeName | list[TaskTypeName]:
        nums_task = self.nums_task if hasattr(self, "nums_task") else None
        if nums_task is not None and nums_task > 0:
            return ["binary"] * nums_task
        return ["binary"]

    def __init__(
        self,
        dense_features: list[DenseFeature] | None = None,
        sparse_features: list[SparseFeature] | None = None,
        sequence_features: list[SequenceFeature] | None = None,
        target: list[str] | str | None = None,
        task: TaskTypeName | list[TaskTypeName] | None = None,
        shared_mlp_params: dict | None = None,
        task_mlp_params: dict | None = None,
        tower_mlp_params: dict | None = None,
        tower_mlp_params_list: list[dict] | None = None,
        **kwargs,
    ) -> None:
        dense_features = dense_features or []
        sparse_features = sparse_features or []
        sequence_features = sequence_features or []
        shared_mlp_params = shared_mlp_params or {}
        task_mlp_params = task_mlp_params or {}
        tower_mlp_params = tower_mlp_params or {}
        tower_mlp_params_list = tower_mlp_params_list or []

        shared_mlp_params.setdefault("hidden_dims", [])
        task_mlp_params.setdefault("hidden_dims", [256, 128])
        tower_mlp_params.setdefault("hidden_dims", [64])

        default_activation = task_mlp_params.get("activation", "relu")
        default_dropout = task_mlp_params.get("dropout", 0.0)
        default_norm_type = task_mlp_params.get("norm_type", "none")

        shared_mlp_params.setdefault("activation", default_activation)
        shared_mlp_params.setdefault("dropout", default_dropout)
        shared_mlp_params.setdefault("norm_type", default_norm_type)
        task_mlp_params.setdefault("activation", default_activation)
        task_mlp_params.setdefault("dropout", default_dropout)
        task_mlp_params.setdefault("norm_type", default_norm_type)
        tower_mlp_params.setdefault("activation", default_activation)
        tower_mlp_params.setdefault("dropout", default_dropout)
        tower_mlp_params.setdefault("norm_type", default_norm_type)

        if target is None:
            target = []
        elif isinstance(target, str):
            target = [target]

        self.nums_task = len(target) if target else 1

        super().__init__(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
            task=task,
            **kwargs,
        )

        self.nums_task = len(target) if target else 1
        if self.nums_task <= 1:
            raise ValueError("CrossStitch requires at least 2 tasks.")
        if not task_mlp_params["hidden_dims"]:
            raise ValueError("task_mlp_params['hidden_dims'] must not be empty.")
        shared_hidden_dims = shared_mlp_params["hidden_dims"]

        if tower_mlp_params_list:
            if len(tower_mlp_params_list) != self.nums_task:
                raise ValueError(
                    "Number of tower mlp params "
                    f"({len(tower_mlp_params_list)}) must match number of tasks ({self.nums_task})."
                )
            tower_params = [params.copy() for params in tower_mlp_params_list]
        else:
            tower_params = [tower_mlp_params.copy() for _ in range(self.nums_task)]

        self.embedding = EmbeddingLayer(features=self.all_features)
        input_dim = self.embedding.input_dim

        if shared_hidden_dims:
            self.shared_layer = MLP(
                input_dim=input_dim,
                hidden_dims=shared_hidden_dims,
                output_dim=None,
                dropout=shared_mlp_params["dropout"],
                activation=shared_mlp_params["activation"],
                norm_type=shared_mlp_params["norm_type"],
            )
            prev_dim = shared_hidden_dims[-1]
        else:
            self.shared_layer = nn.Identity()
            prev_dim = input_dim
        self.grad_norm_shared_modules = [
            "embedding",
            "shared_layer",
            "task_layers",
            "cross_stitch_layers",
        ]

        self.task_layers = nn.ModuleList()
        self.cross_stitch_layers = nn.ModuleList()
        for hidden_dim in task_mlp_params["hidden_dims"]:
            layer_tasks = nn.ModuleList(
                [
                    MLP(
                        input_dim=prev_dim,
                        hidden_dims=[hidden_dim],
                        output_dim=None,
                        dropout=task_mlp_params["dropout"],
                        activation=task_mlp_params["activation"],
                        norm_type=task_mlp_params["norm_type"],
                    )
                    for _ in range(self.nums_task)
                ]
            )
            self.task_layers.append(layer_tasks)
            self.cross_stitch_layers.append(
                CrossStitchLayer(input_dims=[hidden_dim] * self.nums_task)
            )
            prev_dim = hidden_dim

        self.towers = nn.ModuleList()
        for params in tower_params:
            if tower_mlp_params_list:
                tower = MLP(input_dim=prev_dim, output_dim=1, **params)
            else:
                tower = MLP(
                    input_dim=prev_dim,
                    hidden_dims=params.get("hidden_dims"),
                    output_dim=1,
                    dropout=params.get("dropout", tower_mlp_params["dropout"]),
                    activation=params.get("activation", tower_mlp_params["activation"]),
                    norm_type=params.get("norm_type", tower_mlp_params["norm_type"]),
                )
            self.towers.append(tower)

        self.prediction_layer = TaskHead(
            task_type=self.task, task_dims=[1] * self.nums_task
        )
        self.register_regularization_weights(
            embedding_attr="embedding",
            include_modules=["shared_layer", "task_layers", "towers"],
        )

    def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
        input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
        task_reps = [self.shared_layer(input_flat) for _ in range(self.nums_task)]

        for layer_idx in range(len(self.task_layers)):
            for task_idx in range(self.nums_task):
                task_reps[task_idx] = self.task_layers[layer_idx][task_idx](
                    task_reps[task_idx]
                )
            task_reps = self.cross_stitch_layers[layer_idx](task_reps)

        task_outputs = []
        for task_idx, tower in enumerate(self.towers):
            task_outputs.append(tower(task_reps[task_idx]))

        y = torch.cat(task_outputs, dim=1)
        return self.prediction_layer(y)
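As a quick sanity check on the mixing rule above: per channel d, the einsum computes mixed[b, t, d] = Σ_s α[t, s, d] · x[b, s, d], and cross_stitch_weight is initialized to a per-channel identity, so each task's representation passes through unchanged until training adjusts the weights. A small usage sketch (batch size and dimensions are arbitrary; it reuses the CrossStitchLayer defined in the diff above):

import torch

# reuses the CrossStitchLayer class from the diff above
layer = CrossStitchLayer(input_dims=[16, 16])  # two tasks, 16 channels each
task_a, task_b = torch.randn(4, 16), torch.randn(4, 16)

mixed_a, mixed_b = layer([task_a, task_b])
# identity initialization: outputs equal inputs before any training step
print(torch.allclose(mixed_a, task_a), torch.allclose(mixed_b, task_b))  # True True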