nextrec 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/loss/__init__.py CHANGED
@@ -37,6 +37,5 @@ __all__ = [
     # Utilities
     "get_loss_fn",
     "get_loss_kwargs",
-    "validate_training_mode",
     "VALID_TASK_TYPES",
 ]
@@ -21,138 +21,69 @@ from nextrec.loss.pointwise import (
     WeightedBCELoss,
 )
 
-# Valid task types for validation
+
 VALID_TASK_TYPES = [
-    "binary",
-    "multiclass",
-    "regression",
-    "multivariate_regression",
-    "match",
-    "ranking",
-    "multitask",
-    "multilabel",
+    "binary",
+    "multiclass",
+    "multilabel",
+    "regression",
 ]
 
+def _build_cb_focal(kw):
+    if "class_counts" not in kw:
+        raise ValueError("class_balanced_focal requires class_counts")
+    return ClassBalancedFocalLoss(**kw)
 
-def get_loss_fn(
-    task_type: str = "binary",
-    training_mode: str | None = None,
-    loss: str | nn.Module | None = None,
-    **loss_kwargs,
-) -> nn.Module:
-    """
-    Get loss function based on task type and training mode.
-    """
 
+def get_loss_fn(loss=None, **kw):
     if isinstance(loss, nn.Module):
         return loss
-
-    # Common mappings
-    if task_type == "match":
-        return _get_match_loss(training_mode, loss, **loss_kwargs)
-
-    if task_type in ["ranking", "multitask", "binary", "multilabel"]:
-        return _get_classification_loss(loss, **loss_kwargs)
-
-    if task_type == "multiclass":
-        return _get_multiclass_loss(loss, **loss_kwargs)
-
-    if task_type == "regression":
-        if loss is None or loss == "mse":
-            return nn.MSELoss(**loss_kwargs)
-        if loss == "mae":
-            return nn.L1Loss(**loss_kwargs)
-        if isinstance(loss, str):
-            raise ValueError(f"Unsupported regression loss: {loss}")
-
-    raise ValueError(f"Unsupported task_type: {task_type}")
-
-
-def _get_match_loss(training_mode: str | None, loss: str | None, **loss_kwargs) -> nn.Module:
-    if training_mode == "pointwise":
-        if loss is None or loss in {"bce", "binary_crossentropy"}:
-            return nn.BCELoss(**loss_kwargs)
-        if loss == "weighted_bce":
-            return WeightedBCELoss(**loss_kwargs)
-        if loss == "focal":
-            return FocalLoss(**loss_kwargs)
-        if loss == "class_balanced_focal":
-            return _build_cb_focal(loss_kwargs)
-        if loss == "cosine_contrastive":
-            return CosineContrastiveLoss(**loss_kwargs)
-        if isinstance(loss, str):
-            raise ValueError(f"Unsupported pointwise loss: {loss}")
-
-    if training_mode == "pairwise":
-        if loss is None or loss == "bpr":
-            return BPRLoss(**loss_kwargs)
-        if loss == "hinge":
-            return HingeLoss(**loss_kwargs)
-        if loss == "triplet":
-            return TripletLoss(**loss_kwargs)
-        if isinstance(loss, str):
-            raise ValueError(f"Unsupported pairwise loss: {loss}")
-
-    if training_mode == "listwise":
-        if loss is None or loss in {"sampled_softmax", "softmax"}:
-            return SampledSoftmaxLoss(**loss_kwargs)
-        if loss == "infonce":
-            return InfoNCELoss(**loss_kwargs)
-        if loss == "listnet":
-            return ListNetLoss(**loss_kwargs)
-        if loss == "listmle":
-            return ListMLELoss(**loss_kwargs)
-        if loss == "approx_ndcg":
-            return ApproxNDCGLoss(**loss_kwargs)
-        if loss in {"crossentropy", "ce"}:
-            return nn.CrossEntropyLoss(**loss_kwargs)
-        if isinstance(loss, str):
-            raise ValueError(f"Unsupported listwise loss: {loss}")
-
-    raise ValueError(f"Unknown training_mode: {training_mode}")
-
-
-def _get_classification_loss(loss: str | None, **loss_kwargs) -> nn.Module:
-    if loss is None or loss in {"bce", "binary_crossentropy"}:
-        return nn.BCELoss(**loss_kwargs)
+    if loss is None:
+        raise ValueError("loss must be provided explicitly")
+    if loss in ["bce", "binary_crossentropy"]:
+        return nn.BCELoss(**kw)
     if loss == "weighted_bce":
-        return WeightedBCELoss(**loss_kwargs)
-    if loss == "focal":
-        return FocalLoss(**loss_kwargs)
-    if loss == "class_balanced_focal":
-        return _build_cb_focal(loss_kwargs)
+        return WeightedBCELoss(**kw)
+    if loss in ["focal", "focal_loss"]:
+        return FocalLoss(**kw)
+    if loss in ["cb_focal", "class_balanced_focal"]:
+        return _build_cb_focal(kw)
+    if loss in ["crossentropy", "ce"]:
+        return nn.CrossEntropyLoss(**kw)
    if loss == "mse":
-        return nn.MSELoss(**loss_kwargs)
+        return nn.MSELoss(**kw)
     if loss == "mae":
-        return nn.L1Loss(**loss_kwargs)
-    if loss in {"crossentropy", "ce"}:
-        return nn.CrossEntropyLoss(**loss_kwargs)
-    if isinstance(loss, str):
-        raise ValueError(f"Unsupported loss function: {loss}")
-    raise ValueError("Loss must be specified for classification task.")
-
-
-def _get_multiclass_loss(loss: str | None, **loss_kwargs) -> nn.Module:
-    if loss is None or loss in {"crossentropy", "ce"}:
-        return nn.CrossEntropyLoss(**loss_kwargs)
-    if loss == "focal":
-        return FocalLoss(**loss_kwargs)
-    if loss == "class_balanced_focal":
-        return _build_cb_focal(loss_kwargs)
-    if isinstance(loss, str):
-        raise ValueError(f"Unsupported multiclass loss: {loss}")
-    raise ValueError("Loss must be specified for multiclass task.")
-
-
-def _build_cb_focal(loss_kwargs: dict) -> ClassBalancedFocalLoss:
-    if "class_counts" not in loss_kwargs:
-        raise ValueError("class_balanced_focal requires `class_counts` argument.")
-    return ClassBalancedFocalLoss(**loss_kwargs)
-
+        return nn.L1Loss(**kw)
+
+    # Pairwise ranking Loss
+    if loss == "bpr":
+        return BPRLoss(**kw)
+    if loss == "hinge":
+        return HingeLoss(**kw)
+    if loss == "triplet":
+        return TripletLoss(**kw)
+
+    # Listwise ranking Loss
+    if loss in ["sampled_softmax", "softmax"]:
+        return SampledSoftmaxLoss(**kw)
+    if loss == "infonce":
+        return InfoNCELoss(**kw)
+    if loss == "listnet":
+        return ListNetLoss(**kw)
+    if loss == "listmle":
+        return ListMLELoss(**kw)
+    if loss == "approx_ndcg":
+        return ApproxNDCGLoss(**kw)
+
+    raise ValueError(f"Unsupported loss: {loss}")
 
 
 def get_loss_kwargs(loss_params: dict | list[dict] | None, index: int = 0) -> dict:
     """
-    Resolve per-task loss kwargs from a dict or list of dicts.
+    Resolve the loss kwargs for each head.
+
+    - loss_params is None -> {}
+    - loss_params is a dict -> shared by all heads
+    - loss_params is a list[dict] -> use loss_params[index] (if present and not None), else {}
     """
     if loss_params is None:
         return {}
@@ -160,4 +91,4 @@ def get_loss_kwargs(loss_params: dict | list[dict] | None, index: int = 0) -> dict:
         if index < len(loss_params) and loss_params[index] is not None:
             return loss_params[index]
         return {}
-    return loss_params
+    return loss_params
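
To make the 0.2.5 behavior concrete, here is a small usage sketch against the flattened API shown above. It is not taken from the package's docs; `gamma` is an assumed `FocalLoss` parameter, and the other values are illustrative.

```python
import torch.nn as nn
from nextrec.loss import get_loss_fn, get_loss_kwargs

# Loss names now resolve from one flat registry; extra kwargs are forwarded
# to the loss class. `gamma` here is an assumed FocalLoss parameter.
bce = get_loss_fn("bce")                  # -> nn.BCELoss()
focal = get_loss_fn("focal", gamma=2.0)   # alias: "focal_loss"
bpr = get_loss_fn("bpr")                  # pairwise losses no longer need a training_mode

# An nn.Module instance passes through unchanged; None now raises instead of
# defaulting to a task-specific loss as in 0.2.3.
custom = get_loss_fn(nn.MSELoss())
# get_loss_fn(None)  # -> ValueError: loss must be provided explicitly

# Per-head kwargs: a dict is shared across heads, a list is indexed per head.
per_head = [{"gamma": 2.0}, None]
assert get_loss_kwargs(per_head, index=0) == {"gamma": 2.0}
assert get_loss_kwargs(per_head, index=1) == {}
assert get_loss_kwargs({"reduction": "sum"}, index=3) == {"reduction": "sum"}
```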
@@ -40,7 +40,7 @@ class ESMM(BaseModel):
         ctr_params: dict,
         cvr_params: dict,
         target: list[str] = ['ctr', 'ctcvr'],  # Note: ctcvr = ctr * cvr
-        task: str | list[str] = 'binary',
+        task: list[str] = ['binary', 'binary'],
         optimizer: str = "adam",
         optimizer_params: dict = {},
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
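
A hypothetical call site illustrating the signature change: `task` must now be a list with one entry per head rather than a single string. The tower-config keys and the omitted arguments below are illustrative assumptions, not the documented schema.

```python
model = ESMM(
    ctr_params={"hidden_units": [128, 64]},   # assumed tower-config keys
    cvr_params={"hidden_units": [128, 64]},
    target=["ctr", "ctcvr"],                  # ctcvr = ctr * cvr
    task=["binary", "binary"],                # 0.2.3 also accepted task="binary"
    loss="bce",
    # ...feature definitions and other required arguments omitted
)
```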
@@ -1,12 +1,57 @@
 """
 Date: create on 09/11/2025
-Author:
-    Yang Zhou,zyaztec@gmail.com
+Checkpoint: edit on 24/11/2025
+Author: Yang Zhou,zyaztec@gmail.com
 Reference:
-    [1] Song W, Shi C, Xiao Z, et al. Autoint: Automatic feature interaction learning via
-    self-attentive neural networks[C]//Proceedings of the 28th ACM international conference
-    on information and knowledge management. 2019: 1161-1170.
-    (https://arxiv.org/abs/1810.11921)
+    [1] Song W, Shi C, Xiao Z, et al. Autoint: Automatic feature interaction learning via
+    self-attentive neural networks[C]//Proceedings of the 28th ACM international conference
+    on information and knowledge management. 2019: 1161-1170.
+    (https://arxiv.org/abs/1810.11921)
+
+AutoInt is a CTR prediction model that leverages multi-head self-attention
+to automatically learn high-order feature interactions in an explicit and
+interpretable way. Instead of relying on manual feature engineering or
+implicit MLP-based transformations, AutoInt models feature dependencies
+by attending over all embedded fields and capturing their contextual
+relationships.
+
+In each Interacting Layer:
+(1) Each field embedding is projected into multiple attention heads
+(2) Scaled dot-product attention computes feature-to-feature interactions
+(3) Outputs are aggregated and passed through residual connections
+(4) Layer Normalization ensures stable optimization
+
+By stacking multiple Interacting Layers, AutoInt progressively discovers
+higher-order feature interactions, while maintaining transparency since
+attention weights explicitly show which features interact.
+
+Key Advantages:
+- Explicit modeling of high-order feature interactions
+- Multi-head attention enhances representation diversity
+- Residual structure facilitates deep interaction learning
+- Attention weights provide interpretability of feature relations
+- Eliminates heavy manual feature engineering
+
35
+ AutoInt 是一个 CTR 预估模型,通过多头自注意力机制显式学习高阶特征交互,
36
+ 并具有良好的可解释性。不同于依赖人工特征工程或 MLP 隐式建模的方法,
37
+ AutoInt 通过对所有特征 embedding 进行注意力计算,捕捉特征之间的上下文依赖关系。
38
+
39
+ 在每个 Interacting Layer(交互层)中:
40
+ (1) 每个特征 embedding 通过投影分成多个注意力头
41
+ (2) 使用缩放点积注意力计算特征间交互权重
42
+ (3) 将多头输出进行聚合,并使用残差连接
43
+ (4) Layer Normalization 确保训练稳定性
44
+
45
+ 通过堆叠多个交互层,AutoInt 能逐步学习更高阶的特征交互;
46
+ 同时由于注意力权重可视化,模型具有明确的可解释能力,
47
+ 能展示哪些特征之间的关系最重要。
48
+
49
+ 主要优点:
50
+ - 显式建模高阶特征交互
51
+ - 多头机制增强表示能力
52
+ - 残差结构支持深层交互学习
53
+ - 注意力权重天然具备可解释性
54
+ - 减少繁重的人工特征工程工作
10
55
  """
11
56
 
12
57
  import torch
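
The four steps the new docstring lists map onto a standard interacting layer. Below is a minimal PyTorch sketch of such a layer; this is our own illustration, not nextrec's implementation, and the class name and defaults are ours.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class InteractingLayer(nn.Module):
    """AutoInt-style interacting layer: multi-head self-attention over field
    embeddings, followed by a residual connection and LayerNorm."""

    def __init__(self, embed_dim: int, num_heads: int = 2):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must divide evenly across heads"
        self.num_heads, self.head_dim = num_heads, embed_dim // num_heads
        # (1) Projections that split each field embedding into attention heads.
        self.w_q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.w_k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.w_v = nn.Linear(embed_dim, embed_dim, bias=False)
        self.w_res = nn.Linear(embed_dim, embed_dim, bias=False)  # residual branch
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_fields, embed_dim)
        b, f, d = x.shape
        def split(t):  # -> (batch, heads, fields, head_dim)
            return t.view(b, f, self.num_heads, self.head_dim).transpose(1, 2)
        q, k, v = split(self.w_q(x)), split(self.w_k(x)), split(self.w_v(x))
        # (2) Scaled dot-product attention: every field attends to every field.
        attn = F.softmax(q @ k.transpose(-2, -1) / self.head_dim ** 0.5, dim=-1)
        # (3) Aggregate heads back into one vector per field, add the residual.
        out = (attn @ v).transpose(1, 2).reshape(b, f, d)
        # (4) LayerNorm stabilizes optimization.
        return self.norm(F.relu(out + self.w_res(x)))

fields = torch.randn(32, 10, 16)                  # (batch, num_fields, embed_dim)
print(InteractingLayer(16)(fields).shape)         # torch.Size([32, 10, 16])
```

Stacking several such layers gives the progressively higher-order interactions the docstring describes, since each layer re-attends over the previous layer's field representations.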
@@ -80,7 +125,6 @@ class AutoInt(BaseModel):
 
         # Project embeddings to attention embedding dimension
         num_fields = len(self.interaction_features)
-        total_embedding_dim = sum([f.embedding_dim for f in self.interaction_features])
 
         # If embeddings have different dimensions, project them to att_embedding_dim
         self.need_projection = not all(f.embedding_dim == att_embedding_dim for f in self.interaction_features)
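
A hedged sketch of the projection step this hunk gates on: when field embeddings have heterogeneous dims, each is mapped to `att_embedding_dim` so the attention stack sees a uniform `(batch, num_fields, att_embedding_dim)` tensor. The variable names and dims below are illustrative, not nextrec's internals.

```python
import torch
import torch.nn as nn

field_dims = [16, 16, 8, 32]      # hypothetical per-field embedding_dim values
att_embedding_dim = 16

# Mirror of the gate above: project only if any field dim differs.
need_projection = not all(d == att_embedding_dim for d in field_dims)
projections = nn.ModuleList(
    nn.Linear(d, att_embedding_dim) if need_projection else nn.Identity()
    for d in field_dims
)

batch = [torch.randn(4, d) for d in field_dims]   # one tensor per field, batch of 4
stacked = torch.stack([p(e) for p, e in zip(projections, batch)], dim=1)
print(stacked.shape)  # torch.Size([4, 4, 16])
```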