nextrec-0.1.1-py3-none-any.whl → nextrec-0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +4 -4
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -9
- nextrec/basic/callback.py +1 -0
- nextrec/basic/dataloader.py +168 -127
- nextrec/basic/features.py +24 -27
- nextrec/basic/layers.py +328 -159
- nextrec/basic/loggers.py +50 -37
- nextrec/basic/metrics.py +255 -147
- nextrec/basic/model.py +817 -462
- nextrec/data/__init__.py +5 -5
- nextrec/data/data_utils.py +16 -12
- nextrec/data/preprocessor.py +276 -252
- nextrec/loss/__init__.py +12 -12
- nextrec/loss/loss_utils.py +30 -22
- nextrec/loss/match_losses.py +116 -83
- nextrec/models/match/__init__.py +5 -5
- nextrec/models/match/dssm.py +70 -61
- nextrec/models/match/dssm_v2.py +61 -51
- nextrec/models/match/mind.py +89 -71
- nextrec/models/match/sdm.py +93 -81
- nextrec/models/match/youtube_dnn.py +62 -53
- nextrec/models/multi_task/esmm.py +49 -43
- nextrec/models/multi_task/mmoe.py +65 -56
- nextrec/models/multi_task/ple.py +92 -65
- nextrec/models/multi_task/share_bottom.py +48 -42
- nextrec/models/ranking/__init__.py +7 -7
- nextrec/models/ranking/afm.py +39 -30
- nextrec/models/ranking/autoint.py +70 -57
- nextrec/models/ranking/dcn.py +43 -35
- nextrec/models/ranking/deepfm.py +34 -28
- nextrec/models/ranking/dien.py +115 -79
- nextrec/models/ranking/din.py +84 -60
- nextrec/models/ranking/fibinet.py +51 -35
- nextrec/models/ranking/fm.py +28 -26
- nextrec/models/ranking/masknet.py +31 -31
- nextrec/models/ranking/pnn.py +30 -31
- nextrec/models/ranking/widedeep.py +36 -31
- nextrec/models/ranking/xdeepfm.py +46 -39
- nextrec/utils/__init__.py +9 -9
- nextrec/utils/embedding.py +1 -1
- nextrec/utils/initializer.py +23 -15
- nextrec/utils/optimizer.py +14 -10
- {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/METADATA +6 -40
- nextrec-0.1.2.dist-info/RECORD +51 -0
- nextrec-0.1.1.dist-info/RECORD +0 -51
- {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/WHEEL +0 -0
- {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/licenses/LICENSE +0 -0
nextrec/models/match/youtube_dnn.py

@@ -6,6 +6,7 @@ Reference:
 [1] Covington P, Adams J, Sargin E. Deep neural networks for youtube recommendations[C]
 //Proceedings of the 10th ACM conference on recommender systems. 2016: 191-198.
 """
+
 import torch
 import torch.nn as nn
 from typing import Literal
@@ -18,40 +19,42 @@ from nextrec.basic.layers import MLP, EmbeddingLayer, AveragePooling
 class YoutubeDNN(BaseMatchModel):
     """
     YouTube Deep Neural Network for Recommendations
-
+
     User tower: historical behavior sequence + user features -> user embedding
     Item tower: item features -> item embedding
     Training: sampled softmax loss (listwise)
     """
-
+
     @property
     def model_name(self) -> str:
         return "YouTubeDNN"
-
-    def __init__(
-        (... 23 old signature lines not captured in this rendering ...)
+
+    def __init__(
+        self,
+        user_dense_features: list[DenseFeature] | None = None,
+        user_sparse_features: list[SparseFeature] | None = None,
+        user_sequence_features: list[SequenceFeature] | None = None,
+        item_dense_features: list[DenseFeature] | None = None,
+        item_sparse_features: list[SparseFeature] | None = None,
+        item_sequence_features: list[SequenceFeature] | None = None,
+        user_dnn_hidden_units: list[int] = [256, 128, 64],
+        item_dnn_hidden_units: list[int] = [256, 128, 64],
+        embedding_dim: int = 64,
+        dnn_activation: str = "relu",
+        dnn_dropout: float = 0.0,
+        training_mode: Literal["pointwise", "pairwise", "listwise"] = "listwise",
+        num_negative_samples: int = 100,
+        temperature: float = 1.0,
+        similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
+        device: str = "cpu",
+        embedding_l1_reg: float = 0.0,
+        dense_l1_reg: float = 0.0,
+        embedding_l2_reg: float = 0.0,
+        dense_l2_reg: float = 0.0,
+        early_stop_patience: int = 20,
+        model_id: str = "youtube_dnn",
+    ):
+
         super(YoutubeDNN, self).__init__(
             user_dense_features=user_dense_features,
             user_sparse_features=user_sparse_features,
@@ -69,13 +72,13 @@ class YoutubeDNN(BaseMatchModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=early_stop_patience,
-            model_id=model_id
+            model_id=model_id,
         )
-
+
         self.embedding_dim = embedding_dim
         self.user_dnn_hidden_units = user_dnn_hidden_units
         self.item_dnn_hidden_units = item_dnn_hidden_units
-
+
         # User tower
         user_features = []
         if user_dense_features:
@@ -84,10 +87,10 @@ class YoutubeDNN(BaseMatchModel):
             user_features.extend(user_sparse_features)
         if user_sequence_features:
             user_features.extend(user_sequence_features)
-
+
         if len(user_features) > 0:
             self.user_embedding = EmbeddingLayer(user_features)
-
+
         user_input_dim = 0
         for feat in user_dense_features or []:
             user_input_dim += 1
@@ -96,16 +99,16 @@ class YoutubeDNN(BaseMatchModel):
         for feat in user_sequence_features or []:
             # Sequence features are aggregated via average pooling
             user_input_dim += feat.embedding_dim
-
+
         user_dnn_units = user_dnn_hidden_units + [embedding_dim]
         self.user_dnn = MLP(
             input_dim=user_input_dim,
             dims=user_dnn_units,
             output_layer=False,
             dropout=dnn_dropout,
-            activation=dnn_activation
+            activation=dnn_activation,
         )
-
+
         # Item tower
         item_features = []
         if item_dense_features:
@@ -114,10 +117,10 @@ class YoutubeDNN(BaseMatchModel):
             item_features.extend(item_sparse_features)
         if item_sequence_features:
             item_features.extend(item_sequence_features)
-
+
         if len(item_features) > 0:
             self.item_embedding = EmbeddingLayer(item_features)
-
+
         item_input_dim = 0
         for feat in item_dense_features or []:
             item_input_dim += 1
@@ -125,48 +128,54 @@ class YoutubeDNN(BaseMatchModel):
             item_input_dim += feat.embedding_dim
         for feat in item_sequence_features or []:
             item_input_dim += feat.embedding_dim
-
+
         item_dnn_units = item_dnn_hidden_units + [embedding_dim]
         self.item_dnn = MLP(
             input_dim=item_input_dim,
             dims=item_dnn_units,
             output_layer=False,
             dropout=dnn_dropout,
-            activation=dnn_activation
+            activation=dnn_activation,
         )
-
+
         self._register_regularization_weights(
-            embedding_attr='user_embedding',
-            include_modules=['user_dnn']
+            embedding_attr="user_embedding", include_modules=["user_dnn"]
         )
         self._register_regularization_weights(
-            embedding_attr='item_embedding',
-            include_modules=['item_dnn']
+            embedding_attr="item_embedding", include_modules=["item_dnn"]
         )
-
+
         self.to(device)
-
+
     def user_tower(self, user_input: dict) -> torch.Tensor:
         """
         User tower
        Processes the user's historical behavior sequence and other user features
         """
-        all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
+        all_user_features = (
+            self.user_dense_features
+            + self.user_sparse_features
+            + self.user_sequence_features
+        )
         user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)
         user_emb = self.user_dnn(user_emb)
-
+
         # L2 normalization
         user_emb = torch.nn.functional.normalize(user_emb, p=2, dim=1)
-
+
         return user_emb
-
+
     def item_tower(self, item_input: dict) -> torch.Tensor:
         """Item tower"""
-        all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
+        all_item_features = (
+            self.item_dense_features
+            + self.item_sparse_features
+            + self.item_sequence_features
+        )
         item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
         item_emb = self.item_dnn(item_emb)
-
+
         # L2 normalization
         item_emb = torch.nn.functional.normalize(item_emb, p=2, dim=1)
-
+
         return item_emb
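Both towers end in L2 normalization, so with the default similarity_metric="dot" the user-item score behaves like a temperature-scaled cosine similarity, and the listwise training_mode amounts to a softmax over candidate items. A minimal sketch of that objective with in-batch negatives, written against plain PyTorch (the score helper and toy shapes are illustrative, not nextrec's actual training loop):

    import torch
    import torch.nn.functional as F

    def score(user_emb: torch.Tensor, item_emb: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        # user_emb: [B, D], item_emb: [B, D] -> logits: [B, B]
        # Inputs are already L2-normalized, so the dot product is a cosine similarity.
        return user_emb @ item_emb.t() / temperature

    user_emb = F.normalize(torch.randn(4, 64), p=2, dim=1)
    item_emb = F.normalize(torch.randn(4, 64), p=2, dim=1)
    logits = score(user_emb, item_emb, temperature=0.2)
    # Diagonal entries are the positive pairs; off-diagonals act as sampled negatives.
    loss = F.cross_entropy(logits, torch.arange(4))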
nextrec/models/multi_task/esmm.py

@@ -17,15 +17,15 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 class ESMM(BaseModel):
     """
     Entire Space Multi-Task Model
-
+
     ESMM is designed for CVR (Conversion Rate) prediction. It models two related tasks:
     - CTR task: P(click | impression)
     - CVR task: P(conversion | click)
     - CTCVR task (auxiliary): P(click & conversion | impression) = P(click) * P(conversion | click)
-
+
     This design addresses the sample selection bias and data sparsity issues in CVR modeling.
     """
-
+
     @property
     def model_name(self):
         return "ESMM"
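As a quick numeric check of the decomposition above: if P(click | impression) = 0.10 and P(conversion | click) = 0.05, then P(click & conversion | impression) = 0.10 * 0.05 = 0.005, i.e. five conversions per thousand impressions. Because the product is the quantity supervised with impression-level labels, the CVR head is trained over the entire impression space instead of only on clicked samples, which is how ESMM addresses the selection bias mentioned in the docstring.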
@@ -33,30 +33,34 @@ class ESMM(BaseModel):
     @property
     def task_type(self):
         # ESMM has fixed task types: CTR (binary) and CVR (binary)
-        return ['binary', 'binary']
-
-    def __init__(
-        (... 17 old signature lines not captured in this rendering ...)
+        return ["binary", "binary"]
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature],
+        sparse_features: list[SparseFeature],
+        sequence_features: list[SequenceFeature],
+        ctr_params: dict,
+        cvr_params: dict,
+        target: list[str] = ["ctr", "ctcvr"],  # Note: ctcvr = ctr * cvr
+        task: str | list[str] = "binary",
+        optimizer: str = "adam",
+        optimizer_params: dict = {},
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        device: str = "cpu",
+        model_id: str = "baseline",
+        embedding_l1_reg=1e-6,
+        dense_l1_reg=1e-5,
+        embedding_l2_reg=1e-5,
+        dense_l2_reg=1e-4,
+    ):
+
         # ESMM requires exactly 2 targets: ctr and ctcvr
         if len(target) != 2:
-            raise ValueError(
-                f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}")
+            raise ValueError(
+                f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}"
+            )
+
         super(ESMM, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
@@ -69,13 +73,13 @@ class ESMM(BaseModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=20,
-            model_id=model_id
+            model_id=model_id,
         )

         self.loss = loss
         if self.loss is None:
             self.loss = "bce"
-
+
         # All features
         self.all_features = dense_features + sparse_features + sequence_features

@@ -83,46 +87,48 @@ class ESMM(BaseModel):
         self.embedding = EmbeddingLayer(features=self.all_features)

         # Calculate input dimension
-        emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
-        dense_input_dim = sum([getattr(f, 'embedding_dim', 1) or 1 for f in dense_features])
+        emb_dim_total = sum(
+            [
+                f.embedding_dim
+                for f in self.all_features
+                if not isinstance(f, DenseFeature)
+            ]
+        )
+        dense_input_dim = sum(
+            [getattr(f, "embedding_dim", 1) or 1 for f in dense_features]
+        )
         input_dim = emb_dim_total + dense_input_dim
-
+
         # CTR tower
         self.ctr_tower = MLP(input_dim=input_dim, output_layer=True, **ctr_params)
-
+
         # CVR tower
         self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
         self.prediction_layer = PredictionLayer(
-            task_type=self.task_type,
-            task_dims=[1, 1]
+            task_type=self.task_type, task_dims=[1, 1]
         )

         # Register regularization weights
         self._register_regularization_weights(
-            embedding_attr='embedding',
-            include_modules=['ctr_tower', 'cvr_tower']
+            embedding_attr="embedding", include_modules=["ctr_tower", "cvr_tower"]
         )

-        self.compile(
-            optimizer=optimizer,
-            optimizer_params=optimizer_params,
-            loss=loss
-        )
+        self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss)

     def forward(self, x):
         # Get all embeddings and flatten
         input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
-
+
         # CTR prediction: P(click | impression)
         ctr_logit = self.ctr_tower(input_flat)  # [B, 1]
         cvr_logit = self.cvr_tower(input_flat)  # [B, 1]
         logits = torch.cat([ctr_logit, cvr_logit], dim=1)
         preds = self.prediction_layer(logits)
         ctr, cvr = preds.chunk(2, dim=1)
-
+
         # CTCVR prediction: P(click & conversion | impression) = P(click) * P(conversion | click)
         ctcvr = ctr * cvr  # [B, 1]
-
+
         # Output: [CTR, CTCVR]
         # Note: We supervise CTR with click labels and CTCVR with conversion labels
         y = torch.cat([ctr, ctcvr], dim=1)  # [B, 2]
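Since forward() returns [CTR, CTCVR], both heads can be supervised with impression-level labels. A minimal sketch of the entire-space objective, assuming the default "bce" loss; the tensors are toy values rather than nextrec's API:

    import torch
    import torch.nn.functional as F

    # Impression-level labels; a conversion implies a click.
    click = torch.tensor([1.0, 0.0, 1.0, 1.0])
    conversion = torch.tensor([1.0, 0.0, 0.0, 1.0])

    ctr = torch.tensor([0.9, 0.2, 0.7, 0.8])  # predicted P(click | impression)
    cvr = torch.tensor([0.8, 0.5, 0.3, 0.9])  # predicted P(conversion | click)
    ctcvr = ctr * cvr                          # predicted P(click & conversion | impression)

    # Both terms run over ALL impressions, so the CVR tower is never fit on a
    # click-filtered subset of the data.
    loss = F.binary_cross_entropy(ctr, click) + F.binary_cross_entropy(ctcvr, conversion)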
nextrec/models/multi_task/mmoe.py

@@ -17,13 +17,13 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 class MMOE(BaseModel):
     """
     Multi-gate Mixture-of-Experts
-
+
     MMOE improves upon shared-bottom architecture by using multiple expert networks
     and task-specific gating networks. Each task has its own gate that learns to
     weight the contributions of different experts, allowing for both task-specific
     and shared representations.
     """
-
+
     @property
     def model_name(self):
         return "MMOE"
@@ -31,26 +31,28 @@ class MMOE(BaseModel):
     @property
     def task_type(self):
         return self.task if isinstance(self.task, list) else [self.task]
-
-    def __init__(
-        (... 18 old signature lines not captured in this rendering ...)
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature] = [],
+        sparse_features: list[SparseFeature] = [],
+        sequence_features: list[SequenceFeature] = [],
+        expert_params: dict = {},
+        num_experts: int = 3,
+        tower_params_list: list[dict] = [],
+        target: list[str] = [],
+        task: str | list[str] = "binary",
+        optimizer: str = "adam",
+        optimizer_params: dict = {},
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        device: str = "cpu",
+        model_id: str = "baseline",
+        embedding_l1_reg=1e-6,
+        dense_l1_reg=1e-5,
+        embedding_l2_reg=1e-5,
+        dense_l2_reg=1e-4,
+    ):
+
         super(MMOE, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
@@ -63,20 +65,22 @@ class MMOE(BaseModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=20,
-            model_id=model_id
+            model_id=model_id,
         )

         self.loss = loss
         if self.loss is None:
             self.loss = "bce"
-
+
         # Number of tasks and experts
         self.num_tasks = len(target)
         self.num_experts = num_experts
-
+
         if len(tower_params_list) != self.num_tasks:
-            raise ValueError(
-                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
+            raise ValueError(
+                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
+            )
+
         # All features
         self.all_features = dense_features + sparse_features + sequence_features

@@ -84,78 +88,83 @@ class MMOE(BaseModel):
         self.embedding = EmbeddingLayer(features=self.all_features)

         # Calculate input dimension
-        emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
-        dense_input_dim = sum([getattr(f, 'embedding_dim', 1) or 1 for f in dense_features])
+        emb_dim_total = sum(
+            [
+                f.embedding_dim
+                for f in self.all_features
+                if not isinstance(f, DenseFeature)
+            ]
+        )
+        dense_input_dim = sum(
+            [getattr(f, "embedding_dim", 1) or 1 for f in dense_features]
+        )
         input_dim = emb_dim_total + dense_input_dim
-
+
         # Expert networks (shared by all tasks)
         self.experts = nn.ModuleList()
         for _ in range(num_experts):
             expert = MLP(input_dim=input_dim, output_layer=False, **expert_params)
             self.experts.append(expert)
-
+
         # Get expert output dimension
-        if 'dims' in expert_params and len(expert_params['dims']) > 0:
-            expert_output_dim = expert_params['dims'][-1]
+        if "dims" in expert_params and len(expert_params["dims"]) > 0:
+            expert_output_dim = expert_params["dims"][-1]
         else:
             expert_output_dim = input_dim
-
+
         # Task-specific gates
         self.gates = nn.ModuleList()
         for _ in range(self.num_tasks):
-            gate = nn.Sequential(
-                nn.Linear(input_dim, num_experts),
-                nn.Softmax(dim=1)
-            )
+            gate = nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1))
             self.gates.append(gate)
-
+
         # Task-specific towers
         self.towers = nn.ModuleList()
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
         self.prediction_layer = PredictionLayer(
-            task_type=self.task_type,
-            task_dims=[1] * self.num_tasks
+            task_type=self.task_type, task_dims=[1] * self.num_tasks
         )

         # Register regularization weights
         self._register_regularization_weights(
-            embedding_attr='embedding',
-            include_modules=['experts', 'gates', 'towers']
+            embedding_attr="embedding", include_modules=["experts", "gates", "towers"]
         )

-        self.compile(
-            optimizer=optimizer,
-            optimizer_params=optimizer_params,
-            loss=loss
-        )
+        self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss)

     def forward(self, x):
         # Get all embeddings and flatten
         input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
-
+
         # Expert outputs: [num_experts, B, expert_dim]
         expert_outputs = [expert(input_flat) for expert in self.experts]
-        expert_outputs = torch.stack(expert_outputs, dim=0)  # [num_experts, B, expert_dim]
-
+        expert_outputs = torch.stack(
+            expert_outputs, dim=0
+        )  # [num_experts, B, expert_dim]
+
         # Task-specific processing
         task_outputs = []
         for task_idx in range(self.num_tasks):
             # Gate weights for this task: [B, num_experts]
             gate_weights = self.gates[task_idx](input_flat)  # [B, num_experts]
-
+
             # Weighted sum of expert outputs
             # gate_weights: [B, num_experts, 1]
             # expert_outputs: [num_experts, B, expert_dim]
             gate_weights = gate_weights.unsqueeze(2)  # [B, num_experts, 1]
-            expert_outputs_t = expert_outputs.permute(1, 0, 2)  # [B, num_experts, expert_dim]
-            gated_output = torch.sum(gate_weights * expert_outputs_t, dim=1)  # [B, expert_dim]
-
+            expert_outputs_t = expert_outputs.permute(
+                1, 0, 2
+            )  # [B, num_experts, expert_dim]
+            gated_output = torch.sum(
+                gate_weights * expert_outputs_t, dim=1
+            )  # [B, expert_dim]
+
             # Tower output
             tower_output = self.towers[task_idx](gated_output)  # [B, 1]
             task_outputs.append(tower_output)
-
+
         # Stack outputs: [B, num_tasks]
         y = torch.cat(task_outputs, dim=1)
         return self.prediction_layer(y)
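The unsqueeze/permute/sum sequence in forward() is a per-task convex combination of expert outputs: for task t, the mixed representation is the sum over experts e of g_t,e(x) * f_e(x), with the softmax gate supplying the weights. A standalone sketch on random tensors (illustrative shapes, not nextrec code) showing an equivalent einsum formulation:

    import torch

    B, E, D = 8, 3, 16  # batch size, number of experts, expert output dim
    expert_outputs = torch.randn(E, B, D)                   # stacked expert outputs
    gate_weights = torch.softmax(torch.randn(B, E), dim=1)  # one such gate per task

    # Same computation as the model's unsqueeze/permute/sum:
    ref = (gate_weights.unsqueeze(2) * expert_outputs.permute(1, 0, 2)).sum(dim=1)
    mixed = torch.einsum("be,ebd->bd", gate_weights, expert_outputs)  # [B, D]
    assert torch.allclose(mixed, ref, atol=1e-6)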