nextrec-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +41 -0
- nextrec/__version__.py +1 -0
- nextrec/basic/__init__.py +0 -0
- nextrec/basic/activation.py +92 -0
- nextrec/basic/callback.py +35 -0
- nextrec/basic/dataloader.py +447 -0
- nextrec/basic/features.py +87 -0
- nextrec/basic/layers.py +985 -0
- nextrec/basic/loggers.py +124 -0
- nextrec/basic/metrics.py +557 -0
- nextrec/basic/model.py +1438 -0
- nextrec/data/__init__.py +27 -0
- nextrec/data/data_utils.py +132 -0
- nextrec/data/preprocessor.py +662 -0
- nextrec/loss/__init__.py +35 -0
- nextrec/loss/loss_utils.py +136 -0
- nextrec/loss/match_losses.py +294 -0
- nextrec/models/generative/hstu.py +0 -0
- nextrec/models/generative/tiger.py +0 -0
- nextrec/models/match/__init__.py +13 -0
- nextrec/models/match/dssm.py +200 -0
- nextrec/models/match/dssm_v2.py +162 -0
- nextrec/models/match/mind.py +210 -0
- nextrec/models/match/sdm.py +253 -0
- nextrec/models/match/youtube_dnn.py +172 -0
- nextrec/models/multi_task/esmm.py +129 -0
- nextrec/models/multi_task/mmoe.py +161 -0
- nextrec/models/multi_task/ple.py +260 -0
- nextrec/models/multi_task/share_bottom.py +126 -0
- nextrec/models/ranking/__init__.py +17 -0
- nextrec/models/ranking/afm.py +118 -0
- nextrec/models/ranking/autoint.py +140 -0
- nextrec/models/ranking/dcn.py +120 -0
- nextrec/models/ranking/deepfm.py +95 -0
- nextrec/models/ranking/dien.py +214 -0
- nextrec/models/ranking/din.py +181 -0
- nextrec/models/ranking/fibinet.py +130 -0
- nextrec/models/ranking/fm.py +87 -0
- nextrec/models/ranking/masknet.py +125 -0
- nextrec/models/ranking/pnn.py +128 -0
- nextrec/models/ranking/widedeep.py +105 -0
- nextrec/models/ranking/xdeepfm.py +117 -0
- nextrec/utils/__init__.py +18 -0
- nextrec/utils/common.py +14 -0
- nextrec/utils/embedding.py +19 -0
- nextrec/utils/initializer.py +47 -0
- nextrec/utils/optimizer.py +75 -0
- nextrec-0.1.1.dist-info/METADATA +302 -0
- nextrec-0.1.1.dist-info/RECORD +51 -0
- nextrec-0.1.1.dist-info/WHEEL +4 -0
- nextrec-0.1.1.dist-info/licenses/LICENSE +21 -0
nextrec/loss/loss_utils.py

@@ -0,0 +1,136 @@
"""
Loss utilities for NextRec

Date: create on 09/11/2025
Author:
    Yang Zhou, zyaztec@gmail.com
"""
import torch
import torch.nn as nn
from typing import Literal

from nextrec.loss.match_losses import (
    BPRLoss,
    HingeLoss,
    TripletLoss,
    SampledSoftmaxLoss,
    CosineContrastiveLoss,
    InfoNCELoss
)

# Valid task types for validation
VALID_TASK_TYPES = ['binary', 'multiclass', 'regression', 'multivariate_regression', 'match', 'ranking', 'multitask', 'multilabel']


def get_loss_fn(
    task_type: str = "binary",
    training_mode: str | None = None,
    loss: str | nn.Module | None = None,
    **loss_kwargs
) -> nn.Module:
    """
    Get loss function based on task type and training mode.

    Examples:
        # Ranking task (binary classification)
        >>> loss_fn = get_loss_fn(task_type="binary", loss="bce")

        # Match task with pointwise training
        >>> loss_fn = get_loss_fn(task_type="match", training_mode="pointwise")

        # Match task with pairwise training
        >>> loss_fn = get_loss_fn(task_type="match", training_mode="pairwise", loss="bpr")

        # Match task with listwise training
        >>> loss_fn = get_loss_fn(task_type="match", training_mode="listwise", loss="sampled_softmax")
    """
    if isinstance(loss, nn.Module):
        return loss

    if task_type == "match":
        if training_mode == "pointwise":
            # Pointwise training uses binary cross entropy
            if loss is None or loss == "bce" or loss == "binary_crossentropy":
                return nn.BCELoss(**loss_kwargs)
            elif loss == "cosine_contrastive":
                return CosineContrastiveLoss(**loss_kwargs)
            elif isinstance(loss, str):
                raise ValueError(f"Unsupported pointwise loss: {loss}")

        elif training_mode == "pairwise":
            if loss is None or loss == "bpr":
                return BPRLoss(**loss_kwargs)
            elif loss == "hinge":
                return HingeLoss(**loss_kwargs)
            elif loss == "triplet":
                return TripletLoss(**loss_kwargs)
            elif isinstance(loss, str):
                raise ValueError(f"Unsupported pairwise loss: {loss}")

        elif training_mode == "listwise":
            if loss is None or loss == "sampled_softmax" or loss == "softmax":
                return SampledSoftmaxLoss(**loss_kwargs)
            elif loss == "infonce":
                return InfoNCELoss(**loss_kwargs)
            elif loss == "crossentropy" or loss == "ce":
                return nn.CrossEntropyLoss(**loss_kwargs)
            elif isinstance(loss, str):
                raise ValueError(f"Unsupported listwise loss: {loss}")

        else:
            raise ValueError(f"Unknown training_mode: {training_mode}")

    elif task_type in ["ranking", "multitask", "binary"]:
        if loss is None or loss == "bce" or loss == "binary_crossentropy":
            return nn.BCELoss(**loss_kwargs)
        elif loss == "mse":
            return nn.MSELoss(**loss_kwargs)
        elif loss == "mae":
            return nn.L1Loss(**loss_kwargs)
        elif loss == "crossentropy" or loss == "ce":
            return nn.CrossEntropyLoss(**loss_kwargs)
        elif isinstance(loss, str):
            raise ValueError(f"Unsupported loss function: {loss}")

    elif task_type == "multiclass":
        if loss is None or loss == "crossentropy" or loss == "ce":
            return nn.CrossEntropyLoss(**loss_kwargs)
        elif isinstance(loss, str):
            raise ValueError(f"Unsupported multiclass loss: {loss}")

    elif task_type == "regression":
        if loss is None or loss == "mse":
            return nn.MSELoss(**loss_kwargs)
        elif loss == "mae":
            return nn.L1Loss(**loss_kwargs)
        elif isinstance(loss, str):
            raise ValueError(f"Unsupported regression loss: {loss}")

    else:
        raise ValueError(f"Unsupported task_type: {task_type}")

    return loss


def validate_training_mode(
    training_mode: str,
    support_training_modes: list[str],
    model_name: str = "Model"
) -> None:
    """
    Validate that the requested training mode is supported by the model.

    Args:
        training_mode: Requested training mode
        support_training_modes: List of supported training modes
        model_name: Name of the model (for error messages)

    Raises:
        ValueError: If training mode is not supported
    """
    if training_mode not in support_training_modes:
        raise ValueError(
            f"{model_name} does not support training_mode='{training_mode}'. "
            f"Supported modes: {support_training_modes}"
        )
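Before moving on to match_losses.py, a minimal usage sketch of how these two helpers compose. This is editorial, not part of the diff; it assumes the published package is installed and that the classes behave as defined above.

import torch
from nextrec.loss.loss_utils import get_loss_fn, validate_training_mode

# Guard a hypothetical model that only supports pointwise/pairwise training.
validate_training_mode("pairwise", ["pointwise", "pairwise"], model_name="DSSM")

# Resolve the default pairwise loss (BPR) and apply it to dummy scores.
loss_fn = get_loss_fn(task_type="match", training_mode="pairwise")
pos_score = torch.randn(32)     # [batch_size]
neg_score = torch.randn(32, 4)  # [batch_size, num_neg]
print(loss_fn(pos_score, neg_score))  # scalar (reduction='mean')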
nextrec/loss/match_losses.py

@@ -0,0 +1,294 @@
"""
Loss functions for matching tasks

Date: create on 13/11/2025
Author:
    Yang Zhou, zyaztec@gmail.com
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class BPRLoss(nn.Module):
    def __init__(self, reduction: str = 'mean'):
        super(BPRLoss, self).__init__()
        self.reduction = reduction

    def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
        if neg_score.dim() == 2:
            pos_score = pos_score.unsqueeze(1)  # [batch_size, 1]
            diff = pos_score - neg_score  # [batch_size, num_neg]
            loss = -torch.log(torch.sigmoid(diff) + 1e-8)
            if self.reduction == 'mean':
                return loss.mean()
            elif self.reduction == 'sum':
                return loss.sum()
            else:
                return loss
        else:
            diff = pos_score - neg_score
            loss = -torch.log(torch.sigmoid(diff) + 1e-8)
            if self.reduction == 'mean':
                return loss.mean()
            elif self.reduction == 'sum':
                return loss.sum()
            else:
                return loss
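Editorial aside: BPRLoss computes -log(sigmoid(s_pos - s_neg)) with a 1e-8 floor inside the log. A numerically cleaner equivalent, shown only as a sketch and not what the package ships, is torch's fused log-sigmoid:

import torch
import torch.nn.functional as F

pos_score, neg_score = torch.randn(32), torch.randn(32, 4)
# -log(sigmoid(x)) == -logsigmoid(x), computed stably without the epsilon.
loss = -F.logsigmoid(pos_score.unsqueeze(1) - neg_score)
print(loss.mean())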
class HingeLoss(nn.Module):
    def __init__(self, margin: float = 1.0, reduction: str = 'mean'):
        super(HingeLoss, self).__init__()
        self.margin = margin
        self.reduction = reduction

    def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
        if neg_score.dim() == 2:
            pos_score = pos_score.unsqueeze(1)  # [batch_size, 1]

        diff = pos_score - neg_score
        loss = torch.clamp(self.margin - diff, min=0)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


class TripletLoss(nn.Module):
    def __init__(self, margin: float = 1.0, reduction: str = 'mean', distance: str = 'euclidean'):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.reduction = reduction
        self.distance = distance

    def forward(self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor) -> torch.Tensor:
        if self.distance == 'euclidean':
            pos_dist = torch.sum((anchor - positive) ** 2, dim=-1)

            if negative.dim() == 3:
                anchor_expanded = anchor.unsqueeze(1)  # [batch_size, 1, dim]
                neg_dist = torch.sum((anchor_expanded - negative) ** 2, dim=-1)  # [batch_size, num_neg]
            else:
                neg_dist = torch.sum((anchor - negative) ** 2, dim=-1)

            if neg_dist.dim() == 2:
                pos_dist = pos_dist.unsqueeze(1)  # [batch_size, 1]

        elif self.distance == 'cosine':
            pos_dist = 1 - F.cosine_similarity(anchor, positive, dim=-1)

            if negative.dim() == 3:
                anchor_expanded = anchor.unsqueeze(1)  # [batch_size, 1, dim]
                neg_dist = 1 - F.cosine_similarity(anchor_expanded, negative, dim=-1)
            else:
                neg_dist = 1 - F.cosine_similarity(anchor, negative, dim=-1)

            if neg_dist.dim() == 2:
                pos_dist = pos_dist.unsqueeze(1)
        else:
            raise ValueError(f"Unsupported distance: {self.distance}")

        loss = torch.clamp(pos_dist - neg_dist + self.margin, min=0)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
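For reference, a shape sketch for TripletLoss with hypothetical tensors: a 3-D negative batch is broadcast against the anchor, and the positive distance is unsqueezed so the single positive is compared against every negative.

import torch
from nextrec.loss.match_losses import TripletLoss

anchor   = torch.randn(32, 64)
positive = torch.randn(32, 64)
negative = torch.randn(32, 4, 64)  # 4 negatives per example

# pos_dist: [32, 1] after unsqueeze; neg_dist: [32, 4]; output: scalar mean
print(TripletLoss(margin=1.0, distance="euclidean")(anchor, positive, negative))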
class SampledSoftmaxLoss(nn.Module):
    def __init__(self, reduction: str = 'mean'):
        super(SampledSoftmaxLoss, self).__init__()
        self.reduction = reduction

    def forward(self, pos_logits: torch.Tensor, neg_logits: torch.Tensor) -> torch.Tensor:
        pos_logits = pos_logits.unsqueeze(1)  # [batch_size, 1]
        all_logits = torch.cat([pos_logits, neg_logits], dim=1)  # [batch_size, 1 + num_neg]
        targets = torch.zeros(all_logits.size(0), dtype=torch.long, device=all_logits.device)
        loss = F.cross_entropy(all_logits, targets, reduction=self.reduction)

        return loss


class CosineContrastiveLoss(nn.Module):
    def __init__(self, margin: float = 0.5, reduction: str = 'mean'):
        super(CosineContrastiveLoss, self).__init__()
        self.margin = margin
        self.reduction = reduction

    def forward(self, user_emb: torch.Tensor, item_emb: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
        pos_loss = (1 - similarity) * labels

        neg_loss = torch.clamp(similarity - self.margin, min=0) * (1 - labels)

        loss = pos_loss + neg_loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


class InfoNCELoss(nn.Module):
    def __init__(self, temperature: float = 0.07, reduction: str = 'mean'):
        super(InfoNCELoss, self).__init__()
        self.temperature = temperature
        self.reduction = reduction

    def forward(self, query: torch.Tensor, pos_key: torch.Tensor, neg_keys: torch.Tensor) -> torch.Tensor:
        pos_sim = torch.sum(query * pos_key, dim=-1) / self.temperature  # [batch_size]
        pos_sim = pos_sim.unsqueeze(1)  # [batch_size, 1]
        query_expanded = query.unsqueeze(1)  # [batch_size, 1, dim]
        neg_sim = torch.sum(query_expanded * neg_keys, dim=-1) / self.temperature  # [batch_size, num_neg]
        logits = torch.cat([pos_sim, neg_sim], dim=1)  # [batch_size, 1 + num_neg]
        labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)

        loss = F.cross_entropy(logits, labels, reduction=self.reduction)

        return loss
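Both SampledSoftmaxLoss and InfoNCELoss reduce to a cross-entropy in which the positive sits at index 0 of the concatenated logits; InfoNCE simply builds the logits from dot products scaled by 1/temperature. A small sanity sketch with made-up tensors, illustrative only:

import torch
import torch.nn.functional as F
from nextrec.loss.match_losses import InfoNCELoss

q, p, n = torch.randn(8, 16), torch.randn(8, 16), torch.randn(8, 5, 16)
logits = torch.cat([(q * p).sum(-1, keepdim=True),
                    (q.unsqueeze(1) * n).sum(-1)], dim=1) / 0.07
manual = F.cross_entropy(logits, torch.zeros(8, dtype=torch.long))
assert torch.allclose(manual, InfoNCELoss(temperature=0.07)(q, p, n))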
class ListNetLoss(nn.Module):
    """
    ListNet loss using top-1 probability distribution
    Reference: Cao et al. Learning to Rank: From Pairwise Approach to Listwise Approach (ICML 2007)
    """
    def __init__(self, temperature: float = 1.0, reduction: str = 'mean'):
        super(ListNetLoss, self).__init__()
        self.temperature = temperature
        self.reduction = reduction

    def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Convert scores and labels to probability distributions
        pred_probs = F.softmax(scores / self.temperature, dim=1)
        true_probs = F.softmax(labels / self.temperature, dim=1)

        # Cross entropy between the two distributions
        loss = -torch.sum(true_probs * torch.log(pred_probs + 1e-10), dim=1)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


class ListMLELoss(nn.Module):
    """
    ListMLE (Maximum Likelihood Estimation) loss
    Reference: Xia et al. Listwise approach to learning to rank: theory and algorithm (ICML 2008)
    """
    def __init__(self, reduction: str = 'mean'):
        super(ListMLELoss, self).__init__()
        self.reduction = reduction

    def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Sort by labels in descending order to get the ground-truth ranking
        sorted_labels, sorted_indices = torch.sort(labels, descending=True, dim=1)

        # Reorder scores according to the ground-truth ranking
        batch_size, list_size = scores.shape
        batch_indices = torch.arange(batch_size, device=scores.device).unsqueeze(1).expand(-1, list_size)
        sorted_scores = scores[batch_indices, sorted_indices]

        # Compute the log likelihood:
        # for each position i, log(exp(score_i) / sum(exp(score_j) for j >= i))
        loss = torch.tensor(0.0, device=scores.device)
        for i in range(list_size):
            # Log-sum-exp trick for numerical stability
            remaining_scores = sorted_scores[:, i:]
            log_sum_exp = torch.logsumexp(remaining_scores, dim=1)
            loss = loss + (log_sum_exp - sorted_scores[:, i]).sum()

        if self.reduction == 'mean':
            return loss / batch_size
        elif self.reduction == 'sum':
            return loss
        else:
            return loss / batch_size
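Aside: the per-position Python loop in ListMLELoss can be replaced by a reversed cumulative log-sum-exp, which computes the same tail sums in one pass. A loop-free equivalent of the 'mean' reduction, a sketch rather than the package's code:

import torch

def listmle_vectorized(sorted_scores: torch.Tensor) -> torch.Tensor:
    # log_tail[:, i] == log sum(exp(score_j) for j >= i), via flipped logcumsumexp
    log_tail = torch.logcumsumexp(sorted_scores.flip(1), dim=1).flip(1)
    return (log_tail - sorted_scores).sum() / sorted_scores.size(0)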
class ApproxNDCGLoss(nn.Module):
    """
    Approximate NDCG loss for learning to rank
    Reference: Qin et al. A General Approximation Framework for Direct Optimization of
               Information Retrieval Measures (Information Retrieval 2010)
    """
    def __init__(self, temperature: float = 1.0, reduction: str = 'mean'):
        super(ApproxNDCGLoss, self).__init__()
        self.temperature = temperature
        self.reduction = reduction

    def _dcg(self, relevance: torch.Tensor, k: Optional[int] = None) -> torch.Tensor:
        if k is not None:
            relevance = relevance[:, :k]

        # DCG = sum(rel_i / log2(i + 2)) for i in range(list_size)
        positions = torch.arange(1, relevance.size(1) + 1, device=relevance.device, dtype=torch.float32)
        discounts = torch.log2(positions + 1.0)
        dcg = torch.sum(relevance / discounts, dim=1)

        return dcg

    def forward(self, scores: torch.Tensor, labels: torch.Tensor, k: Optional[int] = None) -> torch.Tensor:
        """
        Args:
            scores: Predicted scores [batch_size, list_size]
            labels: Ground truth relevance labels [batch_size, list_size]
            k: Top-k items for NDCG@k (if None, use all items)

        Returns:
            Approximate NDCG loss (1 - NDCG)
        """
        batch_size = scores.size(0)

        # Use a differentiable sorting approximation with softmax:
        # create the pairwise comparison matrix.
        scores_expanded = scores.unsqueeze(2)  # [batch_size, list_size, 1]
        scores_tiled = scores.unsqueeze(1)     # [batch_size, 1, list_size]

        # Compute pairwise probabilities using sigmoid
        pairwise_diff = (scores_expanded - scores_tiled) / self.temperature
        pairwise_probs = torch.sigmoid(pairwise_diff)  # [batch_size, list_size, list_size]

        # Approximate ranking positions:
        # ranking_probs[i, j] ~ probability that item i is ranked at position j.
        # We use a softmax approximation for differentiable ranking.
        ranking_weights = F.softmax(scores / self.temperature, dim=1)

        # Sort labels to get the ideal DCG
        ideal_labels, _ = torch.sort(labels, descending=True, dim=1)
        ideal_dcg = self._dcg(ideal_labels, k)

        # Compute the approximate DCG using soft ranking:
        # weight each item's relevance by its soft ranking position.
        positions = torch.arange(1, scores.size(1) + 1, device=scores.device, dtype=torch.float32)
        discounts = 1.0 / torch.log2(positions + 1.0)

        # Approximate DCG by weighting relevance with ranking probabilities
        approx_dcg = torch.sum(labels * ranking_weights * discounts, dim=1)

        # Normalize by the ideal DCG to get NDCG
        ndcg = approx_dcg / (ideal_dcg + 1e-10)

        # Loss is 1 - NDCG (we want to maximize NDCG, so we minimize 1 - NDCG)
        loss = 1.0 - ndcg

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
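That closes match_losses.py. For orientation, an illustrative call pattern for the listwise losses, using dummy scores and graded relevance labels:

import torch
from nextrec.loss.match_losses import ListNetLoss, ListMLELoss, ApproxNDCGLoss

scores = torch.randn(2, 10)                    # [batch_size, list_size]
labels = torch.randint(0, 3, (2, 10)).float()  # graded relevance 0..2
print(ListNetLoss(temperature=1.0)(scores, labels))
print(ListMLELoss()(scores, labels))
print(ApproxNDCGLoss(temperature=0.5)(scores, labels, k=5))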
nextrec/models/generative/hstu.py: file without changes
nextrec/models/generative/tiger.py: file without changes
nextrec/models/match/dssm.py

@@ -0,0 +1,200 @@
"""
Date: create on 09/11/2025
Author:
    Yang Zhou, zyaztec@gmail.com
Reference:
[1] Huang P S, He X, Gao J, et al. Learning deep structured semantic models for web search using clickthrough data[C]
    //Proceedings of the 22nd ACM international conference on Information & Knowledge Management. 2013: 2333-2338.
"""
import torch
import torch.nn as nn
from typing import Optional, Literal

from nextrec.basic.model import BaseMatchModel
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
from nextrec.basic.layers import MLP, EmbeddingLayer


class DSSM(BaseMatchModel):
    """
    Deep Structured Semantic Model

    Two-tower model: encodes user and item features into separate embeddings,
    then computes the match score via cosine similarity or dot product.
    """

    @property
    def model_name(self) -> str:
        return "DSSM"

    def __init__(self,
                 user_dense_features: list[DenseFeature] | None = None,
                 user_sparse_features: list[SparseFeature] | None = None,
                 user_sequence_features: list[SequenceFeature] | None = None,
                 item_dense_features: list[DenseFeature] | None = None,
                 item_sparse_features: list[SparseFeature] | None = None,
                 item_sequence_features: list[SequenceFeature] | None = None,
                 user_dnn_hidden_units: list[int] = [256, 128, 64],
                 item_dnn_hidden_units: list[int] = [256, 128, 64],
                 embedding_dim: int = 64,
                 dnn_activation: str = 'relu',
                 dnn_dropout: float = 0.0,
                 training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'pointwise',
                 num_negative_samples: int = 4,
                 temperature: float = 1.0,
                 similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'cosine',
                 device: str = 'cpu',
                 embedding_l1_reg: float = 0.0,
                 dense_l1_reg: float = 0.0,
                 embedding_l2_reg: float = 0.0,
                 dense_l2_reg: float = 0.0,
                 early_stop_patience: int = 20,
                 model_id: str = 'dssm'):

        super(DSSM, self).__init__(
            user_dense_features=user_dense_features,
            user_sparse_features=user_sparse_features,
            user_sequence_features=user_sequence_features,
            item_dense_features=item_dense_features,
            item_sparse_features=item_sparse_features,
            item_sequence_features=item_sequence_features,
            training_mode=training_mode,
            num_negative_samples=num_negative_samples,
            temperature=temperature,
            similarity_metric=similarity_metric,
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
            early_stop_patience=early_stop_patience,
            model_id=model_id
        )

        self.embedding_dim = embedding_dim
        self.user_dnn_hidden_units = user_dnn_hidden_units
        self.item_dnn_hidden_units = item_dnn_hidden_units

        # User tower embedding layer
        user_features = []
        if user_dense_features:
            user_features.extend(user_dense_features)
        if user_sparse_features:
            user_features.extend(user_sparse_features)
        if user_sequence_features:
            user_features.extend(user_sequence_features)

        if len(user_features) > 0:
            self.user_embedding = EmbeddingLayer(user_features)

        # Compute the user tower input dimension
        user_input_dim = 0
        for feat in user_dense_features or []:
            user_input_dim += 1
        for feat in user_sparse_features or []:
            user_input_dim += feat.embedding_dim
        for feat in user_sequence_features or []:
            user_input_dim += feat.embedding_dim

        # User DNN
        user_dnn_units = user_dnn_hidden_units + [embedding_dim]
        self.user_dnn = MLP(
            input_dim=user_input_dim,
            dims=user_dnn_units,
            output_layer=False,
            dropout=dnn_dropout,
            activation=dnn_activation
        )

        # Item tower embedding layer
        item_features = []
        if item_dense_features:
            item_features.extend(item_dense_features)
        if item_sparse_features:
            item_features.extend(item_sparse_features)
        if item_sequence_features:
            item_features.extend(item_sequence_features)

        if len(item_features) > 0:
            self.item_embedding = EmbeddingLayer(item_features)

        # Compute the item tower input dimension
        item_input_dim = 0
        for feat in item_dense_features or []:
            item_input_dim += 1
        for feat in item_sparse_features or []:
            item_input_dim += feat.embedding_dim
        for feat in item_sequence_features or []:
            item_input_dim += feat.embedding_dim

        # Item DNN
        item_dnn_units = item_dnn_hidden_units + [embedding_dim]
        self.item_dnn = MLP(
            input_dim=item_input_dim,
            dims=item_dnn_units,
            output_layer=False,
            dropout=dnn_dropout,
            activation=dnn_activation
        )

        # Register regularization weights
        self._register_regularization_weights(
            embedding_attr='user_embedding',
            include_modules=['user_dnn']
        )
        self._register_regularization_weights(
            embedding_attr='item_embedding',
            include_modules=['item_dnn']
        )

        self.compile(
            optimizer="adam",
            optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
        )

        self.to(device)

    def user_tower(self, user_input: dict) -> torch.Tensor:
        """
        User tower: encode user features into an embedding.

        Args:
            user_input: dict of user features

        Returns:
            user_emb: [batch_size, embedding_dim]
        """
        # Gather the embeddings of all user features
        all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
        user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)

        # Pass through the user DNN
        user_emb = self.user_dnn(user_emb)

        # L2 normalize for cosine similarity
        if self.similarity_metric == 'cosine':
            user_emb = torch.nn.functional.normalize(user_emb, p=2, dim=1)

        return user_emb

    def item_tower(self, item_input: dict) -> torch.Tensor:
        """
        Item tower: encode item features into an embedding.

        Args:
            item_input: dict of item features

        Returns:
            item_emb: [batch_size, embedding_dim] or [batch_size, num_items, embedding_dim]
        """
        # Gather the embeddings of all item features
        all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
        item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)

        # Pass through the item DNN
        item_emb = self.item_dnn(item_emb)

        # L2 normalize for cosine similarity
        if self.similarity_metric == 'cosine':
            item_emb = torch.nn.functional.normalize(item_emb, p=2, dim=1)

        return item_emb
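DSSM itself only defines the two towers; the similarity computation appears to live in BaseMatchModel, which receives similarity_metric and temperature above. As a rough illustration of what cosine scoring does with these tower outputs (plain tensors and hypothetical shapes, not the package's feature pipeline):

import torch
import torch.nn.functional as F

user_emb = F.normalize(torch.randn(32, 64), p=2, dim=1)  # like user_tower output
item_emb = F.normalize(torch.randn(32, 64), p=2, dim=1)  # like item_tower output
# With unit-norm embeddings, the dot product equals the cosine similarity.
score = (user_emb * item_emb).sum(dim=1)  # [batch_size], values in [-1, 1]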