nextrec 0.1.11__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +1 -2
- nextrec/basic/callback.py +1 -2
- nextrec/basic/features.py +39 -8
- nextrec/basic/layers.py +3 -4
- nextrec/basic/loggers.py +15 -10
- nextrec/basic/metrics.py +1 -2
- nextrec/basic/model.py +160 -125
- nextrec/basic/session.py +150 -0
- nextrec/data/__init__.py +13 -2
- nextrec/data/data_utils.py +74 -22
- nextrec/data/dataloader.py +513 -0
- nextrec/data/preprocessor.py +494 -134
- nextrec/loss/__init__.py +31 -24
- nextrec/loss/listwise.py +164 -0
- nextrec/loss/loss_utils.py +133 -106
- nextrec/loss/pairwise.py +105 -0
- nextrec/loss/pointwise.py +198 -0
- nextrec/models/match/dssm.py +26 -17
- nextrec/models/match/dssm_v2.py +20 -2
- nextrec/models/match/mind.py +18 -3
- nextrec/models/match/sdm.py +17 -2
- nextrec/models/match/youtube_dnn.py +23 -10
- nextrec/models/multi_task/esmm.py +8 -8
- nextrec/models/multi_task/mmoe.py +8 -8
- nextrec/models/multi_task/ple.py +8 -8
- nextrec/models/multi_task/share_bottom.py +8 -8
- nextrec/models/ranking/__init__.py +8 -0
- nextrec/models/ranking/afm.py +5 -4
- nextrec/models/ranking/autoint.py +6 -4
- nextrec/models/ranking/dcn.py +6 -4
- nextrec/models/ranking/deepfm.py +5 -4
- nextrec/models/ranking/dien.py +6 -4
- nextrec/models/ranking/din.py +6 -4
- nextrec/models/ranking/fibinet.py +6 -4
- nextrec/models/ranking/fm.py +6 -4
- nextrec/models/ranking/masknet.py +6 -4
- nextrec/models/ranking/pnn.py +6 -4
- nextrec/models/ranking/widedeep.py +6 -4
- nextrec/models/ranking/xdeepfm.py +6 -4
- nextrec/utils/__init__.py +7 -11
- nextrec/utils/embedding.py +2 -4
- nextrec/utils/initializer.py +4 -5
- nextrec/utils/optimizer.py +7 -8
- {nextrec-0.1.11.dist-info → nextrec-0.2.2.dist-info}/METADATA +3 -3
- nextrec-0.2.2.dist-info/RECORD +53 -0
- nextrec/basic/dataloader.py +0 -447
- nextrec/loss/match_losses.py +0 -294
- nextrec/utils/common.py +0 -14
- nextrec-0.1.11.dist-info/RECORD +0 -51
- {nextrec-0.1.11.dist-info → nextrec-0.2.2.dist-info}/WHEEL +0 -0
- {nextrec-0.1.11.dist-info → nextrec-0.2.2.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/__init__.py
CHANGED
|
@@ -1,35 +1,42 @@
|
|
|
1
|
-
from nextrec.loss.
|
|
2
|
-
|
|
3
|
-
HingeLoss,
|
|
4
|
-
TripletLoss,
|
|
5
|
-
SampledSoftmaxLoss,
|
|
6
|
-
CosineContrastiveLoss,
|
|
1
|
+
"""
Public API for nextrec.loss: pointwise, pairwise, and listwise losses plus
loss-selection utilities.
"""

from nextrec.loss.listwise import (
    ApproxNDCGLoss,
    InfoNCELoss,
    ListMLELoss,
    ListNetLoss,
    SampledSoftmaxLoss,
)
from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
from nextrec.loss.pointwise import (
    ClassBalancedFocalLoss,
    CosineContrastiveLoss,
    FocalLoss,
    WeightedBCELoss,
)
from nextrec.loss.loss_utils import (
    get_loss_fn,
    get_loss_kwargs,
    VALID_TASK_TYPES,
)

# NOTE: "validate_training_mode" was removed from __all__ — the name is never
# imported here and loss_utils no longer defines it, so listing it made
# `from nextrec.loss import *` raise AttributeError.
__all__ = [
    # Pointwise
    "CosineContrastiveLoss",
    "WeightedBCELoss",
    "FocalLoss",
    "ClassBalancedFocalLoss",
    # Pairwise
    "BPRLoss",
    "HingeLoss",
    "TripletLoss",
    # Listwise
    "SampledSoftmaxLoss",
    "InfoNCELoss",
    "ListNetLoss",
    "ListMLELoss",
    "ApproxNDCGLoss",
    # Utilities
    "get_loss_fn",
    "get_loss_kwargs",
    "VALID_TASK_TYPES",
]
|
nextrec/loss/listwise.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Listwise loss functions for ranking and contrastive training.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
import torch.nn.functional as F
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SampledSoftmaxLoss(nn.Module):
    """
    Softmax over one positive and multiple sampled negatives.

    The positive logit is placed in column 0 and the batch is scored as a
    multi-class classification problem whose target class is always 0.
    """

    def __init__(self, reduction: str = "mean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, pos_logits: torch.Tensor, neg_logits: torch.Tensor) -> torch.Tensor:
        # Positive first: [B] -> [B, 1], concatenated with [B, N] negatives.
        logits = torch.cat([pos_logits.unsqueeze(1), neg_logits], dim=1)
        # The "correct class" is index 0 (the positive) for every row.
        targets = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
        return F.cross_entropy(logits, targets, reduction=self.reduction)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class InfoNCELoss(nn.Module):
    """
    InfoNCE loss for contrastive learning with one positive and many negatives.

    Similarities are dot products scaled by a temperature; the positive pair
    occupies column 0 of the resulting logit matrix.
    """

    def __init__(self, temperature: float = 0.07, reduction: str = "mean"):
        super().__init__()
        self.temperature = temperature
        self.reduction = reduction

    def forward(
        self, query: torch.Tensor, pos_key: torch.Tensor, neg_keys: torch.Tensor
    ) -> torch.Tensor:
        # Positive similarity kept as [B, 1] so it can head the logit matrix.
        pos_sim = (query * pos_key).sum(dim=-1, keepdim=True) / self.temperature
        # Broadcast the query over the negatives axis: [B, 1, D] * [B, N, D].
        neg_sim = (query.unsqueeze(1) * neg_keys).sum(dim=-1) / self.temperature
        logits = torch.cat([pos_sim, neg_sim], dim=1)
        labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
        return F.cross_entropy(logits, labels, reduction=self.reduction)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class ListNetLoss(nn.Module):
    """
    ListNet loss using top-1 probability distribution.

    Cross-entropy between the softmax of the labels (target distribution)
    and the softmax of the predicted scores.
    Reference: Cao et al. (ICML 2007)
    """

    def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
        super().__init__()
        self.temperature = temperature
        self.reduction = reduction

    def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        target_dist = F.softmax(labels / self.temperature, dim=1)
        predicted_dist = F.softmax(scores / self.temperature, dim=1)
        # Epsilon guards log(0) when a predicted probability underflows.
        per_list = -(target_dist * torch.log(predicted_dist + 1e-10)).sum(dim=1)

        if self.reduction == "sum":
            return per_list.sum()
        if self.reduction == "mean":
            return per_list.mean()
        return per_list
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class ListMLELoss(nn.Module):
    """
    ListMLE (Maximum Likelihood Estimation) loss.

    Maximizes the Plackett-Luce likelihood of the label-ordered permutation:
    for each position i of the label-sorted score list the loss accumulates
    logsumexp(scores[i:]) - scores[i].

    Reference: Xia et al. (ICML 2008)

    Args:
        reduction: "sum" returns the batch total; "mean" (default) and any
            other value return the total divided by the batch size.
    """

    def __init__(self, reduction: str = "mean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        scores: [B, L] predicted scores
        labels: [B, L] relevance labels defining the target ordering
        """
        # Re-order each row's scores by descending label relevance.
        _, order = torch.sort(labels, descending=True, dim=1)
        sorted_scores = torch.gather(scores, 1, order)

        # Suffix logsumexp in one vectorized pass (instead of a per-position
        # Python loop): flipping makes logcumsumexp accumulate from the tail.
        suffix_lse = torch.logcumsumexp(sorted_scores.flip(1), dim=1).flip(1)
        loss = (suffix_lse - sorted_scores).sum()

        if self.reduction == "sum":
            return loss
        # "mean" and any unrecognized reduction normalize by batch size.
        return loss / scores.size(0)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class ApproxNDCGLoss(nn.Module):
    """
    Approximate NDCG loss for learning to rank.

    Ranks are smoothed with pairwise sigmoids so NDCG becomes differentiable;
    the loss is 1 - approximate NDCG.
    Reference: Qin et al. (2010)
    """

    def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
        super().__init__()
        self.temperature = temperature
        self.reduction = reduction

    def _ideal_dcg(self, labels: torch.Tensor, k: Optional[int]) -> torch.Tensor:
        """DCG of the perfect ordering, optionally truncated to the top k."""
        # labels: [B, L]
        ordered, _ = torch.sort(labels, dim=1, descending=True)
        if k is not None:
            ordered = ordered[:, :k]

        gains = torch.pow(2.0, ordered) - 1.0  # [B, K]
        ranks = torch.arange(
            1, gains.size(1) + 1, device=gains.device, dtype=torch.float32
        )  # [K]
        discounts = 1.0 / torch.log2(ranks + 1.0)  # [K]
        return torch.sum(gains * discounts, dim=1)  # [B]

    def forward(
        self, scores: torch.Tensor, labels: torch.Tensor, k: Optional[int] = None
    ) -> torch.Tensor:
        """
        scores: [B, L]
        labels: [B, L]
        """
        list_size = scores.size(1)
        device = scores.device

        # Smoothed probability that item j outranks item i: sigmoid((s_j - s_i)/T).
        pair_diff = (scores.unsqueeze(1) - scores.unsqueeze(2)) / self.temperature  # [B, L, L]
        outrank_prob = torch.sigmoid(pair_diff)  # [B, L, L]
        # Zero the diagonal: an item never outranks itself.
        outrank_prob = outrank_prob * (1.0 - torch.eye(list_size, device=device).unsqueeze(0))

        # Expected (soft) 1-based rank of each item.
        soft_rank = 1.0 + outrank_prob.sum(dim=-1)  # [B, L]
        discounts = 1.0 / torch.log2(soft_rank + 1.0)  # [B, L]

        gains = torch.pow(2.0, labels) - 1.0  # [B, L]
        approx_dcg = torch.sum(gains * discounts, dim=1)  # [B]

        # The k cutoff truncates only the ideal DCG; the approximate DCG
        # always covers the full list.
        ndcg = approx_dcg / (self._ideal_dcg(labels, k) + 1e-10)  # [B]
        loss = 1.0 - ndcg

        if self.reduction == "sum":
            return loss.sum()
        if self.reduction == "mean":
            return loss.mean()
        return loss
|
nextrec/loss/loss_utils.py
CHANGED
|
@@ -1,136 +1,163 @@
|
|
|
1
1
|
"""
|
|
2
|
-
Loss utilities for NextRec
|
|
3
|
-
|
|
4
|
-
Date: create on 09/11/2025
|
|
5
|
-
Author:
|
|
6
|
-
Yang Zhou,zyaztec@gmail.com
|
|
2
|
+
Loss utilities for NextRec.
|
|
7
3
|
"""
|
|
8
|
-
|
|
9
|
-
import torch.nn as nn
|
|
4
|
+
|
|
10
5
|
from typing import Literal
|
|
11
6
|
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
|
|
9
|
+
from nextrec.loss.listwise import (
|
|
10
|
+
ApproxNDCGLoss,
|
|
11
|
+
InfoNCELoss,
|
|
12
|
+
ListMLELoss,
|
|
13
|
+
ListNetLoss,
|
|
16
14
|
SampledSoftmaxLoss,
|
|
17
|
-
|
|
18
|
-
|
|
15
|
+
)
|
|
16
|
+
from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
|
|
17
|
+
from nextrec.loss.pointwise import (
|
|
18
|
+
ClassBalancedFocalLoss,
|
|
19
|
+
CosineContrastiveLoss,
|
|
20
|
+
FocalLoss,
|
|
21
|
+
WeightedBCELoss,
|
|
19
22
|
)
|
|
20
23
|
|
|
21
24
|
# Valid task types for validation
|
|
22
|
-
VALID_TASK_TYPES = [
|
|
25
|
+
VALID_TASK_TYPES = [
|
|
26
|
+
"binary",
|
|
27
|
+
"multiclass",
|
|
28
|
+
"regression",
|
|
29
|
+
"multivariate_regression",
|
|
30
|
+
"match",
|
|
31
|
+
"ranking",
|
|
32
|
+
"multitask",
|
|
33
|
+
"multilabel",
|
|
34
|
+
]
|
|
23
35
|
|
|
24
36
|
|
|
25
37
|
def get_loss_fn(
|
|
26
38
|
task_type: str = "binary",
|
|
27
39
|
training_mode: str | None = None,
|
|
28
40
|
loss: str | nn.Module | None = None,
|
|
29
|
-
**loss_kwargs
|
|
41
|
+
**loss_kwargs,
|
|
30
42
|
) -> nn.Module:
|
|
31
43
|
"""
|
|
32
44
|
Get loss function based on task type and training mode.
|
|
33
|
-
|
|
34
|
-
Examples:
|
|
35
|
-
# Ranking task (binary classification)
|
|
36
|
-
>>> loss_fn = get_loss_fn(task_type="binary", loss="bce")
|
|
37
|
-
|
|
38
|
-
# Match task with pointwise training
|
|
39
|
-
>>> loss_fn = get_loss_fn(task_type="match", training_mode="pointwise")
|
|
40
|
-
|
|
41
|
-
# Match task with pairwise training
|
|
42
|
-
>>> loss_fn = get_loss_fn(task_type="match", training_mode="pairwise", loss="bpr")
|
|
43
|
-
|
|
44
|
-
# Match task with listwise training
|
|
45
|
-
>>> loss_fn = get_loss_fn(task_type="match", training_mode="listwise", loss="sampled_softmax")
|
|
46
45
|
"""
|
|
47
46
|
|
|
48
47
|
if isinstance(loss, nn.Module):
|
|
49
48
|
return loss
|
|
50
49
|
|
|
50
|
+
# Common mappings
|
|
51
51
|
if task_type == "match":
|
|
52
|
-
|
|
53
|
-
# Pointwise training uses binary cross entropy
|
|
54
|
-
if loss is None or loss == "bce" or loss == "binary_crossentropy":
|
|
55
|
-
return nn.BCELoss(**loss_kwargs)
|
|
56
|
-
elif loss == "cosine_contrastive":
|
|
57
|
-
return CosineContrastiveLoss(**loss_kwargs)
|
|
58
|
-
elif isinstance(loss, str):
|
|
59
|
-
raise ValueError(f"Unsupported pointwise loss: {loss}")
|
|
60
|
-
|
|
61
|
-
elif training_mode == "pairwise":
|
|
62
|
-
if loss is None or loss == "bpr":
|
|
63
|
-
return BPRLoss(**loss_kwargs)
|
|
64
|
-
elif loss == "hinge":
|
|
65
|
-
return HingeLoss(**loss_kwargs)
|
|
66
|
-
elif loss == "triplet":
|
|
67
|
-
return TripletLoss(**loss_kwargs)
|
|
68
|
-
elif isinstance(loss, str):
|
|
69
|
-
raise ValueError(f"Unsupported pairwise loss: {loss}")
|
|
70
|
-
|
|
71
|
-
elif training_mode == "listwise":
|
|
72
|
-
if loss is None or loss == "sampled_softmax" or loss == "softmax":
|
|
73
|
-
return SampledSoftmaxLoss(**loss_kwargs)
|
|
74
|
-
elif loss == "infonce":
|
|
75
|
-
return InfoNCELoss(**loss_kwargs)
|
|
76
|
-
elif loss == "crossentropy" or loss == "ce":
|
|
77
|
-
return nn.CrossEntropyLoss(**loss_kwargs)
|
|
78
|
-
elif isinstance(loss, str):
|
|
79
|
-
raise ValueError(f"Unsupported listwise loss: {loss}")
|
|
80
|
-
|
|
81
|
-
else:
|
|
82
|
-
raise ValueError(f"Unknown training_mode: {training_mode}")
|
|
83
|
-
|
|
84
|
-
elif task_type in ["ranking", "multitask", "binary"]:
|
|
85
|
-
if loss is None or loss == "bce" or loss == "binary_crossentropy":
|
|
86
|
-
return nn.BCELoss(**loss_kwargs)
|
|
87
|
-
elif loss == "mse":
|
|
88
|
-
return nn.MSELoss(**loss_kwargs)
|
|
89
|
-
elif loss == "mae":
|
|
90
|
-
return nn.L1Loss(**loss_kwargs)
|
|
91
|
-
elif loss == "crossentropy" or loss == "ce":
|
|
92
|
-
return nn.CrossEntropyLoss(**loss_kwargs)
|
|
93
|
-
elif isinstance(loss, str):
|
|
94
|
-
raise ValueError(f"Unsupported loss function: {loss}")
|
|
52
|
+
return _get_match_loss(training_mode, loss, **loss_kwargs)
|
|
95
53
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
54
|
+
if task_type in ["ranking", "multitask", "binary", "multilabel"]:
|
|
55
|
+
return _get_classification_loss(loss, **loss_kwargs)
|
|
56
|
+
|
|
57
|
+
if task_type == "multiclass":
|
|
58
|
+
return _get_multiclass_loss(loss, **loss_kwargs)
|
|
59
|
+
|
|
60
|
+
if task_type == "regression":
|
|
103
61
|
if loss is None or loss == "mse":
|
|
104
62
|
return nn.MSELoss(**loss_kwargs)
|
|
105
|
-
|
|
63
|
+
if loss == "mae":
|
|
106
64
|
return nn.L1Loss(**loss_kwargs)
|
|
107
|
-
|
|
65
|
+
if isinstance(loss, str):
|
|
108
66
|
raise ValueError(f"Unsupported regression loss: {loss}")
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
)
|
|
67
|
+
|
|
68
|
+
raise ValueError(f"Unsupported task_type: {task_type}")
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _get_match_loss(training_mode: str | None, loss: str | None, **loss_kwargs) -> nn.Module:
|
|
72
|
+
if training_mode == "pointwise":
|
|
73
|
+
if loss is None or loss in {"bce", "binary_crossentropy"}:
|
|
74
|
+
return nn.BCELoss(**loss_kwargs)
|
|
75
|
+
if loss == "weighted_bce":
|
|
76
|
+
return WeightedBCELoss(**loss_kwargs)
|
|
77
|
+
if loss == "focal":
|
|
78
|
+
return FocalLoss(**loss_kwargs)
|
|
79
|
+
if loss == "class_balanced_focal":
|
|
80
|
+
return _build_cb_focal(loss_kwargs)
|
|
81
|
+
if loss == "cosine_contrastive":
|
|
82
|
+
return CosineContrastiveLoss(**loss_kwargs)
|
|
83
|
+
if isinstance(loss, str):
|
|
84
|
+
raise ValueError(f"Unsupported pointwise loss: {loss}")
|
|
85
|
+
|
|
86
|
+
if training_mode == "pairwise":
|
|
87
|
+
if loss is None or loss == "bpr":
|
|
88
|
+
return BPRLoss(**loss_kwargs)
|
|
89
|
+
if loss == "hinge":
|
|
90
|
+
return HingeLoss(**loss_kwargs)
|
|
91
|
+
if loss == "triplet":
|
|
92
|
+
return TripletLoss(**loss_kwargs)
|
|
93
|
+
if isinstance(loss, str):
|
|
94
|
+
raise ValueError(f"Unsupported pairwise loss: {loss}")
|
|
95
|
+
|
|
96
|
+
if training_mode == "listwise":
|
|
97
|
+
if loss is None or loss in {"sampled_softmax", "softmax"}:
|
|
98
|
+
return SampledSoftmaxLoss(**loss_kwargs)
|
|
99
|
+
if loss == "infonce":
|
|
100
|
+
return InfoNCELoss(**loss_kwargs)
|
|
101
|
+
if loss == "listnet":
|
|
102
|
+
return ListNetLoss(**loss_kwargs)
|
|
103
|
+
if loss == "listmle":
|
|
104
|
+
return ListMLELoss(**loss_kwargs)
|
|
105
|
+
if loss == "approx_ndcg":
|
|
106
|
+
return ApproxNDCGLoss(**loss_kwargs)
|
|
107
|
+
if loss in {"crossentropy", "ce"}:
|
|
108
|
+
return nn.CrossEntropyLoss(**loss_kwargs)
|
|
109
|
+
if isinstance(loss, str):
|
|
110
|
+
raise ValueError(f"Unsupported listwise loss: {loss}")
|
|
111
|
+
|
|
112
|
+
raise ValueError(f"Unknown training_mode: {training_mode}")
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _get_classification_loss(loss: str | None, **loss_kwargs) -> nn.Module:
|
|
116
|
+
if loss is None or loss in {"bce", "binary_crossentropy"}:
|
|
117
|
+
return nn.BCELoss(**loss_kwargs)
|
|
118
|
+
if loss == "weighted_bce":
|
|
119
|
+
return WeightedBCELoss(**loss_kwargs)
|
|
120
|
+
if loss == "focal":
|
|
121
|
+
return FocalLoss(**loss_kwargs)
|
|
122
|
+
if loss == "class_balanced_focal":
|
|
123
|
+
return _build_cb_focal(loss_kwargs)
|
|
124
|
+
if loss == "mse":
|
|
125
|
+
return nn.MSELoss(**loss_kwargs)
|
|
126
|
+
if loss == "mae":
|
|
127
|
+
return nn.L1Loss(**loss_kwargs)
|
|
128
|
+
if loss in {"crossentropy", "ce"}:
|
|
129
|
+
return nn.CrossEntropyLoss(**loss_kwargs)
|
|
130
|
+
if isinstance(loss, str):
|
|
131
|
+
raise ValueError(f"Unsupported loss function: {loss}")
|
|
132
|
+
raise ValueError("Loss must be specified for classification task.")
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _get_multiclass_loss(loss: str | None, **loss_kwargs) -> nn.Module:
|
|
136
|
+
if loss is None or loss in {"crossentropy", "ce"}:
|
|
137
|
+
return nn.CrossEntropyLoss(**loss_kwargs)
|
|
138
|
+
if loss == "focal":
|
|
139
|
+
return FocalLoss(**loss_kwargs)
|
|
140
|
+
if loss == "class_balanced_focal":
|
|
141
|
+
return _build_cb_focal(loss_kwargs)
|
|
142
|
+
if isinstance(loss, str):
|
|
143
|
+
raise ValueError(f"Unsupported multiclass loss: {loss}")
|
|
144
|
+
raise ValueError("Loss must be specified for multiclass task.")
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _build_cb_focal(loss_kwargs: dict) -> ClassBalancedFocalLoss:
|
|
148
|
+
if "class_counts" not in loss_kwargs:
|
|
149
|
+
raise ValueError("class_balanced_focal requires `class_counts` argument.")
|
|
150
|
+
return ClassBalancedFocalLoss(**loss_kwargs)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def get_loss_kwargs(loss_params: dict | list[dict] | None, index: int = 0) -> dict:
|
|
121
154
|
"""
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
Args:
|
|
125
|
-
training_mode: Requested training mode
|
|
126
|
-
support_training_modes: List of supported training modes
|
|
127
|
-
model_name: Name of the model (for error messages)
|
|
128
|
-
|
|
129
|
-
Raises:
|
|
130
|
-
ValueError: If training mode is not supported
|
|
155
|
+
Resolve per-task loss kwargs from a dict or list of dicts.
|
|
131
156
|
"""
|
|
132
|
-
if
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
157
|
+
if loss_params is None:
|
|
158
|
+
return {}
|
|
159
|
+
if isinstance(loss_params, list):
|
|
160
|
+
if index < len(loss_params) and loss_params[index] is not None:
|
|
161
|
+
return loss_params[index]
|
|
162
|
+
return {}
|
|
163
|
+
return loss_params
|
nextrec/loss/pairwise.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pairwise loss functions for learning-to-rank and matching tasks.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Literal
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
import torch.nn.functional as F
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BPRLoss(nn.Module):
    """
    Bayesian Personalized Ranking loss with support for multiple negatives.

    Maximizes log-sigmoid of (positive score - negative score).
    """

    def __init__(self, reduction: str = "mean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
        # Broadcast one positive against [B, N] negatives when several
        # negatives per positive are provided.
        if neg_score.dim() == 2:
            pos_score = pos_score.unsqueeze(1)
        margin = pos_score - neg_score
        # Epsilon guards log(0) when sigmoid saturates toward zero.
        pairwise = -torch.log(torch.sigmoid(margin) + 1e-8)

        if self.reduction == "sum":
            return pairwise.sum()
        if self.reduction == "mean":
            return pairwise.mean()
        return pairwise
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class HingeLoss(nn.Module):
    """
    Hinge loss for pairwise ranking.

    Penalizes pairs whose score gap falls short of the margin.
    """

    def __init__(self, margin: float = 1.0, reduction: str = "mean"):
        super().__init__()
        self.margin = margin
        self.reduction = reduction

    def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
        # Broadcast one positive against [B, N] negatives when needed.
        if neg_score.dim() == 2:
            pos_score = pos_score.unsqueeze(1)
        shortfall = torch.clamp(self.margin - (pos_score - neg_score), min=0)

        if self.reduction == "sum":
            return shortfall.sum()
        if self.reduction == "mean":
            return shortfall.mean()
        return shortfall
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class TripletLoss(nn.Module):
    """
    Triplet margin loss with cosine or euclidean distance.

    Accepts a single negative per anchor or a stack of negatives with an
    extra leading negatives axis.
    """

    def __init__(
        self,
        margin: float = 1.0,
        reduction: str = "mean",
        distance: Literal["euclidean", "cosine"] = "euclidean",
    ):
        super().__init__()
        self.margin = margin
        self.reduction = reduction
        self.distance = distance

    def _dist(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Squared euclidean distance or cosine distance along the last dim.
        if self.distance == "euclidean":
            return torch.sum((a - b) ** 2, dim=-1)
        if self.distance == "cosine":
            return 1 - F.cosine_similarity(a, b, dim=-1)
        raise ValueError(f"Unsupported distance: {self.distance}")

    def forward(
        self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor
    ) -> torch.Tensor:
        pos_dist = self._dist(anchor, positive)
        if negative.dim() == 3:
            # Multiple negatives: broadcast the anchor over the negatives axis.
            neg_dist = self._dist(anchor.unsqueeze(1), negative)
        else:
            neg_dist = self._dist(anchor, negative)
        if neg_dist.dim() == 2:
            pos_dist = pos_dist.unsqueeze(1)

        violation = torch.clamp(pos_dist - neg_dist + self.margin, min=0)
        if self.reduction == "sum":
            return violation.sum()
        if self.reduction == "mean":
            return violation.mean()
        return violation
|