nextrec 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +250 -112
- nextrec/basic/loggers.py +63 -44
- nextrec/basic/metrics.py +270 -120
- nextrec/basic/model.py +1084 -402
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +492 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +51 -45
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +273 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +103 -38
- nextrec/models/match/dssm.py +82 -68
- nextrec/models/match/dssm_v2.py +72 -57
- nextrec/models/match/mind.py +175 -107
- nextrec/models/match/sdm.py +104 -87
- nextrec/models/match/youtube_dnn.py +73 -59
- nextrec/models/multi_task/esmm.py +69 -46
- nextrec/models/multi_task/mmoe.py +91 -53
- nextrec/models/multi_task/ple.py +117 -58
- nextrec/models/multi_task/poso.py +163 -55
- nextrec/models/multi_task/share_bottom.py +63 -36
- nextrec/models/ranking/afm.py +80 -45
- nextrec/models/ranking/autoint.py +74 -57
- nextrec/models/ranking/dcn.py +110 -48
- nextrec/models/ranking/dcn_v2.py +265 -45
- nextrec/models/ranking/deepfm.py +39 -24
- nextrec/models/ranking/dien.py +335 -146
- nextrec/models/ranking/din.py +158 -92
- nextrec/models/ranking/fibinet.py +134 -52
- nextrec/models/ranking/fm.py +68 -26
- nextrec/models/ranking/masknet.py +95 -33
- nextrec/models/ranking/pnn.py +128 -58
- nextrec/models/ranking/widedeep.py +40 -28
- nextrec/models/ranking/xdeepfm.py +67 -40
- nextrec/utils/__init__.py +59 -34
- nextrec/utils/config.py +496 -0
- nextrec/utils/device.py +30 -20
- nextrec/utils/distributed.py +36 -9
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +33 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/model.py +22 -0
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +283 -165
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
- nextrec-0.4.3.dist-info/RECORD +69 -0
- nextrec-0.4.3.dist-info/entry_points.txt +2 -0
- nextrec-0.4.1.dist-info/RECORD +0 -66
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
nextrec/models/ranking/din.py
CHANGED
|
@@ -1,18 +1,62 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Date: create on 09/11/2025
|
|
3
|
-
|
|
4
|
-
|
|
3
|
+
Checkpoint: edit on 09/12/2025
|
|
4
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
5
5
|
Reference:
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
6
|
+
[1] Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate
|
|
7
|
+
prediction[C] //Proceedings of the 24th ACM SIGKDD international conference on
|
|
8
|
+
knowledge discovery & data mining. 2018: 1059-1068.
|
|
9
|
+
(https://arxiv.org/abs/1706.06978)
|
|
10
|
+
|
|
11
|
+
Deep Interest Network (DIN) is a CTR model that builds a target-aware user
|
|
12
|
+
representation by attending over the historical behavior sequence. Instead of
|
|
13
|
+
compressing all behaviors into one static vector, DIN highlights the behaviors
|
|
14
|
+
most relevant to the current candidate item, enabling adaptive interest
|
|
15
|
+
modeling for each request.
|
|
16
|
+
|
|
17
|
+
Pipeline:
|
|
18
|
+
(1) Embed candidate item, user behavior sequence, and other sparse/dense fields
|
|
19
|
+
(2) Use a small attention MLP to score each historical behavior against the
|
|
20
|
+
candidate embedding
|
|
21
|
+
(3) Apply masked weighted pooling to obtain a target-specific interest vector
|
|
22
|
+
(4) Concatenate candidate, interest vector, other sparse embeddings, and dense
|
|
23
|
+
features
|
|
24
|
+
(5) Feed the combined representation into an MLP for final prediction
|
|
25
|
+
|
|
26
|
+
Key Advantages:
|
|
27
|
+
- Target-aware attention captures fine-grained interests per candidate item
|
|
28
|
+
- Adaptive pooling handles diverse behavior patterns without heavy feature crafting
|
|
29
|
+
- Masked weighting reduces noise from padded sequence positions
|
|
30
|
+
- Easily incorporates additional sparse/dense context features alongside behavior
|
|
31
|
+
|
|
32
|
+
DIN 是一个 CTR 预估模型,通过对用户历史行为序列进行目标感知的注意力加权,
|
|
33
|
+
构建针对当前候选物品的兴趣表示。它不是将全部行为压缩为固定向量,而是突出
|
|
34
|
+
与候选物品最相关的行为,实现请求级的自适应兴趣建模。
|
|
35
|
+
|
|
36
|
+
处理流程:
|
|
37
|
+
(1) 对候选物品、用户行为序列及其他稀疏/稠密特征做 embedding
|
|
38
|
+
(2) 使用小型注意力 MLP 计算每个历史行为与候选 embedding 的相关性
|
|
39
|
+
(3) 通过掩码加权池化得到目标特定的兴趣向量
|
|
40
|
+
(4) 拼接候选、兴趣向量、其他稀疏 embedding 与稠密特征
|
|
41
|
+
(5) 输入 MLP 完成最终点击率预测
|
|
42
|
+
|
|
43
|
+
主要优点:
|
|
44
|
+
- 目标感知注意力捕捉候选级的细粒度兴趣
|
|
45
|
+
- 自适应池化应对多样化行为模式,减少手工特征工程
|
|
46
|
+
- 掩码加权降低序列填充位置的噪声
|
|
47
|
+
- 便捷融合行为与额外稀疏/稠密上下文信息
|
|
9
48
|
"""
|
|
10
49
|
|
|
11
50
|
import torch
|
|
12
51
|
import torch.nn as nn
|
|
13
52
|
|
|
14
53
|
from nextrec.basic.model import BaseModel
|
|
15
|
-
from nextrec.basic.layers import
|
|
54
|
+
from nextrec.basic.layers import (
|
|
55
|
+
EmbeddingLayer,
|
|
56
|
+
MLP,
|
|
57
|
+
AttentionPoolingLayer,
|
|
58
|
+
PredictionLayer,
|
|
59
|
+
)
|
|
16
60
|
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
|
|
17
61
|
|
|
18
62
|
|
|
@@ -24,28 +68,41 @@ class DIN(BaseModel):
|
|
|
24
68
|
@property
def default_task(self):
    """Return the task type identifier DIN reports by default."""
    task_kind = "binary"
    return task_kind
|
|
27
|
-
|
|
28
|
-
def __init__(
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
71
|
+
|
|
72
|
+
def __init__(
|
|
73
|
+
self,
|
|
74
|
+
dense_features: list[DenseFeature] | None = None,
|
|
75
|
+
sparse_features: list[SparseFeature] | None = None,
|
|
76
|
+
sequence_features: list[SequenceFeature] | None = None,
|
|
77
|
+
behavior_feature_name: str | None = None,
|
|
78
|
+
candidate_feature_name: str | None = None,
|
|
79
|
+
mlp_params: dict | None = None,
|
|
80
|
+
attention_hidden_units: list[int] | None = None,
|
|
81
|
+
attention_activation: str = "dice",
|
|
82
|
+
attention_use_softmax: bool = True,
|
|
83
|
+
target: list[str] | str | None = None,
|
|
84
|
+
task: str | list[str] | None = None,
|
|
85
|
+
optimizer: str = "adam",
|
|
86
|
+
optimizer_params: dict | None = None,
|
|
87
|
+
loss: str | nn.Module | None = "bce",
|
|
88
|
+
loss_params: dict | list[dict] | None = None,
|
|
89
|
+
device: str = "cpu",
|
|
90
|
+
embedding_l1_reg=1e-6,
|
|
91
|
+
dense_l1_reg=1e-5,
|
|
92
|
+
embedding_l2_reg=1e-5,
|
|
93
|
+
dense_l2_reg=1e-4,
|
|
94
|
+
**kwargs,
|
|
95
|
+
):
|
|
96
|
+
|
|
97
|
+
dense_features = dense_features or []
|
|
98
|
+
sparse_features = sparse_features or []
|
|
99
|
+
sequence_features = sequence_features or []
|
|
100
|
+
mlp_params = mlp_params or {}
|
|
101
|
+
attention_hidden_units = attention_hidden_units or [80, 40]
|
|
102
|
+
optimizer_params = optimizer_params or {}
|
|
103
|
+
if loss is None:
|
|
104
|
+
loss = "bce"
|
|
105
|
+
|
|
49
106
|
super(DIN, self).__init__(
|
|
50
107
|
dense_features=dense_features,
|
|
51
108
|
sparse_features=sparse_features,
|
|
@@ -57,43 +114,52 @@ class DIN(BaseModel):
|
|
|
57
114
|
dense_l1_reg=dense_l1_reg,
|
|
58
115
|
embedding_l2_reg=embedding_l2_reg,
|
|
59
116
|
dense_l2_reg=dense_l2_reg,
|
|
60
|
-
**kwargs
|
|
117
|
+
**kwargs,
|
|
61
118
|
)
|
|
62
119
|
|
|
63
|
-
|
|
64
|
-
if self.loss is None:
|
|
65
|
-
self.loss = "bce"
|
|
66
|
-
|
|
67
|
-
# Features classification
|
|
68
|
-
# DIN requires: candidate item + user behavior sequence + other features
|
|
120
|
+
# DIN requires: user behavior sequence + candidate item + other features
|
|
69
121
|
if len(sequence_features) == 0:
|
|
70
|
-
raise ValueError(
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
122
|
+
raise ValueError(
|
|
123
|
+
"DIN requires at least one sequence feature for user behavior history"
|
|
124
|
+
)
|
|
125
|
+
if behavior_feature_name is None:
|
|
126
|
+
raise ValueError("DIN requires an explicit behavior_feature_name")
|
|
127
|
+
|
|
128
|
+
if candidate_feature_name is None:
|
|
129
|
+
raise ValueError("DIN requires an explicit candidate_feature_name")
|
|
130
|
+
|
|
131
|
+
self.behavior_feature = [
|
|
132
|
+
f for f in sequence_features if f.name == behavior_feature_name
|
|
133
|
+
][0]
|
|
134
|
+
self.candidate_feature = [
|
|
135
|
+
f for f in sparse_features if f.name == candidate_feature_name
|
|
136
|
+
][0]
|
|
137
|
+
|
|
138
|
+
# Other sparse features
|
|
139
|
+
self.other_sparse_features = [
|
|
140
|
+
f for f in sparse_features if f.name != self.candidate_feature.name
|
|
141
|
+
]
|
|
81
142
|
|
|
82
143
|
# Embedding layer
|
|
83
144
|
self.embedding = EmbeddingLayer(features=self.all_features)
|
|
84
|
-
|
|
145
|
+
|
|
85
146
|
# Attention layer for behavior sequence
|
|
86
147
|
behavior_emb_dim = self.behavior_feature.embedding_dim
|
|
87
148
|
self.candidate_attention_proj = None
|
|
88
|
-
if
|
|
89
|
-
self.
|
|
149
|
+
if (
|
|
150
|
+
self.candidate_feature is not None
|
|
151
|
+
and self.candidate_feature.embedding_dim != behavior_emb_dim
|
|
152
|
+
):
|
|
153
|
+
self.candidate_attention_proj = nn.Linear(
|
|
154
|
+
self.candidate_feature.embedding_dim, behavior_emb_dim
|
|
155
|
+
)
|
|
90
156
|
self.attention = AttentionPoolingLayer(
|
|
91
157
|
embedding_dim=behavior_emb_dim,
|
|
92
158
|
hidden_units=attention_hidden_units,
|
|
93
159
|
activation=attention_activation,
|
|
94
|
-
use_softmax=attention_use_softmax
|
|
160
|
+
use_softmax=attention_use_softmax,
|
|
95
161
|
)
|
|
96
|
-
|
|
162
|
+
|
|
97
163
|
# Calculate MLP input dimension
|
|
98
164
|
# candidate + attention_pooled_behavior + other_sparse + dense
|
|
99
165
|
mlp_input_dim = 0
|
|
@@ -101,16 +167,18 @@ class DIN(BaseModel):
|
|
|
101
167
|
mlp_input_dim += self.candidate_feature.embedding_dim
|
|
102
168
|
mlp_input_dim += behavior_emb_dim # attention pooled
|
|
103
169
|
mlp_input_dim += sum([f.embedding_dim for f in self.other_sparse_features])
|
|
104
|
-
mlp_input_dim += sum(
|
|
105
|
-
|
|
170
|
+
mlp_input_dim += sum(
|
|
171
|
+
[getattr(f, "embedding_dim", 1) or 1 for f in dense_features]
|
|
172
|
+
)
|
|
173
|
+
|
|
106
174
|
# MLP for final prediction
|
|
107
175
|
self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
|
|
108
176
|
self.prediction_layer = PredictionLayer(task_type=self.task)
|
|
109
177
|
|
|
110
178
|
# Register regularization weights
|
|
111
179
|
self.register_regularization_weights(
|
|
112
|
-
embedding_attr=
|
|
113
|
-
include_modules=[
|
|
180
|
+
embedding_attr="embedding",
|
|
181
|
+
include_modules=["attention", "mlp", "candidate_attention_proj"],
|
|
114
182
|
)
|
|
115
183
|
|
|
116
184
|
self.compile(
|
|
@@ -122,62 +190,60 @@ class DIN(BaseModel):
|
|
|
122
190
|
|
|
123
191
|
def forward(self, x):
    """Run DIN on one batch of inputs.

    The candidate item embedding, projected by ``candidate_attention_proj``
    when that layer exists, is used as the attention query over the embedded
    behavior sequence. Padding positions are masked out.

    The pooled interest vector is concatenated with:
      * the raw candidate embedding,
      * the embeddings of the remaining sparse features, and
      * the projected dense features.

    The result is then passed through the MLP and the prediction layer.

    Args:
        x: batch inputs indexed by feature name. Each entry must support
            ``.long()`` (presumably a tensor; confirm against the caller).

    Returns:
        The output of ``self.prediction_layer`` applied to the MLP output.

    Raises:
        ValueError: if ``self.candidate_feature`` is None.
    """
    candidate = self.candidate_feature
    if candidate is None:
        raise ValueError("DIN requires a candidate item feature")

    lookup = self.embedding.embed_dict
    item_vec = lookup[candidate.embedding_name](x[candidate.name].long())  # [B, emb_dim]

    behavior = self.behavior_feature
    history_ids = x[behavior.name].long()  # [B, seq_len]
    history_vecs = lookup[behavior.embedding_name](history_ids)  # [B, seq_len, emb_dim]

    # Mask is 0.0 where the id equals the padding id and 1.0 elsewhere.
    # Id 0 is treated as padding when no padding_idx is configured.
    pad_id = 0 if behavior.padding_idx is None else behavior.padding_idx
    valid = history_ids.ne(pad_id).unsqueeze(-1).float()

    # The projection only exists when the candidate and behavior embedding
    # sizes differ; otherwise the candidate embedding is the query as-is.
    proj = self.candidate_attention_proj
    query = item_vec if proj is None else proj(item_vec)
    interest = self.attention(query=query, keys=history_vecs, mask=valid)  # [B, emb_dim]

    # Concatenation order: candidate, pooled interest, other sparse, dense.
    pieces = [item_vec, interest]
    pieces.extend(
        lookup[feat.embedding_name](x[feat.name].long())
        for feat in self.other_sparse_features
    )
    pieces.extend(
        self.embedding.project_dense(feat, x) for feat in self.dense_features
    )

    mlp_out = self.mlp(torch.cat(pieces, dim=-1))  # [B, 1]
    return self.prediction_layer(mlp_out)
|
|
@@ -1,10 +1,43 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Date: create on 09/11/2025
|
|
3
|
-
|
|
4
|
-
|
|
3
|
+
Checkpoint: edit on 09/12/2025
|
|
4
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
5
5
|
Reference:
|
|
6
|
-
|
|
7
|
-
|
|
6
|
+
[1] Huang T, Zhang Z, Zhang B, et al. FiBiNET: Combining feature importance and bilinear
|
|
7
|
+
feature interaction for click-through rate prediction[C]//RecSys. 2019: 169-177.
|
|
8
|
+
(https://arxiv.org/abs/1905.09433)
|
|
9
|
+
|
|
10
|
+
FiBiNET (Feature Importance and Bilinear Interaction Network) is a CTR model that
|
|
11
|
+
jointly learns which fields matter most and how they interact. It first uses SENET
|
|
12
|
+
to produce field-wise importance weights and recalibrate embeddings, then applies a
|
|
13
|
+
bilinear interaction layer to capture pairwise feature relationships with enhanced
|
|
14
|
+
expressiveness.
|
|
15
|
+
|
|
16
|
+
Pipeline:
|
|
17
|
+
(1) Embed sparse and sequence features that share a common embedding dimension
|
|
18
|
+
(2) SENET squeezes and excites across fields to generate importance scores
|
|
19
|
+
(3) Reweight embeddings with SENET scores to highlight informative fields
|
|
20
|
+
(4) Compute bilinear interactions on both the original and SENET-reweighted
|
|
21
|
+
embeddings to model pairwise relations
|
|
22
|
+
(5) Concatenate interaction outputs, feed them into an MLP, and add its output to a linear
|
|
23
|
+
term for final prediction
|
|
24
|
+
|
|
25
|
+
Key Advantages:
|
|
26
|
+
- SENET recalibration emphasizes the most informative feature fields
|
|
27
|
+
- Bilinear interactions explicitly model pairwise relationships beyond simple dot
|
|
28
|
+
products
|
|
29
|
+
- Dual-path (standard + SENET-reweighted) interactions enrich representation power
|
|
30
|
+
- Combines linear and deep components for both memorization and generalization
|
|
31
|
+
|
|
32
|
+
FiBiNET 是一个 CTR 预估模型,通过 SENET 重新分配特征字段的重要性,再用双线性
|
|
33
|
+
交互层捕捉成对特征关系。模型先对稀疏/序列特征做 embedding,SENET 生成字段权重并
|
|
34
|
+
重标定 embedding,随后在原始和重标定的 embedding 上分别计算双线性交互,最后将
|
|
35
|
+
交互结果输入 MLP,其输出与线性部分的输出相加得到预测。
|
|
36
|
+
主要优点:
|
|
37
|
+
- SENET 让模型聚焦最重要的特征字段
|
|
38
|
+
- 双线性交互显式建模特征对关系,表达力强于简单点积
|
|
39
|
+
- 标准与重标定两路交互结合,丰富特征表示
|
|
40
|
+
- 线性与深层结构并行,兼顾记忆与泛化
|
|
8
41
|
"""
|
|
9
42
|
|
|
10
43
|
import torch
|
|
@@ -13,6 +46,7 @@ import torch.nn as nn
|
|
|
13
46
|
from nextrec.basic.model import BaseModel
|
|
14
47
|
from nextrec.basic.layers import (
|
|
15
48
|
BiLinearInteractionLayer,
|
|
49
|
+
HadamardInteractionLayer,
|
|
16
50
|
EmbeddingLayer,
|
|
17
51
|
LR,
|
|
18
52
|
MLP,
|
|
@@ -30,27 +64,38 @@ class FiBiNET(BaseModel):
|
|
|
30
64
|
@property
def default_task(self):
    """Return the task type identifier FiBiNET reports by default."""
    chosen = "binary"
    return chosen
|
|
33
|
-
|
|
34
|
-
def __init__(
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
67
|
+
|
|
68
|
+
def __init__(
|
|
69
|
+
self,
|
|
70
|
+
dense_features: list[DenseFeature] | None = None,
|
|
71
|
+
sparse_features: list[SparseFeature] | None = None,
|
|
72
|
+
sequence_features: list[SequenceFeature] | None = None,
|
|
73
|
+
mlp_params: dict | None = None,
|
|
74
|
+
interaction_combo: str = "11", # "0": Hadamard, "1": Bilinear
|
|
75
|
+
bilinear_type: str = "field_interaction",
|
|
76
|
+
senet_reduction: int = 3,
|
|
77
|
+
target: list[str] | str | None = None,
|
|
78
|
+
task: str | list[str] | None = None,
|
|
79
|
+
optimizer: str = "adam",
|
|
80
|
+
optimizer_params: dict | None = None,
|
|
81
|
+
loss: str | nn.Module | None = "bce",
|
|
82
|
+
loss_params: dict | list[dict] | None = None,
|
|
83
|
+
device: str = "cpu",
|
|
84
|
+
embedding_l1_reg=1e-6,
|
|
85
|
+
dense_l1_reg=1e-5,
|
|
86
|
+
embedding_l2_reg=1e-5,
|
|
87
|
+
dense_l2_reg=1e-4,
|
|
88
|
+
**kwargs,
|
|
89
|
+
):
|
|
90
|
+
|
|
91
|
+
dense_features = dense_features or []
|
|
92
|
+
sparse_features = sparse_features or []
|
|
93
|
+
sequence_features = sequence_features or []
|
|
94
|
+
mlp_params = mlp_params or {}
|
|
95
|
+
optimizer_params = optimizer_params or {}
|
|
96
|
+
if loss is None:
|
|
97
|
+
loss = "bce"
|
|
98
|
+
|
|
54
99
|
super(FiBiNET, self).__init__(
|
|
55
100
|
dense_features=dense_features,
|
|
56
101
|
sparse_features=sparse_features,
|
|
@@ -62,39 +107,61 @@ class FiBiNET(BaseModel):
|
|
|
62
107
|
dense_l1_reg=dense_l1_reg,
|
|
63
108
|
embedding_l2_reg=embedding_l2_reg,
|
|
64
109
|
dense_l2_reg=dense_l2_reg,
|
|
65
|
-
**kwargs
|
|
110
|
+
**kwargs,
|
|
66
111
|
)
|
|
67
112
|
|
|
68
113
|
self.loss = loss
|
|
69
|
-
if self.loss is None:
|
|
70
|
-
self.loss = "bce"
|
|
71
|
-
|
|
72
114
|
self.linear_features = sparse_features + sequence_features
|
|
73
|
-
self.deep_features = dense_features + sparse_features + sequence_features
|
|
74
115
|
self.interaction_features = sparse_features + sequence_features
|
|
75
116
|
|
|
76
117
|
if len(self.interaction_features) < 2:
|
|
77
|
-
raise ValueError(
|
|
118
|
+
raise ValueError(
|
|
119
|
+
"FiBiNET requires at least two sparse/sequence features for interactions."
|
|
120
|
+
)
|
|
78
121
|
|
|
79
|
-
self.embedding = EmbeddingLayer(features=self.
|
|
122
|
+
self.embedding = EmbeddingLayer(features=self.all_features)
|
|
80
123
|
|
|
81
124
|
self.num_fields = len(self.interaction_features)
|
|
82
125
|
self.embedding_dim = self.interaction_features[0].embedding_dim
|
|
83
|
-
if any(
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
self.bilinear_senet = BiLinearInteractionLayer(
|
|
93
|
-
input_dim=self.embedding_dim,
|
|
94
|
-
num_fields=self.num_fields,
|
|
95
|
-
bilinear_type=bilinear_type,
|
|
126
|
+
if any(
|
|
127
|
+
f.embedding_dim != self.embedding_dim for f in self.interaction_features
|
|
128
|
+
):
|
|
129
|
+
raise ValueError(
|
|
130
|
+
"All interaction features must share the same embedding_dim in FiBiNET."
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
self.senet = SENETLayer(
|
|
134
|
+
num_fields=self.num_fields, reduction_ratio=senet_reduction
|
|
96
135
|
)
|
|
97
136
|
|
|
137
|
+
self.interaction_combo = interaction_combo
|
|
138
|
+
|
|
139
|
+
# E interaction layers: original embeddings
|
|
140
|
+
if interaction_combo[0] == "0": # Hadamard
|
|
141
|
+
self.interaction_E = HadamardInteractionLayer(
|
|
142
|
+
num_fields=self.num_fields
|
|
143
|
+
) # [B, num_pairs, D]
|
|
144
|
+
elif interaction_combo[0] == "1": # Bilinear
|
|
145
|
+
self.interaction_E = BiLinearInteractionLayer(
|
|
146
|
+
input_dim=self.embedding_dim,
|
|
147
|
+
num_fields=self.num_fields,
|
|
148
|
+
bilinear_type=bilinear_type,
|
|
149
|
+
) # [B, num_pairs, D]
|
|
150
|
+
else:
|
|
151
|
+
raise ValueError("interaction_combo must be '01' or '11'")
|
|
152
|
+
|
|
153
|
+
# V interaction layers: SENET reweighted embeddings
|
|
154
|
+
if interaction_combo[1] == "0":
|
|
155
|
+
self.interaction_V = HadamardInteractionLayer(num_fields=self.num_fields)
|
|
156
|
+
elif interaction_combo[1] == "1":
|
|
157
|
+
self.interaction_V = BiLinearInteractionLayer(
|
|
158
|
+
input_dim=self.embedding_dim,
|
|
159
|
+
num_fields=self.num_fields,
|
|
160
|
+
bilinear_type=bilinear_type,
|
|
161
|
+
)
|
|
162
|
+
else:
|
|
163
|
+
raise ValueError("Deep-FiBiNET SENET side must be '01' or '11'")
|
|
164
|
+
|
|
98
165
|
linear_dim = sum([f.embedding_dim for f in self.linear_features])
|
|
99
166
|
self.linear = LR(linear_dim)
|
|
100
167
|
|
|
@@ -105,8 +172,14 @@ class FiBiNET(BaseModel):
|
|
|
105
172
|
|
|
106
173
|
# Register regularization weights
|
|
107
174
|
self.register_regularization_weights(
|
|
108
|
-
embedding_attr=
|
|
109
|
-
include_modules=[
|
|
175
|
+
embedding_attr="embedding",
|
|
176
|
+
include_modules=[
|
|
177
|
+
"linear",
|
|
178
|
+
"senet",
|
|
179
|
+
"mlp",
|
|
180
|
+
"interaction_E",
|
|
181
|
+
"interaction_V",
|
|
182
|
+
],
|
|
110
183
|
)
|
|
111
184
|
|
|
112
185
|
self.compile(
|
|
@@ -117,15 +190,24 @@ class FiBiNET(BaseModel):
|
|
|
117
190
|
)
|
|
118
191
|
|
|
119
192
|
def forward(self, x):
|
|
120
|
-
input_linear = self.embedding(
|
|
193
|
+
input_linear = self.embedding(
|
|
194
|
+
x=x, features=self.linear_features, squeeze_dim=True
|
|
195
|
+
)
|
|
121
196
|
y_linear = self.linear(input_linear)
|
|
122
197
|
|
|
123
|
-
field_emb = self.embedding(
|
|
198
|
+
field_emb = self.embedding(
|
|
199
|
+
x=x, features=self.interaction_features, squeeze_dim=False
|
|
200
|
+
)
|
|
124
201
|
senet_emb = self.senet(field_emb)
|
|
125
202
|
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
203
|
+
out_E = self.interaction_E(field_emb) # [B, num_pairs, D]
|
|
204
|
+
|
|
205
|
+
out_V = self.interaction_V(senet_emb) # [B, num_pairs, D]
|
|
206
|
+
|
|
207
|
+
deep_input = torch.cat(
|
|
208
|
+
[out_E.flatten(start_dim=1), out_V.flatten(start_dim=1)], dim=1
|
|
209
|
+
)
|
|
210
|
+
|
|
129
211
|
y_deep = self.mlp(deep_input)
|
|
130
212
|
|
|
131
213
|
y = y_linear + y_deep
|