nextrec 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +4 -4
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +9 -10
- nextrec/basic/callback.py +0 -1
- nextrec/basic/dataloader.py +127 -168
- nextrec/basic/features.py +27 -24
- nextrec/basic/layers.py +159 -328
- nextrec/basic/loggers.py +37 -50
- nextrec/basic/metrics.py +147 -255
- nextrec/basic/model.py +462 -817
- nextrec/data/__init__.py +5 -5
- nextrec/data/data_utils.py +12 -16
- nextrec/data/preprocessor.py +252 -276
- nextrec/loss/__init__.py +12 -12
- nextrec/loss/loss_utils.py +22 -30
- nextrec/loss/match_losses.py +83 -116
- nextrec/models/match/__init__.py +5 -5
- nextrec/models/match/dssm.py +61 -70
- nextrec/models/match/dssm_v2.py +51 -61
- nextrec/models/match/mind.py +71 -89
- nextrec/models/match/sdm.py +81 -93
- nextrec/models/match/youtube_dnn.py +53 -62
- nextrec/models/multi_task/esmm.py +43 -49
- nextrec/models/multi_task/mmoe.py +56 -65
- nextrec/models/multi_task/ple.py +65 -92
- nextrec/models/multi_task/share_bottom.py +42 -48
- nextrec/models/ranking/__init__.py +7 -7
- nextrec/models/ranking/afm.py +30 -39
- nextrec/models/ranking/autoint.py +57 -70
- nextrec/models/ranking/dcn.py +35 -43
- nextrec/models/ranking/deepfm.py +28 -34
- nextrec/models/ranking/dien.py +79 -115
- nextrec/models/ranking/din.py +60 -84
- nextrec/models/ranking/fibinet.py +35 -51
- nextrec/models/ranking/fm.py +26 -28
- nextrec/models/ranking/masknet.py +31 -31
- nextrec/models/ranking/pnn.py +31 -30
- nextrec/models/ranking/widedeep.py +31 -36
- nextrec/models/ranking/xdeepfm.py +39 -46
- nextrec/utils/__init__.py +9 -9
- nextrec/utils/embedding.py +1 -1
- nextrec/utils/initializer.py +15 -23
- nextrec/utils/optimizer.py +10 -14
- {nextrec-0.1.3.dist-info → nextrec-0.1.7.dist-info}/METADATA +16 -7
- nextrec-0.1.7.dist-info/RECORD +51 -0
- nextrec-0.1.3.dist-info/RECORD +0 -51
- {nextrec-0.1.3.dist-info → nextrec-0.1.7.dist-info}/WHEEL +0 -0
- {nextrec-0.1.3.dist-info → nextrec-0.1.7.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/__init__.py
CHANGED

@@ -18,18 +18,18 @@ from nextrec.loss.loss_utils import (

__all__ = [
    # Match losses
-    "BPRLoss",
-    "HingeLoss",
-    "TripletLoss",
-    "SampledSoftmaxLoss",
-    "CosineContrastiveLoss",
-    "InfoNCELoss",
+    'BPRLoss',
+    'HingeLoss',
+    'TripletLoss',
+    'SampledSoftmaxLoss',
+    'CosineContrastiveLoss',
+    'InfoNCELoss',
    # Listwise losses
-    "ListNetLoss",
-    "ListMLELoss",
-    "ApproxNDCGLoss",
+    'ListNetLoss',
+    'ListMLELoss',
+    'ApproxNDCGLoss',
    # Utilities
-    "get_loss_fn",
-    "validate_training_mode",
-    "VALID_TASK_TYPES",
+    'get_loss_fn',
+    'validate_training_mode',
+    'VALID_TASK_TYPES',
]
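For reference, the `__all__` list above implies the losses and helpers can be imported directly from nextrec.loss. A minimal sketch, assuming the 0.1.7 package re-exports exactly the names listed in `__all__` (not verified against the installed wheel):

from nextrec.loss import BPRLoss, InfoNCELoss, get_loss_fn, VALID_TASK_TYPES

print(VALID_TASK_TYPES)  # ['binary', 'multiclass', 'regression', ...]
loss_fn = get_loss_fn(task_type="binary", loss="bce")  # mirrors the docstring example in loss_utils.py below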
nextrec/loss/loss_utils.py
CHANGED

@@ -5,52 +5,42 @@ Date: create on 09/11/2025
Author:
Yang Zhou,zyaztec@gmail.com
"""
-
import torch
import torch.nn as nn
from typing import Literal

from nextrec.loss.match_losses import (
-    BPRLoss,
-    HingeLoss,
-    TripletLoss,
+    BPRLoss,
+    HingeLoss,
+    TripletLoss,
    SampledSoftmaxLoss,
-    CosineContrastiveLoss,
-    InfoNCELoss
+    CosineContrastiveLoss,
+    InfoNCELoss
)

# Valid task types for validation
-VALID_TASK_TYPES = [
-    "binary",
-    "multiclass",
-    "regression",
-    "multivariate_regression",
-    "match",
-    "ranking",
-    "multitask",
-    "multilabel",
-]
+VALID_TASK_TYPES = ['binary', 'multiclass', 'regression', 'multivariate_regression', 'match', 'ranking', 'multitask', 'multilabel']


def get_loss_fn(
    task_type: str = "binary",
    training_mode: str | None = None,
    loss: str | nn.Module | None = None,
-    **loss_kwargs
+    **loss_kwargs
) -> nn.Module:
    """
    Get loss function based on task type and training mode.
-
+
    Examples:
        # Ranking task (binary classification)
        >>> loss_fn = get_loss_fn(task_type="binary", loss="bce")
-
+
        # Match task with pointwise training
        >>> loss_fn = get_loss_fn(task_type="match", training_mode="pointwise")
-
+
        # Match task with pairwise training
        >>> loss_fn = get_loss_fn(task_type="match", training_mode="pairwise", loss="bpr")
-
+
        # Match task with listwise training
        >>> loss_fn = get_loss_fn(task_type="match", training_mode="listwise", loss="sampled_softmax")
    """

@@ -67,7 +57,7 @@ def get_loss_fn(
                return CosineContrastiveLoss(**loss_kwargs)
            elif isinstance(loss, str):
                raise ValueError(f"Unsupported pointwise loss: {loss}")
-
+
        elif training_mode == "pairwise":
            if loss is None or loss == "bpr":
                return BPRLoss(**loss_kwargs)

@@ -77,7 +67,7 @@ def get_loss_fn(
                return TripletLoss(**loss_kwargs)
            elif isinstance(loss, str):
                raise ValueError(f"Unsupported pairwise loss: {loss}")
-
+
        elif training_mode == "listwise":
            if loss is None or loss == "sampled_softmax" or loss == "softmax":
                return SampledSoftmaxLoss(**loss_kwargs)

@@ -87,7 +77,7 @@ def get_loss_fn(
                return nn.CrossEntropyLoss(**loss_kwargs)
            elif isinstance(loss, str):
                raise ValueError(f"Unsupported listwise loss: {loss}")
-
+
        else:
            raise ValueError(f"Unknown training_mode: {training_mode}")


@@ -108,7 +98,7 @@ def get_loss_fn(
            return nn.CrossEntropyLoss(**loss_kwargs)
        elif isinstance(loss, str):
            raise ValueError(f"Unsupported multiclass loss: {loss}")
-
+
    elif task_type == "regression":
        if loss is None or loss == "mse":
            return nn.MSELoss(**loss_kwargs)

@@ -116,24 +106,26 @@ def get_loss_fn(
            return nn.L1Loss(**loss_kwargs)
        elif isinstance(loss, str):
            raise ValueError(f"Unsupported regression loss: {loss}")
-
+
    else:
        raise ValueError(f"Unsupported task_type: {task_type}")
-
+
    return loss


def validate_training_mode(
-    training_mode: str,
+    training_mode: str,
+    support_training_modes: list[str],
+    model_name: str = "Model"
) -> None:
    """
    Validate that the requested training mode is supported by the model.
-
+
    Args:
        training_mode: Requested training mode
        support_training_modes: List of supported training modes
        model_name: Name of the model (for error messages)
-
+
    Raises:
        ValueError: If training mode is not supported
    """
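The docstring examples above cover the main get_loss_fn call patterns, and validate_training_mode now takes the supported modes and a model name as explicit parameters. A short usage sketch; the support_training_modes list and model name below are illustrative values, not taken from the diff:

from nextrec.loss import get_loss_fn, validate_training_mode

# Match task, pairwise and listwise training (mirrors the docstring examples).
pairwise_loss = get_loss_fn(task_type="match", training_mode="pairwise", loss="bpr")
listwise_loss = get_loss_fn(task_type="match", training_mode="listwise", loss="sampled_softmax")

# New 0.1.7 signature: raises ValueError if the requested mode is not supported.
validate_training_mode(
    training_mode="pairwise",
    support_training_modes=["pointwise", "pairwise", "listwise"],
    model_name="DSSM",
)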
nextrec/loss/match_losses.py
CHANGED

@@ -13,167 +13,149 @@ from typing import Optional


class BPRLoss(nn.Module):
-    def __init__(self, reduction: str = "mean"):
+    def __init__(self, reduction: str = 'mean'):
        super(BPRLoss, self).__init__()
        self.reduction = reduction
-
+
    def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
        if neg_score.dim() == 2:
            pos_score = pos_score.unsqueeze(1)  # [batch_size, 1]
            diff = pos_score - neg_score  # [batch_size, num_neg]
            loss = -torch.log(torch.sigmoid(diff) + 1e-8)
-            if self.reduction == "mean":
+            if self.reduction == 'mean':
                return loss.mean()
-            elif self.reduction == "sum":
+            elif self.reduction == 'sum':
                return loss.sum()
            else:
                return loss
        else:
            diff = pos_score - neg_score
            loss = -torch.log(torch.sigmoid(diff) + 1e-8)
-            if self.reduction == "mean":
+            if self.reduction == 'mean':
                return loss.mean()
-            elif self.reduction == "sum":
+            elif self.reduction == 'sum':
                return loss.sum()
            else:
                return loss


-class HingeLoss(nn.Module):
-    def __init__(self, margin: float = 1.0, reduction: str = "mean"):
+class HingeLoss(nn.Module):
+    def __init__(self, margin: float = 1.0, reduction: str = 'mean'):
        super(HingeLoss, self).__init__()
        self.margin = margin
        self.reduction = reduction
-
+
    def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
        if neg_score.dim() == 2:
            pos_score = pos_score.unsqueeze(1)  # [batch_size, 1]
-
+
        diff = pos_score - neg_score
        loss = torch.clamp(self.margin - diff, min=0)
-
-        if self.reduction == "mean":
+
+        if self.reduction == 'mean':
            return loss.mean()
-        elif self.reduction == "sum":
+        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


class TripletLoss(nn.Module):
-    def __init__(
-        self, margin: float = 1.0, reduction: str = "mean", distance: str = "euclidean"
-    ):
+    def __init__(self, margin: float = 1.0, reduction: str = 'mean', distance: str = 'euclidean'):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.reduction = reduction
        self.distance = distance
-
-    def forward(
-        self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor
-    ) -> torch.Tensor:
-        if self.distance == "euclidean":
+
+    def forward(self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor) -> torch.Tensor:
+        if self.distance == 'euclidean':
            pos_dist = torch.sum((anchor - positive) ** 2, dim=-1)
-
+
            if negative.dim() == 3:
                anchor_expanded = anchor.unsqueeze(1)  # [batch_size, 1, dim]
-                neg_dist = torch.sum(
-                    (anchor_expanded - negative) ** 2, dim=-1
-                )  # [batch_size, num_neg]
+                neg_dist = torch.sum((anchor_expanded - negative) ** 2, dim=-1)  # [batch_size, num_neg]
            else:
                neg_dist = torch.sum((anchor - negative) ** 2, dim=-1)
-
+
            if neg_dist.dim() == 2:
                pos_dist = pos_dist.unsqueeze(1)  # [batch_size, 1]
-
-        elif self.distance == "cosine":
+
+        elif self.distance == 'cosine':
            pos_dist = 1 - F.cosine_similarity(anchor, positive, dim=-1)
-
+
            if negative.dim() == 3:
                anchor_expanded = anchor.unsqueeze(1)  # [batch_size, 1, dim]
                neg_dist = 1 - F.cosine_similarity(anchor_expanded, negative, dim=-1)
            else:
                neg_dist = 1 - F.cosine_similarity(anchor, negative, dim=-1)
-
+
            if neg_dist.dim() == 2:
                pos_dist = pos_dist.unsqueeze(1)
        else:
            raise ValueError(f"Unsupported distance: {self.distance}")
-
+
        loss = torch.clamp(pos_dist - neg_dist + self.margin, min=0)
-
-        if self.reduction == "mean":
+
+        if self.reduction == 'mean':
            return loss.mean()
-        elif self.reduction == "sum":
+        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


class SampledSoftmaxLoss(nn.Module):
-    def __init__(self, reduction: str = "mean"):
+    def __init__(self, reduction: str = 'mean'):
        super(SampledSoftmaxLoss, self).__init__()
        self.reduction = reduction
-
-    def forward(
-        self, pos_logits: torch.Tensor, neg_logits: torch.Tensor
-    ) -> torch.Tensor:
+
+    def forward(self, pos_logits: torch.Tensor, neg_logits: torch.Tensor) -> torch.Tensor:
        pos_logits = pos_logits.unsqueeze(1)  # [batch_size, 1]
-        all_logits = torch.cat(
-            [pos_logits, neg_logits], dim=1
-        )  # [batch_size, 1 + num_neg]
-        targets = torch.zeros(
-            all_logits.size(0), dtype=torch.long, device=all_logits.device
-        )
+        all_logits = torch.cat([pos_logits, neg_logits], dim=1)  # [batch_size, 1 + num_neg]
+        targets = torch.zeros(all_logits.size(0), dtype=torch.long, device=all_logits.device)
        loss = F.cross_entropy(all_logits, targets, reduction=self.reduction)
-
+
        return loss


class CosineContrastiveLoss(nn.Module):
-    def __init__(self, margin: float = 0.5, reduction: str = "mean"):
+    def __init__(self, margin: float = 0.5, reduction: str = 'mean'):
        super(CosineContrastiveLoss, self).__init__()
        self.margin = margin
        self.reduction = reduction
-
-    def forward(
-        self, user_emb: torch.Tensor, item_emb: torch.Tensor, labels: torch.Tensor
-    ) -> torch.Tensor:
+
+    def forward(self, user_emb: torch.Tensor, item_emb: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
        pos_loss = (1 - similarity) * labels

        neg_loss = torch.clamp(similarity - self.margin, min=0) * (1 - labels)
-
+
        loss = pos_loss + neg_loss
-
-        if self.reduction == "mean":
+
+        if self.reduction == 'mean':
            return loss.mean()
-        elif self.reduction == "sum":
+        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


class InfoNCELoss(nn.Module):
-    def __init__(self, temperature: float = 0.07, reduction: str = "mean"):
+    def __init__(self, temperature: float = 0.07, reduction: str = 'mean'):
        super(InfoNCELoss, self).__init__()
        self.temperature = temperature
        self.reduction = reduction
-
-    def forward(
-        self, query: torch.Tensor, pos_key: torch.Tensor, neg_keys: torch.Tensor
-    ) -> torch.Tensor:
+
+    def forward(self, query: torch.Tensor, pos_key: torch.Tensor, neg_keys: torch.Tensor) -> torch.Tensor:
        pos_sim = torch.sum(query * pos_key, dim=-1) / self.temperature  # [batch_size]
        pos_sim = pos_sim.unsqueeze(1)  # [batch_size, 1]
        query_expanded = query.unsqueeze(1)  # [batch_size, 1, dim]
-        neg_sim = (
-            torch.sum(query_expanded * neg_keys, dim=-1) / self.temperature
-        )  # [batch_size, num_neg]
+        neg_sim = torch.sum(query_expanded * neg_keys, dim=-1) / self.temperature  # [batch_size, num_neg]
        logits = torch.cat([pos_sim, neg_sim], dim=1)  # [batch_size, 1 + num_neg]
        labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)

        loss = F.cross_entropy(logits, labels, reduction=self.reduction)
-
+
        return loss


@@ -182,23 +164,22 @@ class ListNetLoss(nn.Module):
    ListNet loss using top-1 probability distribution
    Reference: Cao et al. Learning to Rank: From Pairwise Approach to Listwise Approach (ICML 2007)
    """
-
-    def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
+    def __init__(self, temperature: float = 1.0, reduction: str = 'mean'):
        super(ListNetLoss, self).__init__()
        self.temperature = temperature
        self.reduction = reduction
-
+
    def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Convert scores and labels to probability distributions
        pred_probs = F.softmax(scores / self.temperature, dim=1)
        true_probs = F.softmax(labels / self.temperature, dim=1)
-
+
        # Cross entropy between two distributions
        loss = -torch.sum(true_probs * torch.log(pred_probs + 1e-10), dim=1)
-
-        if self.reduction == "mean":
+
+        if self.reduction == 'mean':
            return loss.mean()
-        elif self.reduction == "sum":
+        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

@@ -209,24 +190,19 @@ class ListMLELoss(nn.Module):
    ListMLE (Maximum Likelihood Estimation) loss
    Reference: Xia et al. Listwise approach to learning to rank: theory and algorithm (ICML 2008)
    """
-
-    def __init__(self, reduction: str = "mean"):
+    def __init__(self, reduction: str = 'mean'):
        super(ListMLELoss, self).__init__()
        self.reduction = reduction
-
+
    def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Sort by labels in descending order to get ground truth ranking
        sorted_labels, sorted_indices = torch.sort(labels, descending=True, dim=1)
-
+
        # Reorder scores according to ground truth ranking
        batch_size, list_size = scores.shape
-        batch_indices = (
-            torch.arange(batch_size, device=scores.device)
-            .unsqueeze(1)
-            .expand(-1, list_size)
-        )
+        batch_indices = torch.arange(batch_size, device=scores.device).unsqueeze(1).expand(-1, list_size)
        sorted_scores = scores[batch_indices, sorted_indices]
-
+
        # Compute log likelihood
        # For each position, compute log(exp(score_i) / sum(exp(score_j) for j >= i))
        loss = torch.tensor(0.0, device=scores.device)

@@ -235,10 +211,10 @@ class ListMLELoss(nn.Module):
            remaining_scores = sorted_scores[:, i:]
            log_sum_exp = torch.logsumexp(remaining_scores, dim=1)
            loss = loss + (log_sum_exp - sorted_scores[:, i]).sum()
-
-        if self.reduction == "mean":
+
+        if self.reduction == 'mean':
            return loss / batch_size
-        elif self.reduction == "sum":
+        elif self.reduction == 'sum':
            return loss
        else:
            return loss / batch_size

@@ -247,81 +223,72 @@ class ListMLELoss(nn.Module):
class ApproxNDCGLoss(nn.Module):
    """
    Approximate NDCG loss for learning to rank
-    Reference: Qin et al. A General Approximation Framework for Direct Optimization of
+    Reference: Qin et al. A General Approximation Framework for Direct Optimization of
    Information Retrieval Measures (Information Retrieval 2010)
    """
-
-    def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
+    def __init__(self, temperature: float = 1.0, reduction: str = 'mean'):
        super(ApproxNDCGLoss, self).__init__()
        self.temperature = temperature
        self.reduction = reduction
-
+
    def _dcg(self, relevance: torch.Tensor, k: Optional[int] = None) -> torch.Tensor:
        if k is not None:
            relevance = relevance[:, :k]
-
+
        # DCG = sum(rel_i / log2(i + 2)) for i in range(list_size)
-        positions = torch.arange(
-            1, relevance.size(1) + 1, device=relevance.device, dtype=torch.float32
-        )
+        positions = torch.arange(1, relevance.size(1) + 1, device=relevance.device, dtype=torch.float32)
        discounts = torch.log2(positions + 1.0)
        dcg = torch.sum(relevance / discounts, dim=1)
-
+
        return dcg
-
-    def forward(
-        self, scores: torch.Tensor, labels: torch.Tensor, k: Optional[int] = None
-    ) -> torch.Tensor:
+
+    def forward(self, scores: torch.Tensor, labels: torch.Tensor, k: Optional[int] = None) -> torch.Tensor:
        """
        Args:
            scores: Predicted scores [batch_size, list_size]
            labels: Ground truth relevance labels [batch_size, list_size]
            k: Top-k items for NDCG@k (if None, use all items)
-
+
        Returns:
            Approximate NDCG loss (1 - NDCG)
        """
        batch_size = scores.size(0)
-
+
        # Use differentiable sorting approximation with softmax
        # Create pairwise comparison matrix
        scores_expanded = scores.unsqueeze(2)  # [batch_size, list_size, 1]
-        scores_tiled = scores.unsqueeze(1)
-
+        scores_tiled = scores.unsqueeze(1)  # [batch_size, 1, list_size]
+
        # Compute pairwise probabilities using sigmoid
        pairwise_diff = (scores_expanded - scores_tiled) / self.temperature
-        pairwise_probs = torch.sigmoid(
-            pairwise_diff
-        )  # [batch_size, list_size, list_size]
-
+        pairwise_probs = torch.sigmoid(pairwise_diff)  # [batch_size, list_size, list_size]
+
        # Approximate ranking positions
        # ranking_probs[i, j] ≈ probability that item i is ranked at position j
        # We use softmax approximation for differentiable ranking
        ranking_weights = F.softmax(scores / self.temperature, dim=1)
-
+
        # Sort labels to get ideal DCG
        ideal_labels, _ = torch.sort(labels, descending=True, dim=1)
        ideal_dcg = self._dcg(ideal_labels, k)
-
+
        # Compute approximate DCG using soft ranking
        # Weight each item's relevance by its soft ranking position
-        positions = torch.arange(
-            1, scores.size(1) + 1, device=scores.device, dtype=torch.float32
-        )
+        positions = torch.arange(1, scores.size(1) + 1, device=scores.device, dtype=torch.float32)
        discounts = 1.0 / torch.log2(positions + 1.0)
-
+
        # Approximate DCG by weighting relevance with ranking probabilities
        approx_dcg = torch.sum(labels * ranking_weights * discounts, dim=1)
-
+
        # Normalize by ideal DCG to get NDCG
        ndcg = approx_dcg / (ideal_dcg + 1e-10)
-
+
        # Loss is 1 - NDCG (we want to maximize NDCG, so minimize 1 - NDCG)
        loss = 1.0 - ndcg
-
-        if self.reduction == "mean":
+
+        if self.reduction == 'mean':
            return loss.mean()
-        elif self.reduction == "sum":
+        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
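The pairwise and listwise losses above take batched scores or embeddings with an explicit negative dimension. A toy sketch with made-up tensors, following the shape comments in the forward() methods:

import torch
from nextrec.loss.match_losses import BPRLoss, InfoNCELoss

batch_size, num_neg, dim = 32, 4, 64

# Pairwise: one positive score and num_neg negative scores per example.
bpr = BPRLoss(reduction='mean')
pos_score = torch.randn(batch_size)           # [batch_size]
neg_score = torch.randn(batch_size, num_neg)  # [batch_size, num_neg]
print(bpr(pos_score, neg_score))              # scalar loss

# Listwise/contrastive: query embedding vs. one positive key and num_neg negative keys.
info_nce = InfoNCELoss(temperature=0.07)
query = torch.randn(batch_size, dim)
pos_key = torch.randn(batch_size, dim)
neg_keys = torch.randn(batch_size, num_neg, dim)
print(info_nce(query, pos_key, neg_keys))     # scalar loss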
nextrec/models/match/__init__.py
CHANGED