nextrec 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/features.py +5 -1
- nextrec/basic/layers.py +3 -7
- nextrec/basic/model.py +495 -664
- nextrec/data/data_utils.py +44 -12
- nextrec/data/dataloader.py +84 -285
- nextrec/data/preprocessor.py +91 -213
- nextrec/loss/__init__.py +0 -1
- nextrec/loss/loss_utils.py +51 -120
- nextrec/models/multi_task/esmm.py +1 -1
- nextrec/models/ranking/autoint.py +51 -7
- nextrec/models/ranking/masknet.py +268 -76
- nextrec/utils/__init__.py +4 -1
- nextrec/utils/common.py +16 -0
- {nextrec-0.2.3.dist-info → nextrec-0.2.5.dist-info}/METADATA +2 -2
- {nextrec-0.2.3.dist-info → nextrec-0.2.5.dist-info}/RECORD +18 -17
- {nextrec-0.2.3.dist-info → nextrec-0.2.5.dist-info}/WHEEL +0 -0
- {nextrec-0.2.3.dist-info → nextrec-0.2.5.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/__init__.py
CHANGED
nextrec/loss/loss_utils.py
CHANGED
@@ -21,138 +21,69 @@ from nextrec.loss.pointwise import (
     WeightedBCELoss,
 )
 
-
+
 VALID_TASK_TYPES = [
-    "binary",
-    "multiclass",
-    "
-    "
-    "match",
-    "ranking",
-    "multitask",
-    "multilabel",
+    "binary",
+    "multiclass",
+    "multilabel",
+    "regression",
 ]
 
+def _build_cb_focal(kw):
+    if "class_counts" not in kw:
+        raise ValueError("class_balanced_focal requires class_counts")
+    return ClassBalancedFocalLoss(**kw)
 
-def get_loss_fn(
-    task_type: str = "binary",
-    training_mode: str | None = None,
-    loss: str | nn.Module | None = None,
-    **loss_kwargs,
-) -> nn.Module:
-    """
-    Get loss function based on task type and training mode.
-    """
 
+def get_loss_fn(loss=None, **kw):
     if isinstance(loss, nn.Module):
         return loss
-
-
-    if
-        return
-
-    if task_type in ["ranking", "multitask", "binary", "multilabel"]:
-        return _get_classification_loss(loss, **loss_kwargs)
-
-    if task_type == "multiclass":
-        return _get_multiclass_loss(loss, **loss_kwargs)
-
-    if task_type == "regression":
-        if loss is None or loss == "mse":
-            return nn.MSELoss(**loss_kwargs)
-        if loss == "mae":
-            return nn.L1Loss(**loss_kwargs)
-        if isinstance(loss, str):
-            raise ValueError(f"Unsupported regression loss: {loss}")
-
-    raise ValueError(f"Unsupported task_type: {task_type}")
-
-
-def _get_match_loss(training_mode: str | None, loss: str | None, **loss_kwargs) -> nn.Module:
-    if training_mode == "pointwise":
-        if loss is None or loss in {"bce", "binary_crossentropy"}:
-            return nn.BCELoss(**loss_kwargs)
-        if loss == "weighted_bce":
-            return WeightedBCELoss(**loss_kwargs)
-        if loss == "focal":
-            return FocalLoss(**loss_kwargs)
-        if loss == "class_balanced_focal":
-            return _build_cb_focal(loss_kwargs)
-        if loss == "cosine_contrastive":
-            return CosineContrastiveLoss(**loss_kwargs)
-        if isinstance(loss, str):
-            raise ValueError(f"Unsupported pointwise loss: {loss}")
-
-    if training_mode == "pairwise":
-        if loss is None or loss == "bpr":
-            return BPRLoss(**loss_kwargs)
-        if loss == "hinge":
-            return HingeLoss(**loss_kwargs)
-        if loss == "triplet":
-            return TripletLoss(**loss_kwargs)
-        if isinstance(loss, str):
-            raise ValueError(f"Unsupported pairwise loss: {loss}")
-
-    if training_mode == "listwise":
-        if loss is None or loss in {"sampled_softmax", "softmax"}:
-            return SampledSoftmaxLoss(**loss_kwargs)
-        if loss == "infonce":
-            return InfoNCELoss(**loss_kwargs)
-        if loss == "listnet":
-            return ListNetLoss(**loss_kwargs)
-        if loss == "listmle":
-            return ListMLELoss(**loss_kwargs)
-        if loss == "approx_ndcg":
-            return ApproxNDCGLoss(**loss_kwargs)
-        if loss in {"crossentropy", "ce"}:
-            return nn.CrossEntropyLoss(**loss_kwargs)
-        if isinstance(loss, str):
-            raise ValueError(f"Unsupported listwise loss: {loss}")
-
-    raise ValueError(f"Unknown training_mode: {training_mode}")
-
-
-def _get_classification_loss(loss: str | None, **loss_kwargs) -> nn.Module:
-    if loss is None or loss in {"bce", "binary_crossentropy"}:
-        return nn.BCELoss(**loss_kwargs)
+    if loss is None:
+        raise ValueError("loss must be provided explicitly")
+    if loss in ["bce", "binary_crossentropy"]:
+        return nn.BCELoss(**kw)
     if loss == "weighted_bce":
-        return WeightedBCELoss(**
-    if loss
-        return FocalLoss(**
-    if loss
-        return _build_cb_focal(
+        return WeightedBCELoss(**kw)
+    if loss in ["focal", "focal_loss"]:
+        return FocalLoss(**kw)
+    if loss in ["cb_focal", "class_balanced_focal"]:
+        return _build_cb_focal(kw)
+    if loss in ["crossentropy", "ce"]:
+        return nn.CrossEntropyLoss(**kw)
     if loss == "mse":
-        return nn.MSELoss(**
+        return nn.MSELoss(**kw)
     if loss == "mae":
-        return nn.L1Loss(**
-
-
-    if
-
-
-
-
-
-
-    if loss
-        return
-    if loss == "
-        return
-    if
-
-
-
-
-
-
-        return ClassBalancedFocalLoss(**loss_kwargs)
-
+        return nn.L1Loss(**kw)
+
+    # Pairwise ranking Loss
+    if loss == "bpr":
+        return BPRLoss(**kw)
+    if loss == "hinge":
+        return HingeLoss(**kw)
+    if loss == "triplet":
+        return TripletLoss(**kw)
+
+    # Listwise ranking Loss
+    if loss in ["sampled_softmax", "softmax"]:
+        return SampledSoftmaxLoss(**kw)
+    if loss == "infonce":
+        return InfoNCELoss(**kw)
+    if loss == "listnet":
+        return ListNetLoss(**kw)
+    if loss == "listmle":
+        return ListMLELoss(**kw)
+    if loss == "approx_ndcg":
+        return ApproxNDCGLoss(**kw)
+
+    raise ValueError(f"Unsupported loss: {loss}")
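For orientation, a minimal sketch of how the flattened 0.2.5 dispatcher behaves. The import path is assumed from the file layout (nextrec/loss/loss_utils.py); the calls mirror the branches shown in the hunk above:

import torch.nn as nn
from nextrec.loss.loss_utils import get_loss_fn

bce = get_loss_fn(loss="bce")          # string name -> nn.BCELoss()
bpr = get_loss_fn(loss="bpr")          # pairwise ranking -> BPRLoss()
mse = get_loss_fn(loss=nn.MSELoss())   # an nn.Module passes through unchanged

# 0.2.5 drops the implicit task_type/training_mode defaults, so the
# loss name must now be given explicitly:
try:
    get_loss_fn()
except ValueError as err:
    print(err)  # loss must be provided explicitly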
 
 def get_loss_kwargs(loss_params: dict | list[dict] | None, index: int = 0) -> dict:
     """
-
+    Resolve the loss_kwargs for each head.
+
+    - loss_params is None -> {}
+    - loss_params is a dict -> shared by all heads
+    - loss_params is a list[dict] -> use loss_params[index] (if present and not None), otherwise {}
     """
     if loss_params is None:
         return {}
@@ -160,4 +91,4 @@ def get_loss_kwargs(loss_params: dict | list[dict] | None, index: int = 0) -> dict:
         if index < len(loss_params) and loss_params[index] is not None:
             return loss_params[index]
         return {}
-    return loss_params
+    return loss_params
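The resolution rules from the (translated) docstring, as a small usage sketch with the import path assumed from the file layout:

from nextrec.loss.loss_utils import get_loss_kwargs

get_loss_kwargs(None)                           # -> {}
get_loss_kwargs({"reduction": "sum"}, index=3)  # single dict, shared by every head
per_head = [{"gamma": 2.0}, None]
get_loss_kwargs(per_head, index=0)              # -> {"gamma": 2.0}
get_loss_kwargs(per_head, index=1)              # -> {} (entry is None)
get_loss_kwargs(per_head, index=9)              # -> {} (index out of range)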
nextrec/models/multi_task/esmm.py
CHANGED

@@ -40,7 +40,7 @@ class ESMM(BaseModel):
         ctr_params: dict,
         cvr_params: dict,
         target: list[str] = ['ctr', 'ctcvr'],  # Note: ctcvr = ctr * cvr
-        task:
+        task: list[str] = ['binary', 'binary'],
         optimizer: str = "adam",
         optimizer_params: dict = {},
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
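The `# Note: ctcvr = ctr * cvr` comment is the core ESMM identity: the model supervises the CTR and CTCVR heads and recovers CVR only implicitly. A toy illustration with made-up probabilities:

import torch

p_ctr = torch.tensor([0.10, 0.30, 0.05])  # P(click | impression), per sample
p_cvr = torch.tensor([0.50, 0.10, 0.40])  # P(conversion | click)

# ESMM trains on ctr and ctcvr (both binary heads, matching the new
# task=['binary', 'binary'] default); cvr itself is never supervised directly.
p_ctcvr = p_ctr * p_cvr                   # P(click and conversion | impression)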
nextrec/models/ranking/autoint.py
CHANGED

@@ -1,12 +1,57 @@
 """
 Date: create on 09/11/2025
-
-
+Checkpoint: edit on 24/11/2025
+Author: Yang Zhou, zyaztec@gmail.com
 Reference:
-
-
-
-
+[1] Song W, Shi C, Xiao Z, et al. Autoint: Automatic feature interaction learning via
+self-attentive neural networks[C]//Proceedings of the 28th ACM international conference
+on information and knowledge management. 2019: 1161-1170.
+(https://arxiv.org/abs/1810.11921)
+
+AutoInt is a CTR prediction model that leverages multi-head self-attention
+to automatically learn high-order feature interactions in an explicit and
+interpretable way. Instead of relying on manual feature engineering or
+implicit MLP-based transformations, AutoInt models feature dependencies
+by attending over all embedded fields and capturing their contextual
+relationships.
+
+In each Interacting Layer:
+(1) Each field embedding is projected into multiple attention heads
+(2) Scaled dot-product attention computes feature-to-feature interactions
+(3) Outputs are aggregated and passed through residual connections
+(4) Layer Normalization ensures stable optimization
+
+By stacking multiple Interacting Layers, AutoInt progressively discovers
+higher-order feature interactions, while maintaining transparency since
+attention weights explicitly show which features interact.
+
+Key Advantages:
+- Explicit modeling of high-order feature interactions
+- Multi-head attention enhances representation diversity
+- Residual structure facilitates deep interaction learning
+- Attention weights provide interpretability of feature relations
+- Eliminates heavy manual feature engineering
+
+AutoInt is a CTR model that explicitly learns high-order feature interactions via multi-head self-attention,
+with good interpretability. Unlike methods that rely on manual feature engineering or implicit MLP modeling,
+AutoInt computes attention over all feature embeddings to capture contextual dependencies between features.
+
+In each Interacting Layer:
+(1) Each feature embedding is split into multiple attention heads via projection
+(2) Scaled dot-product attention computes interaction weights between features
+(3) Multi-head outputs are aggregated and combined through residual connections
+(4) Layer Normalization ensures training stability
+
+By stacking multiple interacting layers, AutoInt progressively learns higher-order feature interactions;
+and since the attention weights can be visualized, the model is explicitly interpretable,
+revealing which feature relationships matter most.
+
+Key advantages:
+- Explicit modeling of high-order feature interactions
+- Multi-head mechanism strengthens representational power
+- Residual structure supports deep interaction learning
+- Attention weights are inherently interpretable
+- Reduces heavy manual feature engineering
 """
 
 import torch
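The four steps the new docstring lists map onto standard PyTorch pieces. Below is a self-contained sketch of one interacting layer in that spirit; it is illustrative only and not the package's actual module:

import torch
import torch.nn as nn

class InteractingLayerSketch(nn.Module):
    # Illustrative only: one interacting layer following the docstring's
    # four steps, not the nextrec implementation.
    def __init__(self, embed_dim: int, num_heads: int = 2):
        super().__init__()
        # Steps (1)+(2): project fields into heads and run scaled
        # dot-product attention across all fields.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.res = nn.Linear(embed_dim, embed_dim, bias=False)  # residual path
        self.norm = nn.LayerNorm(embed_dim)                     # step (4)

    def forward(self, x):                    # x: (batch, num_fields, embed_dim)
        out, _ = self.attn(x, x, x)          # field-to-field interactions
        return self.norm(out + self.res(x))  # step (3): aggregate + residual

fields = torch.randn(8, 5, 16)               # 8 samples, 5 fields, dim 16
layer = InteractingLayerSketch(16)
print(layer(layer(fields)).shape)            # stacking layers raises interaction order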
@@ -80,7 +125,6 @@ class AutoInt(BaseModel):
 
         # Project embeddings to attention embedding dimension
         num_fields = len(self.interaction_features)
-        total_embedding_dim = sum([f.embedding_dim for f in self.interaction_features])
 
         # If embeddings have different dimensions, project them to att_embedding_dim
         self.need_projection = not all(f.embedding_dim == att_embedding_dim for f in self.interaction_features)