nextrec 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +250 -112
- nextrec/basic/loggers.py +63 -44
- nextrec/basic/metrics.py +270 -120
- nextrec/basic/model.py +1084 -402
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +492 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +51 -45
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +273 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +103 -38
- nextrec/models/match/dssm.py +82 -68
- nextrec/models/match/dssm_v2.py +72 -57
- nextrec/models/match/mind.py +175 -107
- nextrec/models/match/sdm.py +104 -87
- nextrec/models/match/youtube_dnn.py +73 -59
- nextrec/models/multi_task/esmm.py +69 -46
- nextrec/models/multi_task/mmoe.py +91 -53
- nextrec/models/multi_task/ple.py +117 -58
- nextrec/models/multi_task/poso.py +163 -55
- nextrec/models/multi_task/share_bottom.py +63 -36
- nextrec/models/ranking/afm.py +80 -45
- nextrec/models/ranking/autoint.py +74 -57
- nextrec/models/ranking/dcn.py +110 -48
- nextrec/models/ranking/dcn_v2.py +265 -45
- nextrec/models/ranking/deepfm.py +39 -24
- nextrec/models/ranking/dien.py +335 -146
- nextrec/models/ranking/din.py +158 -92
- nextrec/models/ranking/fibinet.py +134 -52
- nextrec/models/ranking/fm.py +68 -26
- nextrec/models/ranking/masknet.py +95 -33
- nextrec/models/ranking/pnn.py +128 -58
- nextrec/models/ranking/widedeep.py +40 -28
- nextrec/models/ranking/xdeepfm.py +67 -40
- nextrec/utils/__init__.py +59 -34
- nextrec/utils/config.py +496 -0
- nextrec/utils/device.py +30 -20
- nextrec/utils/distributed.py +36 -9
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +33 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/model.py +22 -0
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +283 -165
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
- nextrec-0.4.3.dist-info/RECORD +69 -0
- nextrec-0.4.3.dist-info/entry_points.txt +2 -0
- nextrec-0.4.1.dist-info/RECORD +0 -66
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -52,87 +52,110 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
|
|
|
52
52
|
class ESMM(BaseModel):
|
|
53
53
|
"""
|
|
54
54
|
Entire Space Multi-Task Model
|
|
55
|
-
|
|
55
|
+
|
|
56
56
|
ESMM is designed for CVR (Conversion Rate) prediction. It models two related tasks:
|
|
57
57
|
- CTR task: P(click | impression)
|
|
58
58
|
- CVR task: P(conversion | click)
|
|
59
59
|
- CTCVR task (auxiliary): P(click & conversion | impression) = P(click) * P(conversion | click)
|
|
60
|
-
|
|
60
|
+
|
|
61
61
|
This design addresses the sample selection bias and data sparsity issues in CVR modeling.
|
|
62
62
|
"""
|
|
63
|
-
|
|
63
|
+
|
|
64
64
|
@property
|
|
65
65
|
def model_name(self):
|
|
66
66
|
return "ESMM"
|
|
67
|
-
|
|
67
|
+
|
|
68
68
|
@property
|
|
69
69
|
def default_task(self):
|
|
70
|
-
return [
|
|
71
|
-
|
|
72
|
-
def __init__(
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
70
|
+
return ["binary", "binary"]
|
|
71
|
+
|
|
72
|
+
def __init__(
|
|
73
|
+
self,
|
|
74
|
+
dense_features: list[DenseFeature],
|
|
75
|
+
sparse_features: list[SparseFeature],
|
|
76
|
+
sequence_features: list[SequenceFeature],
|
|
77
|
+
ctr_params: dict,
|
|
78
|
+
cvr_params: dict,
|
|
79
|
+
target: list[str] | None = None, # Note: ctcvr = ctr * cvr
|
|
80
|
+
task: list[str] | None = None,
|
|
81
|
+
optimizer: str = "adam",
|
|
82
|
+
optimizer_params: dict | None = None,
|
|
83
|
+
loss: str | nn.Module | list[str | nn.Module] | None = "bce",
|
|
84
|
+
loss_params: dict | list[dict] | None = None,
|
|
85
|
+
device: str = "cpu",
|
|
86
|
+
embedding_l1_reg=1e-6,
|
|
87
|
+
dense_l1_reg=1e-5,
|
|
88
|
+
embedding_l2_reg=1e-5,
|
|
89
|
+
dense_l2_reg=1e-4,
|
|
90
|
+
**kwargs,
|
|
91
|
+
):
|
|
92
|
+
|
|
93
|
+
target = target or ["ctr", "ctcvr"]
|
|
94
|
+
optimizer_params = optimizer_params or {}
|
|
95
|
+
if loss is None:
|
|
96
|
+
loss = "bce"
|
|
97
|
+
|
|
92
98
|
if len(target) != 2:
|
|
93
|
-
raise ValueError(
|
|
94
|
-
|
|
99
|
+
raise ValueError(
|
|
100
|
+
f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}"
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
self.num_tasks = len(target)
|
|
104
|
+
resolved_task = task
|
|
105
|
+
if resolved_task is None:
|
|
106
|
+
resolved_task = self.default_task
|
|
107
|
+
elif isinstance(resolved_task, str):
|
|
108
|
+
resolved_task = [resolved_task] * self.num_tasks
|
|
109
|
+
elif len(resolved_task) == 1 and self.num_tasks > 1:
|
|
110
|
+
resolved_task = resolved_task * self.num_tasks
|
|
111
|
+
elif len(resolved_task) != self.num_tasks:
|
|
112
|
+
raise ValueError(
|
|
113
|
+
f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
|
|
114
|
+
)
|
|
115
|
+
# resolved_task is now guaranteed to be a list[str]
|
|
116
|
+
|
|
95
117
|
super(ESMM, self).__init__(
|
|
96
118
|
dense_features=dense_features,
|
|
97
119
|
sparse_features=sparse_features,
|
|
98
120
|
sequence_features=sequence_features,
|
|
99
121
|
target=target,
|
|
100
|
-
task=
|
|
122
|
+
task=resolved_task, # Both CTR and CTCVR are binary classification
|
|
101
123
|
device=device,
|
|
102
124
|
embedding_l1_reg=embedding_l1_reg,
|
|
103
125
|
dense_l1_reg=dense_l1_reg,
|
|
104
126
|
embedding_l2_reg=embedding_l2_reg,
|
|
105
127
|
dense_l2_reg=dense_l2_reg,
|
|
106
|
-
**kwargs
|
|
128
|
+
**kwargs,
|
|
107
129
|
)
|
|
108
130
|
|
|
109
131
|
self.loss = loss
|
|
110
|
-
|
|
111
|
-
self.loss = "bce"
|
|
112
|
-
|
|
113
|
-
# All features
|
|
114
|
-
self.all_features = dense_features + sparse_features + sequence_features
|
|
115
|
-
# Shared embedding layer
|
|
132
|
+
|
|
116
133
|
self.embedding = EmbeddingLayer(features=self.all_features)
|
|
117
|
-
input_dim = self.embedding.input_dim
|
|
118
|
-
# emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
|
|
119
|
-
# dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
|
|
120
|
-
# input_dim = emb_dim_total + dense_input_dim
|
|
134
|
+
input_dim = self.embedding.input_dim
|
|
121
135
|
|
|
122
136
|
# CTR tower
|
|
123
137
|
self.ctr_tower = MLP(input_dim=input_dim, output_layer=True, **ctr_params)
|
|
124
|
-
|
|
138
|
+
|
|
125
139
|
# CVR tower
|
|
126
140
|
self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
|
|
127
|
-
self.prediction_layer = PredictionLayer(
|
|
141
|
+
self.prediction_layer = PredictionLayer(
|
|
142
|
+
task_type=self.default_task, task_dims=[1, 1]
|
|
143
|
+
)
|
|
128
144
|
# Register regularization weights
|
|
129
|
-
self.register_regularization_weights(
|
|
130
|
-
|
|
145
|
+
self.register_regularization_weights(
|
|
146
|
+
embedding_attr="embedding", include_modules=["ctr_tower", "cvr_tower"]
|
|
147
|
+
)
|
|
148
|
+
self.compile(
|
|
149
|
+
optimizer=optimizer,
|
|
150
|
+
optimizer_params=optimizer_params,
|
|
151
|
+
loss=loss,
|
|
152
|
+
loss_params=loss_params,
|
|
153
|
+
)
|
|
131
154
|
|
|
132
155
|
def forward(self, x):
|
|
133
156
|
# Get all embeddings and flatten
|
|
134
157
|
input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
|
|
135
|
-
|
|
158
|
+
|
|
136
159
|
# CTR prediction: P(click | impression)
|
|
137
160
|
ctr_logit = self.ctr_tower(input_flat) # [B, 1]
|
|
138
161
|
cvr_logit = self.cvr_tower(input_flat) # [B, 1]
|
|
@@ -140,7 +163,7 @@ class ESMM(BaseModel):
|
|
|
140
163
|
preds = self.prediction_layer(logits)
|
|
141
164
|
ctr, cvr = preds.chunk(2, dim=1)
|
|
142
165
|
ctcvr = ctr * cvr # [B, 1]
|
|
143
|
-
|
|
166
|
+
|
|
144
167
|
# Output: [CTR, CTCVR], We supervise CTR with click labels and CTCVR with conversion labels
|
|
145
168
|
y = torch.cat([ctr, ctcvr], dim=1) # [B, 2]
|
|
146
169
|
return y # [B, 2], where y[:, 0] is CTR and y[:, 1] is CTCVR
|
|
@@ -53,13 +53,13 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
|
|
|
53
53
|
class MMOE(BaseModel):
|
|
54
54
|
"""
|
|
55
55
|
Multi-gate Mixture-of-Experts
|
|
56
|
-
|
|
56
|
+
|
|
57
57
|
MMOE improves upon shared-bottom architecture by using multiple expert networks
|
|
58
58
|
and task-specific gating networks. Each task has its own gate that learns to
|
|
59
59
|
weight the contributions of different experts, allowing for both task-specific
|
|
60
60
|
and shared representations.
|
|
61
61
|
"""
|
|
62
|
-
|
|
62
|
+
|
|
63
63
|
@property
|
|
64
64
|
def model_name(self):
|
|
65
65
|
return "MMOE"
|
|
@@ -68,116 +68,154 @@ class MMOE(BaseModel):
|
|
|
68
68
|
def default_task(self):
|
|
69
69
|
num_tasks = getattr(self, "num_tasks", None)
|
|
70
70
|
if num_tasks is not None and num_tasks > 0:
|
|
71
|
-
return [
|
|
72
|
-
return [
|
|
73
|
-
|
|
74
|
-
def __init__(
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
71
|
+
return ["binary"] * num_tasks
|
|
72
|
+
return ["binary"]
|
|
73
|
+
|
|
74
|
+
def __init__(
|
|
75
|
+
self,
|
|
76
|
+
dense_features: list[DenseFeature] | None = None,
|
|
77
|
+
sparse_features: list[SparseFeature] | None = None,
|
|
78
|
+
sequence_features: list[SequenceFeature] | None = None,
|
|
79
|
+
expert_params: dict | None = None,
|
|
80
|
+
num_experts: int = 3,
|
|
81
|
+
tower_params_list: list[dict] | None = None,
|
|
82
|
+
target: list[str] | str | None = None,
|
|
83
|
+
task: str | list[str] = "binary",
|
|
84
|
+
optimizer: str = "adam",
|
|
85
|
+
optimizer_params: dict | None = None,
|
|
86
|
+
loss: str | nn.Module | list[str | nn.Module] | None = "bce",
|
|
87
|
+
loss_params: dict | list[dict] | None = None,
|
|
88
|
+
device: str = "cpu",
|
|
89
|
+
embedding_l1_reg=1e-6,
|
|
90
|
+
dense_l1_reg=1e-5,
|
|
91
|
+
embedding_l2_reg=1e-5,
|
|
92
|
+
dense_l2_reg=1e-4,
|
|
93
|
+
**kwargs,
|
|
94
|
+
):
|
|
95
|
+
|
|
96
|
+
dense_features = dense_features or []
|
|
97
|
+
sparse_features = sparse_features or []
|
|
98
|
+
sequence_features = sequence_features or []
|
|
99
|
+
expert_params = expert_params or {}
|
|
100
|
+
tower_params_list = tower_params_list or []
|
|
101
|
+
optimizer_params = optimizer_params or {}
|
|
102
|
+
if loss is None:
|
|
103
|
+
loss = "bce"
|
|
104
|
+
if target is None:
|
|
105
|
+
target = []
|
|
106
|
+
elif isinstance(target, str):
|
|
107
|
+
target = [target]
|
|
108
|
+
|
|
109
|
+
self.num_tasks = len(target) if target else 1
|
|
110
|
+
|
|
111
|
+
resolved_task = task
|
|
112
|
+
if resolved_task is None:
|
|
113
|
+
resolved_task = self.default_task
|
|
114
|
+
elif isinstance(resolved_task, str):
|
|
115
|
+
resolved_task = [resolved_task] * self.num_tasks
|
|
116
|
+
elif len(resolved_task) == 1 and self.num_tasks > 1:
|
|
117
|
+
resolved_task = resolved_task * self.num_tasks
|
|
118
|
+
elif len(resolved_task) != self.num_tasks:
|
|
119
|
+
raise ValueError(
|
|
120
|
+
f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
|
|
121
|
+
)
|
|
95
122
|
|
|
96
123
|
super(MMOE, self).__init__(
|
|
97
124
|
dense_features=dense_features,
|
|
98
125
|
sparse_features=sparse_features,
|
|
99
126
|
sequence_features=sequence_features,
|
|
100
127
|
target=target,
|
|
101
|
-
task=
|
|
128
|
+
task=resolved_task,
|
|
102
129
|
device=device,
|
|
103
130
|
embedding_l1_reg=embedding_l1_reg,
|
|
104
131
|
dense_l1_reg=dense_l1_reg,
|
|
105
132
|
embedding_l2_reg=embedding_l2_reg,
|
|
106
133
|
dense_l2_reg=dense_l2_reg,
|
|
107
|
-
**kwargs
|
|
134
|
+
**kwargs,
|
|
108
135
|
)
|
|
109
136
|
|
|
110
137
|
self.loss = loss
|
|
111
|
-
|
|
112
|
-
self.loss = "bce"
|
|
113
|
-
|
|
138
|
+
|
|
114
139
|
# Number of tasks and experts
|
|
115
140
|
self.num_tasks = len(target)
|
|
116
141
|
self.num_experts = num_experts
|
|
117
|
-
|
|
142
|
+
|
|
118
143
|
if len(tower_params_list) != self.num_tasks:
|
|
119
|
-
raise ValueError(
|
|
144
|
+
raise ValueError(
|
|
145
|
+
f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
|
|
146
|
+
)
|
|
120
147
|
|
|
121
|
-
self.all_features = dense_features + sparse_features + sequence_features
|
|
122
148
|
self.embedding = EmbeddingLayer(features=self.all_features)
|
|
123
149
|
input_dim = self.embedding.input_dim
|
|
124
|
-
# emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
|
|
125
|
-
# dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
|
|
126
|
-
# input_dim = emb_dim_total + dense_input_dim
|
|
127
150
|
|
|
128
151
|
# Expert networks (shared by all tasks)
|
|
129
152
|
self.experts = nn.ModuleList()
|
|
130
153
|
for _ in range(num_experts):
|
|
131
154
|
expert = MLP(input_dim=input_dim, output_layer=False, **expert_params)
|
|
132
155
|
self.experts.append(expert)
|
|
133
|
-
|
|
156
|
+
|
|
134
157
|
# Get expert output dimension
|
|
135
|
-
if
|
|
136
|
-
expert_output_dim = expert_params[
|
|
158
|
+
if "dims" in expert_params and len(expert_params["dims"]) > 0:
|
|
159
|
+
expert_output_dim = expert_params["dims"][-1]
|
|
137
160
|
else:
|
|
138
161
|
expert_output_dim = input_dim
|
|
139
|
-
|
|
162
|
+
|
|
140
163
|
# Task-specific gates
|
|
141
164
|
self.gates = nn.ModuleList()
|
|
142
165
|
for _ in range(self.num_tasks):
|
|
143
166
|
gate = nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1))
|
|
144
167
|
self.gates.append(gate)
|
|
145
|
-
|
|
168
|
+
|
|
146
169
|
# Task-specific towers
|
|
147
170
|
self.towers = nn.ModuleList()
|
|
148
171
|
for tower_params in tower_params_list:
|
|
149
172
|
tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
|
|
150
173
|
self.towers.append(tower)
|
|
151
|
-
self.prediction_layer = PredictionLayer(
|
|
174
|
+
self.prediction_layer = PredictionLayer(
|
|
175
|
+
task_type=self.default_task, task_dims=[1] * self.num_tasks
|
|
176
|
+
)
|
|
152
177
|
# Register regularization weights
|
|
153
|
-
self.register_regularization_weights(
|
|
154
|
-
|
|
178
|
+
self.register_regularization_weights(
|
|
179
|
+
embedding_attr="embedding", include_modules=["experts", "gates", "towers"]
|
|
180
|
+
)
|
|
181
|
+
self.compile(
|
|
182
|
+
optimizer=optimizer,
|
|
183
|
+
optimizer_params=optimizer_params,
|
|
184
|
+
loss=self.loss,
|
|
185
|
+
loss_params=loss_params,
|
|
186
|
+
)
|
|
155
187
|
|
|
156
188
|
def forward(self, x):
|
|
157
189
|
# Get all embeddings and flatten
|
|
158
190
|
input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
|
|
159
|
-
|
|
191
|
+
|
|
160
192
|
# Expert outputs: [num_experts, B, expert_dim]
|
|
161
193
|
expert_outputs = [expert(input_flat) for expert in self.experts]
|
|
162
|
-
expert_outputs = torch.stack(
|
|
163
|
-
|
|
194
|
+
expert_outputs = torch.stack(
|
|
195
|
+
expert_outputs, dim=0
|
|
196
|
+
) # [num_experts, B, expert_dim]
|
|
197
|
+
|
|
164
198
|
# Task-specific processing
|
|
165
199
|
task_outputs = []
|
|
166
200
|
for task_idx in range(self.num_tasks):
|
|
167
201
|
# Gate weights for this task: [B, num_experts]
|
|
168
202
|
gate_weights = self.gates[task_idx](input_flat) # [B, num_experts]
|
|
169
|
-
|
|
203
|
+
|
|
170
204
|
# Weighted sum of expert outputs
|
|
171
205
|
# gate_weights: [B, num_experts, 1]
|
|
172
206
|
# expert_outputs: [num_experts, B, expert_dim]
|
|
173
207
|
gate_weights = gate_weights.unsqueeze(2) # [B, num_experts, 1]
|
|
174
|
-
expert_outputs_t = expert_outputs.permute(
|
|
175
|
-
|
|
176
|
-
|
|
208
|
+
expert_outputs_t = expert_outputs.permute(
|
|
209
|
+
1, 0, 2
|
|
210
|
+
) # [B, num_experts, expert_dim]
|
|
211
|
+
gated_output = torch.sum(
|
|
212
|
+
gate_weights * expert_outputs_t, dim=1
|
|
213
|
+
) # [B, expert_dim]
|
|
214
|
+
|
|
177
215
|
# Tower output
|
|
178
216
|
tower_output = self.towers[task_idx](gated_output) # [B, 1]
|
|
179
217
|
task_outputs.append(tower_output)
|
|
180
|
-
|
|
218
|
+
|
|
181
219
|
# Stack outputs: [B, num_tasks]
|
|
182
220
|
y = torch.cat(task_outputs, dim=1)
|
|
183
221
|
return self.prediction_layer(y)
|
nextrec/models/multi_task/ple.py
CHANGED
|
@@ -51,6 +51,8 @@ import torch.nn as nn
|
|
|
51
51
|
from nextrec.basic.model import BaseModel
|
|
52
52
|
from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
|
|
53
53
|
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
|
|
54
|
+
from nextrec.utils.model import get_mlp_output_dim
|
|
55
|
+
|
|
54
56
|
|
|
55
57
|
class CGCLayer(nn.Module):
|
|
56
58
|
"""
|
|
@@ -71,26 +73,61 @@ class CGCLayer(nn.Module):
|
|
|
71
73
|
if num_tasks < 1:
|
|
72
74
|
raise ValueError("num_tasks must be >= 1")
|
|
73
75
|
|
|
74
|
-
specific_params_list = self.
|
|
76
|
+
specific_params_list = self.normalize_specific_params(
|
|
77
|
+
specific_expert_params, num_tasks
|
|
78
|
+
)
|
|
75
79
|
|
|
76
|
-
self.output_dim =
|
|
77
|
-
specific_dims = [
|
|
80
|
+
self.output_dim = get_mlp_output_dim(shared_expert_params, input_dim)
|
|
81
|
+
specific_dims = [
|
|
82
|
+
get_mlp_output_dim(params, input_dim) for params in specific_params_list
|
|
83
|
+
]
|
|
78
84
|
dims_set = set(specific_dims + [self.output_dim])
|
|
79
85
|
if len(dims_set) != 1:
|
|
80
|
-
raise ValueError(
|
|
86
|
+
raise ValueError(
|
|
87
|
+
f"Shared/specific expert output dims must match, got {dims_set}"
|
|
88
|
+
)
|
|
81
89
|
|
|
82
90
|
# experts
|
|
83
|
-
self.shared_experts = nn.ModuleList(
|
|
91
|
+
self.shared_experts = nn.ModuleList(
|
|
92
|
+
[
|
|
93
|
+
MLP(
|
|
94
|
+
input_dim=input_dim,
|
|
95
|
+
output_layer=False,
|
|
96
|
+
**shared_expert_params,
|
|
97
|
+
)
|
|
98
|
+
for _ in range(num_shared_experts)
|
|
99
|
+
]
|
|
100
|
+
)
|
|
84
101
|
self.specific_experts = nn.ModuleList()
|
|
85
102
|
for params in specific_params_list:
|
|
86
|
-
task_experts = nn.ModuleList(
|
|
103
|
+
task_experts = nn.ModuleList(
|
|
104
|
+
[
|
|
105
|
+
MLP(
|
|
106
|
+
input_dim=input_dim,
|
|
107
|
+
output_layer=False,
|
|
108
|
+
**params,
|
|
109
|
+
)
|
|
110
|
+
for _ in range(num_specific_experts)
|
|
111
|
+
]
|
|
112
|
+
)
|
|
87
113
|
self.specific_experts.append(task_experts)
|
|
88
114
|
|
|
89
115
|
# gates
|
|
90
116
|
task_gate_expert_num = num_shared_experts + num_specific_experts
|
|
91
|
-
self.task_gates = nn.ModuleList(
|
|
117
|
+
self.task_gates = nn.ModuleList(
|
|
118
|
+
[
|
|
119
|
+
nn.Sequential(
|
|
120
|
+
nn.Linear(input_dim, task_gate_expert_num),
|
|
121
|
+
nn.Softmax(dim=1),
|
|
122
|
+
)
|
|
123
|
+
for _ in range(num_tasks)
|
|
124
|
+
]
|
|
125
|
+
)
|
|
92
126
|
shared_gate_expert_num = num_shared_experts + num_specific_experts * num_tasks
|
|
93
|
-
self.shared_gate = nn.Sequential(
|
|
127
|
+
self.shared_gate = nn.Sequential(
|
|
128
|
+
nn.Linear(input_dim, shared_gate_expert_num),
|
|
129
|
+
nn.Softmax(dim=1),
|
|
130
|
+
)
|
|
94
131
|
|
|
95
132
|
self.num_tasks = num_tasks
|
|
96
133
|
|
|
@@ -98,7 +135,9 @@ class CGCLayer(nn.Module):
|
|
|
98
135
|
self, task_inputs: list[torch.Tensor], shared_input: torch.Tensor
|
|
99
136
|
) -> tuple[list[torch.Tensor], torch.Tensor]:
|
|
100
137
|
if len(task_inputs) != self.num_tasks:
|
|
101
|
-
raise ValueError(
|
|
138
|
+
raise ValueError(
|
|
139
|
+
f"Expected {self.num_tasks} task inputs, got {len(task_inputs)}"
|
|
140
|
+
)
|
|
102
141
|
|
|
103
142
|
shared_outputs = [expert(shared_input) for expert in self.shared_experts]
|
|
104
143
|
shared_stack = torch.stack(shared_outputs, dim=0) # [num_shared, B, D]
|
|
@@ -108,7 +147,7 @@ class CGCLayer(nn.Module):
|
|
|
108
147
|
|
|
109
148
|
for task_idx in range(self.num_tasks):
|
|
110
149
|
task_input = task_inputs[task_idx]
|
|
111
|
-
task_specific_outputs = [expert(task_input) for expert in self.specific_experts[task_idx]]
|
|
150
|
+
task_specific_outputs = [expert(task_input) for expert in self.specific_experts[task_idx]] # type: ignore
|
|
112
151
|
all_specific_for_shared.extend(task_specific_outputs)
|
|
113
152
|
specific_stack = torch.stack(task_specific_outputs, dim=0)
|
|
114
153
|
|
|
@@ -127,19 +166,14 @@ class CGCLayer(nn.Module):
|
|
|
127
166
|
return new_task_fea, new_shared
|
|
128
167
|
|
|
129
168
|
@staticmethod
|
|
130
|
-
def
|
|
131
|
-
dims = params.get("dims")
|
|
132
|
-
if dims:
|
|
133
|
-
return dims[-1]
|
|
134
|
-
return fallback
|
|
135
|
-
|
|
136
|
-
@staticmethod
|
|
137
|
-
def _normalize_specific_params(
|
|
169
|
+
def normalize_specific_params(
|
|
138
170
|
params: dict | list[dict], num_tasks: int
|
|
139
171
|
) -> list[dict]:
|
|
140
172
|
if isinstance(params, list):
|
|
141
173
|
if len(params) != num_tasks:
|
|
142
|
-
raise ValueError(
|
|
174
|
+
raise ValueError(
|
|
175
|
+
f"Length of specific_expert_params ({len(params)}) must match num_tasks ({num_tasks})."
|
|
176
|
+
)
|
|
143
177
|
return [p.copy() for p in params]
|
|
144
178
|
return [params.copy() for _ in range(num_tasks)]
|
|
145
179
|
|
|
@@ -147,13 +181,13 @@ class CGCLayer(nn.Module):
|
|
|
147
181
|
class PLE(BaseModel):
|
|
148
182
|
"""
|
|
149
183
|
Progressive Layered Extraction
|
|
150
|
-
|
|
184
|
+
|
|
151
185
|
PLE is an advanced multi-task learning model that extends MMOE by introducing
|
|
152
186
|
both task-specific experts and shared experts at each level. It uses a progressive
|
|
153
187
|
routing mechanism where experts from level k feed into gates at level k+1.
|
|
154
188
|
This design better captures task-specific and shared information progressively.
|
|
155
189
|
"""
|
|
156
|
-
|
|
190
|
+
|
|
157
191
|
@property
|
|
158
192
|
def model_name(self):
|
|
159
193
|
return "PLE"
|
|
@@ -162,46 +196,60 @@ class PLE(BaseModel):
|
|
|
162
196
|
def default_task(self):
|
|
163
197
|
num_tasks = getattr(self, "num_tasks", None)
|
|
164
198
|
if num_tasks is not None and num_tasks > 0:
|
|
165
|
-
return [
|
|
166
|
-
return [
|
|
167
|
-
|
|
168
|
-
def __init__(
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
199
|
+
return ["binary"] * num_tasks
|
|
200
|
+
return ["binary"]
|
|
201
|
+
|
|
202
|
+
def __init__(
|
|
203
|
+
self,
|
|
204
|
+
dense_features: list[DenseFeature],
|
|
205
|
+
sparse_features: list[SparseFeature],
|
|
206
|
+
sequence_features: list[SequenceFeature],
|
|
207
|
+
shared_expert_params: dict,
|
|
208
|
+
specific_expert_params: dict | list[dict],
|
|
209
|
+
num_shared_experts: int,
|
|
210
|
+
num_specific_experts: int,
|
|
211
|
+
num_levels: int,
|
|
212
|
+
tower_params_list: list[dict],
|
|
213
|
+
target: list[str],
|
|
214
|
+
task: str | list[str] | None = None,
|
|
215
|
+
optimizer: str = "adam",
|
|
216
|
+
optimizer_params: dict | None = None,
|
|
217
|
+
loss: str | nn.Module | list[str | nn.Module] | None = "bce",
|
|
218
|
+
loss_params: dict | list[dict] | None = None,
|
|
219
|
+
device: str = "cpu",
|
|
220
|
+
embedding_l1_reg=1e-6,
|
|
221
|
+
dense_l1_reg=1e-5,
|
|
222
|
+
embedding_l2_reg=1e-5,
|
|
223
|
+
dense_l2_reg=1e-4,
|
|
224
|
+
**kwargs,
|
|
225
|
+
):
|
|
226
|
+
|
|
191
227
|
self.num_tasks = len(target)
|
|
192
228
|
|
|
229
|
+
resolved_task = task
|
|
230
|
+
if resolved_task is None:
|
|
231
|
+
resolved_task = self.default_task
|
|
232
|
+
elif isinstance(resolved_task, str):
|
|
233
|
+
resolved_task = [resolved_task] * self.num_tasks
|
|
234
|
+
elif len(resolved_task) == 1 and self.num_tasks > 1:
|
|
235
|
+
resolved_task = resolved_task * self.num_tasks
|
|
236
|
+
elif len(resolved_task) != self.num_tasks:
|
|
237
|
+
raise ValueError(
|
|
238
|
+
f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
|
|
239
|
+
)
|
|
240
|
+
|
|
193
241
|
super(PLE, self).__init__(
|
|
194
242
|
dense_features=dense_features,
|
|
195
243
|
sparse_features=sparse_features,
|
|
196
244
|
sequence_features=sequence_features,
|
|
197
245
|
target=target,
|
|
198
|
-
task=
|
|
246
|
+
task=resolved_task,
|
|
199
247
|
device=device,
|
|
200
248
|
embedding_l1_reg=embedding_l1_reg,
|
|
201
249
|
dense_l1_reg=dense_l1_reg,
|
|
202
250
|
embedding_l2_reg=embedding_l2_reg,
|
|
203
251
|
dense_l2_reg=dense_l2_reg,
|
|
204
|
-
**kwargs
|
|
252
|
+
**kwargs,
|
|
205
253
|
)
|
|
206
254
|
|
|
207
255
|
self.loss = loss
|
|
@@ -215,7 +263,9 @@ class PLE(BaseModel):
|
|
|
215
263
|
if optimizer_params is None:
|
|
216
264
|
optimizer_params = {}
|
|
217
265
|
if len(tower_params_list) != self.num_tasks:
|
|
218
|
-
raise ValueError(
|
|
266
|
+
raise ValueError(
|
|
267
|
+
f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
|
|
268
|
+
)
|
|
219
269
|
# Embedding layer
|
|
220
270
|
self.embedding = EmbeddingLayer(features=self.all_features)
|
|
221
271
|
|
|
@@ -224,13 +274,13 @@ class PLE(BaseModel):
|
|
|
224
274
|
# emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
|
|
225
275
|
# dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
|
|
226
276
|
# input_dim = emb_dim_total + dense_input_dim
|
|
227
|
-
|
|
277
|
+
|
|
228
278
|
# Get expert output dimension
|
|
229
|
-
if
|
|
230
|
-
expert_output_dim = shared_expert_params[
|
|
279
|
+
if "dims" in shared_expert_params and len(shared_expert_params["dims"]) > 0:
|
|
280
|
+
expert_output_dim = shared_expert_params["dims"][-1]
|
|
231
281
|
else:
|
|
232
282
|
expert_output_dim = input_dim
|
|
233
|
-
|
|
283
|
+
|
|
234
284
|
# Build CGC layers
|
|
235
285
|
self.cgc_layers = nn.ModuleList()
|
|
236
286
|
for level in range(num_levels):
|
|
@@ -245,16 +295,25 @@ class PLE(BaseModel):
|
|
|
245
295
|
)
|
|
246
296
|
self.cgc_layers.append(cgc_layer)
|
|
247
297
|
expert_output_dim = cgc_layer.output_dim
|
|
248
|
-
|
|
298
|
+
|
|
249
299
|
# Task-specific towers
|
|
250
300
|
self.towers = nn.ModuleList()
|
|
251
301
|
for tower_params in tower_params_list:
|
|
252
302
|
tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
|
|
253
303
|
self.towers.append(tower)
|
|
254
|
-
self.prediction_layer = PredictionLayer(
|
|
304
|
+
self.prediction_layer = PredictionLayer(
|
|
305
|
+
task_type=self.default_task, task_dims=[1] * self.num_tasks
|
|
306
|
+
)
|
|
255
307
|
# Register regularization weights
|
|
256
|
-
self.register_regularization_weights(
|
|
257
|
-
|
|
308
|
+
self.register_regularization_weights(
|
|
309
|
+
embedding_attr="embedding", include_modules=["cgc_layers", "towers"]
|
|
310
|
+
)
|
|
311
|
+
self.compile(
|
|
312
|
+
optimizer=optimizer,
|
|
313
|
+
optimizer_params=optimizer_params,
|
|
314
|
+
loss=self.loss,
|
|
315
|
+
loss_params=loss_params,
|
|
316
|
+
)
|
|
258
317
|
|
|
259
318
|
def forward(self, x):
|
|
260
319
|
# Get all embeddings and flatten
|