nextrec 0.3.6-py3-none-any.whl → 0.4.2-py3-none-any.whl
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +244 -113
- nextrec/basic/loggers.py +62 -43
- nextrec/basic/metrics.py +268 -119
- nextrec/basic/model.py +1373 -443
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +498 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +42 -24
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +303 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +106 -40
- nextrec/models/match/dssm.py +82 -69
- nextrec/models/match/dssm_v2.py +72 -58
- nextrec/models/match/mind.py +175 -108
- nextrec/models/match/sdm.py +104 -88
- nextrec/models/match/youtube_dnn.py +73 -60
- nextrec/models/multi_task/esmm.py +53 -39
- nextrec/models/multi_task/mmoe.py +70 -47
- nextrec/models/multi_task/ple.py +107 -50
- nextrec/models/multi_task/poso.py +121 -41
- nextrec/models/multi_task/share_bottom.py +54 -38
- nextrec/models/ranking/afm.py +172 -45
- nextrec/models/ranking/autoint.py +84 -61
- nextrec/models/ranking/dcn.py +59 -42
- nextrec/models/ranking/dcn_v2.py +64 -23
- nextrec/models/ranking/deepfm.py +36 -26
- nextrec/models/ranking/dien.py +158 -102
- nextrec/models/ranking/din.py +88 -60
- nextrec/models/ranking/fibinet.py +55 -35
- nextrec/models/ranking/fm.py +32 -26
- nextrec/models/ranking/masknet.py +95 -34
- nextrec/models/ranking/pnn.py +34 -31
- nextrec/models/ranking/widedeep.py +37 -29
- nextrec/models/ranking/xdeepfm.py +63 -41
- nextrec/utils/__init__.py +61 -32
- nextrec/utils/config.py +490 -0
- nextrec/utils/device.py +52 -12
- nextrec/utils/distributed.py +141 -0
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +32 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +531 -0
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
- nextrec-0.4.2.dist-info/RECORD +69 -0
- nextrec-0.4.2.dist-info/entry_points.txt +2 -0
- nextrec-0.3.6.dist-info/RECORD +0 -64
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/models/multi_task/esmm.py
CHANGED

@@ -52,89 +52,103 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 class ESMM(BaseModel):
     """
     Entire Space Multi-Task Model
-
+
     ESMM is designed for CVR (Conversion Rate) prediction. It models two related tasks:
     - CTR task: P(click | impression)
     - CVR task: P(conversion | click)
     - CTCVR task (auxiliary): P(click & conversion | impression) = P(click) * P(conversion | click)
-
+
     This design addresses the sample selection bias and data sparsity issues in CVR modeling.
     """
-
+
     @property
     def model_name(self):
         return "ESMM"

     @property
-    def
-    … (old lines 70-91 not recovered from the source view)
+    def default_task(self):
+        return ["binary", "binary"]
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature],
+        sparse_features: list[SparseFeature],
+        sequence_features: list[SequenceFeature],
+        ctr_params: dict,
+        cvr_params: dict,
+        target: list[str] = ["ctr", "ctcvr"],  # Note: ctcvr = ctr * cvr
+        task: list[str] | None = None,
+        optimizer: str = "adam",
+        optimizer_params: dict = {},
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        device: str = "cpu",
+        embedding_l1_reg=1e-6,
+        dense_l1_reg=1e-5,
+        embedding_l2_reg=1e-5,
+        dense_l2_reg=1e-4,
+        **kwargs,
+    ):
+
         # ESMM requires exactly 2 targets: ctr and ctcvr
         if len(target) != 2:
-            raise ValueError(
-            …
+            raise ValueError(
+                f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}"
+            )
+
         super(ESMM, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task
+            task=task
+            or self.default_task,  # Both CTR and CTCVR are binary classification
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            …
-            **kwargs
+            **kwargs,
         )

         self.loss = loss
         if self.loss is None:
             self.loss = "bce"
-
+
         # All features
         self.all_features = dense_features + sparse_features + sequence_features
         # Shared embedding layer
         self.embedding = EmbeddingLayer(features=self.all_features)
-        input_dim =
+        input_dim = (
+            self.embedding.input_dim
+        )  # Calculate input dimension, better way than below
         # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
         # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
         # input_dim = emb_dim_total + dense_input_dim

         # CTR tower
         self.ctr_tower = MLP(input_dim=input_dim, output_layer=True, **ctr_params)
-
+
         # CVR tower
         self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
-        self.prediction_layer = PredictionLayer(
+        self.prediction_layer = PredictionLayer(
+            task_type=self.default_task, task_dims=[1, 1]
+        )
         # Register regularization weights
-        self.register_regularization_weights(
-            …
+        self.register_regularization_weights(
+            embedding_attr="embedding", include_modules=["ctr_tower", "cvr_tower"]
+        )
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            loss=loss,
+            loss_params=loss_params,
+        )

     def forward(self, x):
         # Get all embeddings and flatten
         input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
-
+
         # CTR prediction: P(click | impression)
         ctr_logit = self.ctr_tower(input_flat)  # [B, 1]
         cvr_logit = self.cvr_tower(input_flat)  # [B, 1]
@@ -142,7 +156,7 @@ class ESMM(BaseModel):
         preds = self.prediction_layer(logits)
         ctr, cvr = preds.chunk(2, dim=1)
         ctcvr = ctr * cvr  # [B, 1]
-
+
         # Output: [CTR, CTCVR], We supervise CTR with click labels and CTCVR with conversion labels
         y = torch.cat([ctr, ctcvr], dim=1)  # [B, 2]
         return y  # [B, 2], where y[:, 0] is CTR and y[:, 1] is CTCVR
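The ESMM forward pass above hinges on one identity from its docstring: pCTCVR = P(click | impression) * P(conversion | click), computed as `ctcvr = ctr * cvr`. A minimal sketch of that decomposition with toy tensors standing in for the nextrec towers (the names below are illustrative, not the nextrec API):

```python
import torch

batch = 4
# Stand-ins for the post-sigmoid tower outputs in the diff above:
ctr = torch.rand(batch, 1)  # P(click | impression)
cvr = torch.rand(batch, 1)  # P(conversion | click)

# Entire-space product: P(click & conversion | impression)
ctcvr = ctr * cvr  # [B, 1], mirrors `ctcvr = ctr * cvr`

# Both supervised heads live in the impression space, so the CVR tower is
# never fit directly on click-filtered samples (the selection-bias fix):
y = torch.cat([ctr, ctcvr], dim=1)  # [B, 2]: y[:, 0] vs click label, y[:, 1] vs conversion label
assert y.shape == (batch, 2)
```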
nextrec/models/multi_task/mmoe.py
CHANGED

@@ -53,66 +53,74 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 class MMOE(BaseModel):
     """
     Multi-gate Mixture-of-Experts
-
+
     MMOE improves upon shared-bottom architecture by using multiple expert networks
     and task-specific gating networks. Each task has its own gate that learns to
     weight the contributions of different experts, allowing for both task-specific
     and shared representations.
     """
-
+
     @property
     def model_name(self):
         return "MMOE"

     @property
-    def
-    … (old lines 69-90 not recovered from the source view)
+    def default_task(self):
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ["binary"] * num_tasks
+        return ["binary"]
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature] = [],
+        sparse_features: list[SparseFeature] = [],
+        sequence_features: list[SequenceFeature] = [],
+        expert_params: dict = {},
+        num_experts: int = 3,
+        tower_params_list: list[dict] = [],
+        target: list[str] = [],
+        task: str | list[str] | None = None,
+        optimizer: str = "adam",
+        optimizer_params: dict = {},
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        device: str = "cpu",
+        embedding_l1_reg=1e-6,
+        dense_l1_reg=1e-5,
+        embedding_l2_reg=1e-5,
+        dense_l2_reg=1e-4,
+        **kwargs,
+    ):
+
+        self.num_tasks = len(target)
+
         super(MMOE, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task,
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            …
-            **kwargs
+            **kwargs,
         )

         self.loss = loss
         if self.loss is None:
             self.loss = "bce"
-
+
         # Number of tasks and experts
         self.num_tasks = len(target)
         self.num_experts = num_experts
-
+
         if len(tower_params_list) != self.num_tasks:
-            raise ValueError(
+            raise ValueError(
+                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
+            )

         self.all_features = dense_features + sparse_features + sequence_features
         self.embedding = EmbeddingLayer(features=self.all_features)
@@ -126,54 +134,69 @@ class MMOE(BaseModel):
         for _ in range(num_experts):
             expert = MLP(input_dim=input_dim, output_layer=False, **expert_params)
             self.experts.append(expert)
-
+
         # Get expert output dimension
-        if
-        expert_output_dim = expert_params[
+        if "dims" in expert_params and len(expert_params["dims"]) > 0:
+            expert_output_dim = expert_params["dims"][-1]
         else:
             expert_output_dim = input_dim
-
+
         # Task-specific gates
         self.gates = nn.ModuleList()
         for _ in range(self.num_tasks):
             gate = nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1))
             self.gates.append(gate)
-
+
         # Task-specific towers
         self.towers = nn.ModuleList()
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(
+        self.prediction_layer = PredictionLayer(
+            task_type=self.default_task, task_dims=[1] * self.num_tasks
+        )
         # Register regularization weights
-        self.register_regularization_weights(
-            …
+        self.register_regularization_weights(
+            embedding_attr="embedding", include_modules=["experts", "gates", "towers"]
+        )
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            loss=loss,
+            loss_params=loss_params,
+        )

     def forward(self, x):
         # Get all embeddings and flatten
         input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
-
+
         # Expert outputs: [num_experts, B, expert_dim]
         expert_outputs = [expert(input_flat) for expert in self.experts]
-        expert_outputs = torch.stack(
-            …
+        expert_outputs = torch.stack(
+            expert_outputs, dim=0
+        )  # [num_experts, B, expert_dim]
+
         # Task-specific processing
         task_outputs = []
         for task_idx in range(self.num_tasks):
             # Gate weights for this task: [B, num_experts]
             gate_weights = self.gates[task_idx](input_flat)  # [B, num_experts]
-
+
             # Weighted sum of expert outputs
             # gate_weights: [B, num_experts, 1]
             # expert_outputs: [num_experts, B, expert_dim]
             gate_weights = gate_weights.unsqueeze(2)  # [B, num_experts, 1]
-            expert_outputs_t = expert_outputs.permute(
-                …
-                …
+            expert_outputs_t = expert_outputs.permute(
+                1, 0, 2
+            )  # [B, num_experts, expert_dim]
+            gated_output = torch.sum(
+                gate_weights * expert_outputs_t, dim=1
+            )  # [B, expert_dim]
+
             # Tower output
             tower_output = self.towers[task_idx](gated_output)  # [B, 1]
             task_outputs.append(tower_output)
-
+
         # Stack outputs: [B, num_tasks]
         y = torch.cat(task_outputs, dim=1)
         return self.prediction_layer(y)
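The stack/permute/sum sequence in the MMOE forward pass above is the standard mixture-of-experts reduction: for task k, output_k = sum_i g_k(x)_i * f_i(x). A minimal sketch with toy tensors (no nextrec imports; shapes match the comments in the diff):

```python
import torch

B, E, D = 4, 3, 8  # batch, num_experts, expert_dim
expert_outputs = torch.randn(E, B, D)  # stacked experts: [num_experts, B, expert_dim]
gate_weights = torch.softmax(torch.randn(B, E), dim=1)  # one task's gate: [B, num_experts]

# Same reduction as the diff: [B, E, 1] * [B, E, D], summed over experts -> [B, D]
gated = (gate_weights.unsqueeze(2) * expert_outputs.permute(1, 0, 2)).sum(dim=1)
assert gated.shape == (B, D)
```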
nextrec/models/multi_task/ple.py
CHANGED
@@ -52,6 +52,7 @@ from nextrec.basic.model import BaseModel
 from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature

+
 class CGCLayer(nn.Module):
     """
     CGC (Customized Gate Control) block used by PLE.
@@ -71,26 +72,61 @@ class CGCLayer(nn.Module):
         if num_tasks < 1:
             raise ValueError("num_tasks must be >= 1")

-        specific_params_list = self._normalize_specific_params(
+        specific_params_list = self._normalize_specific_params(
+            specific_expert_params, num_tasks
+        )

         self.output_dim = self._get_output_dim(shared_expert_params, input_dim)
-        specific_dims = [
+        specific_dims = [
+            self._get_output_dim(params, input_dim) for params in specific_params_list
+        ]
         dims_set = set(specific_dims + [self.output_dim])
         if len(dims_set) != 1:
-            raise ValueError(
+            raise ValueError(
+                f"Shared/specific expert output dims must match, got {dims_set}"
+            )

         # experts
-        self.shared_experts = nn.ModuleList(
+        self.shared_experts = nn.ModuleList(
+            [
+                MLP(
+                    input_dim=input_dim,
+                    output_layer=False,
+                    **shared_expert_params,
+                )
+                for _ in range(num_shared_experts)
+            ]
+        )
         self.specific_experts = nn.ModuleList()
         for params in specific_params_list:
-            task_experts = nn.ModuleList(
+            task_experts = nn.ModuleList(
+                [
+                    MLP(
+                        input_dim=input_dim,
+                        output_layer=False,
+                        **params,
+                    )
+                    for _ in range(num_specific_experts)
+                ]
+            )
             self.specific_experts.append(task_experts)

         # gates
         task_gate_expert_num = num_shared_experts + num_specific_experts
-        self.task_gates = nn.ModuleList(
+        self.task_gates = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.Linear(input_dim, task_gate_expert_num),
+                    nn.Softmax(dim=1),
+                )
+                for _ in range(num_tasks)
+            ]
+        )
         shared_gate_expert_num = num_shared_experts + num_specific_experts * num_tasks
-        self.shared_gate = nn.Sequential(
+        self.shared_gate = nn.Sequential(
+            nn.Linear(input_dim, shared_gate_expert_num),
+            nn.Softmax(dim=1),
+        )

         self.num_tasks = num_tasks

@@ -98,7 +134,9 @@ class CGCLayer(nn.Module):
         self, task_inputs: list[torch.Tensor], shared_input: torch.Tensor
     ) -> tuple[list[torch.Tensor], torch.Tensor]:
         if len(task_inputs) != self.num_tasks:
-            raise ValueError(
+            raise ValueError(
+                f"Expected {self.num_tasks} task inputs, got {len(task_inputs)}"
+            )

         shared_outputs = [expert(shared_input) for expert in self.shared_experts]
         shared_stack = torch.stack(shared_outputs, dim=0)  # [num_shared, B, D]
@@ -108,7 +146,7 @@

         for task_idx in range(self.num_tasks):
             task_input = task_inputs[task_idx]
-            task_specific_outputs = [expert(task_input) for expert in self.specific_experts[task_idx]]
+            task_specific_outputs = [expert(task_input) for expert in self.specific_experts[task_idx]]  # type: ignore
             all_specific_for_shared.extend(task_specific_outputs)
             specific_stack = torch.stack(task_specific_outputs, dim=0)

@@ -139,7 +177,9 @@
     ) -> list[dict]:
         if isinstance(params, list):
             if len(params) != num_tasks:
-                raise ValueError(
+                raise ValueError(
+                    f"Length of specific_expert_params ({len(params)}) must match num_tasks ({num_tasks})."
+                )
             return [p.copy() for p in params]
         return [params.copy() for _ in range(num_tasks)]

@@ -147,57 +187,63 @@
 class PLE(BaseModel):
     """
     Progressive Layered Extraction
-
+
     PLE is an advanced multi-task learning model that extends MMOE by introducing
     both task-specific experts and shared experts at each level. It uses a progressive
     routing mechanism where experts from level k feed into gates at level k+1.
     This design better captures task-specific and shared information progressively.
     """
-
+
     @property
     def model_name(self):
         return "PLE"

     @property
-    def
-    … (old lines 163-187 not recovered from the source view)
+    def default_task(self):
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ["binary"] * num_tasks
+        return ["binary"]
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature],
+        sparse_features: list[SparseFeature],
+        sequence_features: list[SequenceFeature],
+        shared_expert_params: dict,
+        specific_expert_params: dict | list[dict],
+        num_shared_experts: int,
+        num_specific_experts: int,
+        num_levels: int,
+        tower_params_list: list[dict],
+        target: list[str],
+        task: str | list[str] | None = None,
+        optimizer: str = "adam",
+        optimizer_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        device: str = "cpu",
+        embedding_l1_reg=1e-6,
+        dense_l1_reg=1e-5,
+        embedding_l2_reg=1e-5,
+        dense_l2_reg=1e-4,
+        **kwargs,
+    ):
+
+        self.num_tasks = len(target)
+
         super(PLE, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task,
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            …
-            **kwargs
+            **kwargs,
         )

         self.loss = loss
@@ -211,7 +257,9 @@ class PLE(BaseModel):
         if optimizer_params is None:
             optimizer_params = {}
         if len(tower_params_list) != self.num_tasks:
-            raise ValueError(
+            raise ValueError(
+                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
+            )
         # Embedding layer
         self.embedding = EmbeddingLayer(features=self.all_features)

@@ -220,13 +268,13 @@
         # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
         # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
         # input_dim = emb_dim_total + dense_input_dim
-
+
         # Get expert output dimension
-        if
-        expert_output_dim = shared_expert_params[
+        if "dims" in shared_expert_params and len(shared_expert_params["dims"]) > 0:
+            expert_output_dim = shared_expert_params["dims"][-1]
         else:
             expert_output_dim = input_dim
-
+
         # Build CGC layers
         self.cgc_layers = nn.ModuleList()
         for level in range(num_levels):
@@ -241,16 +289,25 @@
             )
             self.cgc_layers.append(cgc_layer)
             expert_output_dim = cgc_layer.output_dim
-
+
         # Task-specific towers
         self.towers = nn.ModuleList()
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(
+        self.prediction_layer = PredictionLayer(
+            task_type=self.default_task, task_dims=[1] * self.num_tasks
+        )
         # Register regularization weights
-        self.register_regularization_weights(
-            …
+        self.register_regularization_weights(
+            embedding_attr="embedding", include_modules=["cgc_layers", "towers"]
+        )
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            loss=self.loss,
+            loss_params=loss_params,
+        )

     def forward(self, x):
         # Get all embeddings and flatten
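The two gate fan-ins constructed in CGCLayer above encode PLE's routing rule: each task gate mixes that task's specific experts plus the shared pool, while the shared gate sees every expert. A quick arithmetic check with assumed toy counts:

```python
# Toy counts, chosen only to illustrate CGCLayer's fan-in arithmetic.
num_shared_experts, num_specific_experts, num_tasks = 2, 2, 3

# Each task gate mixes its own specific experts plus the shared pool:
task_gate_expert_num = num_shared_experts + num_specific_experts
assert task_gate_expert_num == 4

# The shared gate mixes the shared pool plus every task's specific experts:
shared_gate_expert_num = num_shared_experts + num_specific_experts * num_tasks
assert shared_gate_expert_num == 8
```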