nextrec 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. nextrec/__init__.py +4 -4
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -9
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/dataloader.py +168 -127
  6. nextrec/basic/features.py +24 -27
  7. nextrec/basic/layers.py +328 -159
  8. nextrec/basic/loggers.py +50 -37
  9. nextrec/basic/metrics.py +255 -147
  10. nextrec/basic/model.py +817 -462
  11. nextrec/data/__init__.py +5 -5
  12. nextrec/data/data_utils.py +16 -12
  13. nextrec/data/preprocessor.py +276 -252
  14. nextrec/loss/__init__.py +12 -12
  15. nextrec/loss/loss_utils.py +30 -22
  16. nextrec/loss/match_losses.py +116 -83
  17. nextrec/models/match/__init__.py +5 -5
  18. nextrec/models/match/dssm.py +70 -61
  19. nextrec/models/match/dssm_v2.py +61 -51
  20. nextrec/models/match/mind.py +89 -71
  21. nextrec/models/match/sdm.py +93 -81
  22. nextrec/models/match/youtube_dnn.py +62 -53
  23. nextrec/models/multi_task/esmm.py +49 -43
  24. nextrec/models/multi_task/mmoe.py +65 -56
  25. nextrec/models/multi_task/ple.py +92 -65
  26. nextrec/models/multi_task/share_bottom.py +48 -42
  27. nextrec/models/ranking/__init__.py +7 -7
  28. nextrec/models/ranking/afm.py +39 -30
  29. nextrec/models/ranking/autoint.py +70 -57
  30. nextrec/models/ranking/dcn.py +43 -35
  31. nextrec/models/ranking/deepfm.py +34 -28
  32. nextrec/models/ranking/dien.py +115 -79
  33. nextrec/models/ranking/din.py +84 -60
  34. nextrec/models/ranking/fibinet.py +51 -35
  35. nextrec/models/ranking/fm.py +28 -26
  36. nextrec/models/ranking/masknet.py +31 -31
  37. nextrec/models/ranking/pnn.py +30 -31
  38. nextrec/models/ranking/widedeep.py +36 -31
  39. nextrec/models/ranking/xdeepfm.py +46 -39
  40. nextrec/utils/__init__.py +9 -9
  41. nextrec/utils/embedding.py +1 -1
  42. nextrec/utils/initializer.py +23 -15
  43. nextrec/utils/optimizer.py +14 -10
  44. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/METADATA +6 -40
  45. nextrec-0.1.2.dist-info/RECORD +51 -0
  46. nextrec-0.1.1.dist-info/RECORD +0 -51
  47. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/WHEEL +0 -0
  48. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/__init__.py CHANGED
