nextrec 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/features.py +2 -1
- nextrec/basic/layers.py +2 -2
- nextrec/basic/model.py +82 -49
- nextrec/data/__init__.py +2 -4
- nextrec/data/dataloader.py +3 -3
- nextrec/data/preprocessor.py +2 -2
- nextrec/loss/__init__.py +31 -24
- nextrec/loss/listwise.py +162 -4
- nextrec/loss/loss_utils.py +133 -105
- nextrec/loss/pairwise.py +103 -4
- nextrec/loss/pointwise.py +196 -4
- nextrec/models/match/dssm.py +24 -15
- nextrec/models/match/dssm_v2.py +18 -0
- nextrec/models/match/mind.py +16 -1
- nextrec/models/match/sdm.py +15 -0
- nextrec/models/match/youtube_dnn.py +21 -8
- nextrec/models/multi_task/esmm.py +5 -5
- nextrec/models/multi_task/mmoe.py +5 -5
- nextrec/models/multi_task/ple.py +5 -5
- nextrec/models/multi_task/share_bottom.py +5 -5
- nextrec/models/ranking/__init__.py +8 -0
- nextrec/models/ranking/afm.py +3 -1
- nextrec/models/ranking/autoint.py +3 -1
- nextrec/models/ranking/dcn.py +3 -1
- nextrec/models/ranking/deepfm.py +3 -1
- nextrec/models/ranking/dien.py +3 -1
- nextrec/models/ranking/din.py +3 -1
- nextrec/models/ranking/fibinet.py +3 -1
- nextrec/models/ranking/fm.py +3 -1
- nextrec/models/ranking/masknet.py +3 -1
- nextrec/models/ranking/pnn.py +3 -1
- nextrec/models/ranking/widedeep.py +3 -1
- nextrec/models/ranking/xdeepfm.py +3 -1
- nextrec/utils/__init__.py +5 -5
- nextrec/utils/initializer.py +3 -3
- nextrec/utils/optimizer.py +6 -6
- {nextrec-0.2.1.dist-info → nextrec-0.2.3.dist-info}/METADATA +2 -2
- nextrec-0.2.3.dist-info/RECORD +53 -0
- nextrec/loss/match_losses.py +0 -293
- nextrec-0.2.1.dist-info/RECORD +0 -54
- {nextrec-0.2.1.dist-info → nextrec-0.2.3.dist-info}/WHEEL +0 -0
- {nextrec-0.2.1.dist-info → nextrec-0.2.3.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/loss_utils.py
CHANGED
@@ -1,135 +1,163 @@
 """
-Loss utilities for NextRec
-
-Date: create on 09/11/2025
-Author: Yang Zhou,zyaztec@gmail.com
+Loss utilities for NextRec.
 """
-
-import torch.nn as nn
+
 from typing import Literal
 
-
-
-
-
+import torch.nn as nn
+
+from nextrec.loss.listwise import (
+    ApproxNDCGLoss,
+    InfoNCELoss,
+    ListMLELoss,
+    ListNetLoss,
     SampledSoftmaxLoss,
-
-
+)
+from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
+from nextrec.loss.pointwise import (
+    ClassBalancedFocalLoss,
+    CosineContrastiveLoss,
+    FocalLoss,
+    WeightedBCELoss,
 )
 
 # Valid task types for validation
-VALID_TASK_TYPES = [
+VALID_TASK_TYPES = [
+    "binary",
+    "multiclass",
+    "regression",
+    "multivariate_regression",
+    "match",
+    "ranking",
+    "multitask",
+    "multilabel",
+]
 
 
 def get_loss_fn(
     task_type: str = "binary",
     training_mode: str | None = None,
     loss: str | nn.Module | None = None,
-    **loss_kwargs
+    **loss_kwargs,
 ) -> nn.Module:
     """
     Get loss function based on task type and training mode.
-
-    Examples:
-        # Ranking task (binary classification)
-        >>> loss_fn = get_loss_fn(task_type="binary", loss="bce")
-
-        # Match task with pointwise training
-        >>> loss_fn = get_loss_fn(task_type="match", training_mode="pointwise")
-
-        # Match task with pairwise training
-        >>> loss_fn = get_loss_fn(task_type="match", training_mode="pairwise", loss="bpr")
-
-        # Match task with listwise training
-        >>> loss_fn = get_loss_fn(task_type="match", training_mode="listwise", loss="sampled_softmax")
     """
 
     if isinstance(loss, nn.Module):
         return loss
 
+    # Common mappings
     if task_type == "match":
-
-        # Pointwise training uses binary cross entropy
-        if loss is None or loss == "bce" or loss == "binary_crossentropy":
-            return nn.BCELoss(**loss_kwargs)
-        elif loss == "cosine_contrastive":
-            return CosineContrastiveLoss(**loss_kwargs)
-        elif isinstance(loss, str):
-            raise ValueError(f"Unsupported pointwise loss: {loss}")
-
-        elif training_mode == "pairwise":
-            if loss is None or loss == "bpr":
-                return BPRLoss(**loss_kwargs)
-            elif loss == "hinge":
-                return HingeLoss(**loss_kwargs)
-            elif loss == "triplet":
-                return TripletLoss(**loss_kwargs)
-            elif isinstance(loss, str):
-                raise ValueError(f"Unsupported pairwise loss: {loss}")
-
-        elif training_mode == "listwise":
-            if loss is None or loss == "sampled_softmax" or loss == "softmax":
-                return SampledSoftmaxLoss(**loss_kwargs)
-            elif loss == "infonce":
-                return InfoNCELoss(**loss_kwargs)
-            elif loss == "crossentropy" or loss == "ce":
-                return nn.CrossEntropyLoss(**loss_kwargs)
-            elif isinstance(loss, str):
-                raise ValueError(f"Unsupported listwise loss: {loss}")
-
-        else:
-            raise ValueError(f"Unknown training_mode: {training_mode}")
-
-    elif task_type in ["ranking", "multitask", "binary"]:
-        if loss is None or loss == "bce" or loss == "binary_crossentropy":
-            return nn.BCELoss(**loss_kwargs)
-        elif loss == "mse":
-            return nn.MSELoss(**loss_kwargs)
-        elif loss == "mae":
-            return nn.L1Loss(**loss_kwargs)
-        elif loss == "crossentropy" or loss == "ce":
-            return nn.CrossEntropyLoss(**loss_kwargs)
-        elif isinstance(loss, str):
-            raise ValueError(f"Unsupported loss function: {loss}")
+        return _get_match_loss(training_mode, loss, **loss_kwargs)
 
-
-
-
-
-
-
-
+    if task_type in ["ranking", "multitask", "binary", "multilabel"]:
+        return _get_classification_loss(loss, **loss_kwargs)
+
+    if task_type == "multiclass":
+        return _get_multiclass_loss(loss, **loss_kwargs)
+
+    if task_type == "regression":
         if loss is None or loss == "mse":
             return nn.MSELoss(**loss_kwargs)
-        elif loss == "mae":
+        if loss == "mae":
             return nn.L1Loss(**loss_kwargs)
-        elif isinstance(loss, str):
+        if isinstance(loss, str):
             raise ValueError(f"Unsupported regression loss: {loss}")
-
-
-
-
-
-
-
-
-
-
-
-    )
+
+    raise ValueError(f"Unsupported task_type: {task_type}")
+
+
+def _get_match_loss(training_mode: str | None, loss: str | None, **loss_kwargs) -> nn.Module:
+    if training_mode == "pointwise":
+        if loss is None or loss in {"bce", "binary_crossentropy"}:
+            return nn.BCELoss(**loss_kwargs)
+        if loss == "weighted_bce":
+            return WeightedBCELoss(**loss_kwargs)
+        if loss == "focal":
+            return FocalLoss(**loss_kwargs)
+        if loss == "class_balanced_focal":
+            return _build_cb_focal(loss_kwargs)
+        if loss == "cosine_contrastive":
+            return CosineContrastiveLoss(**loss_kwargs)
+        if isinstance(loss, str):
+            raise ValueError(f"Unsupported pointwise loss: {loss}")
+
+    if training_mode == "pairwise":
+        if loss is None or loss == "bpr":
+            return BPRLoss(**loss_kwargs)
+        if loss == "hinge":
+            return HingeLoss(**loss_kwargs)
+        if loss == "triplet":
+            return TripletLoss(**loss_kwargs)
+        if isinstance(loss, str):
+            raise ValueError(f"Unsupported pairwise loss: {loss}")
+
+    if training_mode == "listwise":
+        if loss is None or loss in {"sampled_softmax", "softmax"}:
+            return SampledSoftmaxLoss(**loss_kwargs)
+        if loss == "infonce":
+            return InfoNCELoss(**loss_kwargs)
+        if loss == "listnet":
+            return ListNetLoss(**loss_kwargs)
+        if loss == "listmle":
+            return ListMLELoss(**loss_kwargs)
+        if loss == "approx_ndcg":
+            return ApproxNDCGLoss(**loss_kwargs)
+        if loss in {"crossentropy", "ce"}:
+            return nn.CrossEntropyLoss(**loss_kwargs)
+        if isinstance(loss, str):
+            raise ValueError(f"Unsupported listwise loss: {loss}")
+
+    raise ValueError(f"Unknown training_mode: {training_mode}")
+
+
+def _get_classification_loss(loss: str | None, **loss_kwargs) -> nn.Module:
+    if loss is None or loss in {"bce", "binary_crossentropy"}:
+        return nn.BCELoss(**loss_kwargs)
+    if loss == "weighted_bce":
+        return WeightedBCELoss(**loss_kwargs)
+    if loss == "focal":
+        return FocalLoss(**loss_kwargs)
+    if loss == "class_balanced_focal":
+        return _build_cb_focal(loss_kwargs)
+    if loss == "mse":
+        return nn.MSELoss(**loss_kwargs)
+    if loss == "mae":
+        return nn.L1Loss(**loss_kwargs)
+    if loss in {"crossentropy", "ce"}:
+        return nn.CrossEntropyLoss(**loss_kwargs)
+    if isinstance(loss, str):
+        raise ValueError(f"Unsupported loss function: {loss}")
+    raise ValueError("Loss must be specified for classification task.")
+
+
+def _get_multiclass_loss(loss: str | None, **loss_kwargs) -> nn.Module:
+    if loss is None or loss in {"crossentropy", "ce"}:
+        return nn.CrossEntropyLoss(**loss_kwargs)
+    if loss == "focal":
+        return FocalLoss(**loss_kwargs)
+    if loss == "class_balanced_focal":
+        return _build_cb_focal(loss_kwargs)
+    if isinstance(loss, str):
+        raise ValueError(f"Unsupported multiclass loss: {loss}")
+    raise ValueError("Loss must be specified for multiclass task.")
+
+
+def _build_cb_focal(loss_kwargs: dict) -> ClassBalancedFocalLoss:
+    if "class_counts" not in loss_kwargs:
+        raise ValueError("class_balanced_focal requires `class_counts` argument.")
+    return ClassBalancedFocalLoss(**loss_kwargs)
+
+
+def get_loss_kwargs(loss_params: dict | list[dict] | None, index: int = 0) -> dict:
     """
-
-
-    Args:
-        training_mode: Requested training mode
-        support_training_modes: List of supported training modes
-        model_name: Name of the model (for error messages)
-
-    Raises:
-        ValueError: If training mode is not supported
+    Resolve per-task loss kwargs from a dict or list of dicts.
     """
-    if
-
-
-
-
+    if loss_params is None:
+        return {}
+    if isinstance(loss_params, list):
+        if index < len(loss_params) and loss_params[index] is not None:
+            return loss_params[index]
+        return {}
+    return loss_params
nextrec/loss/pairwise.py
CHANGED
@@ -1,6 +1,105 @@
 """
-
-
-Date: create on 22/11/2025
-Author: Yang Zhou,zyaztec@gmail.com
+Pairwise loss functions for learning-to-rank and matching tasks.
 """
+
+from typing import Literal
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BPRLoss(nn.Module):
+    """
+    Bayesian Personalized Ranking loss with support for multiple negatives.
+    """
+
+    def __init__(self, reduction: str = "mean"):
+        super().__init__()
+        self.reduction = reduction
+
+    def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
+        if neg_score.dim() == 2:
+            pos_score = pos_score.unsqueeze(1)
+            diff = pos_score - neg_score
+        else:
+            diff = pos_score - neg_score
+
+        loss = -torch.log(torch.sigmoid(diff) + 1e-8)
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+
+class HingeLoss(nn.Module):
+    """
+    Hinge loss for pairwise ranking.
+    """
+
+    def __init__(self, margin: float = 1.0, reduction: str = "mean"):
+        super().__init__()
+        self.margin = margin
+        self.reduction = reduction
+
+    def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
+        if neg_score.dim() == 2:
+            pos_score = pos_score.unsqueeze(1)
+
+        diff = pos_score - neg_score
+        loss = torch.clamp(self.margin - diff, min=0)
+
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+
+class TripletLoss(nn.Module):
+    """
+    Triplet margin loss with cosine or euclidean distance.
+    """
+
+    def __init__(
+        self,
+        margin: float = 1.0,
+        reduction: str = "mean",
+        distance: Literal["euclidean", "cosine"] = "euclidean",
+    ):
+        super().__init__()
+        self.margin = margin
+        self.reduction = reduction
+        self.distance = distance
+
+    def forward(
+        self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor
+    ) -> torch.Tensor:
+        if self.distance == "euclidean":
+            pos_dist = torch.sum((anchor - positive) ** 2, dim=-1)
+            if negative.dim() == 3:
+                anchor_expanded = anchor.unsqueeze(1)
+                neg_dist = torch.sum((anchor_expanded - negative) ** 2, dim=-1)
+            else:
+                neg_dist = torch.sum((anchor - negative) ** 2, dim=-1)
+            if neg_dist.dim() == 2:
+                pos_dist = pos_dist.unsqueeze(1)
+        elif self.distance == "cosine":
+            pos_dist = 1 - F.cosine_similarity(anchor, positive, dim=-1)
+            if negative.dim() == 3:
+                anchor_expanded = anchor.unsqueeze(1)
+                neg_dist = 1 - F.cosine_similarity(anchor_expanded, negative, dim=-1)
+            else:
+                neg_dist = 1 - F.cosine_similarity(anchor, negative, dim=-1)
+            if neg_dist.dim() == 2:
+                pos_dist = pos_dist.unsqueeze(1)
+        else:
+            raise ValueError(f"Unsupported distance: {self.distance}")
+
+        loss = torch.clamp(pos_dist - neg_dist + self.margin, min=0)
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
nextrec/loss/pointwise.py
CHANGED
@@ -1,6 +1,198 @@
 """
-
-
-Date: create on 22/11/2025
-Author: Yang Zhou,zyaztec@gmail.com
+Pointwise loss functions, including imbalance-aware variants.
 """
+
+from typing import Optional, Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class CosineContrastiveLoss(nn.Module):
+    """
+    Contrastive loss using cosine similarity for positive/negative pairs.
+    """
+
+    def __init__(self, margin: float = 0.5, reduction: str = "mean"):
+        super().__init__()
+        self.margin = margin
+        self.reduction = reduction
+
+    def forward(
+        self, user_emb: torch.Tensor, item_emb: torch.Tensor, labels: torch.Tensor
+    ) -> torch.Tensor:
+        labels = labels.float()
+        similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
+        pos_loss = torch.clamp(self.margin - similarity, min=0) * labels
+        neg_loss = torch.clamp(similarity - self.margin, min=0) * (1 - labels)
+        loss = pos_loss + neg_loss
+
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+
+class WeightedBCELoss(nn.Module):
+    """
+    Binary cross entropy with controllable positive class weight.
+    Supports probability or logit inputs via `logits` flag.
+    If `auto_balance=True` and `pos_weight` is None, the positive weight is
+    computed from the batch as (#neg / #pos) for stable imbalance handling.
+    """
+    def __init__(
+        self,
+        pos_weight: float | torch.Tensor | None = None,
+        reduction: str = "mean",
+        logits: bool = False,
+        auto_balance: bool = False,
+    ):
+        super().__init__()
+        self.reduction = reduction
+        self.logits = logits
+        self.auto_balance = auto_balance
+
+        if pos_weight is not None:
+            self.register_buffer(
+                "pos_weight",
+                torch.as_tensor(pos_weight, dtype=torch.float32),
+            )
+        else:
+            self.pos_weight = None
+
+    def _resolve_pos_weight(self, labels: torch.Tensor) -> torch.Tensor:
+        if self.pos_weight is not None:
+            return self.pos_weight.to(device=labels.device)
+
+        if not self.auto_balance:
+            return torch.tensor(1.0, device=labels.device, dtype=labels.dtype)
+
+        labels_float = labels.float()
+        pos = torch.clamp(labels_float.sum(), min=1.0)
+        neg = torch.clamp(labels_float.numel() - labels_float.sum(), min=1.0)
+        return (neg / pos).to(device=labels.device, dtype=labels.dtype)
+
+    def forward(self, inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+        labels = labels.float()
+        current_pos_weight = self._resolve_pos_weight(labels)
+        current_pos_weight = current_pos_weight.to(inputs.dtype)
+
+        if self.logits:
+            loss = F.binary_cross_entropy_with_logits(
+                inputs, labels, pos_weight=current_pos_weight, reduction="none"
+            )
+        else:
+            probs = torch.clamp(inputs, min=1e-6, max=1 - 1e-6)
+            base_loss = F.binary_cross_entropy(probs, labels, reduction="none")
+            loss = torch.where(labels == 1, base_loss * current_pos_weight, base_loss)
+
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        else:
+            return loss
+
+
+class FocalLoss(nn.Module):
+    """
+    Standard focal loss for binary or multi-class classification.
+    """
+
+    def __init__(
+        self,
+        gamma: float = 2.0,
+        alpha: Optional[float | Sequence[float] | torch.Tensor] = None,
+        reduction: str = "mean",
+        logits: bool = False,
+    ):
+        super().__init__()
+        self.gamma = gamma
+        self.reduction = reduction
+        self.logits = logits
+        self.alpha = alpha
+
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        # Multi-class case
+        if inputs.dim() > 1 and inputs.size(1) > 1:
+            log_probs = F.log_softmax(inputs, dim=1)
+            probs = log_probs.exp()
+            targets_one_hot = F.one_hot(targets.long(), num_classes=inputs.size(1)).float()
+
+            alpha = self._get_alpha(inputs)
+            alpha_factor = targets_one_hot * alpha
+            focal_weight = (1.0 - probs) ** self.gamma
+            loss = torch.sum(alpha_factor * focal_weight * (-log_probs), dim=1)
+        else:
+            targets = targets.float()
+            if self.logits:
+                ce_loss = F.binary_cross_entropy_with_logits(
+                    inputs, targets, reduction="none"
+                )
+                probs = torch.sigmoid(inputs)
+            else:
+                ce_loss = F.binary_cross_entropy(inputs, targets, reduction="none")
+                probs = torch.clamp(inputs, min=1e-6, max=1 - 1e-6)
+
+            p_t = probs * targets + (1 - probs) * (1 - targets)
+            alpha_factor = self._get_binary_alpha(targets, inputs.device)
+            focal_weight = (1.0 - p_t) ** self.gamma
+            loss = alpha_factor * focal_weight * ce_loss
+
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+    def _get_alpha(self, inputs: torch.Tensor) -> torch.Tensor:
+        if self.alpha is None:
+            return torch.ones_like(inputs)
+        if isinstance(self.alpha, torch.Tensor):
+            return self.alpha.to(inputs.device)
+        alpha_tensor = torch.tensor(self.alpha, device=inputs.device, dtype=inputs.dtype)
+        return alpha_tensor
+
+    def _get_binary_alpha(self, targets: torch.Tensor, device: torch.device) -> torch.Tensor:
+        if self.alpha is None:
+            return torch.ones_like(targets)
+        if isinstance(self.alpha, (float, int)):
+            return torch.where(targets == 1, self.alpha, 1 - float(self.alpha)).to(device)
+        alpha_tensor = torch.tensor(self.alpha, device=device, dtype=targets.dtype)
+        return torch.where(targets == 1, alpha_tensor, 1 - alpha_tensor)
+
+
+class ClassBalancedFocalLoss(nn.Module):
+    """
+    Focal loss weighted by effective number of samples per class.
+    Reference: "Class-Balanced Loss Based on Effective Number of Samples"
+    """
+
+    def __init__(
+        self,
+        class_counts: Sequence[int] | torch.Tensor,
+        beta: float = 0.9999,
+        gamma: float = 2.0,
+        reduction: str = "mean",
+    ):
+        super().__init__()
+        self.gamma = gamma
+        self.reduction = reduction
+        class_counts = torch.as_tensor(class_counts, dtype=torch.float32)
+        effective_num = 1.0 - torch.pow(beta, class_counts)
+        weights = (1.0 - beta) / (effective_num + 1e-12)
+        weights = weights / weights.sum() * len(weights)
+        self.register_buffer("class_weights", weights)
+
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        focal = FocalLoss(
+            gamma=self.gamma, alpha=self.class_weights, reduction="none", logits=True
+        )
+        loss = focal(inputs, targets)
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
nextrec/models/match/dssm.py
CHANGED
@@ -19,7 +19,8 @@ class DSSM(BaseMatchModel):
     """
     Deep Structured Semantic Model
 
-
+    Dual-tower model that encodes user and item features separately and
+    computes similarity via cosine or dot product.
     """
 
     @property
@@ -48,6 +49,12 @@ class DSSM(BaseMatchModel):
                  embedding_l2_reg: float = 0.0,
                  dense_l2_reg: float = 0.0,
                  early_stop_patience: int = 20,
+                 optimizer: str | torch.optim.Optimizer = "adam",
+                 optimizer_params: dict | None = None,
+                 scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+                 scheduler_params: dict | None = None,
+                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+                 loss_params: dict | list[dict] | None = None,
                  **kwargs):
 
         super(DSSM, self).__init__(
@@ -86,7 +93,7 @@ class DSSM(BaseMatchModel):
         if len(user_features) > 0:
             self.user_embedding = EmbeddingLayer(user_features)
 
-        #
+        # Compute user tower input dimension
         user_input_dim = 0
         for feat in user_dense_features or []:
             user_input_dim += 1
@@ -117,7 +124,7 @@ class DSSM(BaseMatchModel):
         if len(item_features) > 0:
             self.item_embedding = EmbeddingLayer(item_features)
 
-        #
+        # Compute item tower input dimension
         item_input_dim = 0
         for feat in item_dense_features or []:
             item_input_dim += 1
@@ -136,7 +143,6 @@ class DSSM(BaseMatchModel):
             activation=dnn_activation
         )
 
-        # 注册正则化权重
         self._register_regularization_weights(
             embedding_attr='user_embedding',
             include_modules=['user_dnn']
@@ -146,28 +152,33 @@ class DSSM(BaseMatchModel):
             include_modules=['item_dnn']
         )
 
+        if optimizer_params is None:
+            optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
+
         self.compile(
-            optimizer=
-            optimizer_params=
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=scheduler,
+            scheduler_params=scheduler_params,
+            loss=loss,
+            loss_params=loss_params,
         )
 
         self.to(device)
 
     def user_tower(self, user_input: dict) -> torch.Tensor:
         """
-        User tower
+        User tower encodes user features into embeddings.
 
         Args:
-            user_input: user
+            user_input: user feature dict
 
         Returns:
             user_emb: [batch_size, embedding_dim]
         """
-        # 获取user特征的embedding
         all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
         user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)
 
-        # 通过user DNN
         user_emb = self.user_dnn(user_emb)
 
         # L2 normalize for cosine similarity
@@ -178,19 +189,17 @@ class DSSM(BaseMatchModel):
 
     def item_tower(self, item_input: dict) -> torch.Tensor:
         """
-        Item tower
+        Item tower encodes item features into embeddings.
 
         Args:
-            item_input: item
+            item_input: item feature dict
 
         Returns:
-            item_emb: [batch_size, embedding_dim]
+            item_emb: [batch_size, embedding_dim] or [batch_size, num_items, embedding_dim]
         """
-        # 获取item特征的embedding
         all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
         item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
 
-        # 通过item DNN
         item_emb = self.item_dnn(item_emb)
 
         # L2 normalize for cosine similarity