nextrec 0.4.16__py3-none-any.whl → 0.4.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/heads.py +99 -0
- nextrec/basic/loggers.py +5 -5
- nextrec/basic/model.py +217 -88
- nextrec/cli.py +1 -1
- nextrec/data/dataloader.py +93 -95
- nextrec/data/preprocessor.py +108 -46
- nextrec/loss/grad_norm.py +13 -13
- nextrec/models/multi_task/esmm.py +10 -11
- nextrec/models/multi_task/mmoe.py +20 -19
- nextrec/models/multi_task/ple.py +35 -34
- nextrec/models/multi_task/poso.py +23 -21
- nextrec/models/multi_task/share_bottom.py +18 -17
- nextrec/models/ranking/afm.py +4 -3
- nextrec/models/ranking/autoint.py +4 -3
- nextrec/models/ranking/dcn.py +4 -3
- nextrec/models/ranking/dcn_v2.py +4 -3
- nextrec/models/ranking/deepfm.py +4 -3
- nextrec/models/ranking/dien.py +2 -2
- nextrec/models/ranking/din.py +2 -2
- nextrec/models/ranking/eulernet.py +4 -3
- nextrec/models/ranking/ffm.py +4 -3
- nextrec/models/ranking/fibinet.py +2 -2
- nextrec/models/ranking/fm.py +4 -3
- nextrec/models/ranking/lr.py +4 -3
- nextrec/models/ranking/masknet.py +4 -5
- nextrec/models/ranking/pnn.py +5 -4
- nextrec/models/ranking/widedeep.py +8 -8
- nextrec/models/ranking/xdeepfm.py +5 -4
- nextrec/utils/console.py +20 -6
- nextrec/utils/data.py +154 -32
- nextrec/utils/model.py +86 -1
- {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/METADATA +5 -6
- {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/RECORD +37 -36
- {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/WHEEL +0 -0
- {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/grad_norm.py
CHANGED
@@ -2,7 +2,7 @@
 GradNorm loss weighting for multi-task learning.
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 24/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 
 Reference:
@@ -45,7 +45,7 @@ class GradNormLossWeighting:
     Adaptive multi-task loss weighting with GradNorm.
 
     Args:
-
+        nums_task: Number of tasks.
        alpha: GradNorm balancing strength.
        lr: Learning rate for the weight optimizer.
        init_weights: Optional initial weights per task.
@@ -58,7 +58,7 @@ class GradNormLossWeighting:
 
     def __init__(
         self,
-
+        nums_task: int,
         alpha: float = 1.5,
         lr: float = 0.025,
         init_weights: Iterable[float] | None = None,
@@ -68,9 +68,9 @@
         init_ema_decay: float = 0.9,
         eps: float = 1e-8,
     ) -> None:
-        if
-            raise ValueError("GradNorm requires
-        self.
+        if nums_task <= 1:
+            raise ValueError("GradNorm requires nums_task > 1.")
+        self.nums_task = nums_task
         self.alpha = alpha
         self.eps = eps
         if ema_decay is not None:
@@ -87,12 +87,12 @@
         self.init_ema_count = 0
 
         if init_weights is None:
-            weights = torch.ones(self.
+            weights = torch.ones(self.nums_task, dtype=torch.float32)
         else:
             weights = torch.tensor(list(init_weights), dtype=torch.float32)
-            if weights.numel() != self.
+            if weights.numel() != self.nums_task:
                 raise ValueError(
-                    "init_weights length must match
+                    "init_weights length must match nums_task for GradNorm."
                 )
         if device is not None:
             weights = weights.to(device)
@@ -123,9 +123,9 @@
         """
         Return weighted total loss and update task weights with GradNorm.
         """
-        if len(task_losses) != self.
+        if len(task_losses) != self.nums_task:
             raise ValueError(
-                f"Expected {self.
+                f"Expected {self.nums_task} task losses, got {len(task_losses)}."
             )
         shared_params = [p for p in shared_params if p.requires_grad]
         if not shared_params:
@@ -152,7 +152,7 @@
 
         weights_detached = self.weights.detach()
         weighted_losses = [
-            weights_detached[i] * task_losses[i] for i in range(self.
+            weights_detached[i] * task_losses[i] for i in range(self.nums_task)
         ]
         total_loss = torch.stack(weighted_losses).sum()
 
@@ -226,7 +226,7 @@
 
         with torch.no_grad():
             w = self.weights.clamp(min=self.eps)
-            w = w * self.
+            w = w * self.nums_task / (w.sum() + self.eps)
             self.weights.copy_(w)
 
         self.pending_grad = None
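Every hunk in this file is the rename of the task-count argument to `nums_task`; the weighting math itself is unchanged. A minimal standalone sketch of the two `nums_task`-dependent steps visible above (the weighted total loss and the weight renormalization), with illustrative function names that are not part of the package API:

```python
import torch


def weighted_total_loss(weights: torch.Tensor, task_losses: list[torch.Tensor]) -> torch.Tensor:
    """Weighted sum of per-task losses, mirroring the hunk around line 155."""
    nums_task = weights.numel()
    if len(task_losses) != nums_task:
        raise ValueError(f"Expected {nums_task} task losses, got {len(task_losses)}.")
    weights_detached = weights.detach()  # weights are updated by GradNorm, not through this sum
    weighted = [weights_detached[i] * task_losses[i] for i in range(nums_task)]
    return torch.stack(weighted).sum()


def renormalize_weights(weights: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Clamp weights to stay positive and rescale them to sum to nums_task (hunk around line 229)."""
    nums_task = weights.numel()
    w = weights.clamp(min=eps)
    return w * nums_task / (w.sum() + eps)
```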
nextrec/models/multi_task/esmm.py
CHANGED

@@ -1,6 +1,6 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 Reference:
 [1] Ma X, Zhao L, Huang G, et al. Entire space multi-task model: An effective approach
@@ -45,7 +45,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -100,17 +101,17 @@ class ESMM(BaseModel):
                 f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}"
             )
 
-        self.
+        self.nums_task = len(target)
         resolved_task = task
         if resolved_task is None:
             resolved_task = self.default_task
         elif isinstance(resolved_task, str):
-            resolved_task = [resolved_task] * self.
-        elif len(resolved_task) == 1 and self.
-            resolved_task = resolved_task * self.
-        elif len(resolved_task) != self.
+            resolved_task = [resolved_task] * self.nums_task
+        elif len(resolved_task) == 1 and self.nums_task > 1:
+            resolved_task = resolved_task * self.nums_task
+        elif len(resolved_task) != self.nums_task:
             raise ValueError(
-                f"Length of task ({len(resolved_task)}) must match number of targets ({self.
+                f"Length of task ({len(resolved_task)}) must match number of targets ({self.nums_task})."
             )
         # resolved_task is now guaranteed to be a list[str]
 
@@ -139,9 +140,7 @@
         # CVR tower
         self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
         self.grad_norm_shared_modules = ["embedding"]
-        self.prediction_layer =
-            task_type=self.default_task, task_dims=[1, 1]
-        )
+        self.prediction_layer = TaskHead(task_type=self.default_task, task_dims=[1, 1])
         # Register regularization weights
         self.register_regularization_weights(
             embedding_attr="embedding", include_modules=["ctr_tower", "cvr_tower"]
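The three-line prediction-layer construction collapses into a single `TaskHead` call; `TaskHead` itself lives in the new `nextrec/basic/heads.py` (+99 lines), which this diff does not show. A hedged sketch of what a head taking the `task_type`/`task_dims` arguments seen here typically does, not the package's actual implementation:

```python
import torch
import torch.nn as nn


class SimpleTaskHead(nn.Module):
    """Apply a per-task output activation to concatenated tower logits (illustrative only)."""

    def __init__(self, task_type: list[str], task_dims: list[int]) -> None:
        super().__init__()
        assert len(task_type) == len(task_dims)
        self.task_type = task_type
        self.task_dims = task_dims

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: [B, sum(task_dims)], e.g. [B, 2] for ESMM's ctr/ctcvr towers.
        outs, start = [], 0
        for t, d in zip(self.task_type, self.task_dims):
            chunk = y[:, start:start + d]
            outs.append(torch.sigmoid(chunk) if t == "binary" else chunk)
            start += d
        return torch.cat(outs, dim=1)
```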
nextrec/models/multi_task/mmoe.py
CHANGED

@@ -1,6 +1,6 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 Reference:
 [1] Ma J, Zhao Z, Yi X, et al. Modeling task relationships in multi-task learning with
@@ -46,7 +46,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -66,9 +67,9 @@ class MMOE(BaseModel):
 
     @property
     def default_task(self):
-
-        if
-            return ["binary"] *
+        nums_task = getattr(self, "nums_task", None)
+        if nums_task is not None and nums_task > 0:
+            return ["binary"] * nums_task
         return ["binary"]
 
     def __init__(
@@ -106,18 +107,18 @@
         elif isinstance(target, str):
             target = [target]
 
-        self.
+        self.nums_task = len(target) if target else 1
 
         resolved_task = task
         if resolved_task is None:
             resolved_task = self.default_task
         elif isinstance(resolved_task, str):
-            resolved_task = [resolved_task] * self.
-        elif len(resolved_task) == 1 and self.
-            resolved_task = resolved_task * self.
-        elif len(resolved_task) != self.
+            resolved_task = [resolved_task] * self.nums_task
+        elif len(resolved_task) == 1 and self.nums_task > 1:
+            resolved_task = resolved_task * self.nums_task
+        elif len(resolved_task) != self.nums_task:
             raise ValueError(
-                f"Length of task ({len(resolved_task)}) must match number of targets ({self.
+                f"Length of task ({len(resolved_task)}) must match number of targets ({self.nums_task})."
             )
 
         super(MMOE, self).__init__(
@@ -137,12 +138,12 @@
         self.loss = loss
 
         # Number of tasks and experts
-        self.
+        self.nums_task = len(target)
         self.num_experts = num_experts
 
-        if len(tower_params_list) != self.
+        if len(tower_params_list) != self.nums_task:
             raise ValueError(
-                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.
+                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.nums_task})"
             )
 
         self.embedding = EmbeddingLayer(features=self.all_features)
@@ -162,7 +163,7 @@
 
         # Task-specific gates
         self.gates = nn.ModuleList()
-        for _ in range(self.
+        for _ in range(self.nums_task):
             gate = nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1))
             self.gates.append(gate)
         self.grad_norm_shared_modules = ["embedding", "experts", "gates"]
@@ -172,8 +173,8 @@
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer =
-            task_type=self.default_task, task_dims=[1] * self.
+        self.prediction_layer = TaskHead(
+            task_type=self.default_task, task_dims=[1] * self.nums_task
         )
         # Register regularization weights
         self.register_regularization_weights(
@@ -198,7 +199,7 @@
 
         # Task-specific processing
         task_outputs = []
-        for task_idx in range(self.
+        for task_idx in range(self.nums_task):
             # Gate weights for this task: [B, num_experts]
             gate_weights = self.gates[task_idx](input_flat)  # [B, num_experts]
 
@@ -217,6 +218,6 @@
             tower_output = self.towers[task_idx](gated_output)  # [B, 1]
             task_outputs.append(tower_output)
 
-        # Stack outputs: [B,
+        # Stack outputs: [B, nums_task]
         y = torch.cat(task_outputs, dim=1)
         return self.prediction_layer(y)
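MMOE, PLE, POSO and ShareBottom all carry the same target/task normalization block, now written against `self.nums_task`. Pulled out into a standalone helper (illustrative names, not a package function), the logic shown in these hunks is:

```python
def resolve_tasks(task: str | list[str] | None, nums_task: int, default: list[str]) -> list[str]:
    """Expand a task spec to one entry per target, as the multi-task constructors above do."""
    if task is None:
        return default
    if isinstance(task, str):
        return [task] * nums_task
    if len(task) == 1 and nums_task > 1:
        return list(task) * nums_task
    if len(task) != nums_task:
        raise ValueError(
            f"Length of task ({len(task)}) must match number of targets ({nums_task})."
        )
    return list(task)


# e.g. resolve_tasks("binary", 3, ["binary"] * 3) -> ["binary", "binary", "binary"]
```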
nextrec/models/multi_task/ple.py
CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 Reference:
 [1] Tang H, Liu J, Zhao M, et al. Progressive layered extraction (PLE): A novel
@@ -49,7 +49,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 from nextrec.utils.model import get_mlp_output_dim
 
@@ -63,18 +64,18 @@ class CGCLayer(nn.Module):
     def __init__(
         self,
         input_dim: int,
-
+        nums_task: int,
         num_shared_experts: int,
         num_specific_experts: int,
         shared_expert_params: dict,
         specific_expert_params: dict | list[dict],
     ):
         super().__init__()
-        if
-            raise ValueError("
+        if nums_task < 1:
+            raise ValueError("nums_task must be >= 1")
 
         specific_params_list = self.normalize_specific_params(
-            specific_expert_params,
+            specific_expert_params, nums_task
         )
 
         self.output_dim = get_mlp_output_dim(shared_expert_params, input_dim)
@@ -120,23 +121,23 @@
                     nn.Linear(input_dim, task_gate_expert_num),
                     nn.Softmax(dim=1),
                 )
-                for _ in range(
+                for _ in range(nums_task)
             ]
         )
-        shared_gate_expert_num = num_shared_experts + num_specific_experts *
+        shared_gate_expert_num = num_shared_experts + num_specific_experts * nums_task
         self.shared_gate = nn.Sequential(
             nn.Linear(input_dim, shared_gate_expert_num),
             nn.Softmax(dim=1),
         )
 
-        self.
+        self.nums_task = nums_task
 
     def forward(
         self, task_inputs: list[torch.Tensor], shared_input: torch.Tensor
     ) -> tuple[list[torch.Tensor], torch.Tensor]:
-        if len(task_inputs) != self.
+        if len(task_inputs) != self.nums_task:
             raise ValueError(
-                f"Expected {self.
+                f"Expected {self.nums_task} task inputs, got {len(task_inputs)}"
             )
 
         shared_outputs = [expert(shared_input) for expert in self.shared_experts]
@@ -145,7 +146,7 @@
         new_task_fea: list[torch.Tensor] = []
         all_specific_for_shared: list[torch.Tensor] = []
 
-        for task_idx in range(self.
+        for task_idx in range(self.nums_task):
             task_input = task_inputs[task_idx]
             task_specific_outputs = [expert(task_input) for expert in self.specific_experts[task_idx]]  # type: ignore
             all_specific_for_shared.extend(task_specific_outputs)
@@ -167,15 +168,15 @@
 
     @staticmethod
     def normalize_specific_params(
-        params: dict | list[dict],
+        params: dict | list[dict], nums_task: int
     ) -> list[dict]:
         if isinstance(params, list):
-            if len(params) !=
+            if len(params) != nums_task:
                 raise ValueError(
-                    f"Length of specific_expert_params ({len(params)}) must match
+                    f"Length of specific_expert_params ({len(params)}) must match nums_task ({nums_task})."
                 )
             return [p.copy() for p in params]
-        return [params.copy() for _ in range(
+        return [params.copy() for _ in range(nums_task)]
 
 
 class PLE(BaseModel):
@@ -194,9 +195,9 @@ class PLE(BaseModel):
 
     @property
     def default_task(self):
-
-        if
-            return ["binary"] *
+        nums_task = getattr(self, "nums_task", None)
+        if nums_task is not None and nums_task > 0:
+            return ["binary"] * nums_task
         return ["binary"]
 
     def __init__(
@@ -224,18 +225,18 @@
         **kwargs,
     ):
 
-        self.
+        self.nums_task = len(target)
 
         resolved_task = task
         if resolved_task is None:
             resolved_task = self.default_task
         elif isinstance(resolved_task, str):
-            resolved_task = [resolved_task] * self.
-        elif len(resolved_task) == 1 and self.
-            resolved_task = resolved_task * self.
-        elif len(resolved_task) != self.
+            resolved_task = [resolved_task] * self.nums_task
+        elif len(resolved_task) == 1 and self.nums_task > 1:
+            resolved_task = resolved_task * self.nums_task
+        elif len(resolved_task) != self.nums_task:
             raise ValueError(
-                f"Length of task ({len(resolved_task)}) must match number of targets ({self.
+                f"Length of task ({len(resolved_task)}) must match number of targets ({self.nums_task})."
             )
 
         super(PLE, self).__init__(
@@ -256,15 +257,15 @@
         if self.loss is None:
             self.loss = "bce"
         # Number of tasks, experts, and levels
-        self.
+        self.nums_task = len(target)
         self.num_shared_experts = num_shared_experts
         self.num_specific_experts = num_specific_experts
         self.num_levels = num_levels
         if optimizer_params is None:
             optimizer_params = {}
-        if len(tower_params_list) != self.
+        if len(tower_params_list) != self.nums_task:
             raise ValueError(
-                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.
+                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.nums_task})"
             )
         # Embedding layer
         self.embedding = EmbeddingLayer(features=self.all_features)
@@ -287,7 +288,7 @@
             level_input_dim = input_dim if level == 0 else expert_output_dim
             cgc_layer = CGCLayer(
                 input_dim=level_input_dim,
-
+                nums_task=self.nums_task,
                 num_shared_experts=num_shared_experts,
                 num_specific_experts=num_specific_experts,
                 shared_expert_params=shared_expert_params,
@@ -302,8 +303,8 @@
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer =
-            task_type=self.default_task, task_dims=[1] * self.
+        self.prediction_layer = TaskHead(
+            task_type=self.default_task, task_dims=[1] * self.nums_task
         )
         # Register regularization weights
         self.register_regularization_weights(
@@ -321,7 +322,7 @@
         input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
 
         # Initial features for each task and shared
-        task_fea = [input_flat for _ in range(self.
+        task_fea = [input_flat for _ in range(self.nums_task)]
         shared_fea = input_flat
 
         # Progressive Layered Extraction: CGC
@@ -330,10 +331,10 @@
 
         # task tower
         task_outputs = []
-        for task_idx in range(self.
+        for task_idx in range(self.nums_task):
            tower_output = self.towers[task_idx](task_fea[task_idx])  # [B, 1]
            task_outputs.append(tower_output)
 
-        # [B,
+        # [B, nums_task]
         y = torch.cat(task_outputs, dim=1)
         return self.prediction_layer(y)
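`CGCLayer.normalize_specific_params` now takes the task count explicitly. Assuming the module path follows the file location, its behaviour as shown above is (dict keys are illustrative, since the helper only copies them):

```python
from nextrec.models.multi_task.ple import CGCLayer

params = {"dims": [64, 32], "dropout": 0.1}

# A single dict is copied once per task ...
per_task = CGCLayer.normalize_specific_params(params, nums_task=3)
assert len(per_task) == 3 and per_task[0] is not params  # independent copies

# ... while a list must already contain exactly nums_task entries:
# CGCLayer.normalize_specific_params([params, params], nums_task=3)  # raises ValueError
```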
nextrec/models/multi_task/poso.py
CHANGED

@@ -1,5 +1,6 @@
 """
 Date: create on 28/11/2025
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 Reference:
 [1] Wang et al. "POSO: Personalized Cold Start Modules for Large-scale Recommender Systems", 2021.
@@ -44,7 +45,8 @@ import torch.nn.functional as F
 
 from nextrec.basic.activation import activation_layer
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 from nextrec.utils.model import select_features
 
@@ -195,7 +197,7 @@ class POSOMMoE(nn.Module):
         pc_dim: int,  # for poso feature dimension
         num_experts: int,
         expert_hidden_dims: list[int],
-
+        nums_task: int,
         activation: str = "relu",
         expert_dropout: float = 0.0,
         gate_hidden_dim: int = 32,  # for poso gate hidden dimension
@@ -204,7 +206,7 @@
     ) -> None:
         super().__init__()
         self.num_experts = num_experts
-        self.
+        self.nums_task = nums_task
 
         # Experts built with framework MLP, same as standard MMoE
         self.experts = nn.ModuleList(
@@ -225,7 +227,7 @@
 
         # Task-specific gates: gate_t(x) over experts
         self.gates = nn.ModuleList(
-            [nn.Linear(input_dim, num_experts) for _ in range(
+            [nn.Linear(input_dim, num_experts) for _ in range(nums_task)]
         )
         self.gate_use_softmax = gate_use_softmax
 
@@ -247,7 +249,7 @@
         """
         x: (B, input_dim)
         pc: (B, pc_dim)
-        return: list of task outputs z_t with length
+        return: list of task outputs z_t with length nums_task, each (B, D)
         """
         # 1) Expert outputs with POSO PC gate
         masked_expert_outputs = []
@@ -261,7 +263,7 @@
 
         # 2) Task gates depend on x as in standard MMoE
         task_outputs: list[torch.Tensor] = []
-        for t in range(self.
+        for t in range(self.nums_task):
             logits = self.gates[t](x)  # (B, E)
             if self.gate_use_softmax:
                 gate = F.softmax(logits, dim=1)
@@ -288,9 +290,9 @@ class POSO(BaseModel):
 
     @property
     def default_task(self) -> list[str]:
-
-        if
-            return ["binary"] *
+        nums_task = getattr(self, "nums_task", None)
+        if nums_task is not None and nums_task > 0:
+            return ["binary"] * nums_task
         return ["binary"]
 
     def __init__(
@@ -332,24 +334,24 @@
         dense_l2_reg: float = 1e-4,
         **kwargs,
     ):
-        self.
+        self.nums_task = len(target)
 
-        # Normalize task to match
+        # Normalize task to match nums_task
         resolved_task = task
         if resolved_task is None:
             resolved_task = self.default_task
         elif isinstance(resolved_task, str):
-            resolved_task = [resolved_task] * self.
-        elif len(resolved_task) == 1 and self.
-            resolved_task = resolved_task * self.
-        elif len(resolved_task) != self.
+            resolved_task = [resolved_task] * self.nums_task
+        elif len(resolved_task) == 1 and self.nums_task > 1:
+            resolved_task = resolved_task * self.nums_task
+        elif len(resolved_task) != self.nums_task:
             raise ValueError(
-                f"Length of task ({len(resolved_task)}) must match number of targets ({self.
+                f"Length of task ({len(resolved_task)}) must match number of targets ({self.nums_task})."
             )
 
-        if len(tower_params_list) != self.
+        if len(tower_params_list) != self.nums_task:
             raise ValueError(
-                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.
+                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.nums_task})"
             )
 
         super().__init__(
@@ -465,7 +467,7 @@
             pc_dim=self.pc_input_dim,
             num_experts=num_experts,
             expert_hidden_dims=expert_hidden_dims,
-
+            nums_task=self.nums_task,
             activation=expert_activation,
             expert_dropout=expert_dropout,
             gate_hidden_dim=expert_gate_hidden_dim,
@@ -487,9 +489,9 @@
             self.grad_norm_shared_modules = ["embedding"]
         else:
             self.grad_norm_shared_modules = ["embedding", "mmoe"]
-        self.prediction_layer =
+        self.prediction_layer = TaskHead(
             task_type=self.default_task,
-            task_dims=[1] * self.
+            task_dims=[1] * self.nums_task,
         )
         include_modules = (
             ["towers", "tower_heads"]
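In `POSOMMoE` the per-task gates remain plain linear layers over the MMoE input, optionally softmax-normalized; only the loop bound changes to `nums_task`. A sketch of that gating step (the expert-mixing line sits outside the shown hunks and is assumed here):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def mix_experts_per_task(
    expert_outputs: torch.Tensor,   # (B, E, D): stacked expert outputs
    gate_layers: nn.ModuleList,     # nums_task linear layers mapping input_dim -> E
    x: torch.Tensor,                # (B, input_dim): gate input
    use_softmax: bool = True,
) -> list[torch.Tensor]:
    """One output tensor of shape (B, D) per task, weighted over experts."""
    task_outputs = []
    for gate_layer in gate_layers:
        logits = gate_layer(x)  # (B, E)
        gate = F.softmax(logits, dim=1) if use_softmax else logits
        # Assumed mixing: broadcast gate over the expert dimension and sum it out.
        task_outputs.append((gate.unsqueeze(-1) * expert_outputs).sum(dim=1))
    return task_outputs
```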
nextrec/models/multi_task/share_bottom.py
CHANGED

@@ -1,6 +1,6 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 Reference:
 [1] Caruana R. Multitask learning[J]. Machine Learning, 1997, 28: 41-75.
@@ -43,7 +43,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -54,9 +55,9 @@ class ShareBottom(BaseModel):
 
     @property
     def default_task(self):
-
-        if
-            return ["binary"] *
+        nums_task = getattr(self, "nums_task", None)
+        if nums_task is not None and nums_task > 0:
+            return ["binary"] * nums_task
         return ["binary"]
 
     def __init__(
@@ -82,18 +83,18 @@
 
         optimizer_params = optimizer_params or {}
 
-        self.
+        self.nums_task = len(target)
 
         resolved_task = task
         if resolved_task is None:
             resolved_task = self.default_task
         elif isinstance(resolved_task, str):
-            resolved_task = [resolved_task] * self.
-        elif len(resolved_task) == 1 and self.
-            resolved_task = resolved_task * self.
-        elif len(resolved_task) != self.
+            resolved_task = [resolved_task] * self.nums_task
+        elif len(resolved_task) == 1 and self.nums_task > 1:
+            resolved_task = resolved_task * self.nums_task
+        elif len(resolved_task) != self.nums_task:
             raise ValueError(
-                f"Length of task ({len(resolved_task)}) must match number of targets ({self.
+                f"Length of task ({len(resolved_task)}) must match number of targets ({self.nums_task})."
             )
 
         super(ShareBottom, self).__init__(
@@ -114,10 +115,10 @@
         if self.loss is None:
             self.loss = "bce"
         # Number of tasks
-        self.
-        if len(tower_params_list) != self.
+        self.nums_task = len(target)
+        if len(tower_params_list) != self.nums_task:
             raise ValueError(
-                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.
+                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.nums_task})"
             )
         # Embedding layer
         self.embedding = EmbeddingLayer(features=self.all_features)
@@ -142,8 +143,8 @@
         for tower_params in tower_params_list:
             tower = MLP(input_dim=bottom_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer =
-            task_type=self.default_task, task_dims=[1] * self.
+        self.prediction_layer = TaskHead(
+            task_type=self.default_task, task_dims=[1] * self.nums_task
         )
         # Register regularization weights
         self.register_regularization_weights(
@@ -169,6 +170,6 @@
             tower_output = tower(bottom_output)  # [B, 1]
             task_outputs.append(tower_output)
 
-        # Stack outputs: [B,
+        # Stack outputs: [B, nums_task]
         y = torch.cat(task_outputs, dim=1)
         return self.prediction_layer(y)
|