@@ -18,18 +18,18 @@ from nextrec.loss.loss_utils import (
 
 __all__ = [
     # Match losses
-    'BPRLoss',
-    'HingeLoss',
-    'TripletLoss',
-    'SampledSoftmaxLoss',
-    'CosineContrastiveLoss',
-    'InfoNCELoss',
+    "BPRLoss",
+    "HingeLoss",
+    "TripletLoss",
+    "SampledSoftmaxLoss",
+    "CosineContrastiveLoss",
+    "InfoNCELoss",
     # Listwise losses
-    'ListNetLoss',
-    'ListMLELoss',
-    'ApproxNDCGLoss',
+    "ListNetLoss",
+    "ListMLELoss",
+    "ApproxNDCGLoss",
     # Utilities
-    'get_loss_fn',
-    'validate_training_mode',
-    'VALID_TASK_TYPES',
+    "get_loss_fn",
+    "validate_training_mode",
+    "VALID_TASK_TYPES",
 ]
nextrec/loss/loss_utils.py CHANGED
@@ -5,42 +5,52 @@ Date: create on 09/11/2025
 Author:
     Yang Zhou,zyaztec@gmail.com
 """
+
 import torch
 import torch.nn as nn
 from typing import Literal
 
 from nextrec.loss.match_losses import (
-    BPRLoss,
-    HingeLoss,
-    TripletLoss,
+    BPRLoss,
+    HingeLoss,
+    TripletLoss,
     SampledSoftmaxLoss,
-    CosineContrastiveLoss,
-    InfoNCELoss
+    CosineContrastiveLoss,
+    InfoNCELoss,
 )
 
 # Valid task types for validation
-VALID_TASK_TYPES = ['binary', 'multiclass', 'regression', 'multivariate_regression', 'match', 'ranking', 'multitask', 'multilabel']
+VALID_TASK_TYPES = [
+    "binary",
+    "multiclass",
+    "regression",
+    "multivariate_regression",
+    "match",
+    "ranking",
+    "multitask",
+    "multilabel",
+]
 
 
 def get_loss_fn(
     task_type: str = "binary",
     training_mode: str | None = None,
     loss: str | nn.Module | None = None,
-    **loss_kwargs
+    **loss_kwargs,
 ) -> nn.Module:
     """
     Get loss function based on task type and training mode.
-
+
     Examples:
         # Ranking task (binary classification)
         >>> loss_fn = get_loss_fn(task_type="binary", loss="bce")
-
+
         # Match task with pointwise training
         >>> loss_fn = get_loss_fn(task_type="match", training_mode="pointwise")
-
+
         # Match task with pairwise training
         >>> loss_fn = get_loss_fn(task_type="match", training_mode="pairwise", loss="bpr")
-
+
         # Match task with listwise training
        >>> loss_fn = get_loss_fn(task_type="match", training_mode="listwise", loss="sampled_softmax")
     """
@@ -57,7 +67,7 @@ def get_loss_fn(
             return CosineContrastiveLoss(**loss_kwargs)
         elif isinstance(loss, str):
             raise ValueError(f"Unsupported pointwise loss: {loss}")
-
+
     elif training_mode == "pairwise":
         if loss is None or loss == "bpr":
             return BPRLoss(**loss_kwargs)
@@ -67,7 +77,7 @@ def get_loss_fn(
             return TripletLoss(**loss_kwargs)
         elif isinstance(loss, str):
             raise ValueError(f"Unsupported pairwise loss: {loss}")
-
+
     elif training_mode == "listwise":
         if loss is None or loss == "sampled_softmax" or loss == "softmax":
             return SampledSoftmaxLoss(**loss_kwargs)
@@ -77,7 +87,7 @@ def get_loss_fn(
             return nn.CrossEntropyLoss(**loss_kwargs)
         elif isinstance(loss, str):
             raise ValueError(f"Unsupported listwise loss: {loss}")
-
+
     else:
         raise ValueError(f"Unknown training_mode: {training_mode}")
 
@@ -98,7 +108,7 @@ def get_loss_fn(
             return nn.CrossEntropyLoss(**loss_kwargs)
         elif isinstance(loss, str):
             raise ValueError(f"Unsupported multiclass loss: {loss}")
-
+
     elif task_type == "regression":
         if loss is None or loss == "mse":
             return nn.MSELoss(**loss_kwargs)
@@ -106,26 +116,24 @@ def get_loss_fn(
             return nn.L1Loss(**loss_kwargs)
         elif isinstance(loss, str):
             raise ValueError(f"Unsupported regression loss: {loss}")
-
+
     else:
         raise ValueError(f"Unsupported task_type: {task_type}")
-
+
     return loss
 
 
 def validate_training_mode(
-    training_mode: str,
-    support_training_modes: list[str],
-    model_name: str = "Model"
+    training_mode: str, support_training_modes: list[str], model_name: str = "Model"
 ) -> None:
     """
     Validate that the requested training mode is supported by the model.
-
+
     Args:
         training_mode: Requested training mode
         support_training_modes: List of supported training modes
         model_name: Name of the model (for error messages)
-
+
     Raises:
         ValueError: If training mode is not supported
     """
nextrec/loss/match_losses.py CHANGED
@@ -13,149 +13,167 @@ from typing import Optional
 
 
 class BPRLoss(nn.Module):
-    def __init__(self, reduction: str = 'mean'):
+    def __init__(self, reduction: str = "mean"):
         super(BPRLoss, self).__init__()
         self.reduction = reduction
-
+
     def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
         if neg_score.dim() == 2:
             pos_score = pos_score.unsqueeze(1)  # [batch_size, 1]
             diff = pos_score - neg_score  # [batch_size, num_neg]
             loss = -torch.log(torch.sigmoid(diff) + 1e-8)
-            if self.reduction == 'mean':
+            if self.reduction == "mean":
                 return loss.mean()
-            elif self.reduction == 'sum':
+            elif self.reduction == "sum":
                 return loss.sum()
             else:
                 return loss
         else:
             diff = pos_score - neg_score
             loss = -torch.log(torch.sigmoid(diff) + 1e-8)
-            if self.reduction == 'mean':
+            if self.reduction == "mean":
                 return loss.mean()
-            elif self.reduction == 'sum':
+            elif self.reduction == "sum":
                 return loss.sum()
             else:
                 return loss
 
 
-class HingeLoss(nn.Module):
-    def __init__(self, margin: float = 1.0, reduction: str = 'mean'):
+class HingeLoss(nn.Module):
+    def __init__(self, margin: float = 1.0, reduction: str = "mean"):
         super(HingeLoss, self).__init__()
         self.margin = margin
         self.reduction = reduction
-
+
     def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
         if neg_score.dim() == 2:
             pos_score = pos_score.unsqueeze(1)  # [batch_size, 1]
-
+
         diff = pos_score - neg_score
         loss = torch.clamp(self.margin - diff, min=0)
-
-        if self.reduction == 'mean':
+
+        if self.reduction == "mean":
             return loss.mean()
-        elif self.reduction == 'sum':
+        elif self.reduction == "sum":
             return loss.sum()
         else:
             return loss
 
 
 class TripletLoss(nn.Module):
-    def __init__(self, margin: float = 1.0, reduction: str = 'mean', distance: str = 'euclidean'):
+    def __init__(
+        self, margin: float = 1.0, reduction: str = "mean", distance: str = "euclidean"
+    ):
         super(TripletLoss, self).__init__()
         self.margin = margin
         self.reduction = reduction
         self.distance = distance
-
-    def forward(self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor) -> torch.Tensor:
-        if self.distance == 'euclidean':
+
+    def forward(
+        self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor
+    ) -> torch.Tensor:
+        if self.distance == "euclidean":
             pos_dist = torch.sum((anchor - positive) ** 2, dim=-1)
-
+
             if negative.dim() == 3:
                 anchor_expanded = anchor.unsqueeze(1)  # [batch_size, 1, dim]
-                neg_dist = torch.sum((anchor_expanded - negative) ** 2, dim=-1)  # [batch_size, num_neg]
+                neg_dist = torch.sum(
+                    (anchor_expanded - negative) ** 2, dim=-1
+                )  # [batch_size, num_neg]
             else:
                 neg_dist = torch.sum((anchor - negative) ** 2, dim=-1)
-
+
             if neg_dist.dim() == 2:
                 pos_dist = pos_dist.unsqueeze(1)  # [batch_size, 1]
-
-        elif self.distance == 'cosine':
+
+        elif self.distance == "cosine":
             pos_dist = 1 - F.cosine_similarity(anchor, positive, dim=-1)
-
+
             if negative.dim() == 3:
                 anchor_expanded = anchor.unsqueeze(1)  # [batch_size, 1, dim]
                 neg_dist = 1 - F.cosine_similarity(anchor_expanded, negative, dim=-1)
             else:
                 neg_dist = 1 - F.cosine_similarity(anchor, negative, dim=-1)
-
+
             if neg_dist.dim() == 2:
                 pos_dist = pos_dist.unsqueeze(1)
         else:
             raise ValueError(f"Unsupported distance: {self.distance}")
-
+
         loss = torch.clamp(pos_dist - neg_dist + self.margin, min=0)
-
-        if self.reduction == 'mean':
+
+        if self.reduction == "mean":
             return loss.mean()
-        elif self.reduction == 'sum':
+        elif self.reduction == "sum":
             return loss.sum()
         else:
             return loss
 
 
 class SampledSoftmaxLoss(nn.Module):
-    def __init__(self, reduction: str = 'mean'):
+    def __init__(self, reduction: str = "mean"):
         super(SampledSoftmaxLoss, self).__init__()
         self.reduction = reduction
-
-    def forward(self, pos_logits: torch.Tensor, neg_logits: torch.Tensor) -> torch.Tensor:
+
+    def forward(
+        self, pos_logits: torch.Tensor, neg_logits: torch.Tensor
+    ) -> torch.Tensor:
         pos_logits = pos_logits.unsqueeze(1)  # [batch_size, 1]
-        all_logits = torch.cat([pos_logits, neg_logits], dim=1)  # [batch_size, 1 + num_neg]
-        targets = torch.zeros(all_logits.size(0), dtype=torch.long, device=all_logits.device)
+        all_logits = torch.cat(
+            [pos_logits, neg_logits], dim=1
+        )  # [batch_size, 1 + num_neg]
+        targets = torch.zeros(
+            all_logits.size(0), dtype=torch.long, device=all_logits.device
+        )
         loss = F.cross_entropy(all_logits, targets, reduction=self.reduction)
-
+
         return loss
 
 
 class CosineContrastiveLoss(nn.Module):
-    def __init__(self, margin: float = 0.5, reduction: str = 'mean'):
+    def __init__(self, margin: float = 0.5, reduction: str = "mean"):
         super(CosineContrastiveLoss, self).__init__()
         self.margin = margin
         self.reduction = reduction
-
-    def forward(self, user_emb: torch.Tensor, item_emb: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+
+    def forward(
+        self, user_emb: torch.Tensor, item_emb: torch.Tensor, labels: torch.Tensor
+    ) -> torch.Tensor:
         similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
         pos_loss = (1 - similarity) * labels
 
         neg_loss = torch.clamp(similarity - self.margin, min=0) * (1 - labels)
-
+
         loss = pos_loss + neg_loss
-
-        if self.reduction == 'mean':
+
+        if self.reduction == "mean":
             return loss.mean()
-        elif self.reduction == 'sum':
+        elif self.reduction == "sum":
             return loss.sum()
         else:
             return loss
 
 
 class InfoNCELoss(nn.Module):
-    def __init__(self, temperature: float = 0.07, reduction: str = 'mean'):
+    def __init__(self, temperature: float = 0.07, reduction: str = "mean"):
         super(InfoNCELoss, self).__init__()
         self.temperature = temperature
         self.reduction = reduction
-
-    def forward(self, query: torch.Tensor, pos_key: torch.Tensor, neg_keys: torch.Tensor) -> torch.Tensor:
+
+    def forward(
+        self, query: torch.Tensor, pos_key: torch.Tensor, neg_keys: torch.Tensor
+    ) -> torch.Tensor:
         pos_sim = torch.sum(query * pos_key, dim=-1) / self.temperature  # [batch_size]
         pos_sim = pos_sim.unsqueeze(1)  # [batch_size, 1]
         query_expanded = query.unsqueeze(1)  # [batch_size, 1, dim]
-        neg_sim = torch.sum(query_expanded * neg_keys, dim=-1) / self.temperature  # [batch_size, num_neg]
+        neg_sim = (
+            torch.sum(query_expanded * neg_keys, dim=-1) / self.temperature
+        )  # [batch_size, num_neg]
         logits = torch.cat([pos_sim, neg_sim], dim=1)  # [batch_size, 1 + num_neg]
         labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
 
         loss = F.cross_entropy(logits, labels, reduction=self.reduction)
-
+
         return loss
 
 
@@ -164,22 +182,23 @@ class ListNetLoss(nn.Module):
     ListNet loss using top-1 probability distribution
     Reference: Cao et al. Learning to Rank: From Pairwise Approach to Listwise Approach (ICML 2007)
     """
-    def __init__(self, temperature: float = 1.0, reduction: str = 'mean'):
+
+    def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
         super(ListNetLoss, self).__init__()
         self.temperature = temperature
         self.reduction = reduction
-
+
     def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         # Convert scores and labels to probability distributions
         pred_probs = F.softmax(scores / self.temperature, dim=1)
         true_probs = F.softmax(labels / self.temperature, dim=1)
-
+
         # Cross entropy between two distributions
         loss = -torch.sum(true_probs * torch.log(pred_probs + 1e-10), dim=1)
-
-        if self.reduction == 'mean':
+
+        if self.reduction == "mean":
             return loss.mean()
-        elif self.reduction == 'sum':
+        elif self.reduction == "sum":
             return loss.sum()
         else:
             return loss
@@ -190,19 +209,24 @@ class ListMLELoss(nn.Module):
     ListMLE (Maximum Likelihood Estimation) loss
     Reference: Xia et al. Listwise approach to learning to rank: theory and algorithm (ICML 2008)
     """
-    def __init__(self, reduction: str = 'mean'):
+
+    def __init__(self, reduction: str = "mean"):
         super(ListMLELoss, self).__init__()
         self.reduction = reduction
-
+
     def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         # Sort by labels in descending order to get ground truth ranking
         sorted_labels, sorted_indices = torch.sort(labels, descending=True, dim=1)
-
+
         # Reorder scores according to ground truth ranking
         batch_size, list_size = scores.shape
-        batch_indices = torch.arange(batch_size, device=scores.device).unsqueeze(1).expand(-1, list_size)
+        batch_indices = (
+            torch.arange(batch_size, device=scores.device)
+            .unsqueeze(1)
+            .expand(-1, list_size)
+        )
         sorted_scores = scores[batch_indices, sorted_indices]
-
+
         # Compute log likelihood
         # For each position, compute log(exp(score_i) / sum(exp(score_j) for j >= i))
         loss = torch.tensor(0.0, device=scores.device)
@@ -211,10 +235,10 @@ class ListMLELoss(nn.Module):
             remaining_scores = sorted_scores[:, i:]
             log_sum_exp = torch.logsumexp(remaining_scores, dim=1)
             loss = loss + (log_sum_exp - sorted_scores[:, i]).sum()
-
-        if self.reduction == 'mean':
+
+        if self.reduction == "mean":
             return loss / batch_size
-        elif self.reduction == 'sum':
+        elif self.reduction == "sum":
             return loss
         else:
             return loss / batch_size
@@ -223,72 +247,81 @@ class ListMLELoss(nn.Module):
 class ApproxNDCGLoss(nn.Module):
     """
     Approximate NDCG loss for learning to rank
-    Reference: Qin et al. A General Approximation Framework for Direct Optimization of 
+    Reference: Qin et al. A General Approximation Framework for Direct Optimization of
     Information Retrieval Measures (Information Retrieval 2010)
     """
-    def __init__(self, temperature: float = 1.0, reduction: str = 'mean'):
+
+    def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
         super(ApproxNDCGLoss, self).__init__()
         self.temperature = temperature
         self.reduction = reduction
-
+
     def _dcg(self, relevance: torch.Tensor, k: Optional[int] = None) -> torch.Tensor:
         if k is not None:
             relevance = relevance[:, :k]
-
+
         # DCG = sum(rel_i / log2(i + 2)) for i in range(list_size)
-        positions = torch.arange(1, relevance.size(1) + 1, device=relevance.device, dtype=torch.float32)
+        positions = torch.arange(
+            1, relevance.size(1) + 1, device=relevance.device, dtype=torch.float32
+        )
         discounts = torch.log2(positions + 1.0)
         dcg = torch.sum(relevance / discounts, dim=1)
-
+
         return dcg
-
-    def forward(self, scores: torch.Tensor, labels: torch.Tensor, k: Optional[int] = None) -> torch.Tensor:
+
+    def forward(
+        self, scores: torch.Tensor, labels: torch.Tensor, k: Optional[int] = None
+    ) -> torch.Tensor:
         """
         Args:
             scores: Predicted scores [batch_size, list_size]
             labels: Ground truth relevance labels [batch_size, list_size]
             k: Top-k items for NDCG@k (if None, use all items)
-
+
         Returns:
             Approximate NDCG loss (1 - NDCG)
         """
         batch_size = scores.size(0)
-
+
         # Use differentiable sorting approximation with softmax
         # Create pairwise comparison matrix
         scores_expanded = scores.unsqueeze(2)  # [batch_size, list_size, 1]
-        scores_tiled = scores.unsqueeze(1) # [batch_size, 1, list_size]
-
+        scores_tiled = scores.unsqueeze(1)  # [batch_size, 1, list_size]
+
         # Compute pairwise probabilities using sigmoid
         pairwise_diff = (scores_expanded - scores_tiled) / self.temperature
-        pairwise_probs = torch.sigmoid(pairwise_diff)  # [batch_size, list_size, list_size]
-
+        pairwise_probs = torch.sigmoid(
+            pairwise_diff
+        )  # [batch_size, list_size, list_size]
+
         # Approximate ranking positions
         # ranking_probs[i, j] ≈ probability that item i is ranked at position j
         # We use softmax approximation for differentiable ranking
         ranking_weights = F.softmax(scores / self.temperature, dim=1)
-
+
         # Sort labels to get ideal DCG
         ideal_labels, _ = torch.sort(labels, descending=True, dim=1)
         ideal_dcg = self._dcg(ideal_labels, k)
-
+
         # Compute approximate DCG using soft ranking
         # Weight each item's relevance by its soft ranking position
-        positions = torch.arange(1, scores.size(1) + 1, device=scores.device, dtype=torch.float32)
+        positions = torch.arange(
+            1, scores.size(1) + 1, device=scores.device, dtype=torch.float32
+        )
         discounts = 1.0 / torch.log2(positions + 1.0)
-
+
         # Approximate DCG by weighting relevance with ranking probabilities
         approx_dcg = torch.sum(labels * ranking_weights * discounts, dim=1)
-
+
         # Normalize by ideal DCG to get NDCG
         ndcg = approx_dcg / (ideal_dcg + 1e-10)
-
+
         # Loss is 1 - NDCG (we want to maximize NDCG, so minimize 1 - NDCG)
         loss = 1.0 - ndcg
-
-        if self.reduction == 'mean':
+
+        if self.reduction == "mean":
             return loss.mean()
-        elif self.reduction == 'sum':
+        elif self.reduction == "sum":
             return loss.sum()
         else:
             return loss
nextrec/models/match/__init__.py CHANGED
@@ -5,9 +5,9 @@ from .mind import MIND
 from .sdm import SDM
 
 __all__ = [
-    'DSSM',
-    'DSSM_v2',
-    'YoutubeDNN',
-    'MIND',
-    'SDM',
+    "DSSM",
+    "DSSM_v2",
+    "YoutubeDNN",
+    "MIND",
+    "SDM",
 ]