nextrec 0.4.25__py3-none-any.whl → 0.4.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/asserts.py +72 -0
- nextrec/basic/loggers.py +18 -1
- nextrec/basic/model.py +54 -51
- nextrec/data/batch_utils.py +23 -3
- nextrec/data/dataloader.py +3 -8
- nextrec/models/multi_task/[pre]aitm.py +173 -0
- nextrec/models/multi_task/[pre]snr_trans.py +232 -0
- nextrec/models/multi_task/[pre]star.py +192 -0
- nextrec/models/multi_task/apg.py +330 -0
- nextrec/models/multi_task/cross_stitch.py +229 -0
- nextrec/models/multi_task/escm.py +290 -0
- nextrec/models/multi_task/esmm.py +8 -21
- nextrec/models/multi_task/hmoe.py +203 -0
- nextrec/models/multi_task/mmoe.py +20 -28
- nextrec/models/multi_task/pepnet.py +81 -76
- nextrec/models/multi_task/ple.py +30 -44
- nextrec/models/multi_task/poso.py +13 -22
- nextrec/models/multi_task/share_bottom.py +14 -25
- nextrec/models/ranking/afm.py +2 -2
- nextrec/models/ranking/autoint.py +2 -4
- nextrec/models/ranking/dcn.py +2 -3
- nextrec/models/ranking/dcn_v2.py +2 -3
- nextrec/models/ranking/deepfm.py +2 -3
- nextrec/models/ranking/dien.py +7 -9
- nextrec/models/ranking/din.py +8 -10
- nextrec/models/ranking/eulernet.py +1 -2
- nextrec/models/ranking/ffm.py +1 -2
- nextrec/models/ranking/fibinet.py +2 -3
- nextrec/models/ranking/fm.py +1 -1
- nextrec/models/ranking/lr.py +1 -1
- nextrec/models/ranking/masknet.py +1 -2
- nextrec/models/ranking/pnn.py +1 -2
- nextrec/models/ranking/widedeep.py +2 -3
- nextrec/models/ranking/xdeepfm.py +2 -4
- nextrec/models/representation/rqvae.py +4 -4
- nextrec/models/retrieval/dssm.py +18 -26
- nextrec/models/retrieval/dssm_v2.py +15 -22
- nextrec/models/retrieval/mind.py +9 -15
- nextrec/models/retrieval/sdm.py +36 -33
- nextrec/models/retrieval/youtube_dnn.py +16 -24
- nextrec/models/sequential/hstu.py +2 -2
- nextrec/utils/__init__.py +5 -1
- nextrec/utils/model.py +9 -14
- {nextrec-0.4.25.dist-info → nextrec-0.4.28.dist-info}/METADATA +72 -62
- nextrec-0.4.28.dist-info/RECORD +90 -0
- nextrec/models/multi_task/aitm.py +0 -0
- nextrec/models/multi_task/snr_trans.py +0 -0
- nextrec-0.4.25.dist-info/RECORD +0 -86
- {nextrec-0.4.25.dist-info → nextrec-0.4.28.dist-info}/WHEEL +0 -0
- {nextrec-0.4.25.dist-info → nextrec-0.4.28.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.25.dist-info → nextrec-0.4.28.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Date: create on 01/01/2026 - prerelease version: still need to align with the source paper
|
|
3
|
+
Checkpoint: edit on 01/01/2026
|
|
4
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
5
|
+
Reference:
|
|
6
|
+
- [1] Ma J, Zhao Z, Chen J, Li A, Hong L, Chi EH. SNR: Sub-Network Routing for Flexible Parameter Sharing in Multi-Task Learning in E-Commerce by Exploiting Task Relationships in the Label Space. Proceedings of the 33rd AAAI Conference on Artificial Intelligence (AAAI 2019), 2019, pp. 216-223.
|
|
7
|
+
URL: https://ojs.aaai.org/index.php/AAAI/article/view/3788
|
|
8
|
+
- [2] MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation: https://github.com/alipay/MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation/
|
|
9
|
+
|
|
10
|
+
SNR-Trans stacks multiple expert layers and applies sparse routing with
|
|
11
|
+
learnable per-output transform matrices. Intermediate gates route expert
|
|
12
|
+
outputs to the next expert stage, while the final gate routes to each task
|
|
13
|
+
tower for prediction.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
import torch.nn as nn
|
|
20
|
+
|
|
21
|
+
from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
22
|
+
from nextrec.basic.layers import EmbeddingLayer, MLP
|
|
23
|
+
from nextrec.basic.heads import TaskHead
|
|
24
|
+
from nextrec.basic.model import BaseModel
|
|
25
|
+
from nextrec.utils.types import TaskTypeName
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class SNRTransGate(nn.Module):
    """
    Sparse routing gate with per-output transform matrices.

    Implements the hard-concrete style routing used by SNR (Ma et al., AAAI
    2019): a stretched, clamped sigmoid over learnable logits produces a soft
    binary mask ``z`` of shape ``[num_outputs, num_inputs]``, and each routed
    input branch is additionally mapped through a learnable ``units x units``
    transform matrix before the masked sum over inputs.
    """

    def __init__(self, num_inputs: int, num_outputs: int, units: int) -> None:
        """
        Args:
            num_inputs: number of incoming expert branches.
            num_outputs: number of outgoing routed branches.
            units: feature dimension shared by every branch.

        Raises:
            ValueError: if ``num_inputs`` or ``num_outputs`` is < 1.
        """
        super().__init__()
        if num_inputs < 1 or num_outputs < 1:
            raise ValueError("num_inputs and num_outputs must be >= 1")

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.units = units

        # Hard-concrete gate parameters: beta is the temperature, and
        # (gamma, epsilon) stretch the (0, 1) sigmoid output so clamping can
        # reach exact 0/1 (Louizos et al., 2018 "L0 regularization").
        self.alpha = nn.Parameter(torch.rand(1), requires_grad=True)
        self.beta = 0.9
        self.gamma = -0.1
        self.epsilon = 1.1
        self.eps = 1e-8  # numerical floor for log() arguments

        # Per-edge routing logits u in (0, 1), one per (output, input) pair.
        u_init = torch.empty(num_outputs, num_inputs)
        u_init = nn.init.uniform_(u_init, self.eps, 1 - self.eps)
        self.u = nn.Parameter(u_init, requires_grad=True)

        # One learnable units x units transform per (output, input) edge.
        trans = torch.empty(num_outputs, num_inputs, units, units)
        nn.init.xavier_normal_(trans)
        self.trans_matrix = nn.Parameter(trans, requires_grad=True)

    def forward(self, inputs: list[torch.Tensor]) -> list[torch.Tensor]:
        """
        Route ``num_inputs`` tensors of shape ``[B, units]`` into
        ``num_outputs`` tensors of the same shape.

        Raises:
            ValueError: if ``len(inputs)`` does not match ``num_inputs``.
        """
        if len(inputs) != self.num_inputs:
            raise ValueError(
                f"SNRTransGate expects {self.num_inputs} inputs, got {len(inputs)}"
            )

        # Bug fix: the hard-concrete gate divides the *whole* logit by the
        # temperature, i.e. sigmoid((log u - log(1 - u) + log alpha) / beta);
        # previously only log(alpha) was divided by beta. u and alpha are
        # clamped so the logs stay finite while the unconstrained parameters
        # drift during training.
        u = self.u.clamp(self.eps, 1 - self.eps)
        alpha = self.alpha.clamp_min(self.eps)
        s = torch.sigmoid(
            (torch.log(u) - torch.log(1 - u) + torch.log(alpha)) / self.beta
        )
        # Stretch to (gamma, epsilon) then clamp back into [0, 1] so the mask
        # can saturate at exactly 0 (edge pruned) or 1 (edge kept).
        s_ = s * (self.epsilon - self.gamma) + self.gamma
        z = torch.clamp(s_, min=0.0, max=1.0)

        x_stack = torch.stack(inputs, dim=1)  # [B, num_inputs, units]
        transformed = torch.einsum(
            "bnu,onuv->bonv", x_stack, self.trans_matrix
        )  # [B, num_outputs, num_inputs, units]
        # Broadcast the [num_outputs, num_inputs] mask over batch and units.
        weighted = transformed * z.unsqueeze(0).unsqueeze(-1)
        outputs = weighted.sum(dim=2)  # [B, num_outputs, units]
        return [outputs[:, i, :] for i in range(self.num_outputs)]
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class SNRTrans(BaseModel):
    """
    SNR-Trans with sparse expert routing.

    Stacks one expert stage per entry in ``expert_mlp_params['hidden_dims']``.
    After each stage, an :class:`SNRTransGate` routes the expert outputs:
    intermediate gates emit one routed input per expert for the next stage,
    while the final gate emits one routed representation per task tower.
    """

    @property
    def model_name(self) -> str:
        return "SNRTrans"

    @property
    def default_task(self) -> TaskTypeName | list[TaskTypeName]:
        # nums_task is assigned before super().__init__ runs, but guard with
        # hasattr in case the base class queries this property even earlier.
        nums_task = self.nums_task if hasattr(self, "nums_task") else None
        if nums_task is not None and nums_task > 0:
            return ["binary"] * nums_task
        return ["binary"]

    def __init__(
        self,
        dense_features: list[DenseFeature] | None = None,
        sparse_features: list[SparseFeature] | None = None,
        sequence_features: list[SequenceFeature] | None = None,
        expert_mlp_params: dict | None = None,
        num_experts: int = 4,
        tower_mlp_params_list: list[dict] | None = None,
        target: list[str] | str | None = None,
        task: TaskTypeName | list[TaskTypeName] | None = None,
        **kwargs,
    ) -> None:
        """
        Args:
            dense_features: dense feature definitions (optional).
            sparse_features: sparse feature definitions (optional).
            sequence_features: sequence feature definitions (optional).
            expert_mlp_params: shared expert MLP config; each entry of
                ``hidden_dims`` becomes one expert stage.
            num_experts: number of parallel experts per stage (must be > 1).
            tower_mlp_params_list: optional per-task tower MLP configs; if
                given, its length must match the number of tasks.
            target: task target column name(s); at least two are required.
            task: task type name(s) forwarded to the base model.
            **kwargs: forwarded to :class:`BaseModel`.

        Raises:
            ValueError: on fewer than 2 tasks, ``num_experts <= 1``, empty
                ``hidden_dims``, or a tower-params/tasks length mismatch.
        """
        dense_features = dense_features or []
        sparse_features = sparse_features or []
        sequence_features = sequence_features or []
        expert_mlp_params = expert_mlp_params or {}
        tower_mlp_params_list = tower_mlp_params_list or []

        expert_mlp_params.setdefault("hidden_dims", [256, 128])
        expert_mlp_params.setdefault("activation", "relu")
        expert_mlp_params.setdefault("dropout", 0.0)
        expert_mlp_params.setdefault("norm_type", "none")
        expert_hidden_dims = expert_mlp_params["hidden_dims"]

        if target is None:
            target = []
        elif isinstance(target, str):
            target = [target]

        # Must be set before super().__init__ so default_task sees it.
        self.nums_task = len(target) if target else 1

        super().__init__(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
            task=task,
            **kwargs,
        )

        # (The duplicate re-assignment of self.nums_task that followed
        # super().__init__ was removed: it recomputed the same value.)
        self.num_experts = num_experts

        if self.nums_task <= 1:
            raise ValueError("SNRTrans requires at least 2 tasks.")
        if self.num_experts <= 1:
            raise ValueError("num_experts must be greater than 1.")
        if not expert_hidden_dims:
            raise ValueError("expert_mlp_params['hidden_dims'] must not be empty.")

        if tower_mlp_params_list:
            if len(tower_mlp_params_list) != self.nums_task:
                raise ValueError(
                    "Number of tower mlp params "
                    f"({len(tower_mlp_params_list)}) must match number of tasks ({self.nums_task})."
                )
            # Copy so the caller's dicts are not mutated by MLP defaults.
            tower_params = [params.copy() for params in tower_mlp_params_list]
        else:
            tower_params = [{} for _ in range(self.nums_task)]

        self.embedding = EmbeddingLayer(features=self.all_features)
        input_dim = self.embedding.input_dim

        # One expert stage (num_experts single-hidden-layer MLPs) plus one
        # routing gate per hidden dim. The last gate routes to task towers.
        self.expert_layers = nn.ModuleList()
        self.gates = nn.ModuleList()
        prev_dim = input_dim
        for idx, hidden_dim in enumerate(expert_hidden_dims):
            layer_experts = nn.ModuleList(
                [
                    MLP(
                        input_dim=prev_dim,
                        hidden_dims=[hidden_dim],
                        output_dim=None,
                        dropout=expert_mlp_params["dropout"],
                        activation=expert_mlp_params["activation"],
                        norm_type=expert_mlp_params["norm_type"],
                    )
                    for _ in range(self.num_experts)
                ]
            )
            self.expert_layers.append(layer_experts)
            # Final stage routes to one branch per task; earlier stages route
            # to one branch per expert of the next stage.
            output_dim = (
                self.nums_task
                if idx == len(expert_hidden_dims) - 1
                else self.num_experts
            )
            self.gates.append(
                SNRTransGate(
                    num_inputs=self.num_experts,
                    num_outputs=output_dim,
                    units=hidden_dim,
                )
            )
            prev_dim = hidden_dim

        self.towers = nn.ModuleList(
            [
                MLP(input_dim=expert_hidden_dims[-1], output_dim=1, **params)
                for params in tower_params
            ]
        )
        self.prediction_layer = TaskHead(
            task_type=self.task, task_dims=[1] * self.nums_task
        )
        self.grad_norm_shared_modules = ["embedding", "expert_layers", "gates"]
        self.register_regularization_weights(
            embedding_attr="embedding",
            include_modules=["expert_layers", "gates", "towers"],
        )

    def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Run the routed expert stack and return per-task predictions
        concatenated along dim 1 (one column per task).
        """
        input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)

        gate_outputs: list[torch.Tensor] | None = None
        for layer_idx, (layer_experts, gate) in enumerate(
            zip(self.expert_layers, self.gates)
        ):
            expert_outputs = []
            if layer_idx == 0:
                # First stage: every expert consumes the shared embedding.
                expert_inputs = [input_flat] * self.num_experts
            else:
                if gate_outputs is None:
                    raise RuntimeError("SNRTrans gate outputs are not initialized.")
                expert_inputs = gate_outputs
            for expert, expert_input in zip(layer_experts, expert_inputs):
                expert_outputs.append(expert(expert_input))
            gate_outputs = gate(expert_outputs)

        if gate_outputs is None or len(gate_outputs) != self.nums_task:
            raise RuntimeError("SNRTrans gate outputs do not match task count.")

        task_outputs = []
        for task_idx in range(self.nums_task):
            tower_output = self.towers[task_idx](gate_outputs[task_idx])
            task_outputs.append(tower_output)

        y = torch.cat(task_outputs, dim=1)
        return self.prediction_layer(y)
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Date: create on 01/01/2026 - prerelease version: still need to align with the source paper
|
|
3
|
+
Checkpoint: edit on 01/01/2026
|
|
4
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
5
|
+
Reference:
|
|
6
|
+
- [1] Sheng XR, Zhao L, Zhou G, Ding X, Dai B, Luo Q, Yang S, Lv J, Zhang C, Deng H, Zhu X. One Model to Serve All: Star Topology Adaptive Recommender for Multi-Domain CTR Prediction. arXiv preprint arXiv:2101.11427, 2021.
|
|
7
|
+
URL: https://arxiv.org/abs/2101.11427
|
|
8
|
+
- [2] MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation: https://github.com/alipay/MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation/
|
|
9
|
+
|
|
10
|
+
STAR uses shared-specific linear layers to adapt representations per task while
|
|
11
|
+
optionally reusing shared parameters. It can also apply domain-specific batch
|
|
12
|
+
normalization on the first hidden layer when a domain mask is provided.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
import torch.nn as nn
|
|
19
|
+
|
|
20
|
+
from nextrec.basic.activation import activation_layer
|
|
21
|
+
from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
22
|
+
from nextrec.basic.heads import TaskHead
|
|
23
|
+
from nextrec.basic.layers import DomainBatchNorm, EmbeddingLayer
|
|
24
|
+
from nextrec.basic.model import BaseModel
|
|
25
|
+
from nextrec.utils.types import TaskTypeName
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class SharedSpecificLinear(nn.Module):
|
|
29
|
+
"""
|
|
30
|
+
Shared-specific linear layer: task-specific projection plus optional shared one.
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
def __init__(
|
|
34
|
+
self,
|
|
35
|
+
input_dim: int,
|
|
36
|
+
output_dim: int,
|
|
37
|
+
nums_task: int,
|
|
38
|
+
use_shared: bool = True,
|
|
39
|
+
) -> None:
|
|
40
|
+
super().__init__()
|
|
41
|
+
self.use_shared = use_shared
|
|
42
|
+
self.shared = nn.Linear(input_dim, output_dim) if use_shared else None
|
|
43
|
+
self.specific = nn.ModuleList(
|
|
44
|
+
[nn.Linear(input_dim, output_dim) for _ in range(nums_task)]
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
def forward(self, x: torch.Tensor, task_idx: int) -> torch.Tensor:
|
|
48
|
+
output = self.specific[task_idx](x)
|
|
49
|
+
if self.use_shared and self.shared is not None:
|
|
50
|
+
output = output + self.shared(x)
|
|
51
|
+
return output
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class STAR(BaseModel):
    """
    STAR: shared-specific multi-task tower with optional domain-specific batch norm.

    Every hidden layer is a :class:`SharedSpecificLinear` (task-specific
    projection plus optional shared one). When ``norm_type == "batch_norm"``
    and a ``domain_mask`` is passed to :meth:`forward`, a domain-specific
    batch norm is applied to the first hidden layer's output.
    """

    @property
    def model_name(self) -> str:
        return "STAR"

    @property
    def default_task(self) -> TaskTypeName | list[TaskTypeName]:
        # nums_task is assigned before super().__init__ runs, but guard with
        # hasattr in case the base class queries this property even earlier.
        nums_task = self.nums_task if hasattr(self, "nums_task") else None
        if nums_task is not None and nums_task > 0:
            return ["binary"] * nums_task
        return ["binary"]

    def __init__(
        self,
        dense_features: list[DenseFeature] | None = None,
        sparse_features: list[SparseFeature] | None = None,
        sequence_features: list[SequenceFeature] | None = None,
        target: list[str] | str | None = None,
        task: TaskTypeName | list[TaskTypeName] | None = None,
        mlp_params: dict | None = None,
        use_shared: bool = True,
        **kwargs,
    ) -> None:
        """
        Args:
            dense_features: dense feature definitions (optional).
            sparse_features: sparse feature definitions (optional).
            sequence_features: sequence feature definitions (optional).
            target: task target column name(s).
            task: task type name(s) forwarded to the base model.
            mlp_params: hidden-layer config (``hidden_dims``, ``activation``,
                ``dropout``, ``norm_type``); missing keys get defaults.
            use_shared: whether every shared-specific layer adds a shared
                projection to the task-specific one.
            **kwargs: forwarded to :class:`BaseModel`.

        Raises:
            ValueError: if ``mlp_params['hidden_dims']`` is empty.
        """
        dense_features = dense_features or []
        sparse_features = sparse_features or []
        sequence_features = sequence_features or []
        mlp_params = mlp_params or {}
        mlp_params.setdefault("hidden_dims", [256, 128])
        mlp_params.setdefault("activation", "relu")
        mlp_params.setdefault("dropout", 0.0)
        mlp_params.setdefault("norm_type", "none")

        if target is None:
            target = []
        elif isinstance(target, str):
            target = [target]

        # Must be set before super().__init__ so default_task sees it.
        self.nums_task = len(target) if target else 1

        super().__init__(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
            task=task,
            **kwargs,
        )

        if not mlp_params["hidden_dims"]:
            raise ValueError("mlp_params['hidden_dims'] must not be empty.")

        norm_type = mlp_params["norm_type"]
        # Domain batch norm is only built for "batch_norm"; any other
        # norm_type is effectively ignored by this model.
        self.dnn_use_bn = norm_type == "batch_norm"
        # NOTE(review): stored but not read anywhere else in this class —
        # presumably kept for external inspection; confirm before removing.
        self.dnn_dropout = mlp_params["dropout"]

        self.embedding = EmbeddingLayer(features=self.all_features)
        input_dim = self.embedding.input_dim

        # One shared-specific layer per hidden dim: input_dim -> h0 -> h1 ...
        layer_units = [input_dim] + list(mlp_params["hidden_dims"])
        self.star_layers = nn.ModuleList(
            [
                SharedSpecificLinear(
                    input_dim=layer_units[idx],
                    output_dim=layer_units[idx + 1],
                    nums_task=self.nums_task,
                    use_shared=use_shared,
                )
                for idx in range(len(mlp_params["hidden_dims"]))
            ]
        )
        self.activation_layers = nn.ModuleList(
            [
                activation_layer(mlp_params["activation"])
                for _ in range(len(mlp_params["hidden_dims"]))
            ]
        )
        # Identity placeholders keep forward() branch-free when dropout is 0.
        if mlp_params["dropout"] > 0:
            self.dropout_layers = nn.ModuleList(
                [
                    nn.Dropout(mlp_params["dropout"])
                    for _ in range(len(mlp_params["hidden_dims"]))
                ]
            )
        else:
            self.dropout_layers = nn.ModuleList(
                [nn.Identity() for _ in range(len(mlp_params["hidden_dims"]))]
            )

        # One batch-norm statistic set per task/domain, sized to the first
        # hidden layer (the only place forward() applies it).
        self.domain_bn = (
            DomainBatchNorm(
                num_features=mlp_params["hidden_dims"][0], num_domains=self.nums_task
            )
            if self.dnn_use_bn
            else None
        )

        # Final 1-unit shared-specific projection produces each task's logit.
        self.final_layer = SharedSpecificLinear(
            input_dim=mlp_params["hidden_dims"][-1],
            output_dim=1,
            nums_task=self.nums_task,
            use_shared=use_shared,
        )
        self.prediction_layer = TaskHead(
            task_type=self.task, task_dims=[1] * self.nums_task
        )

        self.grad_norm_shared_modules = ["embedding", "star_layers", "final_layer"]
        self.register_regularization_weights(
            embedding_attr="embedding",
            include_modules=["star_layers", "final_layer"],
        )

    def forward(
        self, x: dict[str, torch.Tensor], domain_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Run every task's shared-specific stack over the shared embedding and
        return per-task predictions concatenated along dim 1.

        ``domain_mask`` selects per-sample domain statistics for the
        domain batch norm; when None (or BN is disabled) no BN is applied.
        """
        input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)

        task_outputs = []
        for task_idx in range(self.nums_task):
            output = input_flat
            for layer_idx, layer in enumerate(self.star_layers):
                output = layer(output, task_idx)
                output = self.activation_layers[layer_idx](output)
                output = self.dropout_layers[layer_idx](output)
                # NOTE(review): domain BN runs *after* activation and dropout
                # on the first hidden layer only — verify this ordering
                # against the STAR paper's partitioned normalization.
                if (
                    layer_idx == 0
                    and self.dnn_use_bn
                    and self.domain_bn is not None
                    and domain_mask is not None
                ):
                    output = self.domain_bn(output, domain_mask)
            task_outputs.append(self.final_layer(output, task_idx))

        logits = torch.cat(task_outputs, dim=1)
        return self.prediction_layer(logits)
|