nextrec 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/layers.py +32 -8
- nextrec/basic/loggers.py +1 -1
- nextrec/basic/metrics.py +2 -1
- nextrec/basic/model.py +3 -3
- nextrec/cli.py +41 -47
- nextrec/data/dataloader.py +1 -1
- nextrec/models/multi_task/esmm.py +23 -16
- nextrec/models/multi_task/mmoe.py +36 -17
- nextrec/models/multi_task/ple.py +18 -12
- nextrec/models/multi_task/poso.py +68 -37
- nextrec/models/multi_task/share_bottom.py +16 -2
- nextrec/models/ranking/afm.py +14 -14
- nextrec/models/ranking/autoint.py +2 -2
- nextrec/models/ranking/dcn.py +61 -19
- nextrec/models/ranking/dcn_v2.py +224 -45
- nextrec/models/ranking/deepfm.py +14 -9
- nextrec/models/ranking/dien.py +215 -82
- nextrec/models/ranking/din.py +95 -57
- nextrec/models/ranking/fibinet.py +92 -30
- nextrec/models/ranking/fm.py +44 -8
- nextrec/models/ranking/masknet.py +7 -7
- nextrec/models/ranking/pnn.py +105 -38
- nextrec/models/ranking/widedeep.py +8 -4
- nextrec/models/ranking/xdeepfm.py +10 -5
- nextrec/utils/config.py +9 -3
- nextrec/utils/file.py +2 -1
- nextrec/utils/model.py +22 -0
- {nextrec-0.4.2.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
- {nextrec-0.4.2.dist-info → nextrec-0.4.3.dist-info}/RECORD +33 -33
- {nextrec-0.4.2.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
- {nextrec-0.4.2.dist-info → nextrec-0.4.3.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.2.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
nextrec/models/ranking/din.py
CHANGED
|
@@ -1,11 +1,50 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Date: create on 09/11/2025
|
|
3
|
-
|
|
4
|
-
|
|
3
|
+
Checkpoint: edit on 09/12/2025
|
|
4
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
5
5
|
Reference:
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
6
|
+
[1] Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate
|
|
7
|
+
prediction[C] //Proceedings of the 24th ACM SIGKDD international conference on
|
|
8
|
+
knowledge discovery & data mining. 2018: 1059-1068.
|
|
9
|
+
(https://arxiv.org/abs/1706.06978)
|
|
10
|
+
|
|
11
|
+
Deep Interest Network (DIN) is a CTR model that builds a target-aware user
|
|
12
|
+
representation by attending over the historical behavior sequence. Instead of
|
|
13
|
+
compressing all behaviors into one static vector, DIN highlights the behaviors
|
|
14
|
+
most relevant to the current candidate item, enabling adaptive interest
|
|
15
|
+
modeling for each request.
|
|
16
|
+
|
|
17
|
+
Pipeline:
|
|
18
|
+
(1) Embed candidate item, user behavior sequence, and other sparse/dense fields
|
|
19
|
+
(2) Use a small attention MLP to score each historical behavior against the
|
|
20
|
+
candidate embedding
|
|
21
|
+
(3) Apply masked weighted pooling to obtain a target-specific interest vector
|
|
22
|
+
(4) Concatenate candidate, interest vector, other sparse embeddings, and dense
|
|
23
|
+
features
|
|
24
|
+
(5) Feed the combined representation into an MLP for final prediction
|
|
25
|
+
|
|
26
|
+
Key Advantages:
|
|
27
|
+
- Target-aware attention captures fine-grained interests per candidate item
|
|
28
|
+
- Adaptive pooling handles diverse behavior patterns without heavy feature crafting
|
|
29
|
+
- Masked weighting reduces noise from padded sequence positions
|
|
30
|
+
- Easily incorporates additional sparse/dense context features alongside behavior
|
|
31
|
+
|
|
32
|
+
DIN 是一个 CTR 预估模型,通过对用户历史行为序列进行目标感知的注意力加权,
|
|
33
|
+
构建针对当前候选物品的兴趣表示。它不是将全部行为压缩为固定向量,而是突出
|
|
34
|
+
与候选物品最相关的行为,实现请求级的自适应兴趣建模。
|
|
35
|
+
|
|
36
|
+
处理流程:
|
|
37
|
+
(1) 对候选物品、用户行为序列及其他稀疏/稠密特征做 embedding
|
|
38
|
+
(2) 使用小型注意力 MLP 计算每个历史行为与候选 embedding 的相关性
|
|
39
|
+
(3) 通过掩码加权池化得到目标特定的兴趣向量
|
|
40
|
+
(4) 拼接候选、兴趣向量、其他稀疏 embedding 与稠密特征
|
|
41
|
+
(5) 输入 MLP 完成最终点击率预测
|
|
42
|
+
|
|
43
|
+
主要优点:
|
|
44
|
+
- 目标感知注意力捕捉候选级的细粒度兴趣
|
|
45
|
+
- 自适应池化应对多样化行为模式,减少手工特征工程
|
|
46
|
+
- 掩码加权降低序列填充位置的噪声
|
|
47
|
+
- 便捷融合行为与额外稀疏/稠密上下文信息
|
|
9
48
|
"""
|
|
10
49
|
|
|
11
50
|
import torch
|
|
@@ -32,17 +71,19 @@ class DIN(BaseModel):
|
|
|
32
71
|
|
|
33
72
|
def __init__(
|
|
34
73
|
self,
|
|
35
|
-
dense_features: list[DenseFeature],
|
|
36
|
-
sparse_features: list[SparseFeature],
|
|
37
|
-
sequence_features: list[SequenceFeature],
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
74
|
+
dense_features: list[DenseFeature] | None = None,
|
|
75
|
+
sparse_features: list[SparseFeature] | None = None,
|
|
76
|
+
sequence_features: list[SequenceFeature] | None = None,
|
|
77
|
+
behavior_feature_name: str | None = None,
|
|
78
|
+
candidate_feature_name: str | None = None,
|
|
79
|
+
mlp_params: dict | None = None,
|
|
80
|
+
attention_hidden_units: list[int] | None = None,
|
|
81
|
+
attention_activation: str = "dice",
|
|
41
82
|
attention_use_softmax: bool = True,
|
|
42
|
-
target: list[str] =
|
|
83
|
+
target: list[str] | str | None = None,
|
|
43
84
|
task: str | list[str] | None = None,
|
|
44
85
|
optimizer: str = "adam",
|
|
45
|
-
optimizer_params: dict =
|
|
86
|
+
optimizer_params: dict | None = None,
|
|
46
87
|
loss: str | nn.Module | None = "bce",
|
|
47
88
|
loss_params: dict | list[dict] | None = None,
|
|
48
89
|
device: str = "cpu",
|
|
@@ -53,6 +94,15 @@ class DIN(BaseModel):
|
|
|
53
94
|
**kwargs,
|
|
54
95
|
):
|
|
55
96
|
|
|
97
|
+
dense_features = dense_features or []
|
|
98
|
+
sparse_features = sparse_features or []
|
|
99
|
+
sequence_features = sequence_features or []
|
|
100
|
+
mlp_params = mlp_params or {}
|
|
101
|
+
attention_hidden_units = attention_hidden_units or [80, 40]
|
|
102
|
+
optimizer_params = optimizer_params or {}
|
|
103
|
+
if loss is None:
|
|
104
|
+
loss = "bce"
|
|
105
|
+
|
|
56
106
|
super(DIN, self).__init__(
|
|
57
107
|
dense_features=dense_features,
|
|
58
108
|
sparse_features=sparse_features,
|
|
@@ -67,30 +117,28 @@ class DIN(BaseModel):
|
|
|
67
117
|
**kwargs,
|
|
68
118
|
)
|
|
69
119
|
|
|
70
|
-
|
|
71
|
-
if self.loss is None:
|
|
72
|
-
self.loss = "bce"
|
|
73
|
-
|
|
74
|
-
# Features classification
|
|
75
|
-
# DIN requires: candidate item + user behavior sequence + other features
|
|
120
|
+
# DIN requires: user behavior sequence + candidate item + other features
|
|
76
121
|
if len(sequence_features) == 0:
|
|
77
122
|
raise ValueError(
|
|
78
123
|
"DIN requires at least one sequence feature for user behavior history"
|
|
79
124
|
)
|
|
125
|
+
if behavior_feature_name is None:
|
|
126
|
+
raise ValueError("DIN requires an explicit behavior_feature_name")
|
|
80
127
|
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
sparse_features[-1] if sparse_features else None
|
|
84
|
-
) # Candidate item
|
|
128
|
+
if candidate_feature_name is None:
|
|
129
|
+
raise ValueError("DIN requires an explicit candidate_feature_name")
|
|
85
130
|
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
131
|
+
self.behavior_feature = [
|
|
132
|
+
f for f in sequence_features if f.name == behavior_feature_name
|
|
133
|
+
][0]
|
|
134
|
+
self.candidate_feature = [
|
|
135
|
+
f for f in sparse_features if f.name == candidate_feature_name
|
|
136
|
+
][0]
|
|
91
137
|
|
|
92
|
-
#
|
|
93
|
-
self.
|
|
138
|
+
# Other sparse features
|
|
139
|
+
self.other_sparse_features = [
|
|
140
|
+
f for f in sparse_features if f.name != self.candidate_feature.name
|
|
141
|
+
]
|
|
94
142
|
|
|
95
143
|
# Embedding layer
|
|
96
144
|
self.embedding = EmbeddingLayer(features=self.all_features)
|
|
@@ -142,14 +190,13 @@ class DIN(BaseModel):
|
|
|
142
190
|
|
|
143
191
|
def forward(self, x):
|
|
144
192
|
# Get candidate item embedding
|
|
145
|
-
if self.candidate_feature:
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
candidate_emb = None
|
|
193
|
+
if self.candidate_feature is None:
|
|
194
|
+
raise ValueError("DIN requires a candidate item feature")
|
|
195
|
+
candidate_emb = self.embedding.embed_dict[
|
|
196
|
+
self.candidate_feature.embedding_name
|
|
197
|
+
](
|
|
198
|
+
x[self.candidate_feature.name].long()
|
|
199
|
+
) # [B, emb_dim]
|
|
153
200
|
|
|
154
201
|
# Get behavior sequence embedding
|
|
155
202
|
behavior_seq = x[self.behavior_feature.name].long() # [B, seq_len]
|
|
@@ -168,24 +215,17 @@ class DIN(BaseModel):
|
|
|
168
215
|
mask = (behavior_seq != 0).unsqueeze(-1).float()
|
|
169
216
|
|
|
170
217
|
# Apply attention pooling
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
) # [B, emb_dim]
|
|
178
|
-
else:
|
|
179
|
-
# If no candidate, use mean pooling
|
|
180
|
-
pooled_behavior = torch.sum(behavior_emb * mask, dim=1) / (
|
|
181
|
-
mask.sum(dim=1) + 1e-9
|
|
182
|
-
)
|
|
218
|
+
candidate_query = candidate_emb
|
|
219
|
+
if self.candidate_attention_proj is not None:
|
|
220
|
+
candidate_query = self.candidate_attention_proj(candidate_query)
|
|
221
|
+
pooled_behavior = self.attention(
|
|
222
|
+
query=candidate_query, keys=behavior_emb, mask=mask
|
|
223
|
+
) # [B, emb_dim]
|
|
183
224
|
|
|
184
225
|
# Get other features
|
|
185
226
|
other_embeddings = []
|
|
186
227
|
|
|
187
|
-
|
|
188
|
-
other_embeddings.append(candidate_emb)
|
|
228
|
+
other_embeddings.append(candidate_emb)
|
|
189
229
|
|
|
190
230
|
other_embeddings.append(pooled_behavior)
|
|
191
231
|
|
|
@@ -197,11 +237,9 @@ class DIN(BaseModel):
|
|
|
197
237
|
other_embeddings.append(feat_emb)
|
|
198
238
|
|
|
199
239
|
# Dense features
|
|
200
|
-
for feat in self.
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
val = val.unsqueeze(1)
|
|
204
|
-
other_embeddings.append(val)
|
|
240
|
+
for feat in self.dense_features:
|
|
241
|
+
dense_val = self.embedding.project_dense(feat, x)
|
|
242
|
+
other_embeddings.append(dense_val)
|
|
205
243
|
|
|
206
244
|
# Concatenate all features
|
|
207
245
|
concat_input = torch.cat(other_embeddings, dim=-1) # [B, total_dim]
|
|
@@ -1,10 +1,43 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Date: create on 09/11/2025
|
|
3
|
-
|
|
4
|
-
|
|
3
|
+
Checkpoint: edit on 09/12/2025
|
|
4
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
5
5
|
Reference:
|
|
6
|
-
|
|
7
|
-
|
|
6
|
+
[1] Huang T, Zhang Z, Zhang B, et al. FiBiNET: Combining feature importance and bilinear
|
|
7
|
+
feature interaction for click-through rate prediction[C]//RecSys. 2019: 169-177.
|
|
8
|
+
(https://arxiv.org/abs/1905.09433)
|
|
9
|
+
|
|
10
|
+
FiBiNET (Feature Importance and Bilinear Interaction Network) is a CTR model that
|
|
11
|
+
jointly learns which fields matter most and how they interact. It first uses SENET
|
|
12
|
+
to produce field-wise importance weights and recalibrate embeddings, then applies a
|
|
13
|
+
bilinear interaction layer to capture pairwise feature relationships with enhanced
|
|
14
|
+
expressiveness.
|
|
15
|
+
|
|
16
|
+
Pipeline:
|
|
17
|
+
(1) Embed sparse and sequence features that share a common embedding dimension
|
|
18
|
+
(2) SENET squeezes and excites across fields to generate importance scores
|
|
19
|
+
(3) Reweight embeddings with SENET scores to highlight informative fields
|
|
20
|
+
(4) Compute bilinear interactions on both the original and SENET-reweighted
|
|
21
|
+
embeddings to model pairwise relations
|
|
22
|
+
(5) Concatenate interaction outputs and feed them into an MLP alongside a linear
|
|
23
|
+
term for final prediction
|
|
24
|
+
|
|
25
|
+
Key Advantages:
|
|
26
|
+
- SENET recalibration emphasizes the most informative feature fields
|
|
27
|
+
- Bilinear interactions explicitly model pairwise relationships beyond simple dot
|
|
28
|
+
products
|
|
29
|
+
- Dual-path (standard + SENET-reweighted) interactions enrich representation power
|
|
30
|
+
- Combines linear and deep components for both memorization and generalization
|
|
31
|
+
|
|
32
|
+
FiBiNET 是一个 CTR 预估模型,通过 SENET 重新分配特征字段的重要性,再用双线性
|
|
33
|
+
交互层捕捉成对特征关系。模型先对稀疏/序列特征做 embedding,SENET 生成字段权重并
|
|
34
|
+
重标定 embedding,随后在原始和重标定的 embedding 上分别计算双线性交互,最后将
|
|
35
|
+
交互结果与线性部分一起输入 MLP 得到预测。
|
|
36
|
+
主要优点:
|
|
37
|
+
- SENET 让模型聚焦最重要的特征字段
|
|
38
|
+
- 双线性交互显式建模特征对关系,表达力强于简单点积
|
|
39
|
+
- 标准与重标定两路交互结合,丰富特征表示
|
|
40
|
+
- 线性与深层结构并行,兼顾记忆与泛化
|
|
8
41
|
"""
|
|
9
42
|
|
|
10
43
|
import torch
|
|
@@ -13,6 +46,7 @@ import torch.nn as nn
|
|
|
13
46
|
from nextrec.basic.model import BaseModel
|
|
14
47
|
from nextrec.basic.layers import (
|
|
15
48
|
BiLinearInteractionLayer,
|
|
49
|
+
HadamardInteractionLayer,
|
|
16
50
|
EmbeddingLayer,
|
|
17
51
|
LR,
|
|
18
52
|
MLP,
|
|
@@ -33,16 +67,17 @@ class FiBiNET(BaseModel):
|
|
|
33
67
|
|
|
34
68
|
def __init__(
|
|
35
69
|
self,
|
|
36
|
-
dense_features: list[DenseFeature] |
|
|
37
|
-
sparse_features: list[SparseFeature] |
|
|
38
|
-
sequence_features: list[SequenceFeature] |
|
|
39
|
-
mlp_params: dict =
|
|
70
|
+
dense_features: list[DenseFeature] | None = None,
|
|
71
|
+
sparse_features: list[SparseFeature] | None = None,
|
|
72
|
+
sequence_features: list[SequenceFeature] | None = None,
|
|
73
|
+
mlp_params: dict | None = None,
|
|
74
|
+
interaction_combo: str = "11", # "0": Hadamard, "1": Bilinear
|
|
40
75
|
bilinear_type: str = "field_interaction",
|
|
41
76
|
senet_reduction: int = 3,
|
|
42
|
-
target: list[str] |
|
|
77
|
+
target: list[str] | str | None = None,
|
|
43
78
|
task: str | list[str] | None = None,
|
|
44
79
|
optimizer: str = "adam",
|
|
45
|
-
optimizer_params: dict =
|
|
80
|
+
optimizer_params: dict | None = None,
|
|
46
81
|
loss: str | nn.Module | None = "bce",
|
|
47
82
|
loss_params: dict | list[dict] | None = None,
|
|
48
83
|
device: str = "cpu",
|
|
@@ -53,6 +88,14 @@ class FiBiNET(BaseModel):
|
|
|
53
88
|
**kwargs,
|
|
54
89
|
):
|
|
55
90
|
|
|
91
|
+
dense_features = dense_features or []
|
|
92
|
+
sparse_features = sparse_features or []
|
|
93
|
+
sequence_features = sequence_features or []
|
|
94
|
+
mlp_params = mlp_params or {}
|
|
95
|
+
optimizer_params = optimizer_params or {}
|
|
96
|
+
if loss is None:
|
|
97
|
+
loss = "bce"
|
|
98
|
+
|
|
56
99
|
super(FiBiNET, self).__init__(
|
|
57
100
|
dense_features=dense_features,
|
|
58
101
|
sparse_features=sparse_features,
|
|
@@ -68,11 +111,7 @@ class FiBiNET(BaseModel):
|
|
|
68
111
|
)
|
|
69
112
|
|
|
70
113
|
self.loss = loss
|
|
71
|
-
if self.loss is None:
|
|
72
|
-
self.loss = "bce"
|
|
73
|
-
|
|
74
114
|
self.linear_features = sparse_features + sequence_features
|
|
75
|
-
self.deep_features = dense_features + sparse_features + sequence_features
|
|
76
115
|
self.interaction_features = sparse_features + sequence_features
|
|
77
116
|
|
|
78
117
|
if len(self.interaction_features) < 2:
|
|
@@ -80,7 +119,7 @@ class FiBiNET(BaseModel):
|
|
|
80
119
|
"FiBiNET requires at least two sparse/sequence features for interactions."
|
|
81
120
|
)
|
|
82
121
|
|
|
83
|
-
self.embedding = EmbeddingLayer(features=self.
|
|
122
|
+
self.embedding = EmbeddingLayer(features=self.all_features)
|
|
84
123
|
|
|
85
124
|
self.num_fields = len(self.interaction_features)
|
|
86
125
|
self.embedding_dim = self.interaction_features[0].embedding_dim
|
|
@@ -94,16 +133,34 @@ class FiBiNET(BaseModel):
|
|
|
94
133
|
self.senet = SENETLayer(
|
|
95
134
|
num_fields=self.num_fields, reduction_ratio=senet_reduction
|
|
96
135
|
)
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
136
|
+
|
|
137
|
+
self.interaction_combo = interaction_combo
|
|
138
|
+
|
|
139
|
+
# E interaction layers: original embeddings
|
|
140
|
+
if interaction_combo[0] == "0": # Hadamard
|
|
141
|
+
self.interaction_E = HadamardInteractionLayer(
|
|
142
|
+
num_fields=self.num_fields
|
|
143
|
+
) # [B, num_pairs, D]
|
|
144
|
+
elif interaction_combo[0] == "1": # Bilinear
|
|
145
|
+
self.interaction_E = BiLinearInteractionLayer(
|
|
146
|
+
input_dim=self.embedding_dim,
|
|
147
|
+
num_fields=self.num_fields,
|
|
148
|
+
bilinear_type=bilinear_type,
|
|
149
|
+
) # [B, num_pairs, D]
|
|
150
|
+
else:
|
|
151
|
+
raise ValueError("interaction_combo must be '01' or '11'")
|
|
152
|
+
|
|
153
|
+
# V interaction layers: SENET reweighted embeddings
|
|
154
|
+
if interaction_combo[1] == "0":
|
|
155
|
+
self.interaction_V = HadamardInteractionLayer(num_fields=self.num_fields)
|
|
156
|
+
elif interaction_combo[1] == "1":
|
|
157
|
+
self.interaction_V = BiLinearInteractionLayer(
|
|
158
|
+
input_dim=self.embedding_dim,
|
|
159
|
+
num_fields=self.num_fields,
|
|
160
|
+
bilinear_type=bilinear_type,
|
|
161
|
+
)
|
|
162
|
+
else:
|
|
163
|
+
raise ValueError("Deep-FiBiNET SENET side must be '01' or '11'")
|
|
107
164
|
|
|
108
165
|
linear_dim = sum([f.embedding_dim for f in self.linear_features])
|
|
109
166
|
self.linear = LR(linear_dim)
|
|
@@ -119,9 +176,9 @@ class FiBiNET(BaseModel):
|
|
|
119
176
|
include_modules=[
|
|
120
177
|
"linear",
|
|
121
178
|
"senet",
|
|
122
|
-
"bilinear_standard",
|
|
123
|
-
"bilinear_senet",
|
|
124
179
|
"mlp",
|
|
180
|
+
"interaction_E",
|
|
181
|
+
"interaction_V",
|
|
125
182
|
],
|
|
126
183
|
)
|
|
127
184
|
|
|
@@ -143,9 +200,14 @@ class FiBiNET(BaseModel):
|
|
|
143
200
|
)
|
|
144
201
|
senet_emb = self.senet(field_emb)
|
|
145
202
|
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
203
|
+
out_E = self.interaction_E(field_emb) # [B, num_pairs, D]
|
|
204
|
+
|
|
205
|
+
out_V = self.interaction_V(senet_emb) # [B, num_pairs, D]
|
|
206
|
+
|
|
207
|
+
deep_input = torch.cat(
|
|
208
|
+
[out_E.flatten(start_dim=1), out_V.flatten(start_dim=1)], dim=1
|
|
209
|
+
)
|
|
210
|
+
|
|
149
211
|
y_deep = self.mlp(deep_input)
|
|
150
212
|
|
|
151
213
|
y = y_linear + y_deep
|
nextrec/models/ranking/fm.py
CHANGED
|
@@ -1,9 +1,41 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Date: create on 09/11/2025
|
|
3
|
-
|
|
4
|
-
|
|
3
|
+
Checkpoint: edit on 09/12/2025
|
|
4
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
5
5
|
Reference:
|
|
6
|
-
|
|
6
|
+
[1] Rendle S. Factorization machines[C]//ICDM. 2010: 995-1000.
|
|
7
|
+
|
|
8
|
+
Factorization Machines (FM) capture second-order feature interactions with
|
|
9
|
+
linear complexity by factorizing the pairwise interaction matrix. Each field
|
|
10
|
+
is embedded into a latent vector; FM models the dot product of every pair of
|
|
11
|
+
embeddings and sums them along with a linear term, enabling strong performance
|
|
12
|
+
with sparse high-dimensional data and minimal feature engineering.
|
|
13
|
+
|
|
14
|
+
Pipeline:
|
|
15
|
+
(1) Embed sparse and sequence fields into low-dimensional vectors
|
|
16
|
+
(2) Compute linear logit over concatenated embeddings
|
|
17
|
+
(3) Compute pairwise interaction logit via factorized dot products
|
|
18
|
+
(4) Sum linear + interaction terms and apply prediction layer
|
|
19
|
+
|
|
20
|
+
Key Advantages:
|
|
21
|
+
- Models pairwise interactions efficiently (O(nk) vs. O(n^2))
|
|
22
|
+
- Works well on sparse inputs without handcrafted crosses
|
|
23
|
+
- Simple architecture with strong baseline performance
|
|
24
|
+
|
|
25
|
+
FM 是一种通过分解二阶特征交互矩阵、以线性复杂度建模特征对的 CTR 模型。
|
|
26
|
+
每个特征映射为低维向量,FM 对任意特征对进行内积求和并叠加线性项,
|
|
27
|
+
无需复杂特征工程即可在稀疏高维场景取得稳健效果。
|
|
28
|
+
|
|
29
|
+
处理流程:
|
|
30
|
+
(1) 对稀疏/序列特征做 embedding
|
|
31
|
+
(2) 计算线性部分的 logit
|
|
32
|
+
(3) 计算嵌入对之间的二阶交互 logit
|
|
33
|
+
(4) 线性项与交互项求和,再通过预测层输出
|
|
34
|
+
|
|
35
|
+
主要优点:
|
|
36
|
+
- 线性复杂度建模二阶交互,效率高
|
|
37
|
+
- 对稀疏特征友好,减少人工特征交叉
|
|
38
|
+
- 结构简单、表现强健,常作 CTR 基线
|
|
7
39
|
"""
|
|
8
40
|
|
|
9
41
|
import torch.nn as nn
|
|
@@ -29,13 +61,13 @@ class FM(BaseModel):
|
|
|
29
61
|
|
|
30
62
|
def __init__(
|
|
31
63
|
self,
|
|
32
|
-
dense_features: list[DenseFeature] |
|
|
33
|
-
sparse_features: list[SparseFeature] |
|
|
34
|
-
sequence_features: list[SequenceFeature] |
|
|
35
|
-
target: list[str] |
|
|
64
|
+
dense_features: list[DenseFeature] | None = None,
|
|
65
|
+
sparse_features: list[SparseFeature] | None = None,
|
|
66
|
+
sequence_features: list[SequenceFeature] | None = None,
|
|
67
|
+
target: list[str] | str | None = None,
|
|
36
68
|
task: str | list[str] | None = None,
|
|
37
69
|
optimizer: str = "adam",
|
|
38
|
-
optimizer_params: dict =
|
|
70
|
+
optimizer_params: dict | None = None,
|
|
39
71
|
loss: str | nn.Module | None = "bce",
|
|
40
72
|
loss_params: dict | list[dict] | None = None,
|
|
41
73
|
device: str = "cpu",
|
|
@@ -46,6 +78,10 @@ class FM(BaseModel):
|
|
|
46
78
|
**kwargs,
|
|
47
79
|
):
|
|
48
80
|
|
|
81
|
+
dense_features = dense_features or []
|
|
82
|
+
sparse_features = sparse_features or []
|
|
83
|
+
sequence_features = sequence_features or []
|
|
84
|
+
|
|
49
85
|
super(FM, self).__init__(
|
|
50
86
|
dense_features=dense_features,
|
|
51
87
|
sparse_features=sparse_features,
|
|
@@ -166,7 +166,7 @@ class MaskNet(BaseModel):
|
|
|
166
166
|
dense_features: list[DenseFeature] | None = None,
|
|
167
167
|
sparse_features: list[SparseFeature] | None = None,
|
|
168
168
|
sequence_features: list[SequenceFeature] | None = None,
|
|
169
|
-
|
|
169
|
+
architecture: str = "parallel", # "serial" or "parallel"
|
|
170
170
|
num_blocks: int = 3,
|
|
171
171
|
mask_hidden_dim: int = 64,
|
|
172
172
|
block_hidden_dim: int = 256,
|
|
@@ -232,11 +232,11 @@ class MaskNet(BaseModel):
|
|
|
232
232
|
)
|
|
233
233
|
|
|
234
234
|
self.v_emb_dim = self.num_fields * self.embedding_dim
|
|
235
|
-
self.
|
|
236
|
-
assert self.
|
|
235
|
+
self.architecture = architecture.lower()
|
|
236
|
+
assert self.architecture in (
|
|
237
237
|
"serial",
|
|
238
238
|
"parallel",
|
|
239
|
-
), "
|
|
239
|
+
), "architecture must be either 'serial' or 'parallel'."
|
|
240
240
|
|
|
241
241
|
self.num_blocks = max(1, num_blocks)
|
|
242
242
|
self.block_hidden_dim = block_hidden_dim
|
|
@@ -244,7 +244,7 @@ class MaskNet(BaseModel):
|
|
|
244
244
|
nn.Dropout(block_dropout) if block_dropout > 0 else nn.Identity()
|
|
245
245
|
)
|
|
246
246
|
|
|
247
|
-
if self.
|
|
247
|
+
if self.architecture == "serial":
|
|
248
248
|
self.first_block = MaskBlockOnEmbedding(
|
|
249
249
|
num_fields=self.num_fields,
|
|
250
250
|
embedding_dim=self.embedding_dim,
|
|
@@ -284,7 +284,7 @@ class MaskNet(BaseModel):
|
|
|
284
284
|
self.output_layer = None
|
|
285
285
|
self.prediction_layer = PredictionLayer(task_type=self.task)
|
|
286
286
|
|
|
287
|
-
if self.
|
|
287
|
+
if self.architecture == "serial":
|
|
288
288
|
self.register_regularization_weights(
|
|
289
289
|
embedding_attr="embedding",
|
|
290
290
|
include_modules=["mask_blocks", "output_layer"],
|
|
@@ -306,7 +306,7 @@ class MaskNet(BaseModel):
|
|
|
306
306
|
B = field_emb.size(0)
|
|
307
307
|
v_emb_flat = field_emb.view(B, -1) # flattened embeddings
|
|
308
308
|
|
|
309
|
-
if self.
|
|
309
|
+
if self.architecture == "parallel":
|
|
310
310
|
block_outputs = []
|
|
311
311
|
for block in self.mask_blocks:
|
|
312
312
|
h = block(field_emb, v_emb_flat) # [B, block_hidden_dim]
|