nextrec 0.1.1-py3-none-any.whl → 0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +4 -4
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -9
- nextrec/basic/callback.py +1 -0
- nextrec/basic/dataloader.py +168 -127
- nextrec/basic/features.py +24 -27
- nextrec/basic/layers.py +328 -159
- nextrec/basic/loggers.py +50 -37
- nextrec/basic/metrics.py +255 -147
- nextrec/basic/model.py +817 -462
- nextrec/data/__init__.py +5 -5
- nextrec/data/data_utils.py +16 -12
- nextrec/data/preprocessor.py +276 -252
- nextrec/loss/__init__.py +12 -12
- nextrec/loss/loss_utils.py +30 -22
- nextrec/loss/match_losses.py +116 -83
- nextrec/models/match/__init__.py +5 -5
- nextrec/models/match/dssm.py +70 -61
- nextrec/models/match/dssm_v2.py +61 -51
- nextrec/models/match/mind.py +89 -71
- nextrec/models/match/sdm.py +93 -81
- nextrec/models/match/youtube_dnn.py +62 -53
- nextrec/models/multi_task/esmm.py +49 -43
- nextrec/models/multi_task/mmoe.py +65 -56
- nextrec/models/multi_task/ple.py +92 -65
- nextrec/models/multi_task/share_bottom.py +48 -42
- nextrec/models/ranking/__init__.py +7 -7
- nextrec/models/ranking/afm.py +39 -30
- nextrec/models/ranking/autoint.py +70 -57
- nextrec/models/ranking/dcn.py +43 -35
- nextrec/models/ranking/deepfm.py +34 -28
- nextrec/models/ranking/dien.py +115 -79
- nextrec/models/ranking/din.py +84 -60
- nextrec/models/ranking/fibinet.py +51 -35
- nextrec/models/ranking/fm.py +28 -26
- nextrec/models/ranking/masknet.py +31 -31
- nextrec/models/ranking/pnn.py +30 -31
- nextrec/models/ranking/widedeep.py +36 -31
- nextrec/models/ranking/xdeepfm.py +46 -39
- nextrec/utils/__init__.py +9 -9
- nextrec/utils/embedding.py +1 -1
- nextrec/utils/initializer.py +23 -15
- nextrec/utils/optimizer.py +14 -10
- {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/METADATA +6 -40
- nextrec-0.1.2.dist-info/RECORD +51 -0
- nextrec-0.1.1.dist-info/RECORD +0 -51
- {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/WHEEL +0 -0
- {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/__init__.py
CHANGED
@@ -18,18 +18,18 @@ from nextrec.loss.loss_utils import (
 
 __all__ = [
     # Match losses
-
-
-
-
-
-
+    "BPRLoss",
+    "HingeLoss",
+    "TripletLoss",
+    "SampledSoftmaxLoss",
+    "CosineContrastiveLoss",
+    "InfoNCELoss",
     # Listwise losses
-
-
-
+    "ListNetLoss",
+    "ListMLELoss",
+    "ApproxNDCGLoss",
     # Utilities
-
-
-
+    "get_loss_fn",
+    "validate_training_mode",
+    "VALID_TASK_TYPES",
 ]
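With the `__all__` list now populated, the public loss API can be imported straight from the package namespace. A minimal sketch, assuming the imports inside `__init__.py` re-export exactly the names listed above:

```python
# Hypothetical usage of the nextrec.loss public API (names taken from __all__ above).
from nextrec.loss import BPRLoss, InfoNCELoss, get_loss_fn, VALID_TASK_TYPES

print(VALID_TASK_TYPES)  # the task types accepted by get_loss_fn
```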
nextrec/loss/loss_utils.py
CHANGED
@@ -5,42 +5,52 @@ Date: create on 09/11/2025
 Author:
     Yang Zhou,zyaztec@gmail.com
 """
+
 import torch
 import torch.nn as nn
 from typing import Literal
 
 from nextrec.loss.match_losses import (
-    BPRLoss,
-    HingeLoss,
-    TripletLoss,
+    BPRLoss,
+    HingeLoss,
+    TripletLoss,
     SampledSoftmaxLoss,
-    CosineContrastiveLoss,
-    InfoNCELoss
+    CosineContrastiveLoss,
+    InfoNCELoss,
 )
 
 # Valid task types for validation
-VALID_TASK_TYPES = [
+VALID_TASK_TYPES = [
+    "binary",
+    "multiclass",
+    "regression",
+    "multivariate_regression",
+    "match",
+    "ranking",
+    "multitask",
+    "multilabel",
+]
 
 
 def get_loss_fn(
     task_type: str = "binary",
     training_mode: str | None = None,
     loss: str | nn.Module | None = None,
-    **loss_kwargs
+    **loss_kwargs,
 ) -> nn.Module:
     """
     Get loss function based on task type and training mode.
-
+
     Examples:
         # Ranking task (binary classification)
         >>> loss_fn = get_loss_fn(task_type="binary", loss="bce")
-
+
         # Match task with pointwise training
         >>> loss_fn = get_loss_fn(task_type="match", training_mode="pointwise")
-
+
         # Match task with pairwise training
         >>> loss_fn = get_loss_fn(task_type="match", training_mode="pairwise", loss="bpr")
-
+
         # Match task with listwise training
        >>> loss_fn = get_loss_fn(task_type="match", training_mode="listwise", loss="sampled_softmax")
     """
@@ -57,7 +67,7 @@ def get_loss_fn(
                 return CosineContrastiveLoss(**loss_kwargs)
             elif isinstance(loss, str):
                 raise ValueError(f"Unsupported pointwise loss: {loss}")
-
+
         elif training_mode == "pairwise":
             if loss is None or loss == "bpr":
                 return BPRLoss(**loss_kwargs)
@@ -67,7 +77,7 @@ def get_loss_fn(
                 return TripletLoss(**loss_kwargs)
             elif isinstance(loss, str):
                 raise ValueError(f"Unsupported pairwise loss: {loss}")
-
+
         elif training_mode == "listwise":
             if loss is None or loss == "sampled_softmax" or loss == "softmax":
                 return SampledSoftmaxLoss(**loss_kwargs)
@@ -77,7 +87,7 @@ def get_loss_fn(
                 return nn.CrossEntropyLoss(**loss_kwargs)
             elif isinstance(loss, str):
                 raise ValueError(f"Unsupported listwise loss: {loss}")
-
+
         else:
             raise ValueError(f"Unknown training_mode: {training_mode}")
 
@@ -98,7 +108,7 @@ def get_loss_fn(
             return nn.CrossEntropyLoss(**loss_kwargs)
         elif isinstance(loss, str):
             raise ValueError(f"Unsupported multiclass loss: {loss}")
-
+
     elif task_type == "regression":
         if loss is None or loss == "mse":
             return nn.MSELoss(**loss_kwargs)
@@ -106,26 +116,24 @@ def get_loss_fn(
             return nn.L1Loss(**loss_kwargs)
         elif isinstance(loss, str):
             raise ValueError(f"Unsupported regression loss: {loss}")
-
+
     else:
         raise ValueError(f"Unsupported task_type: {task_type}")
-
+
     return loss
 
 
 def validate_training_mode(
-    training_mode: str,
-    support_training_modes: list[str],
-    model_name: str = "Model"
+    training_mode: str, support_training_modes: list[str], model_name: str = "Model"
 ) -> None:
     """
     Validate that the requested training mode is supported by the model.
-
+
     Args:
         training_mode: Requested training mode
         support_training_modes: List of supported training modes
         model_name: Name of the model (for error messages)
-
+
     Raises:
         ValueError: If training mode is not supported
     """
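The docstring examples above map directly onto the dispatch branches visible in this diff. A hedged usage sketch (only the branches shown in these hunks are relied on; anything else is an assumption):

```python
from nextrec.loss.loss_utils import get_loss_fn, VALID_TASK_TYPES

assert "match" in VALID_TASK_TYPES and "regression" in VALID_TASK_TYPES

# String shortcuts resolve per task_type / training_mode, as in the branches above.
bce_fn = get_loss_fn(task_type="binary", loss="bce")                            # docstring example
bpr_fn = get_loss_fn(task_type="match", training_mode="pairwise", loss="bpr")   # returns BPRLoss per the pairwise branch
mse_fn = get_loss_fn(task_type="regression", loss="mse")                        # returns nn.MSELoss per the regression branch

# A string that no branch recognizes falls through to the ValueError raises shown above,
# e.g. get_loss_fn(task_type="regression", loss="not-a-loss") raises ValueError.
```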
nextrec/loss/match_losses.py
CHANGED
@@ -13,149 +13,167 @@ from typing import Optional
 
 
 class BPRLoss(nn.Module):
-    def __init__(self, reduction: str =
+    def __init__(self, reduction: str = "mean"):
         super(BPRLoss, self).__init__()
         self.reduction = reduction
-
+
     def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
         if neg_score.dim() == 2:
             pos_score = pos_score.unsqueeze(1)  # [batch_size, 1]
             diff = pos_score - neg_score  # [batch_size, num_neg]
             loss = -torch.log(torch.sigmoid(diff) + 1e-8)
-            if self.reduction ==
+            if self.reduction == "mean":
                 return loss.mean()
-            elif self.reduction ==
+            elif self.reduction == "sum":
                 return loss.sum()
             else:
                 return loss
         else:
             diff = pos_score - neg_score
             loss = -torch.log(torch.sigmoid(diff) + 1e-8)
-            if self.reduction ==
+            if self.reduction == "mean":
                 return loss.mean()
-            elif self.reduction ==
+            elif self.reduction == "sum":
                 return loss.sum()
             else:
                 return loss
 
 
-class HingeLoss(nn.Module):
-    def __init__(self, margin: float = 1.0, reduction: str =
+class HingeLoss(nn.Module):
+    def __init__(self, margin: float = 1.0, reduction: str = "mean"):
         super(HingeLoss, self).__init__()
         self.margin = margin
         self.reduction = reduction
-
+
     def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
         if neg_score.dim() == 2:
             pos_score = pos_score.unsqueeze(1)  # [batch_size, 1]
-
+
         diff = pos_score - neg_score
         loss = torch.clamp(self.margin - diff, min=0)
-
-        if self.reduction ==
+
+        if self.reduction == "mean":
             return loss.mean()
-        elif self.reduction ==
+        elif self.reduction == "sum":
             return loss.sum()
         else:
             return loss
 
 
 class TripletLoss(nn.Module):
-    def __init__(
+    def __init__(
+        self, margin: float = 1.0, reduction: str = "mean", distance: str = "euclidean"
+    ):
         super(TripletLoss, self).__init__()
         self.margin = margin
         self.reduction = reduction
         self.distance = distance
-
-    def forward(
-
+
+    def forward(
+        self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor
+    ) -> torch.Tensor:
+        if self.distance == "euclidean":
             pos_dist = torch.sum((anchor - positive) ** 2, dim=-1)
-
+
             if negative.dim() == 3:
                 anchor_expanded = anchor.unsqueeze(1)  # [batch_size, 1, dim]
-                neg_dist = torch.sum(
+                neg_dist = torch.sum(
+                    (anchor_expanded - negative) ** 2, dim=-1
+                )  # [batch_size, num_neg]
             else:
                 neg_dist = torch.sum((anchor - negative) ** 2, dim=-1)
-
+
             if neg_dist.dim() == 2:
                 pos_dist = pos_dist.unsqueeze(1)  # [batch_size, 1]
-
-        elif self.distance ==
+
+        elif self.distance == "cosine":
             pos_dist = 1 - F.cosine_similarity(anchor, positive, dim=-1)
-
+
             if negative.dim() == 3:
                 anchor_expanded = anchor.unsqueeze(1)  # [batch_size, 1, dim]
                 neg_dist = 1 - F.cosine_similarity(anchor_expanded, negative, dim=-1)
             else:
                 neg_dist = 1 - F.cosine_similarity(anchor, negative, dim=-1)
-
+
             if neg_dist.dim() == 2:
                 pos_dist = pos_dist.unsqueeze(1)
         else:
             raise ValueError(f"Unsupported distance: {self.distance}")
-
+
         loss = torch.clamp(pos_dist - neg_dist + self.margin, min=0)
-
-        if self.reduction ==
+
+        if self.reduction == "mean":
             return loss.mean()
-        elif self.reduction ==
+        elif self.reduction == "sum":
             return loss.sum()
         else:
             return loss
 
 
 class SampledSoftmaxLoss(nn.Module):
-    def __init__(self, reduction: str =
+    def __init__(self, reduction: str = "mean"):
         super(SampledSoftmaxLoss, self).__init__()
         self.reduction = reduction
-
-    def forward(
+
+    def forward(
+        self, pos_logits: torch.Tensor, neg_logits: torch.Tensor
+    ) -> torch.Tensor:
         pos_logits = pos_logits.unsqueeze(1)  # [batch_size, 1]
-        all_logits = torch.cat(
-
+        all_logits = torch.cat(
+            [pos_logits, neg_logits], dim=1
+        )  # [batch_size, 1 + num_neg]
+        targets = torch.zeros(
+            all_logits.size(0), dtype=torch.long, device=all_logits.device
+        )
         loss = F.cross_entropy(all_logits, targets, reduction=self.reduction)
-
+
         return loss
 
 
 class CosineContrastiveLoss(nn.Module):
-    def __init__(self, margin: float = 0.5, reduction: str =
+    def __init__(self, margin: float = 0.5, reduction: str = "mean"):
         super(CosineContrastiveLoss, self).__init__()
         self.margin = margin
         self.reduction = reduction
-
-    def forward(
+
+    def forward(
+        self, user_emb: torch.Tensor, item_emb: torch.Tensor, labels: torch.Tensor
+    ) -> torch.Tensor:
         similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
         pos_loss = (1 - similarity) * labels
 
         neg_loss = torch.clamp(similarity - self.margin, min=0) * (1 - labels)
-
+
         loss = pos_loss + neg_loss
-
-        if self.reduction ==
+
+        if self.reduction == "mean":
             return loss.mean()
-        elif self.reduction ==
+        elif self.reduction == "sum":
             return loss.sum()
         else:
             return loss
 
 
 class InfoNCELoss(nn.Module):
-    def __init__(self, temperature: float = 0.07, reduction: str =
+    def __init__(self, temperature: float = 0.07, reduction: str = "mean"):
         super(InfoNCELoss, self).__init__()
         self.temperature = temperature
         self.reduction = reduction
-
-    def forward(
+
+    def forward(
+        self, query: torch.Tensor, pos_key: torch.Tensor, neg_keys: torch.Tensor
+    ) -> torch.Tensor:
         pos_sim = torch.sum(query * pos_key, dim=-1) / self.temperature  # [batch_size]
         pos_sim = pos_sim.unsqueeze(1)  # [batch_size, 1]
         query_expanded = query.unsqueeze(1)  # [batch_size, 1, dim]
-        neg_sim =
+        neg_sim = (
+            torch.sum(query_expanded * neg_keys, dim=-1) / self.temperature
+        )  # [batch_size, num_neg]
         logits = torch.cat([pos_sim, neg_sim], dim=1)  # [batch_size, 1 + num_neg]
         labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
 
         loss = F.cross_entropy(logits, labels, reduction=self.reduction)
-
+
         return loss
 
 
@@ -164,22 +182,23 @@ class ListNetLoss(nn.Module):
     ListNet loss using top-1 probability distribution
     Reference: Cao et al. Learning to Rank: From Pairwise Approach to Listwise Approach (ICML 2007)
     """
-
+
+    def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
         super(ListNetLoss, self).__init__()
         self.temperature = temperature
         self.reduction = reduction
-
+
     def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         # Convert scores and labels to probability distributions
         pred_probs = F.softmax(scores / self.temperature, dim=1)
         true_probs = F.softmax(labels / self.temperature, dim=1)
-
+
         # Cross entropy between two distributions
         loss = -torch.sum(true_probs * torch.log(pred_probs + 1e-10), dim=1)
-
-        if self.reduction ==
+
+        if self.reduction == "mean":
             return loss.mean()
-        elif self.reduction ==
+        elif self.reduction == "sum":
             return loss.sum()
         else:
             return loss
@@ -190,19 +209,24 @@ class ListMLELoss(nn.Module):
     ListMLE (Maximum Likelihood Estimation) loss
     Reference: Xia et al. Listwise approach to learning to rank: theory and algorithm (ICML 2008)
     """
-
+
+    def __init__(self, reduction: str = "mean"):
         super(ListMLELoss, self).__init__()
         self.reduction = reduction
-
+
     def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         # Sort by labels in descending order to get ground truth ranking
         sorted_labels, sorted_indices = torch.sort(labels, descending=True, dim=1)
-
+
         # Reorder scores according to ground truth ranking
         batch_size, list_size = scores.shape
-        batch_indices =
+        batch_indices = (
+            torch.arange(batch_size, device=scores.device)
+            .unsqueeze(1)
+            .expand(-1, list_size)
+        )
         sorted_scores = scores[batch_indices, sorted_indices]
-
+
         # Compute log likelihood
         # For each position, compute log(exp(score_i) / sum(exp(score_j) for j >= i))
         loss = torch.tensor(0.0, device=scores.device)
@@ -211,10 +235,10 @@ class ListMLELoss(nn.Module):
             remaining_scores = sorted_scores[:, i:]
             log_sum_exp = torch.logsumexp(remaining_scores, dim=1)
             loss = loss + (log_sum_exp - sorted_scores[:, i]).sum()
-
-        if self.reduction ==
+
+        if self.reduction == "mean":
             return loss / batch_size
-        elif self.reduction ==
+        elif self.reduction == "sum":
             return loss
         else:
             return loss / batch_size
@@ -223,72 +247,81 @@ class ListMLELoss(nn.Module):
 class ApproxNDCGLoss(nn.Module):
     """
     Approximate NDCG loss for learning to rank
-    Reference: Qin et al. A General Approximation Framework for Direct Optimization of 
+    Reference: Qin et al. A General Approximation Framework for Direct Optimization of
     Information Retrieval Measures (Information Retrieval 2010)
     """
-
+
+    def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
         super(ApproxNDCGLoss, self).__init__()
         self.temperature = temperature
         self.reduction = reduction
-
+
     def _dcg(self, relevance: torch.Tensor, k: Optional[int] = None) -> torch.Tensor:
         if k is not None:
             relevance = relevance[:, :k]
-
+
         # DCG = sum(rel_i / log2(i + 2)) for i in range(list_size)
-        positions = torch.arange(
+        positions = torch.arange(
+            1, relevance.size(1) + 1, device=relevance.device, dtype=torch.float32
+        )
         discounts = torch.log2(positions + 1.0)
         dcg = torch.sum(relevance / discounts, dim=1)
-
+
         return dcg
-
-    def forward(
+
+    def forward(
+        self, scores: torch.Tensor, labels: torch.Tensor, k: Optional[int] = None
+    ) -> torch.Tensor:
         """
         Args:
             scores: Predicted scores [batch_size, list_size]
             labels: Ground truth relevance labels [batch_size, list_size]
             k: Top-k items for NDCG@k (if None, use all items)
-
+
         Returns:
             Approximate NDCG loss (1 - NDCG)
         """
         batch_size = scores.size(0)
-
+
         # Use differentiable sorting approximation with softmax
         # Create pairwise comparison matrix
         scores_expanded = scores.unsqueeze(2)  # [batch_size, list_size, 1]
-        scores_tiled = scores.unsqueeze(1)
-
+        scores_tiled = scores.unsqueeze(1)  # [batch_size, 1, list_size]
+
         # Compute pairwise probabilities using sigmoid
         pairwise_diff = (scores_expanded - scores_tiled) / self.temperature
-        pairwise_probs = torch.sigmoid(
-
+        pairwise_probs = torch.sigmoid(
+            pairwise_diff
+        )  # [batch_size, list_size, list_size]
+
         # Approximate ranking positions
         # ranking_probs[i, j] ≈ probability that item i is ranked at position j
         # We use softmax approximation for differentiable ranking
         ranking_weights = F.softmax(scores / self.temperature, dim=1)
-
+
         # Sort labels to get ideal DCG
         ideal_labels, _ = torch.sort(labels, descending=True, dim=1)
         ideal_dcg = self._dcg(ideal_labels, k)
-
+
         # Compute approximate DCG using soft ranking
         # Weight each item's relevance by its soft ranking position
-        positions = torch.arange(
+        positions = torch.arange(
+            1, scores.size(1) + 1, device=scores.device, dtype=torch.float32
+        )
         discounts = 1.0 / torch.log2(positions + 1.0)
-
+
         # Approximate DCG by weighting relevance with ranking probabilities
         approx_dcg = torch.sum(labels * ranking_weights * discounts, dim=1)
-
+
         # Normalize by ideal DCG to get NDCG
         ndcg = approx_dcg / (ideal_dcg + 1e-10)
-
+
         # Loss is 1 - NDCG (we want to maximize NDCG, so minimize 1 - NDCG)
         loss = 1.0 - ndcg
-
-        if self.reduction ==
+
+        if self.reduction == "mean":
             return loss.mean()
-        elif self.reduction ==
+        elif self.reduction == "sum":
             return loss.sum()
         else:
             return loss
nextrec/models/match/__init__.py
CHANGED