nextrec-0.4.24-py3-none-any.whl → nextrec-0.4.27-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/asserts.py +72 -0
  3. nextrec/basic/loggers.py +18 -1
  4. nextrec/basic/model.py +191 -71
  5. nextrec/basic/summary.py +58 -0
  6. nextrec/cli.py +13 -0
  7. nextrec/data/data_processing.py +3 -9
  8. nextrec/data/dataloader.py +25 -2
  9. nextrec/data/preprocessor.py +283 -36
  10. nextrec/models/multi_task/[pre]aitm.py +173 -0
  11. nextrec/models/multi_task/[pre]snr_trans.py +232 -0
  12. nextrec/models/multi_task/[pre]star.py +192 -0
  13. nextrec/models/multi_task/apg.py +330 -0
  14. nextrec/models/multi_task/cross_stitch.py +229 -0
  15. nextrec/models/multi_task/escm.py +290 -0
  16. nextrec/models/multi_task/esmm.py +8 -21
  17. nextrec/models/multi_task/hmoe.py +203 -0
  18. nextrec/models/multi_task/mmoe.py +20 -28
  19. nextrec/models/multi_task/pepnet.py +68 -66
  20. nextrec/models/multi_task/ple.py +30 -44
  21. nextrec/models/multi_task/poso.py +13 -22
  22. nextrec/models/multi_task/share_bottom.py +14 -25
  23. nextrec/models/ranking/afm.py +2 -2
  24. nextrec/models/ranking/autoint.py +2 -4
  25. nextrec/models/ranking/dcn.py +2 -3
  26. nextrec/models/ranking/dcn_v2.py +2 -3
  27. nextrec/models/ranking/deepfm.py +2 -3
  28. nextrec/models/ranking/dien.py +7 -9
  29. nextrec/models/ranking/din.py +8 -10
  30. nextrec/models/ranking/eulernet.py +1 -2
  31. nextrec/models/ranking/ffm.py +1 -2
  32. nextrec/models/ranking/fibinet.py +2 -3
  33. nextrec/models/ranking/fm.py +1 -1
  34. nextrec/models/ranking/lr.py +1 -1
  35. nextrec/models/ranking/masknet.py +1 -2
  36. nextrec/models/ranking/pnn.py +1 -2
  37. nextrec/models/ranking/widedeep.py +2 -3
  38. nextrec/models/ranking/xdeepfm.py +2 -4
  39. nextrec/models/representation/rqvae.py +4 -4
  40. nextrec/models/retrieval/dssm.py +18 -26
  41. nextrec/models/retrieval/dssm_v2.py +15 -22
  42. nextrec/models/retrieval/mind.py +9 -15
  43. nextrec/models/retrieval/sdm.py +36 -33
  44. nextrec/models/retrieval/youtube_dnn.py +16 -24
  45. nextrec/models/sequential/hstu.py +2 -2
  46. nextrec/utils/__init__.py +5 -1
  47. nextrec/utils/config.py +2 -0
  48. nextrec/utils/model.py +16 -77
  49. nextrec/utils/torch_utils.py +11 -0
  50. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/METADATA +72 -62
  51. nextrec-0.4.27.dist-info/RECORD +90 -0
  52. nextrec/models/multi_task/aitm.py +0 -0
  53. nextrec/models/multi_task/snr_trans.py +0 -0
  54. nextrec-0.4.24.dist-info/RECORD +0 -86
  55. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/WHEEL +0 -0
  56. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/entry_points.txt +0 -0
  57. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,330 @@
+ """
+ Date: create on 01/01/2026
+ Checkpoint: edit on 01/01/2026
+ Author: Yang Zhou, zyaztec@gmail.com
+ Reference:
+ - [1] Yan B, Wang P, Zhang K, Li F, Deng H, Xu J, Zheng B. APG: Adaptive Parameter Generation Network for Click-Through Rate Prediction. Advances in Neural Information Processing Systems 35 (NeurIPS 2022), 2022.
+       URL: https://arxiv.org/abs/2203.16218
+ - [2] MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation: https://github.com/alipay/MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation/
+ """
+
+ from __future__ import annotations
+
+ import math
+ import torch
+ import torch.nn as nn
+
+ from nextrec.basic.activation import activation_layer
+ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+ from nextrec.basic.layers import EmbeddingLayer, MLP
+ from nextrec.basic.heads import TaskHead
+ from nextrec.basic.model import BaseModel
+ from nextrec.utils.model import select_features
+ from nextrec.utils.types import ActivationName, TaskTypeName
+
+
+ class APGLayer(nn.Module):
+     def __init__(
+         self,
+         input_dim: int,
+         output_dim: int,
+         scene_emb_dim: int,
+         activation: ActivationName = "relu",
+         generate_activation: ActivationName | None = None,
+         inner_activation: ActivationName | None = None,
+         use_uv_shared: bool = True,
+         use_mf_p: bool = False,
+         mf_k: int = 16,
+         mf_p: int = 4,
+     ) -> None:
+         super().__init__()
+         self.use_uv_shared = use_uv_shared
+         self.use_mf_p = use_mf_p
+         self.input_dim = input_dim
+         self.output_dim = output_dim
+
+         self.activation = (
+             activation_layer(activation) if activation is not None else nn.Identity()
+         )
+         self.inner_activation = (
+             activation_layer(inner_activation) if inner_activation is not None else None
+         )
+
+         min_dim = min(int(input_dim), int(output_dim))
+         self.p_dim = math.ceil(float(min_dim) / float(mf_p))
+         self.k_dim = math.ceil(float(min_dim) / float(mf_k))
+
+         if use_uv_shared:
+             if use_mf_p:
+                 self.shared_weight_np = nn.Parameter(
+                     torch.empty(self.input_dim, self.p_dim)
+                 )
+                 self.shared_bias_np = nn.Parameter(torch.zeros(self.p_dim))
+                 self.shared_weight_pk = nn.Parameter(
+                     torch.empty(self.p_dim, self.k_dim)
+                 )
+                 self.shared_bias_pk = nn.Parameter(torch.zeros(self.k_dim))
+
+                 self.shared_weight_kp = nn.Parameter(
+                     torch.empty(self.k_dim, self.p_dim)
+                 )
+                 self.shared_bias_kp = nn.Parameter(torch.zeros(self.p_dim))
+                 self.shared_weight_pm = nn.Parameter(
+                     torch.empty(self.p_dim, self.output_dim)
+                 )
+                 self.shared_bias_pm = nn.Parameter(torch.zeros(self.output_dim))
+             else:
+                 self.shared_weight_nk = nn.Parameter(
+                     torch.empty(self.input_dim, self.k_dim)
+                 )
+                 self.shared_bias_nk = nn.Parameter(torch.zeros(self.k_dim))
+                 self.shared_weight_km = nn.Parameter(
+                     torch.empty(self.k_dim, self.output_dim)
+                 )
+                 self.shared_bias_km = nn.Parameter(torch.zeros(self.output_dim))
+         self.specific_weight_kk = MLP(
+             input_dim=scene_emb_dim,
+             hidden_dims=None,
+             output_dim=self.k_dim * self.k_dim,
+             activation="relu",
+             output_activation=generate_activation or "none",
+         )
+         self.specific_bias_kk = MLP(
+             input_dim=scene_emb_dim,
+             hidden_dims=None,
+             output_dim=self.k_dim,
+             activation="relu",
+             output_activation=generate_activation or "none",
+         )
+         if not use_uv_shared:
+             self.specific_weight_nk = MLP(
+                 input_dim=scene_emb_dim,
+                 hidden_dims=None,
+                 output_dim=self.input_dim * self.k_dim,
+                 activation="relu",
+                 output_activation=generate_activation or "none",
+             )
+             self.specific_bias_nk = MLP(
+                 input_dim=scene_emb_dim,
+                 hidden_dims=None,
+                 output_dim=self.k_dim,
+                 activation="relu",
+                 output_activation=generate_activation or "none",
+             )
+             self.specific_weight_km = MLP(
+                 input_dim=scene_emb_dim,
+                 hidden_dims=None,
+                 output_dim=self.k_dim * self.output_dim,
+                 activation="relu",
+                 output_activation=generate_activation or "none",
+             )
+             self.specific_bias_km = MLP(
+                 input_dim=scene_emb_dim,
+                 hidden_dims=None,
+                 output_dim=self.output_dim,
+                 activation="relu",
+                 output_activation=generate_activation or "none",
+             )
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         if self.use_uv_shared:
+             if self.use_mf_p:
+                 nn.init.xavier_uniform_(self.shared_weight_np)
+                 nn.init.zeros_(self.shared_bias_np)
+                 nn.init.xavier_uniform_(self.shared_weight_pk)
+                 nn.init.zeros_(self.shared_bias_pk)
+                 nn.init.xavier_uniform_(self.shared_weight_kp)
+                 nn.init.zeros_(self.shared_bias_kp)
+                 nn.init.xavier_uniform_(self.shared_weight_pm)
+                 nn.init.zeros_(self.shared_bias_pm)
+             else:
+                 nn.init.xavier_uniform_(self.shared_weight_nk)
+                 nn.init.zeros_(self.shared_bias_nk)
+                 nn.init.xavier_uniform_(self.shared_weight_km)
+                 nn.init.zeros_(self.shared_bias_km)
+
+     def forward(self, inputs: torch.Tensor, scene_emb: torch.Tensor) -> torch.Tensor:
+         specific_weight_kk = self.specific_weight_kk(scene_emb)
+         specific_weight_kk = specific_weight_kk.view(-1, self.k_dim, self.k_dim)
+         specific_bias_kk = self.specific_bias_kk(scene_emb)
+
+         if self.use_uv_shared:
+             if self.use_mf_p:
+                 output_np = inputs @ self.shared_weight_np + self.shared_bias_np
+                 if self.inner_activation is not None:
+                     output_np = self.inner_activation(output_np)
+                 output_pk = output_np @ self.shared_weight_pk + self.shared_bias_pk
+                 if self.inner_activation is not None:
+                     output_pk = self.inner_activation(output_pk)
+                 output_kk = (
+                     torch.bmm(output_pk.unsqueeze(1), specific_weight_kk).squeeze(1)
+                     + specific_bias_kk
+                 )
+                 if self.inner_activation is not None:
+                     output_kk = self.inner_activation(output_kk)
+                 output_kp = output_kk @ self.shared_weight_kp + self.shared_bias_kp
+                 if self.inner_activation is not None:
+                     output_kp = self.inner_activation(output_kp)
+                 output = output_kp @ self.shared_weight_pm + self.shared_bias_pm
+             else:
+                 output_nk = inputs @ self.shared_weight_nk + self.shared_bias_nk
+                 if self.inner_activation is not None:
+                     output_nk = self.inner_activation(output_nk)
+                 output_kk = (
+                     torch.bmm(output_nk.unsqueeze(1), specific_weight_kk).squeeze(1)
+                     + specific_bias_kk
+                 )
+                 if self.inner_activation is not None:
+                     output_kk = self.inner_activation(output_kk)
+                 output = output_kk @ self.shared_weight_km + self.shared_bias_km
+         else:
+             specific_weight_nk = self.specific_weight_nk(scene_emb)
+             specific_weight_nk = specific_weight_nk.view(-1, self.input_dim, self.k_dim)
+             specific_bias_nk = self.specific_bias_nk(scene_emb)
+             specific_weight_km = self.specific_weight_km(scene_emb)
+             specific_weight_km = specific_weight_km.view(
+                 -1, self.k_dim, self.output_dim
+             )
+             specific_bias_km = self.specific_bias_km(scene_emb)
+
+             output_nk = (
+                 torch.bmm(inputs.unsqueeze(1), specific_weight_nk).squeeze(1)
+                 + specific_bias_nk
+             )
+             if self.inner_activation is not None:
+                 output_nk = self.inner_activation(output_nk)
+             output_kk = (
+                 torch.bmm(output_nk.unsqueeze(1), specific_weight_kk).squeeze(1)
+                 + specific_bias_kk
+             )
+             if self.inner_activation is not None:
+                 output_kk = self.inner_activation(output_kk)
+             output = (
+                 torch.bmm(output_kk.unsqueeze(1), specific_weight_km).squeeze(1)
+                 + specific_bias_km
+             )
+
+         return self.activation(output)
+
+
+ class APG(BaseModel):
+     """
+     Adaptive Parameter Generation (APG) model.
+
+     APG stacks APG layers whose middle transformation matrix is generated from
+     a scene embedding, enabling scenario-conditioned multi-task learning.
+     """
+
+     @property
+     def model_name(self) -> str:
+         return "APG"
+
+     @property
+     def default_task(self) -> TaskTypeName | list[TaskTypeName]:
+         nums_task = self.nums_task if hasattr(self, "nums_task") else None
+         if nums_task is not None and nums_task > 0:
+             return ["binary"] * nums_task
+         return ["binary"]
+
+     def __init__(
+         self,
+         dense_features: list[DenseFeature] | None = None,
+         sparse_features: list[SparseFeature] | None = None,
+         sequence_features: list[SequenceFeature] | None = None,
+         target: list[str] | str | None = None,
+         task: TaskTypeName | list[TaskTypeName] | None = None,
+         mlp_params: dict | None = None,
+         inner_activation: ActivationName | None = None,
+         generate_activation: ActivationName | None = None,
+         scene_features: list[str] | str | None = None,
+         detach_scene: bool = True,
+         use_uv_shared: bool = True,
+         use_mf_p: bool = False,
+         mf_k: int = 16,
+         mf_p: int = 4,
+         **kwargs,
+     ) -> None:
+         dense_features = dense_features or []
+         sparse_features = sparse_features or []
+         sequence_features = sequence_features or []
+         mlp_params = mlp_params or {}
+         mlp_params.setdefault("hidden_dims", [256, 128])
+         mlp_params.setdefault("activation", "relu")
+
+         if target is None:
+             target = []
+         elif isinstance(target, str):
+             target = [target]
+
+         self.nums_task = len(target) if target else 1
+
+         super().__init__(
+             dense_features=dense_features,
+             sparse_features=sparse_features,
+             sequence_features=sequence_features,
+             target=target,
+             task=task,
+             **kwargs,
+         )
+
+         if not scene_features:
+             raise ValueError("APG requires scene_features to generate parameters.")
+         if isinstance(scene_features, str):
+             scene_features = [scene_features]
+         self.scene_features = select_features(
+             self.all_features, scene_features, "scene_features"
+         )
+         self.detach_scene = detach_scene
+
+         if len(mlp_params["hidden_dims"]) == 0:
+             raise ValueError("mlp_params['hidden_dims'] cannot be empty for APG.")
+
+         self.embedding = EmbeddingLayer(features=self.all_features)
+         input_dim = self.embedding.input_dim
+         scene_emb_dim = self.embedding.compute_output_dim(self.scene_features)
+
+         layer_units = [input_dim] + list(mlp_params["hidden_dims"])
+         self.apg_layers = nn.ModuleList(
+             [
+                 APGLayer(
+                     input_dim=layer_units[idx],
+                     output_dim=layer_units[idx + 1],
+                     scene_emb_dim=scene_emb_dim,
+                     activation=mlp_params["activation"],
+                     generate_activation=generate_activation,
+                     inner_activation=inner_activation,
+                     use_uv_shared=use_uv_shared,
+                     use_mf_p=use_mf_p,
+                     mf_k=mf_k,
+                     mf_p=mf_p,
+                 )
+                 for idx in range(len(mlp_params["hidden_dims"]))
+             ]
+         )
+
+         self.towers = nn.ModuleList(
+             [nn.Linear(mlp_params["hidden_dims"][-1], 1) for _ in range(self.nums_task)]
+         )
+         self.prediction_layer = TaskHead(
+             task_type=self.task, task_dims=[1] * self.nums_task
+         )
+
+         self.grad_norm_shared_modules = ["embedding", "apg_layers"]
+         self.register_regularization_weights(
+             embedding_attr="embedding", include_modules=["apg_layers", "towers"]
+         )
+
+     def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+         input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
+         scene_emb = self.embedding(x=x, features=self.scene_features, squeeze_dim=True)
+         if self.detach_scene:
+             scene_emb = scene_emb.detach()
+
+         apg_output = input_flat
+         for layer in self.apg_layers:
+             apg_output = layer(apg_output, scene_emb)
+
+         task_outputs = [tower(apg_output) for tower in self.towers]
+         logits = torch.cat(task_outputs, dim=1)
+         return self.prediction_layer(logits)
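Note on the new APGLayer above: it sandwiches a small scene-generated K×K matrix between shared low-rank projections, so only the small "middle" block varies per scene while the large projections are amortized across scenes. The following standalone PyTorch sketch illustrates that idea in isolation; SimpleAPGLayer, its dimensions, and the smoke test are illustrative assumptions, not part of the nextrec API.

import torch
import torch.nn as nn

class SimpleAPGLayer(nn.Module):
    """Minimal sketch of APG-style scene-conditioned parameter generation (illustrative only)."""

    def __init__(self, input_dim: int, output_dim: int, scene_dim: int, k: int = 8):
        super().__init__()
        # Shared low-rank projections: U maps input -> k, V maps k -> output.
        self.u = nn.Linear(input_dim, k)
        self.v = nn.Linear(k, output_dim)
        # Generator emits a per-example k x k "middle" matrix from the scene embedding.
        self.generator = nn.Linear(scene_dim, k * k)
        self.k = k

    def forward(self, x: torch.Tensor, scene_emb: torch.Tensor) -> torch.Tensor:
        w_kk = self.generator(scene_emb).view(-1, self.k, self.k)  # (batch, k, k)
        h = self.u(x)                                              # (batch, k)
        h = torch.bmm(h.unsqueeze(1), w_kk).squeeze(1)             # (batch, k)
        return torch.relu(self.v(h))                               # (batch, output_dim)

# Tiny smoke test with random tensors.
layer = SimpleAPGLayer(input_dim=32, output_dim=16, scene_dim=8)
out = layer(torch.randn(4, 32), torch.randn(4, 8))
print(out.shape)  # torch.Size([4, 16])

Generating only the k×k block (rather than a full input_dim×output_dim weight) is the memory and compute saving the APG paper targets; the nextrec module adds the shared/specific and mf_p variants on top of this core pattern.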
@@ -0,0 +1,229 @@
+ """
+ Date: create on 01/01/2026
+ Checkpoint: edit on 01/01/2026
+ Author: Yang Zhou, zyaztec@gmail.com
+ Reference:
+ - [1] Misra I, Shrivastava A, Gupta A, Hebert M. Cross-Stitch Networks for Multi-Task Learning. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR 2016), 2016, pp. 3994–4003.
+       URL: https://www.cv-foundation.org/openaccess/content_cvpr_2016/html/Misra_Cross-Stitch_Networks_for_CVPR_2016_paper.html
+ - [2] MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation: https://github.com/alipay/MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation/
+
+ Cross-Stitch networks mix task-specific representations with a learnable
+ linear combination at each layer, enabling soft sharing while preserving
+ task-specific subspaces.
+ """
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+
+ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+ from nextrec.basic.layers import EmbeddingLayer, MLP
+ from nextrec.basic.heads import TaskHead
+ from nextrec.basic.model import BaseModel
+ from nextrec.utils.types import TaskTypeName
+
+
+ class CrossStitchLayer(nn.Module):
+     """
+     Cross-stitch layer to linearly mix task-specific representations.
+     """
+
+     def __init__(self, input_dims: list[int]) -> None:
+         super().__init__()
+         if len(input_dims) < 2:
+             raise ValueError("CrossStitchLayer requires at least 2 inputs.")
+         self.input_dims = list(input_dims)
+         if len(set(self.input_dims)) != 1:
+             raise ValueError(
+                 "CrossStitchLayer expects all input dims to be equal to align channels."
+             )
+         self.num_tasks = len(self.input_dims)
+         self.unit_dim = self.input_dims[0]
+         identity = torch.eye(self.num_tasks).unsqueeze(-1)
+         weight = identity.repeat(1, 1, self.unit_dim)
+         self.cross_stitch_weight = nn.Parameter(weight)
+
+     def forward(self, inputs: list[torch.Tensor]) -> list[torch.Tensor]:
+         if len(inputs) != len(self.input_dims):
+             raise ValueError(
+                 f"CrossStitchLayer expects {len(self.input_dims)} inputs, got {len(inputs)}"
+             )
+         stacked = torch.stack(inputs, dim=1)
+         mixed = torch.einsum("b s d, t s d -> b t d", stacked, self.cross_stitch_weight)
+         return [mixed[:, task_idx, :] for task_idx in range(self.num_tasks)]
+
+
+ class CrossStitch(BaseModel):
+     """
+     Cross-Stitch Networks for multi-task learning.
+     """
+
+     @property
+     def model_name(self) -> str:
+         return "CrossStitch"
+
+     @property
+     def default_task(self) -> TaskTypeName | list[TaskTypeName]:
+         nums_task = self.nums_task if hasattr(self, "nums_task") else None
+         if nums_task is not None and nums_task > 0:
+             return ["binary"] * nums_task
+         return ["binary"]
+
+     def __init__(
+         self,
+         dense_features: list[DenseFeature] | None = None,
+         sparse_features: list[SparseFeature] | None = None,
+         sequence_features: list[SequenceFeature] | None = None,
+         target: list[str] | str | None = None,
+         task: TaskTypeName | list[TaskTypeName] | None = None,
+         shared_mlp_params: dict | None = None,
+         task_mlp_params: dict | None = None,
+         tower_mlp_params: dict | None = None,
+         tower_mlp_params_list: list[dict] | None = None,
+         **kwargs,
+     ) -> None:
+         dense_features = dense_features or []
+         sparse_features = sparse_features or []
+         sequence_features = sequence_features or []
+         shared_mlp_params = shared_mlp_params or {}
+         task_mlp_params = task_mlp_params or {}
+         tower_mlp_params = tower_mlp_params or {}
+         tower_mlp_params_list = tower_mlp_params_list or []
+
+         shared_mlp_params.setdefault("hidden_dims", [])
+         task_mlp_params.setdefault("hidden_dims", [256, 128])
+         tower_mlp_params.setdefault("hidden_dims", [64])
+
+         default_activation = task_mlp_params.get("activation", "relu")
+         default_dropout = task_mlp_params.get("dropout", 0.0)
+         default_norm_type = task_mlp_params.get("norm_type", "none")
+
+         shared_mlp_params.setdefault("activation", default_activation)
+         shared_mlp_params.setdefault("dropout", default_dropout)
+         shared_mlp_params.setdefault("norm_type", default_norm_type)
+         task_mlp_params.setdefault("activation", default_activation)
+         task_mlp_params.setdefault("dropout", default_dropout)
+         task_mlp_params.setdefault("norm_type", default_norm_type)
+         tower_mlp_params.setdefault("activation", default_activation)
+         tower_mlp_params.setdefault("dropout", default_dropout)
+         tower_mlp_params.setdefault("norm_type", default_norm_type)
+
+         if target is None:
+             target = []
+         elif isinstance(target, str):
+             target = [target]
+
+         self.nums_task = len(target) if target else 1
+
+         super().__init__(
+             dense_features=dense_features,
+             sparse_features=sparse_features,
+             sequence_features=sequence_features,
+             target=target,
+             task=task,
+             **kwargs,
+         )
+
+         self.nums_task = len(target) if target else 1
+         if self.nums_task <= 1:
+             raise ValueError("CrossStitch requires at least 2 tasks.")
+         if not task_mlp_params["hidden_dims"]:
+             raise ValueError("task_mlp_params['hidden_dims'] must not be empty.")
+         shared_hidden_dims = shared_mlp_params["hidden_dims"]
+
+         if tower_mlp_params_list:
+             if len(tower_mlp_params_list) != self.nums_task:
+                 raise ValueError(
+                     "Number of tower mlp params "
+                     f"({len(tower_mlp_params_list)}) must match number of tasks ({self.nums_task})."
+                 )
+             tower_params = [params.copy() for params in tower_mlp_params_list]
+         else:
+             tower_params = [tower_mlp_params.copy() for _ in range(self.nums_task)]
+
+         self.embedding = EmbeddingLayer(features=self.all_features)
+         input_dim = self.embedding.input_dim
+
+         if shared_hidden_dims:
+             self.shared_layer = MLP(
+                 input_dim=input_dim,
+                 hidden_dims=shared_hidden_dims,
+                 output_dim=None,
+                 dropout=shared_mlp_params["dropout"],
+                 activation=shared_mlp_params["activation"],
+                 norm_type=shared_mlp_params["norm_type"],
+             )
+             prev_dim = shared_hidden_dims[-1]
+         else:
+             self.shared_layer = nn.Identity()
+             prev_dim = input_dim
+         self.grad_norm_shared_modules = [
+             "embedding",
+             "shared_layer",
+             "task_layers",
+             "cross_stitch_layers",
+         ]
+
+         self.task_layers = nn.ModuleList()
+         self.cross_stitch_layers = nn.ModuleList()
+         for hidden_dim in task_mlp_params["hidden_dims"]:
+             layer_tasks = nn.ModuleList(
+                 [
+                     MLP(
+                         input_dim=prev_dim,
+                         hidden_dims=[hidden_dim],
+                         output_dim=None,
+                         dropout=task_mlp_params["dropout"],
+                         activation=task_mlp_params["activation"],
+                         norm_type=task_mlp_params["norm_type"],
+                     )
+                     for _ in range(self.nums_task)
+                 ]
+             )
+             self.task_layers.append(layer_tasks)
+             self.cross_stitch_layers.append(
+                 CrossStitchLayer(input_dims=[hidden_dim] * self.nums_task)
+             )
+             prev_dim = hidden_dim
+
+         self.towers = nn.ModuleList()
+         for params in tower_params:
+             if tower_mlp_params_list:
+                 tower = MLP(input_dim=prev_dim, output_dim=1, **params)
+             else:
+                 tower = MLP(
+                     input_dim=prev_dim,
+                     hidden_dims=params.get("hidden_dims"),
+                     output_dim=1,
+                     dropout=params.get("dropout", tower_mlp_params["dropout"]),
+                     activation=params.get("activation", tower_mlp_params["activation"]),
+                     norm_type=params.get("norm_type", tower_mlp_params["norm_type"]),
+                 )
+             self.towers.append(tower)
+
+         self.prediction_layer = TaskHead(
+             task_type=self.task, task_dims=[1] * self.nums_task
+         )
+         self.register_regularization_weights(
+             embedding_attr="embedding",
+             include_modules=["shared_layer", "task_layers", "towers"],
+         )
+
+     def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+         input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
+         task_reps = [self.shared_layer(input_flat) for _ in range(self.nums_task)]
+
+         for layer_idx in range(len(self.task_layers)):
+             for task_idx in range(self.nums_task):
+                 task_reps[task_idx] = self.task_layers[layer_idx][task_idx](
+                     task_reps[task_idx]
+                 )
+             task_reps = self.cross_stitch_layers[layer_idx](task_reps)
+
+         task_outputs = []
+         for task_idx, tower in enumerate(self.towers):
+             task_outputs.append(tower(task_reps[task_idx]))
+
+         y = torch.cat(task_outputs, dim=1)
+         return self.prediction_layer(y)
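To make the mixing performed by the new CrossStitchLayer concrete, here is a minimal standalone PyTorch sketch of the same per-channel, identity-initialized cross-stitch operation. TinyCrossStitch and the example tensors are illustrative assumptions, not the nextrec module itself.

import torch
import torch.nn as nn

class TinyCrossStitch(nn.Module):
    """Minimal per-channel cross-stitch mixing for num_tasks equal-width inputs."""

    def __init__(self, num_tasks: int, dim: int):
        super().__init__()
        # Identity init: each task starts by keeping only its own representation.
        weight = torch.eye(num_tasks).unsqueeze(-1).repeat(1, 1, dim)
        self.weight = nn.Parameter(weight)  # (tasks_out, tasks_in, dim)

    def forward(self, reps: list[torch.Tensor]) -> list[torch.Tensor]:
        stacked = torch.stack(reps, dim=1)  # (batch, tasks_in, dim)
        mixed = torch.einsum("bsd,tsd->btd", stacked, self.weight)
        return [mixed[:, t, :] for t in range(stacked.size(1))]

# Two tasks sharing a 16-dim representation.
stitch = TinyCrossStitch(num_tasks=2, dim=16)
a, b = torch.randn(4, 16), torch.randn(4, 16)
mixed_a, mixed_b = stitch([a, b])
print(torch.allclose(mixed_a, a), torch.allclose(mixed_b, b))  # True True at initialization

Because the stitch tensor starts as the identity, each task initially sees only its own representation; training can then move weight off the diagonal to share information across tasks, which is the soft-sharing behavior described in the module docstring.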