nextrec-0.1.4-py3-none-any.whl → nextrec-0.1.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +4 -4
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +9 -10
- nextrec/basic/callback.py +0 -1
- nextrec/basic/dataloader.py +127 -168
- nextrec/basic/features.py +27 -24
- nextrec/basic/layers.py +159 -328
- nextrec/basic/loggers.py +37 -50
- nextrec/basic/metrics.py +147 -255
- nextrec/basic/model.py +462 -817
- nextrec/data/__init__.py +5 -5
- nextrec/data/data_utils.py +12 -16
- nextrec/data/preprocessor.py +252 -276
- nextrec/loss/__init__.py +12 -12
- nextrec/loss/loss_utils.py +22 -30
- nextrec/loss/match_losses.py +83 -116
- nextrec/models/match/__init__.py +5 -5
- nextrec/models/match/dssm.py +61 -70
- nextrec/models/match/dssm_v2.py +51 -61
- nextrec/models/match/mind.py +71 -89
- nextrec/models/match/sdm.py +81 -93
- nextrec/models/match/youtube_dnn.py +53 -62
- nextrec/models/multi_task/esmm.py +43 -49
- nextrec/models/multi_task/mmoe.py +56 -65
- nextrec/models/multi_task/ple.py +65 -92
- nextrec/models/multi_task/share_bottom.py +42 -48
- nextrec/models/ranking/__init__.py +7 -7
- nextrec/models/ranking/afm.py +30 -39
- nextrec/models/ranking/autoint.py +57 -70
- nextrec/models/ranking/dcn.py +35 -43
- nextrec/models/ranking/deepfm.py +28 -34
- nextrec/models/ranking/dien.py +79 -115
- nextrec/models/ranking/din.py +60 -84
- nextrec/models/ranking/fibinet.py +35 -51
- nextrec/models/ranking/fm.py +26 -28
- nextrec/models/ranking/masknet.py +31 -31
- nextrec/models/ranking/pnn.py +31 -30
- nextrec/models/ranking/widedeep.py +31 -36
- nextrec/models/ranking/xdeepfm.py +39 -46
- nextrec/utils/__init__.py +9 -9
- nextrec/utils/embedding.py +1 -1
- nextrec/utils/initializer.py +15 -23
- nextrec/utils/optimizer.py +10 -14
- {nextrec-0.1.4.dist-info → nextrec-0.1.8.dist-info}/METADATA +16 -7
- nextrec-0.1.8.dist-info/RECORD +51 -0
- nextrec-0.1.4.dist-info/RECORD +0 -51
- {nextrec-0.1.4.dist-info → nextrec-0.1.8.dist-info}/WHEEL +0 -0
- {nextrec-0.1.4.dist-info → nextrec-0.1.8.dist-info}/licenses/LICENSE +0 -0
nextrec/models/match/youtube_dnn.py

@@ -6,7 +6,6 @@ Reference:
     [1] Covington P, Adams J, Sargin E. Deep neural networks for youtube recommendations[C]
     //Proceedings of the 10th ACM conference on recommender systems. 2016: 191-198.
 """
-
 import torch
 import torch.nn as nn
 from typing import Literal
@@ -19,42 +18,40 @@ from nextrec.basic.layers import MLP, EmbeddingLayer, AveragePooling
 class YoutubeDNN(BaseMatchModel):
     """
     YouTube Deep Neural Network for Recommendations
-
+
     User tower: historical behavior sequence + user features -> user embedding
     Item tower: item features -> item embedding
     Training: sampled softmax loss (listwise)
     """
-
+
     @property
     def model_name(self) -> str:
         return "YouTubeDNN"
-
-    def __init__(
-        self,
-        user_dense_features: list[DenseFeature] | None = None,
-        user_sparse_features: list[SparseFeature] | None = None,
-        user_sequence_features: list[SequenceFeature] | None = None,
-        item_dense_features: list[DenseFeature] | None = None,
-        item_sparse_features: list[SparseFeature] | None = None,
-        item_sequence_features: list[SequenceFeature] | None = None,
-        user_dnn_hidden_units: list[int] = [256, 128, 64],
-        item_dnn_hidden_units: list[int] = [256, 128, 64],
-        embedding_dim: int = 64,
-        dnn_activation: str = 'relu',
-        dnn_dropout: float = 0.0,
-        training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'listwise',
-        num_negative_samples: int = 100,
-        temperature: float = 1.0,
-        similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
-        device: str = 'cpu',
-        embedding_l1_reg: float = 0.0,
-        dense_l1_reg: float = 0.0,
-        embedding_l2_reg: float = 0.0,
-        dense_l2_reg: float = 0.0,
-        early_stop_patience: int = 20,
-        model_id: str = 'youtube_dnn',
-    ):
-
+
+    def __init__(self,
+                 user_dense_features: list[DenseFeature] | None = None,
+                 user_sparse_features: list[SparseFeature] | None = None,
+                 user_sequence_features: list[SequenceFeature] | None = None,
+                 item_dense_features: list[DenseFeature] | None = None,
+                 item_sparse_features: list[SparseFeature] | None = None,
+                 item_sequence_features: list[SequenceFeature] | None = None,
+                 user_dnn_hidden_units: list[int] = [256, 128, 64],
+                 item_dnn_hidden_units: list[int] = [256, 128, 64],
+                 embedding_dim: int = 64,
+                 dnn_activation: str = 'relu',
+                 dnn_dropout: float = 0.0,
+                 training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'listwise',
+                 num_negative_samples: int = 100,
+                 temperature: float = 1.0,
+                 similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
+                 device: str = 'cpu',
+                 embedding_l1_reg: float = 0.0,
+                 dense_l1_reg: float = 0.0,
+                 embedding_l2_reg: float = 0.0,
+                 dense_l2_reg: float = 0.0,
+                 early_stop_patience: int = 20,
+                 model_id: str = 'youtube_dnn'):
+
         super(YoutubeDNN, self).__init__(
             user_dense_features=user_dense_features,
             user_sparse_features=user_sparse_features,
@@ -72,13 +69,13 @@ class YoutubeDNN(BaseMatchModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=early_stop_patience,
-            model_id=model_id
+            model_id=model_id
         )
-
+
         self.embedding_dim = embedding_dim
         self.user_dnn_hidden_units = user_dnn_hidden_units
         self.item_dnn_hidden_units = item_dnn_hidden_units
-
+
         # User tower
         user_features = []
         if user_dense_features:
@@ -87,10 +84,10 @@ class YoutubeDNN(BaseMatchModel):
             user_features.extend(user_sparse_features)
         if user_sequence_features:
             user_features.extend(user_sequence_features)
-
+
         if len(user_features) > 0:
             self.user_embedding = EmbeddingLayer(user_features)
-
+
         user_input_dim = 0
         for feat in user_dense_features or []:
             user_input_dim += 1
@@ -99,16 +96,16 @@ class YoutubeDNN(BaseMatchModel):
         for feat in user_sequence_features or []:
             # Sequence features are aggregated via average pooling
             user_input_dim += feat.embedding_dim
-
+
         user_dnn_units = user_dnn_hidden_units + [embedding_dim]
         self.user_dnn = MLP(
             input_dim=user_input_dim,
             dims=user_dnn_units,
             output_layer=False,
             dropout=dnn_dropout,
-            activation=dnn_activation
+            activation=dnn_activation
         )
-
+
         # Item tower
         item_features = []
         if item_dense_features:
@@ -117,10 +114,10 @@ class YoutubeDNN(BaseMatchModel):
             item_features.extend(item_sparse_features)
         if item_sequence_features:
             item_features.extend(item_sequence_features)
-
+
         if len(item_features) > 0:
             self.item_embedding = EmbeddingLayer(item_features)
-
+
         item_input_dim = 0
         for feat in item_dense_features or []:
             item_input_dim += 1
@@ -128,54 +125,48 @@ class YoutubeDNN(BaseMatchModel):
             item_input_dim += feat.embedding_dim
         for feat in item_sequence_features or []:
             item_input_dim += feat.embedding_dim
-
+
         item_dnn_units = item_dnn_hidden_units + [embedding_dim]
         self.item_dnn = MLP(
             input_dim=item_input_dim,
             dims=item_dnn_units,
             output_layer=False,
             dropout=dnn_dropout,
-            activation=dnn_activation
+            activation=dnn_activation
         )
-
+
         self._register_regularization_weights(
-            embedding_attr='user_embedding', include_modules=['user_dnn']
+            embedding_attr='user_embedding',
+            include_modules=['user_dnn']
         )
         self._register_regularization_weights(
-            embedding_attr='item_embedding', include_modules=['item_dnn']
+            embedding_attr='item_embedding',
+            include_modules=['item_dnn']
         )
-
+
         self.to(device)
-
+
     def user_tower(self, user_input: dict) -> torch.Tensor:
         """
         User tower
         Processes the user's historical behavior sequence and other user features
         """
-        all_user_features = (
-            self.user_dense_features
-            + self.user_sparse_features
-            + self.user_sequence_features
-        )
+        all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
         user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)
         user_emb = self.user_dnn(user_emb)
-
+
         # L2 normalization
         user_emb = torch.nn.functional.normalize(user_emb, p=2, dim=1)
-
+
         return user_emb
-
+
     def item_tower(self, item_input: dict) -> torch.Tensor:
         """Item tower"""
-        all_item_features = (
-            self.item_dense_features
-            + self.item_sparse_features
-            + self.item_sequence_features
-        )
+        all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
         item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
         item_emb = self.item_dnn(item_emb)
-
+
         # L2 normalization
         item_emb = torch.nn.functional.normalize(item_emb, p=2, dim=1)
-
+
         return item_emb
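Both towers above end with torch.nn.functional.normalize(..., p=2, dim=1), so the dot product of the two embeddings is cosine similarity and the temperature argument rescales the softmax logits. Below is a minimal, self-contained sketch of an in-batch sampled-softmax objective of the kind the docstring's "sampled softmax loss (listwise)" refers to; it is an illustration under that assumption, not the library's actual loss code, and the tensors are random stand-ins for the tower outputs.

import torch
import torch.nn.functional as F

# Stand-ins for user_tower / item_tower outputs in the diff above;
# both are L2-normalized, so their dot product is cosine similarity.
user_emb = F.normalize(torch.randn(4, 64), p=2, dim=1)  # [B, D]
item_emb = F.normalize(torch.randn(4, 64), p=2, dim=1)  # [B, D]

temperature = 0.2  # illustrative value; the constructor default is 1.0
logits = user_emb @ item_emb.T / temperature            # [B, B]

# In-batch softmax: row i's positive item sits on the diagonal, and the
# other items in the batch act as sampled negatives.
labels = torch.arange(logits.size(0))
loss = F.cross_entropy(logits, labels)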
nextrec/models/multi_task/esmm.py

@@ -17,15 +17,15 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 class ESMM(BaseModel):
     """
     Entire Space Multi-Task Model
-
+
     ESMM is designed for CVR (Conversion Rate) prediction. It models two related tasks:
     - CTR task: P(click | impression)
     - CVR task: P(conversion | click)
     - CTCVR task (auxiliary): P(click & conversion | impression) = P(click) * P(conversion | click)
-
+
     This design addresses the sample selection bias and data sparsity issues in CVR modeling.
     """
-
+
     @property
     def model_name(self):
         return "ESMM"
@@ -33,34 +33,30 @@ class ESMM(BaseModel):
     @property
     def task_type(self):
         # ESMM has fixed task types: CTR (binary) and CVR (binary)
-        return ['binary', 'binary']
-
-    def __init__(
-        self,
-        dense_features: list[DenseFeature],
-        sparse_features: list[SparseFeature],
-        sequence_features: list[SequenceFeature],
-        ctr_params: dict,
-        cvr_params: dict,
-        target: list[str] = ['ctr', 'ctcvr'],  # Note: ctcvr = ctr * cvr
-        task: str | list[str] = 'binary',
-        optimizer: str = "adam",
-        optimizer_params: dict = {},
-        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
-        device: str = 'cpu',
-        model_id: str = "baseline",
-        embedding_l1_reg=1e-6,
-        dense_l1_reg=1e-5,
-        embedding_l2_reg=1e-5,
-        dense_l2_reg=1e-4,
-    ):
-
+        return ['binary', 'binary']
+
+    def __init__(self,
+                 dense_features: list[DenseFeature],
+                 sparse_features: list[SparseFeature],
+                 sequence_features: list[SequenceFeature],
+                 ctr_params: dict,
+                 cvr_params: dict,
+                 target: list[str] = ['ctr', 'ctcvr'], # Note: ctcvr = ctr * cvr
+                 task: str | list[str] = 'binary',
+                 optimizer: str = "adam",
+                 optimizer_params: dict = {},
+                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+                 device: str = 'cpu',
+                 model_id: str = "baseline",
+                 embedding_l1_reg=1e-6,
+                 dense_l1_reg=1e-5,
+                 embedding_l2_reg=1e-5,
+                 dense_l2_reg=1e-4):
+
         # ESMM requires exactly 2 targets: ctr and ctcvr
         if len(target) != 2:
-            raise ValueError(
-                f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}"
-            )
-
+            raise ValueError(f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}")
+
         super(ESMM, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
@@ -73,13 +69,13 @@ class ESMM(BaseModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=20,
-            model_id=model_id
+            model_id=model_id
         )

         self.loss = loss
         if self.loss is None:
             self.loss = "bce"
-
+
         # All features
         self.all_features = dense_features + sparse_features + sequence_features

@@ -87,48 +83,46 @@ class ESMM(BaseModel):
         self.embedding = EmbeddingLayer(features=self.all_features)

         # Calculate input dimension
-        emb_dim_total = sum(
-            [
-                f.embedding_dim
-                for f in self.all_features
-                if not isinstance(f, DenseFeature)
-            ]
-        )
-        dense_input_dim = sum(
-            [getattr(f, "embedding_dim", 1) or 1 for f in dense_features]
-        )
+        emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
+        dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
         input_dim = emb_dim_total + dense_input_dim
-
+
         # CTR tower
         self.ctr_tower = MLP(input_dim=input_dim, output_layer=True, **ctr_params)
-
+
         # CVR tower
         self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
         self.prediction_layer = PredictionLayer(
-            task_type=self.task_type, task_dims=[1, 1]
+            task_type=self.task_type,
+            task_dims=[1, 1]
         )

         # Register regularization weights
         self._register_regularization_weights(
-            embedding_attr='embedding', include_modules=['ctr_tower', 'cvr_tower']
+            embedding_attr='embedding',
+            include_modules=['ctr_tower', 'cvr_tower']
         )

-        self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss)
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            loss=loss
+        )

     def forward(self, x):
         # Get all embeddings and flatten
         input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
-
+
         # CTR prediction: P(click | impression)
         ctr_logit = self.ctr_tower(input_flat)  # [B, 1]
         cvr_logit = self.cvr_tower(input_flat)  # [B, 1]
         logits = torch.cat([ctr_logit, cvr_logit], dim=1)
         preds = self.prediction_layer(logits)
         ctr, cvr = preds.chunk(2, dim=1)
-
+
         # CTCVR prediction: P(click & conversion | impression) = P(click) * P(conversion | click)
         ctcvr = ctr * cvr  # [B, 1]
-
+
         # Output: [CTR, CTCVR]
         # Note: We supervise CTR with click labels and CTCVR with conversion labels
         y = torch.cat([ctr, ctcvr], dim=1)  # [B, 2]
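The forward pass above carries the core ESMM idea: the CVR tower is never supervised directly; its product with CTR is matched against conversion labels over all impressions. A toy numeric sketch of that decomposition (all values made up):

import torch

# ESMM identity used in forward() above:
# P(click & conversion | impression) = P(click | impression) * P(conversion | click)
ctr = torch.tensor([[0.10], [0.30]])  # P(click | impression) per impression
cvr = torch.tensor([[0.20], [0.50]])  # P(conversion | click) per impression
ctcvr = ctr * cvr                     # tensor([[0.02], [0.15]])

# Supervision follows the note in the diff: CTR against click labels and
# CTCVR against conversion labels, both over the full impression space,
# so the CVR tower never trains on a click-filtered (biased) sample.
click = torch.tensor([[1.0], [0.0]])
conversion = torch.tensor([[0.0], [0.0]])
bce = torch.nn.BCELoss()
loss = bce(ctr, click) + bce(ctcvr, conversion)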
nextrec/models/multi_task/mmoe.py

@@ -17,13 +17,13 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 class MMOE(BaseModel):
     """
     Multi-gate Mixture-of-Experts
-
+
     MMOE improves upon shared-bottom architecture by using multiple expert networks
     and task-specific gating networks. Each task has its own gate that learns to
     weight the contributions of different experts, allowing for both task-specific
     and shared representations.
     """
-
+
     @property
     def model_name(self):
         return "MMOE"
@@ -31,28 +31,26 @@ class MMOE(BaseModel):
     @property
     def task_type(self):
         return self.task if isinstance(self.task, list) else [self.task]
-
-    def __init__(
-        self,
-        dense_features: list[DenseFeature]=[],
-        sparse_features: list[SparseFeature]=[],
-        sequence_features: list[SequenceFeature]=[],
-        expert_params: dict={},
-        num_experts: int=3,
-        tower_params_list: list[dict]=[],
-        target: list[str]=[],
-        task: str | list[str] = 'binary',
-        optimizer: str = "adam",
-        optimizer_params: dict = {},
-        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
-        device: str = 'cpu',
-        model_id: str = "baseline",
-        embedding_l1_reg=1e-6,
-        dense_l1_reg=1e-5,
-        embedding_l2_reg=1e-5,
-        dense_l2_reg=1e-4,
-    ):
-
+
+    def __init__(self,
+                 dense_features: list[DenseFeature]=[],
+                 sparse_features: list[SparseFeature]=[],
+                 sequence_features: list[SequenceFeature]=[],
+                 expert_params: dict={},
+                 num_experts: int=3,
+                 tower_params_list: list[dict]=[],
+                 target: list[str]=[],
+                 task: str | list[str] = 'binary',
+                 optimizer: str = "adam",
+                 optimizer_params: dict = {},
+                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+                 device: str = 'cpu',
+                 model_id: str = "baseline",
+                 embedding_l1_reg=1e-6,
+                 dense_l1_reg=1e-5,
+                 embedding_l2_reg=1e-5,
+                 dense_l2_reg=1e-4):
+
         super(MMOE, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
@@ -65,22 +63,20 @@ class MMOE(BaseModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=20,
-            model_id=model_id
+            model_id=model_id
         )

         self.loss = loss
         if self.loss is None:
             self.loss = "bce"
-
+
         # Number of tasks and experts
         self.num_tasks = len(target)
         self.num_experts = num_experts
-
+
         if len(tower_params_list) != self.num_tasks:
-            raise ValueError(
-                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
-            )
-
+            raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
+
         # All features
         self.all_features = dense_features + sparse_features + sequence_features

@@ -88,83 +84,78 @@ class MMOE(BaseModel):
         self.embedding = EmbeddingLayer(features=self.all_features)

         # Calculate input dimension
-        emb_dim_total = sum(
-            [
-                f.embedding_dim
-                for f in self.all_features
-                if not isinstance(f, DenseFeature)
-            ]
-        )
-        dense_input_dim = sum(
-            [getattr(f, "embedding_dim", 1) or 1 for f in dense_features]
-        )
+        emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
+        dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
         input_dim = emb_dim_total + dense_input_dim
-
+
         # Expert networks (shared by all tasks)
         self.experts = nn.ModuleList()
         for _ in range(num_experts):
             expert = MLP(input_dim=input_dim, output_layer=False, **expert_params)
             self.experts.append(expert)
-
+
         # Get expert output dimension
-        if 'dims' in expert_params and len(expert_params['dims']) > 0:
-            expert_output_dim = expert_params['dims'][-1]
+        if 'dims' in expert_params and len(expert_params['dims']) > 0:
+            expert_output_dim = expert_params['dims'][-1]
         else:
             expert_output_dim = input_dim
-
+
         # Task-specific gates
         self.gates = nn.ModuleList()
         for _ in range(self.num_tasks):
-            gate = nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1))
+            gate = nn.Sequential(
+                nn.Linear(input_dim, num_experts),
+                nn.Softmax(dim=1)
+            )
             self.gates.append(gate)
-
+
         # Task-specific towers
         self.towers = nn.ModuleList()
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
         self.prediction_layer = PredictionLayer(
-            task_type=self.task_type, task_dims=[1] * self.num_tasks
+            task_type=self.task_type,
+            task_dims=[1] * self.num_tasks
         )

         # Register regularization weights
         self._register_regularization_weights(
-            embedding_attr='embedding', include_modules=['experts', 'gates', 'towers']
+            embedding_attr='embedding',
+            include_modules=['experts', 'gates', 'towers']
         )

-        self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss)
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            loss=loss
+        )

     def forward(self, x):
         # Get all embeddings and flatten
         input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
-
+
         # Expert outputs: [num_experts, B, expert_dim]
         expert_outputs = [expert(input_flat) for expert in self.experts]
-        expert_outputs = torch.stack(
-            expert_outputs, dim=0
-        )  # [num_experts, B, expert_dim]
-
+        expert_outputs = torch.stack(expert_outputs, dim=0)  # [num_experts, B, expert_dim]
+
         # Task-specific processing
         task_outputs = []
         for task_idx in range(self.num_tasks):
             # Gate weights for this task: [B, num_experts]
             gate_weights = self.gates[task_idx](input_flat)  # [B, num_experts]
-
+
             # Weighted sum of expert outputs
             # gate_weights: [B, num_experts, 1]
             # expert_outputs: [num_experts, B, expert_dim]
             gate_weights = gate_weights.unsqueeze(2)  # [B, num_experts, 1]
-            expert_outputs_t = expert_outputs.permute(
-                1, 0, 2
-            )  # [B, num_experts, expert_dim]
-            gated_output = torch.sum(
-                gate_weights * expert_outputs_t, dim=1
-            )  # [B, expert_dim]
-
+            expert_outputs_t = expert_outputs.permute(1, 0, 2)  # [B, num_experts, expert_dim]
+            gated_output = torch.sum(gate_weights * expert_outputs_t, dim=1)  # [B, expert_dim]
+
             # Tower output
             tower_output = self.towers[task_idx](gated_output)  # [B, 1]
             task_outputs.append(tower_output)
-
+
         # Stack outputs: [B, num_tasks]
         y = torch.cat(task_outputs, dim=1)
         return self.prediction_layer(y)
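For reference, a standalone shape walkthrough of the gating math that the reflowed forward() lines above implement; tensor values are random placeholders.

import torch

B, num_experts, expert_dim = 4, 3, 8

expert_outputs = torch.randn(num_experts, B, expert_dim)          # stacked expert outputs
gate_weights = torch.softmax(torch.randn(B, num_experts), dim=1)  # one gate per task

# Same three steps as the diff: add a trailing axis to the gate weights,
# move the expert axis behind the batch axis, then reduce over experts.
gate_weights = gate_weights.unsqueeze(2)                          # [B, num_experts, 1]
expert_outputs_t = expert_outputs.permute(1, 0, 2)                # [B, num_experts, expert_dim]
gated_output = torch.sum(gate_weights * expert_outputs_t, dim=1)  # [B, expert_dim]

assert gated_output.shape == (B, expert_dim)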