nextrec-0.4.1-py3-none-any.whl → nextrec-0.4.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +220 -106
- nextrec/basic/loggers.py +62 -43
- nextrec/basic/metrics.py +268 -119
- nextrec/basic/model.py +1082 -400
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +498 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +51 -45
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +272 -95
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +103 -38
- nextrec/models/match/dssm.py +82 -68
- nextrec/models/match/dssm_v2.py +72 -57
- nextrec/models/match/mind.py +175 -107
- nextrec/models/match/sdm.py +104 -87
- nextrec/models/match/youtube_dnn.py +73 -59
- nextrec/models/multi_task/esmm.py +53 -37
- nextrec/models/multi_task/mmoe.py +64 -45
- nextrec/models/multi_task/ple.py +101 -48
- nextrec/models/multi_task/poso.py +113 -36
- nextrec/models/multi_task/share_bottom.py +48 -35
- nextrec/models/ranking/afm.py +72 -37
- nextrec/models/ranking/autoint.py +72 -55
- nextrec/models/ranking/dcn.py +55 -35
- nextrec/models/ranking/dcn_v2.py +64 -23
- nextrec/models/ranking/deepfm.py +32 -22
- nextrec/models/ranking/dien.py +155 -99
- nextrec/models/ranking/din.py +85 -57
- nextrec/models/ranking/fibinet.py +52 -32
- nextrec/models/ranking/fm.py +29 -23
- nextrec/models/ranking/masknet.py +91 -29
- nextrec/models/ranking/pnn.py +31 -28
- nextrec/models/ranking/widedeep.py +34 -26
- nextrec/models/ranking/xdeepfm.py +60 -38
- nextrec/utils/__init__.py +59 -34
- nextrec/utils/config.py +490 -0
- nextrec/utils/device.py +30 -20
- nextrec/utils/distributed.py +36 -9
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +32 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +283 -165
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/METADATA +4 -4
- nextrec-0.4.2.dist-info/RECORD +69 -0
- nextrec-0.4.2.dist-info/entry_points.txt +2 -0
- nextrec-0.4.1.dist-info/RECORD +0 -66
- {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
- {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/models/multi_task/esmm.py
CHANGED

@@ -52,87 +52,103 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 class ESMM(BaseModel):
     """
     Entire Space Multi-Task Model
-
+
     ESMM is designed for CVR (Conversion Rate) prediction. It models two related tasks:
     - CTR task: P(click | impression)
     - CVR task: P(conversion | click)
     - CTCVR task (auxiliary): P(click & conversion | impression) = P(click) * P(conversion | click)
-
+
     This design addresses the sample selection bias and data sparsity issues in CVR modeling.
     """
-
+
     @property
     def model_name(self):
         return "ESMM"
-
+
     @property
     def default_task(self):
-        return [
-
-    def __init__(
-        [original lines 73-90 lost in extraction]
+        return ["binary", "binary"]
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature],
+        sparse_features: list[SparseFeature],
+        sequence_features: list[SequenceFeature],
+        ctr_params: dict,
+        cvr_params: dict,
+        target: list[str] = ["ctr", "ctcvr"],  # Note: ctcvr = ctr * cvr
+        task: list[str] | None = None,
+        optimizer: str = "adam",
+        optimizer_params: dict = {},
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        device: str = "cpu",
+        embedding_l1_reg=1e-6,
+        dense_l1_reg=1e-5,
+        embedding_l2_reg=1e-5,
+        dense_l2_reg=1e-4,
+        **kwargs,
+    ):
+
         # ESMM requires exactly 2 targets: ctr and ctcvr
         if len(target) != 2:
-            raise ValueError(
-
+            raise ValueError(
+                f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}"
+            )
+
         super(ESMM, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task
+            task=task
+            or self.default_task,  # Both CTR and CTCVR are binary classification
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            **kwargs
+            **kwargs,
         )

         self.loss = loss
         if self.loss is None:
             self.loss = "bce"
-
+
         # All features
         self.all_features = dense_features + sparse_features + sequence_features
         # Shared embedding layer
         self.embedding = EmbeddingLayer(features=self.all_features)
-        input_dim =
+        input_dim = (
+            self.embedding.input_dim
+        )  # Calculate input dimension, better way than below
         # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
         # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
         # input_dim = emb_dim_total + dense_input_dim

         # CTR tower
         self.ctr_tower = MLP(input_dim=input_dim, output_layer=True, **ctr_params)
-
+
         # CVR tower
         self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
-        self.prediction_layer = PredictionLayer(
+        self.prediction_layer = PredictionLayer(
+            task_type=self.default_task, task_dims=[1, 1]
+        )
         # Register regularization weights
-        self.register_regularization_weights(
-
+        self.register_regularization_weights(
+            embedding_attr="embedding", include_modules=["ctr_tower", "cvr_tower"]
+        )
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            loss=loss,
+            loss_params=loss_params,
+        )

     def forward(self, x):
         # Get all embeddings and flatten
         input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
-
+
         # CTR prediction: P(click | impression)
         ctr_logit = self.ctr_tower(input_flat)  # [B, 1]
         cvr_logit = self.cvr_tower(input_flat)  # [B, 1]

@@ -140,7 +156,7 @@ class ESMM(BaseModel):
         preds = self.prediction_layer(logits)
         ctr, cvr = preds.chunk(2, dim=1)
         ctcvr = ctr * cvr  # [B, 1]
-
+
         # Output: [CTR, CTCVR], We supervise CTR with click labels and CTCVR with conversion labels
         y = torch.cat([ctr, ctcvr], dim=1)  # [B, 2]
         return y  # [B, 2], where y[:, 0] is CTR and y[:, 1] is CTCVR
nextrec/models/multi_task/mmoe.py
CHANGED

@@ -53,13 +53,13 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 class MMOE(BaseModel):
     """
     Multi-gate Mixture-of-Experts
-
+
     MMOE improves upon shared-bottom architecture by using multiple expert networks
     and task-specific gating networks. Each task has its own gate that learns to
     weight the contributions of different experts, allowing for both task-specific
     and shared representations.
     """
-
+
     @property
     def model_name(self):
         return "MMOE"

@@ -68,29 +68,31 @@ class MMOE(BaseModel):
     def default_task(self):
         num_tasks = getattr(self, "num_tasks", None)
         if num_tasks is not None and num_tasks > 0:
-            return [
-        return [
-
-    def __init__(
-        [original lines 75-93 lost in extraction]
+            return ["binary"] * num_tasks
+        return ["binary"]
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature] = [],
+        sparse_features: list[SparseFeature] = [],
+        sequence_features: list[SequenceFeature] = [],
+        expert_params: dict = {},
+        num_experts: int = 3,
+        tower_params_list: list[dict] = [],
+        target: list[str] = [],
+        task: str | list[str] | None = None,
+        optimizer: str = "adam",
+        optimizer_params: dict = {},
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        device: str = "cpu",
+        embedding_l1_reg=1e-6,
+        dense_l1_reg=1e-5,
+        embedding_l2_reg=1e-5,
+        dense_l2_reg=1e-4,
+        **kwargs,
+    ):
+
         self.num_tasks = len(target)

         super(MMOE, self).__init__(

@@ -104,19 +106,21 @@ class MMOE(BaseModel):
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            **kwargs
+            **kwargs,
         )

         self.loss = loss
         if self.loss is None:
             self.loss = "bce"
-
+
         # Number of tasks and experts
         self.num_tasks = len(target)
         self.num_experts = num_experts
-
+
         if len(tower_params_list) != self.num_tasks:
-            raise ValueError(
+            raise ValueError(
+                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
+            )

         self.all_features = dense_features + sparse_features + sequence_features
         self.embedding = EmbeddingLayer(features=self.all_features)

@@ -130,54 +134,69 @@ class MMOE(BaseModel):
         for _ in range(num_experts):
             expert = MLP(input_dim=input_dim, output_layer=False, **expert_params)
             self.experts.append(expert)
-
+
         # Get expert output dimension
-        if
-        expert_output_dim = expert_params[
+        if "dims" in expert_params and len(expert_params["dims"]) > 0:
+            expert_output_dim = expert_params["dims"][-1]
         else:
             expert_output_dim = input_dim
-
+
         # Task-specific gates
         self.gates = nn.ModuleList()
         for _ in range(self.num_tasks):
             gate = nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1))
             self.gates.append(gate)
-
+
         # Task-specific towers
         self.towers = nn.ModuleList()
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(
+        self.prediction_layer = PredictionLayer(
+            task_type=self.default_task, task_dims=[1] * self.num_tasks
+        )
         # Register regularization weights
-        self.register_regularization_weights(
-
+        self.register_regularization_weights(
+            embedding_attr="embedding", include_modules=["experts", "gates", "towers"]
+        )
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            loss=loss,
+            loss_params=loss_params,
+        )

     def forward(self, x):
         # Get all embeddings and flatten
         input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
-
+
         # Expert outputs: [num_experts, B, expert_dim]
         expert_outputs = [expert(input_flat) for expert in self.experts]
-        expert_outputs = torch.stack(
-
+        expert_outputs = torch.stack(
+            expert_outputs, dim=0
+        )  # [num_experts, B, expert_dim]
+
         # Task-specific processing
         task_outputs = []
         for task_idx in range(self.num_tasks):
             # Gate weights for this task: [B, num_experts]
             gate_weights = self.gates[task_idx](input_flat)  # [B, num_experts]
-
+
             # Weighted sum of expert outputs
             # gate_weights: [B, num_experts, 1]
             # expert_outputs: [num_experts, B, expert_dim]
             gate_weights = gate_weights.unsqueeze(2)  # [B, num_experts, 1]
-            expert_outputs_t = expert_outputs.permute(
-
-
+            expert_outputs_t = expert_outputs.permute(
+                1, 0, 2
+            )  # [B, num_experts, expert_dim]
+            gated_output = torch.sum(
+                gate_weights * expert_outputs_t, dim=1
+            )  # [B, expert_dim]
+
             # Tower output
             tower_output = self.towers[task_idx](gated_output)  # [B, 1]
             task_outputs.append(tower_output)
-
+
         # Stack outputs: [B, num_tasks]
         y = torch.cat(task_outputs, dim=1)
         return self.prediction_layer(y)
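As a cross-check on the gating arithmetic in MMOE.forward above, here is a minimal sketch (plain PyTorch with made-up shapes, not nextrec code) showing that the unsqueeze/permute/sum sequence is simply a per-task weighted sum of expert outputs, equivalently written as a single einsum:

import torch

num_experts, batch, expert_dim, num_tasks = 3, 4, 8, 2
expert_outputs = torch.randn(num_experts, batch, expert_dim)                       # [E, B, D]
gate_weights = torch.softmax(torch.randn(num_tasks, batch, num_experts), dim=-1)  # [T, B, E]

# For task t: out[b, d] = sum_e gate[t, b, e] * expert_outputs[e, b, d]
gated = torch.einsum("tbe,ebd->tbd", gate_weights, expert_outputs)  # [T, B, D]

# Equivalent to the per-task loop in the diff above:
for t in range(num_tasks):
    w = gate_weights[t].unsqueeze(2)                         # [B, E, 1]
    ref = (w * expert_outputs.permute(1, 0, 2)).sum(dim=1)   # [B, D]
    assert torch.allclose(gated[t], ref, atol=1e-6)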
nextrec/models/multi_task/ple.py
CHANGED
@@ -52,6 +52,7 @@ from nextrec.basic.model import BaseModel
 from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature

+
 class CGCLayer(nn.Module):
     """
     CGC (Customized Gate Control) block used by PLE.

@@ -71,26 +72,61 @@ class CGCLayer(nn.Module):
         if num_tasks < 1:
             raise ValueError("num_tasks must be >= 1")

-        specific_params_list = self._normalize_specific_params(
+        specific_params_list = self._normalize_specific_params(
+            specific_expert_params, num_tasks
+        )

         self.output_dim = self._get_output_dim(shared_expert_params, input_dim)
-        specific_dims = [
+        specific_dims = [
+            self._get_output_dim(params, input_dim) for params in specific_params_list
+        ]
         dims_set = set(specific_dims + [self.output_dim])
         if len(dims_set) != 1:
-            raise ValueError(
+            raise ValueError(
+                f"Shared/specific expert output dims must match, got {dims_set}"
+            )

         # experts
-        self.shared_experts = nn.ModuleList(
+        self.shared_experts = nn.ModuleList(
+            [
+                MLP(
+                    input_dim=input_dim,
+                    output_layer=False,
+                    **shared_expert_params,
+                )
+                for _ in range(num_shared_experts)
+            ]
+        )
         self.specific_experts = nn.ModuleList()
         for params in specific_params_list:
-            task_experts = nn.ModuleList(
+            task_experts = nn.ModuleList(
+                [
+                    MLP(
+                        input_dim=input_dim,
+                        output_layer=False,
+                        **params,
+                    )
+                    for _ in range(num_specific_experts)
+                ]
+            )
             self.specific_experts.append(task_experts)

         # gates
         task_gate_expert_num = num_shared_experts + num_specific_experts
-        self.task_gates = nn.ModuleList(
+        self.task_gates = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.Linear(input_dim, task_gate_expert_num),
+                    nn.Softmax(dim=1),
+                )
+                for _ in range(num_tasks)
+            ]
+        )
         shared_gate_expert_num = num_shared_experts + num_specific_experts * num_tasks
-        self.shared_gate = nn.Sequential(
+        self.shared_gate = nn.Sequential(
+            nn.Linear(input_dim, shared_gate_expert_num),
+            nn.Softmax(dim=1),
+        )

         self.num_tasks = num_tasks

@@ -98,7 +134,9 @@ class CGCLayer(nn.Module):
         self, task_inputs: list[torch.Tensor], shared_input: torch.Tensor
     ) -> tuple[list[torch.Tensor], torch.Tensor]:
         if len(task_inputs) != self.num_tasks:
-            raise ValueError(
+            raise ValueError(
+                f"Expected {self.num_tasks} task inputs, got {len(task_inputs)}"
+            )

         shared_outputs = [expert(shared_input) for expert in self.shared_experts]
         shared_stack = torch.stack(shared_outputs, dim=0)  # [num_shared, B, D]

@@ -108,7 +146,7 @@ class CGCLayer(nn.Module):
         for task_idx in range(self.num_tasks):
             task_input = task_inputs[task_idx]
-            task_specific_outputs = [expert(task_input) for expert in self.specific_experts[task_idx]]
+            task_specific_outputs = [expert(task_input) for expert in self.specific_experts[task_idx]]  # type: ignore
             all_specific_for_shared.extend(task_specific_outputs)
             specific_stack = torch.stack(task_specific_outputs, dim=0)

@@ -139,7 +177,9 @@ class CGCLayer(nn.Module):
     ) -> list[dict]:
         if isinstance(params, list):
             if len(params) != num_tasks:
-                raise ValueError(
+                raise ValueError(
+                    f"Length of specific_expert_params ({len(params)}) must match num_tasks ({num_tasks})."
+                )
             return [p.copy() for p in params]
         return [params.copy() for _ in range(num_tasks)]

@@ -147,13 +187,13 @@ class CGCLayer(nn.Module):
 class PLE(BaseModel):
     """
     Progressive Layered Extraction
-
+
     PLE is an advanced multi-task learning model that extends MMOE by introducing
     both task-specific experts and shared experts at each level. It uses a progressive
     routing mechanism where experts from level k feed into gates at level k+1.
     This design better captures task-specific and shared information progressively.
     """
-
+
     @property
     def model_name(self):
         return "PLE"

@@ -162,32 +202,34 @@ class PLE(BaseModel):
     def default_task(self):
         num_tasks = getattr(self, "num_tasks", None)
         if num_tasks is not None and num_tasks > 0:
-            return [
-        return [
-
-    def __init__(
-        [original lines 169-190 lost in extraction]
+            return ["binary"] * num_tasks
+        return ["binary"]
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature],
+        sparse_features: list[SparseFeature],
+        sequence_features: list[SequenceFeature],
+        shared_expert_params: dict,
+        specific_expert_params: dict | list[dict],
+        num_shared_experts: int,
+        num_specific_experts: int,
+        num_levels: int,
+        tower_params_list: list[dict],
+        target: list[str],
+        task: str | list[str] | None = None,
+        optimizer: str = "adam",
+        optimizer_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        device: str = "cpu",
+        embedding_l1_reg=1e-6,
+        dense_l1_reg=1e-5,
+        embedding_l2_reg=1e-5,
+        dense_l2_reg=1e-4,
+        **kwargs,
+    ):
+
         self.num_tasks = len(target)

         super(PLE, self).__init__(

@@ -201,7 +243,7 @@ class PLE(BaseModel):
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            **kwargs
+            **kwargs,
         )

         self.loss = loss

@@ -215,7 +257,9 @@ class PLE(BaseModel):
         if optimizer_params is None:
             optimizer_params = {}
         if len(tower_params_list) != self.num_tasks:
-            raise ValueError(
+            raise ValueError(
+                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
+            )
         # Embedding layer
         self.embedding = EmbeddingLayer(features=self.all_features)

@@ -224,13 +268,13 @@ class PLE(BaseModel):
         # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
         # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
         # input_dim = emb_dim_total + dense_input_dim
-
+
         # Get expert output dimension
-        if
-        expert_output_dim = shared_expert_params[
+        if "dims" in shared_expert_params and len(shared_expert_params["dims"]) > 0:
+            expert_output_dim = shared_expert_params["dims"][-1]
         else:
             expert_output_dim = input_dim
-
+
         # Build CGC layers
         self.cgc_layers = nn.ModuleList()
         for level in range(num_levels):

@@ -245,16 +289,25 @@ class PLE(BaseModel):
             )
             self.cgc_layers.append(cgc_layer)
             expert_output_dim = cgc_layer.output_dim
-
+
         # Task-specific towers
         self.towers = nn.ModuleList()
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(
+        self.prediction_layer = PredictionLayer(
+            task_type=self.default_task, task_dims=[1] * self.num_tasks
+        )
         # Register regularization weights
-        self.register_regularization_weights(
-
+        self.register_regularization_weights(
+            embedding_attr="embedding", include_modules=["cgc_layers", "towers"]
+        )
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            loss=self.loss,
+            loss_params=loss_params,
+        )

     def forward(self, x):
         # Get all embeddings and flatten