nextrec-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +41 -0
- nextrec/__version__.py +1 -0
- nextrec/basic/__init__.py +0 -0
- nextrec/basic/activation.py +92 -0
- nextrec/basic/callback.py +35 -0
- nextrec/basic/dataloader.py +447 -0
- nextrec/basic/features.py +87 -0
- nextrec/basic/layers.py +985 -0
- nextrec/basic/loggers.py +124 -0
- nextrec/basic/metrics.py +557 -0
- nextrec/basic/model.py +1438 -0
- nextrec/data/__init__.py +27 -0
- nextrec/data/data_utils.py +132 -0
- nextrec/data/preprocessor.py +662 -0
- nextrec/loss/__init__.py +35 -0
- nextrec/loss/loss_utils.py +136 -0
- nextrec/loss/match_losses.py +294 -0
- nextrec/models/generative/hstu.py +0 -0
- nextrec/models/generative/tiger.py +0 -0
- nextrec/models/match/__init__.py +13 -0
- nextrec/models/match/dssm.py +200 -0
- nextrec/models/match/dssm_v2.py +162 -0
- nextrec/models/match/mind.py +210 -0
- nextrec/models/match/sdm.py +253 -0
- nextrec/models/match/youtube_dnn.py +172 -0
- nextrec/models/multi_task/esmm.py +129 -0
- nextrec/models/multi_task/mmoe.py +161 -0
- nextrec/models/multi_task/ple.py +260 -0
- nextrec/models/multi_task/share_bottom.py +126 -0
- nextrec/models/ranking/__init__.py +17 -0
- nextrec/models/ranking/afm.py +118 -0
- nextrec/models/ranking/autoint.py +140 -0
- nextrec/models/ranking/dcn.py +120 -0
- nextrec/models/ranking/deepfm.py +95 -0
- nextrec/models/ranking/dien.py +214 -0
- nextrec/models/ranking/din.py +181 -0
- nextrec/models/ranking/fibinet.py +130 -0
- nextrec/models/ranking/fm.py +87 -0
- nextrec/models/ranking/masknet.py +125 -0
- nextrec/models/ranking/pnn.py +128 -0
- nextrec/models/ranking/widedeep.py +105 -0
- nextrec/models/ranking/xdeepfm.py +117 -0
- nextrec/utils/__init__.py +18 -0
- nextrec/utils/common.py +14 -0
- nextrec/utils/embedding.py +19 -0
- nextrec/utils/initializer.py +47 -0
- nextrec/utils/optimizer.py +75 -0
- nextrec-0.1.1.dist-info/METADATA +302 -0
- nextrec-0.1.1.dist-info/RECORD +51 -0
- nextrec-0.1.1.dist-info/WHEEL +4 -0
- nextrec-0.1.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,172 @@ nextrec/models/match/youtube_dnn.py

"""
Date: create on 09/11/2025
Author:
    Yang Zhou,zyaztec@gmail.com
Reference:
[1] Covington P, Adams J, Sargin E. Deep neural networks for youtube recommendations[C]
    //Proceedings of the 10th ACM conference on recommender systems. 2016: 191-198.
"""
import torch
import torch.nn as nn
from typing import Literal

from nextrec.basic.model import BaseMatchModel
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
from nextrec.basic.layers import MLP, EmbeddingLayer, AveragePooling


class YoutubeDNN(BaseMatchModel):
    """
    YouTube Deep Neural Network for Recommendations

    User tower: historical behavior sequence + user features -> user embedding
    Item tower: item features -> item embedding
    Training: sampled softmax loss (listwise)
    """

    @property
    def model_name(self) -> str:
        return "YouTubeDNN"

    def __init__(self,
                 user_dense_features: list[DenseFeature] | None = None,
                 user_sparse_features: list[SparseFeature] | None = None,
                 user_sequence_features: list[SequenceFeature] | None = None,
                 item_dense_features: list[DenseFeature] | None = None,
                 item_sparse_features: list[SparseFeature] | None = None,
                 item_sequence_features: list[SequenceFeature] | None = None,
                 user_dnn_hidden_units: list[int] = [256, 128, 64],
                 item_dnn_hidden_units: list[int] = [256, 128, 64],
                 embedding_dim: int = 64,
                 dnn_activation: str = 'relu',
                 dnn_dropout: float = 0.0,
                 training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'listwise',
                 num_negative_samples: int = 100,
                 temperature: float = 1.0,
                 similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
                 device: str = 'cpu',
                 embedding_l1_reg: float = 0.0,
                 dense_l1_reg: float = 0.0,
                 embedding_l2_reg: float = 0.0,
                 dense_l2_reg: float = 0.0,
                 early_stop_patience: int = 20,
                 model_id: str = 'youtube_dnn'):

        super(YoutubeDNN, self).__init__(
            user_dense_features=user_dense_features,
            user_sparse_features=user_sparse_features,
            user_sequence_features=user_sequence_features,
            item_dense_features=item_dense_features,
            item_sparse_features=item_sparse_features,
            item_sequence_features=item_sequence_features,
            training_mode=training_mode,
            num_negative_samples=num_negative_samples,
            temperature=temperature,
            similarity_metric=similarity_metric,
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
            early_stop_patience=early_stop_patience,
            model_id=model_id
        )

        self.embedding_dim = embedding_dim
        self.user_dnn_hidden_units = user_dnn_hidden_units
        self.item_dnn_hidden_units = item_dnn_hidden_units

        # User tower
        user_features = []
        if user_dense_features:
            user_features.extend(user_dense_features)
        if user_sparse_features:
            user_features.extend(user_sparse_features)
        if user_sequence_features:
            user_features.extend(user_sequence_features)

        if len(user_features) > 0:
            self.user_embedding = EmbeddingLayer(user_features)

        user_input_dim = 0
        for feat in user_dense_features or []:
            user_input_dim += 1
        for feat in user_sparse_features or []:
            user_input_dim += feat.embedding_dim
        for feat in user_sequence_features or []:
            # Sequence features are aggregated by average pooling
            user_input_dim += feat.embedding_dim

        user_dnn_units = user_dnn_hidden_units + [embedding_dim]
        self.user_dnn = MLP(
            input_dim=user_input_dim,
            dims=user_dnn_units,
            output_layer=False,
            dropout=dnn_dropout,
            activation=dnn_activation
        )

        # Item tower
        item_features = []
        if item_dense_features:
            item_features.extend(item_dense_features)
        if item_sparse_features:
            item_features.extend(item_sparse_features)
        if item_sequence_features:
            item_features.extend(item_sequence_features)

        if len(item_features) > 0:
            self.item_embedding = EmbeddingLayer(item_features)

        item_input_dim = 0
        for feat in item_dense_features or []:
            item_input_dim += 1
        for feat in item_sparse_features or []:
            item_input_dim += feat.embedding_dim
        for feat in item_sequence_features or []:
            item_input_dim += feat.embedding_dim

        item_dnn_units = item_dnn_hidden_units + [embedding_dim]
        self.item_dnn = MLP(
            input_dim=item_input_dim,
            dims=item_dnn_units,
            output_layer=False,
            dropout=dnn_dropout,
            activation=dnn_activation
        )

        self._register_regularization_weights(
            embedding_attr='user_embedding',
            include_modules=['user_dnn']
        )
        self._register_regularization_weights(
            embedding_attr='item_embedding',
            include_modules=['item_dnn']
        )

        self.to(device)

    def user_tower(self, user_input: dict) -> torch.Tensor:
        """
        User tower
        Processes the user's historical behavior sequence and other user features
        """
        all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
        user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)
        user_emb = self.user_dnn(user_emb)

        # L2 normalization
        user_emb = torch.nn.functional.normalize(user_emb, p=2, dim=1)

        return user_emb

    def item_tower(self, item_input: dict) -> torch.Tensor:
        """Item tower"""
        all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
        item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
        item_emb = self.item_dnn(item_emb)

        # L2 normalization
        item_emb = torch.nn.functional.normalize(item_emb, p=2, dim=1)

        return item_emb
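Both tower outputs above are L2-normalized, so scoring reduces to a temperature-scaled dot product between user and item embeddings. The snippet below is a minimal, self-contained PyTorch sketch of the listwise objective the docstring refers to, using in-batch negatives and toy data; it illustrates the idea only and is not the loss implementation shipped in BaseMatchModel or nextrec.loss.

import torch
import torch.nn.functional as F

# Toy illustration: 4 (user, positive item) pairs; every other item in the
# batch serves as a negative for a sampled-softmax-style listwise loss.
torch.manual_seed(0)
user_emb = F.normalize(torch.randn(4, 64), p=2, dim=1)  # stand-in for user_tower output
item_emb = F.normalize(torch.randn(4, 64), p=2, dim=1)  # stand-in for item_tower output

temperature = 1.0                                # plays the role of the `temperature` argument
logits = user_emb @ item_emb.t() / temperature   # [4, 4] similarity matrix ('dot' metric)
labels = torch.arange(4)                         # diagonal entries are the positive pairs
loss = F.cross_entropy(logits, labels)           # softmax over each row of candidates

Because both embeddings are normalized, the 'dot' and 'cosine' similarity metrics coincide in this sketch.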
@@ -0,0 +1,129 @@ nextrec/models/multi_task/esmm.py

"""
Date: create on 09/11/2025
Author:
    Yang Zhou,zyaztec@gmail.com
Reference:
[1] Ma X, Zhao L, Huang G, et al. Entire space multi-task model: An effective approach for estimating post-click conversion rate[C]//SIGIR. 2018: 1137-1140.
"""

import torch
import torch.nn as nn

from nextrec.basic.model import BaseModel
from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature


class ESMM(BaseModel):
    """
    Entire Space Multi-Task Model

    ESMM is designed for CVR (Conversion Rate) prediction. It models two related tasks:
    - CTR task: P(click | impression)
    - CVR task: P(conversion | click)
    - CTCVR task (auxiliary): P(click & conversion | impression) = P(click) * P(conversion | click)

    This design addresses the sample selection bias and data sparsity issues in CVR modeling.
    """

    @property
    def model_name(self):
        return "ESMM"

    @property
    def task_type(self):
        # ESMM has fixed task types: CTR (binary) and CVR (binary)
        return ['binary', 'binary']

    def __init__(self,
                 dense_features: list[DenseFeature],
                 sparse_features: list[SparseFeature],
                 sequence_features: list[SequenceFeature],
                 ctr_params: dict,
                 cvr_params: dict,
                 target: list[str] = ['ctr', 'ctcvr'],  # Note: ctcvr = ctr * cvr
                 task: str | list[str] = 'binary',
                 optimizer: str = "adam",
                 optimizer_params: dict = {},
                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
                 device: str = 'cpu',
                 model_id: str = "baseline",
                 embedding_l1_reg=1e-6,
                 dense_l1_reg=1e-5,
                 embedding_l2_reg=1e-5,
                 dense_l2_reg=1e-4):

        # ESMM requires exactly 2 targets: ctr and ctcvr
        if len(target) != 2:
            raise ValueError(f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}")

        super(ESMM, self).__init__(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
            task=task,  # Both CTR and CTCVR are binary classification
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
            early_stop_patience=20,
            model_id=model_id
        )

        self.loss = loss
        if self.loss is None:
            self.loss = "bce"

        # All features
        self.all_features = dense_features + sparse_features + sequence_features

        # Shared embedding layer
        self.embedding = EmbeddingLayer(features=self.all_features)

        # Calculate input dimension
        emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
        dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
        input_dim = emb_dim_total + dense_input_dim

        # CTR tower
        self.ctr_tower = MLP(input_dim=input_dim, output_layer=True, **ctr_params)

        # CVR tower
        self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
        self.prediction_layer = PredictionLayer(
            task_type=self.task_type,
            task_dims=[1, 1]
        )

        # Register regularization weights
        self._register_regularization_weights(
            embedding_attr='embedding',
            include_modules=['ctr_tower', 'cvr_tower']
        )

        self.compile(
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            loss=loss
        )

    def forward(self, x):
        # Get all embeddings and flatten
        input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)

        # CTR prediction: P(click | impression)
        ctr_logit = self.ctr_tower(input_flat)  # [B, 1]
        cvr_logit = self.cvr_tower(input_flat)  # [B, 1]
        logits = torch.cat([ctr_logit, cvr_logit], dim=1)
        preds = self.prediction_layer(logits)
        ctr, cvr = preds.chunk(2, dim=1)

        # CTCVR prediction: P(click & conversion | impression) = P(click) * P(conversion | click)
        ctcvr = ctr * cvr  # [B, 1]

        # Output: [CTR, CTCVR]
        # Note: We supervise CTR with click labels and CTCVR with conversion labels
        y = torch.cat([ctr, ctcvr], dim=1)  # [B, 2]
        return y  # [B, 2], where y[:, 0] is CTR and y[:, 1] is CTCVR
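A point worth making explicit: ESMM never trains the CVR tower on clicked samples directly. Both heads are supervised over the full impression space (click labels for CTR, conversion labels for CTCVR), and pCVR is recovered implicitly through the factorization pCTCVR = pCTR * pCVR. Below is a small hand-computed sketch of that supervision scheme with made-up probabilities; the actual loss wiring lives in BaseModel.compile/fit, which is not part of this diff.

import torch
import torch.nn.functional as F

# Pretend `y` is ESMM's forward output for 5 impressions:
# y[:, 0] = pCTR, y[:, 1] = pCTCVR (both already sigmoid-activated).
y = torch.tensor([[0.30, 0.06],
                  [0.10, 0.01],
                  [0.80, 0.40],
                  [0.05, 0.01],
                  [0.50, 0.10]])
click = torch.tensor([1., 0., 1., 0., 1.])        # defined for every impression
conversion = torch.tensor([0., 0., 1., 0., 0.])   # also defined for every impression

# Both BCE terms run over the entire impression space, which is what removes
# the sample selection bias of fitting CVR on clicked traffic only.
loss = (F.binary_cross_entropy(y[:, 0], click)
        + F.binary_cross_entropy(y[:, 1], conversion))

# The implicit CVR estimate falls out of the factorization: pCVR = pCTCVR / pCTR.
p_cvr = y[:, 1] / y[:, 0].clamp_min(1e-8)         # e.g. 0.06 / 0.30 = 0.20 for row 0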
@@ -0,0 +1,161 @@ nextrec/models/multi_task/mmoe.py

"""
Date: create on 09/11/2025
Author:
    Yang Zhou,zyaztec@gmail.com
Reference:
[1] Ma J, Zhao Z, Yi X, et al. Modeling task relationships in multi-task learning with multi-gate mixture-of-experts[C]//KDD. 2018: 1930-1939.
"""

import torch
import torch.nn as nn

from nextrec.basic.model import BaseModel
from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature


class MMOE(BaseModel):
    """
    Multi-gate Mixture-of-Experts

    MMOE improves upon shared-bottom architecture by using multiple expert networks
    and task-specific gating networks. Each task has its own gate that learns to
    weight the contributions of different experts, allowing for both task-specific
    and shared representations.
    """

    @property
    def model_name(self):
        return "MMOE"

    @property
    def task_type(self):
        return self.task if isinstance(self.task, list) else [self.task]

    def __init__(self,
                 dense_features: list[DenseFeature] = [],
                 sparse_features: list[SparseFeature] = [],
                 sequence_features: list[SequenceFeature] = [],
                 expert_params: dict = {},
                 num_experts: int = 3,
                 tower_params_list: list[dict] = [],
                 target: list[str] = [],
                 task: str | list[str] = 'binary',
                 optimizer: str = "adam",
                 optimizer_params: dict = {},
                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
                 device: str = 'cpu',
                 model_id: str = "baseline",
                 embedding_l1_reg=1e-6,
                 dense_l1_reg=1e-5,
                 embedding_l2_reg=1e-5,
                 dense_l2_reg=1e-4):

        super(MMOE, self).__init__(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
            task=task,
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
            early_stop_patience=20,
            model_id=model_id
        )

        self.loss = loss
        if self.loss is None:
            self.loss = "bce"

        # Number of tasks and experts
        self.num_tasks = len(target)
        self.num_experts = num_experts

        if len(tower_params_list) != self.num_tasks:
            raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")

        # All features
        self.all_features = dense_features + sparse_features + sequence_features

        # Embedding layer
        self.embedding = EmbeddingLayer(features=self.all_features)

        # Calculate input dimension
        emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
        dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
        input_dim = emb_dim_total + dense_input_dim

        # Expert networks (shared by all tasks)
        self.experts = nn.ModuleList()
        for _ in range(num_experts):
            expert = MLP(input_dim=input_dim, output_layer=False, **expert_params)
            self.experts.append(expert)

        # Get expert output dimension
        if 'dims' in expert_params and len(expert_params['dims']) > 0:
            expert_output_dim = expert_params['dims'][-1]
        else:
            expert_output_dim = input_dim

        # Task-specific gates
        self.gates = nn.ModuleList()
        for _ in range(self.num_tasks):
            gate = nn.Sequential(
                nn.Linear(input_dim, num_experts),
                nn.Softmax(dim=1)
            )
            self.gates.append(gate)

        # Task-specific towers
        self.towers = nn.ModuleList()
        for tower_params in tower_params_list:
            tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
            self.towers.append(tower)
        self.prediction_layer = PredictionLayer(
            task_type=self.task_type,
            task_dims=[1] * self.num_tasks
        )

        # Register regularization weights
        self._register_regularization_weights(
            embedding_attr='embedding',
            include_modules=['experts', 'gates', 'towers']
        )

        self.compile(
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            loss=loss
        )

    def forward(self, x):
        # Get all embeddings and flatten
        input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)

        # Expert outputs: [num_experts, B, expert_dim]
        expert_outputs = [expert(input_flat) for expert in self.experts]
        expert_outputs = torch.stack(expert_outputs, dim=0)  # [num_experts, B, expert_dim]

        # Task-specific processing
        task_outputs = []
        for task_idx in range(self.num_tasks):
            # Gate weights for this task: [B, num_experts]
            gate_weights = self.gates[task_idx](input_flat)  # [B, num_experts]

            # Weighted sum of expert outputs
            # gate_weights: [B, num_experts, 1]
            # expert_outputs: [num_experts, B, expert_dim]
            gate_weights = gate_weights.unsqueeze(2)  # [B, num_experts, 1]
            expert_outputs_t = expert_outputs.permute(1, 0, 2)  # [B, num_experts, expert_dim]
            gated_output = torch.sum(gate_weights * expert_outputs_t, dim=1)  # [B, expert_dim]

            # Tower output
            tower_output = self.towers[task_idx](gated_output)  # [B, 1]
            task_outputs.append(tower_output)

        # Stack outputs: [B, num_tasks]
        y = torch.cat(task_outputs, dim=1)
        return self.prediction_layer(y)
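Written out, each task head computes y_k = tower_k( sum_i g_k(x)_i * expert_i(x) ), where g_k(x) = softmax(W_k x) are the per-task gate weights. The following shape-for-shape sketch mirrors the forward pass above, with plain nn.Linear layers standing in for nextrec's MLP experts and towers; it stacks experts on dim=1 directly instead of stacking on dim=0 and permuting, which is equivalent.

import torch
import torch.nn as nn

B, input_dim, expert_dim, num_experts, num_tasks = 8, 32, 16, 3, 2
x = torch.randn(B, input_dim)  # stand-in for the flattened embedding output

experts = nn.ModuleList([nn.Linear(input_dim, expert_dim) for _ in range(num_experts)])
gates = nn.ModuleList([
    nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1))
    for _ in range(num_tasks)
])
towers = nn.ModuleList([nn.Linear(expert_dim, 1) for _ in range(num_tasks)])

expert_out = torch.stack([e(x) for e in experts], dim=1)  # [B, num_experts, expert_dim]
task_outputs = []
for k in range(num_tasks):
    w = gates[k](x).unsqueeze(2)            # [B, num_experts, 1], rows sum to 1
    mixed = (w * expert_out).sum(dim=1)     # [B, expert_dim], per-task expert blend
    task_outputs.append(towers[k](mixed))   # [B, 1]

y = torch.cat(task_outputs, dim=1)          # [B, num_tasks]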