nextrec 0.1.4-py3-none-any.whl → 0.1.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. nextrec/__init__.py +4 -4
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +9 -10
  4. nextrec/basic/callback.py +0 -1
  5. nextrec/basic/dataloader.py +127 -168
  6. nextrec/basic/features.py +27 -24
  7. nextrec/basic/layers.py +159 -328
  8. nextrec/basic/loggers.py +37 -50
  9. nextrec/basic/metrics.py +147 -255
  10. nextrec/basic/model.py +462 -817
  11. nextrec/data/__init__.py +5 -5
  12. nextrec/data/data_utils.py +12 -16
  13. nextrec/data/preprocessor.py +252 -276
  14. nextrec/loss/__init__.py +12 -12
  15. nextrec/loss/loss_utils.py +22 -30
  16. nextrec/loss/match_losses.py +83 -116
  17. nextrec/models/match/__init__.py +5 -5
  18. nextrec/models/match/dssm.py +61 -70
  19. nextrec/models/match/dssm_v2.py +51 -61
  20. nextrec/models/match/mind.py +71 -89
  21. nextrec/models/match/sdm.py +81 -93
  22. nextrec/models/match/youtube_dnn.py +53 -62
  23. nextrec/models/multi_task/esmm.py +43 -49
  24. nextrec/models/multi_task/mmoe.py +56 -65
  25. nextrec/models/multi_task/ple.py +65 -92
  26. nextrec/models/multi_task/share_bottom.py +42 -48
  27. nextrec/models/ranking/__init__.py +7 -7
  28. nextrec/models/ranking/afm.py +30 -39
  29. nextrec/models/ranking/autoint.py +57 -70
  30. nextrec/models/ranking/dcn.py +35 -43
  31. nextrec/models/ranking/deepfm.py +28 -34
  32. nextrec/models/ranking/dien.py +79 -115
  33. nextrec/models/ranking/din.py +60 -84
  34. nextrec/models/ranking/fibinet.py +35 -51
  35. nextrec/models/ranking/fm.py +26 -28
  36. nextrec/models/ranking/masknet.py +31 -31
  37. nextrec/models/ranking/pnn.py +31 -30
  38. nextrec/models/ranking/widedeep.py +31 -36
  39. nextrec/models/ranking/xdeepfm.py +39 -46
  40. nextrec/utils/__init__.py +9 -9
  41. nextrec/utils/embedding.py +1 -1
  42. nextrec/utils/initializer.py +15 -23
  43. nextrec/utils/optimizer.py +10 -14
  44. {nextrec-0.1.4.dist-info → nextrec-0.1.8.dist-info}/METADATA +16 -7
  45. nextrec-0.1.8.dist-info/RECORD +51 -0
  46. nextrec-0.1.4.dist-info/RECORD +0 -51
  47. {nextrec-0.1.4.dist-info → nextrec-0.1.8.dist-info}/WHEEL +0 -0
  48. {nextrec-0.1.4.dist-info → nextrec-0.1.8.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/__init__.py CHANGED
@@ -18,18 +18,18 @@ from nextrec.loss.loss_utils import (
 
 __all__ = [
     # Match losses
-    "BPRLoss",
-    "HingeLoss",
-    "TripletLoss",
-    "SampledSoftmaxLoss",
-    "CosineContrastiveLoss",
-    "InfoNCELoss",
+    'BPRLoss',
+    'HingeLoss',
+    'TripletLoss',
+    'SampledSoftmaxLoss',
+    'CosineContrastiveLoss',
+    'InfoNCELoss',
     # Listwise losses
-    "ListNetLoss",
-    "ListMLELoss",
-    "ApproxNDCGLoss",
+    'ListNetLoss',
+    'ListMLELoss',
+    'ApproxNDCGLoss',
     # Utilities
-    "get_loss_fn",
-    "validate_training_mode",
-    "VALID_TASK_TYPES",
+    'get_loss_fn',
+    'validate_training_mode',
+    'VALID_TASK_TYPES',
 ]
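The `__all__` above is the public surface of `nextrec.loss`; only the quoting style changes between versions, so existing imports keep working. A minimal import sketch (usage assumed, not taken from the package docs):

# Assumed usage: names come from the __all__ list above.
from nextrec.loss import BPRLoss, get_loss_fn, VALID_TASK_TYPES

print(VALID_TASK_TYPES)  # supported task_type strings, defined in loss_utils.py
loss_fn = get_loss_fn(task_type="match", training_mode="pairwise", loss="bpr")
assert isinstance(loss_fn, BPRLoss)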
nextrec/loss/loss_utils.py CHANGED
@@ -5,52 +5,42 @@ Date: create on 09/11/2025
 Author:
 Yang Zhou,zyaztec@gmail.com
 """
-
 import torch
 import torch.nn as nn
 from typing import Literal
 
 from nextrec.loss.match_losses import (
-    BPRLoss,
-    HingeLoss,
-    TripletLoss,
+    BPRLoss,
+    HingeLoss,
+    TripletLoss,
     SampledSoftmaxLoss,
-    CosineContrastiveLoss,
-    InfoNCELoss,
+    CosineContrastiveLoss,
+    InfoNCELoss
 )
 
 # Valid task types for validation
-VALID_TASK_TYPES = [
-    "binary",
-    "multiclass",
-    "regression",
-    "multivariate_regression",
-    "match",
-    "ranking",
-    "multitask",
-    "multilabel",
-]
+VALID_TASK_TYPES = ['binary', 'multiclass', 'regression', 'multivariate_regression', 'match', 'ranking', 'multitask', 'multilabel']
 
 
 def get_loss_fn(
     task_type: str = "binary",
     training_mode: str | None = None,
     loss: str | nn.Module | None = None,
-    **loss_kwargs,
+    **loss_kwargs
 ) -> nn.Module:
     """
     Get loss function based on task type and training mode.
-
+
     Examples:
         # Ranking task (binary classification)
         >>> loss_fn = get_loss_fn(task_type="binary", loss="bce")
-
+
         # Match task with pointwise training
         >>> loss_fn = get_loss_fn(task_type="match", training_mode="pointwise")
-
+
         # Match task with pairwise training
         >>> loss_fn = get_loss_fn(task_type="match", training_mode="pairwise", loss="bpr")
-
+
         # Match task with listwise training
         >>> loss_fn = get_loss_fn(task_type="match", training_mode="listwise", loss="sampled_softmax")
     """
@@ -67,7 +57,7 @@ def get_loss_fn(
                 return CosineContrastiveLoss(**loss_kwargs)
             elif isinstance(loss, str):
                 raise ValueError(f"Unsupported pointwise loss: {loss}")
-
+
         elif training_mode == "pairwise":
             if loss is None or loss == "bpr":
                 return BPRLoss(**loss_kwargs)
@@ -77,7 +67,7 @@ def get_loss_fn(
                 return TripletLoss(**loss_kwargs)
             elif isinstance(loss, str):
                 raise ValueError(f"Unsupported pairwise loss: {loss}")
-
+
         elif training_mode == "listwise":
             if loss is None or loss == "sampled_softmax" or loss == "softmax":
                 return SampledSoftmaxLoss(**loss_kwargs)
@@ -87,7 +77,7 @@ def get_loss_fn(
                 return nn.CrossEntropyLoss(**loss_kwargs)
             elif isinstance(loss, str):
                 raise ValueError(f"Unsupported listwise loss: {loss}")
-
+
         else:
             raise ValueError(f"Unknown training_mode: {training_mode}")
 
@@ -108,7 +98,7 @@ def get_loss_fn(
             return nn.CrossEntropyLoss(**loss_kwargs)
         elif isinstance(loss, str):
             raise ValueError(f"Unsupported multiclass loss: {loss}")
-
+
     elif task_type == "regression":
         if loss is None or loss == "mse":
             return nn.MSELoss(**loss_kwargs)
@@ -116,24 +106,26 @@ def get_loss_fn(
             return nn.L1Loss(**loss_kwargs)
         elif isinstance(loss, str):
             raise ValueError(f"Unsupported regression loss: {loss}")
-
+
     else:
         raise ValueError(f"Unsupported task_type: {task_type}")
-
+
     return loss
 
 
 def validate_training_mode(
-    training_mode: str, support_training_modes: list[str], model_name: str = "Model"
+    training_mode: str,
+    support_training_modes: list[str],
+    model_name: str = "Model"
 ) -> None:
     """
     Validate that the requested training mode is supported by the model.
-
+
     Args:
         training_mode: Requested training mode
         support_training_modes: List of supported training modes
         model_name: Name of the model (for error messages)
-
+
     Raises:
         ValueError: If training mode is not supported
     """
nextrec/loss/match_losses.py CHANGED
@@ -13,167 +13,149 @@ from typing import Optional
 
 
 class BPRLoss(nn.Module):
-    def __init__(self, reduction: str = "mean"):
+    def __init__(self, reduction: str = 'mean'):
         super(BPRLoss, self).__init__()
         self.reduction = reduction
-
+
     def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
         if neg_score.dim() == 2:
             pos_score = pos_score.unsqueeze(1) # [batch_size, 1]
             diff = pos_score - neg_score # [batch_size, num_neg]
             loss = -torch.log(torch.sigmoid(diff) + 1e-8)
-            if self.reduction == "mean":
+            if self.reduction == 'mean':
                 return loss.mean()
-            elif self.reduction == "sum":
+            elif self.reduction == 'sum':
                 return loss.sum()
             else:
                 return loss
         else:
             diff = pos_score - neg_score
             loss = -torch.log(torch.sigmoid(diff) + 1e-8)
-            if self.reduction == "mean":
+            if self.reduction == 'mean':
                 return loss.mean()
-            elif self.reduction == "sum":
+            elif self.reduction == 'sum':
                 return loss.sum()
             else:
                 return loss
 
 
-class HingeLoss(nn.Module):
-    def __init__(self, margin: float = 1.0, reduction: str = "mean"):
+class HingeLoss(nn.Module):
+    def __init__(self, margin: float = 1.0, reduction: str = 'mean'):
         super(HingeLoss, self).__init__()
         self.margin = margin
         self.reduction = reduction
-
+
     def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
         if neg_score.dim() == 2:
             pos_score = pos_score.unsqueeze(1) # [batch_size, 1]
-
+
         diff = pos_score - neg_score
         loss = torch.clamp(self.margin - diff, min=0)
-
-        if self.reduction == "mean":
+
+        if self.reduction == 'mean':
             return loss.mean()
-        elif self.reduction == "sum":
+        elif self.reduction == 'sum':
             return loss.sum()
         else:
             return loss
 
 
 class TripletLoss(nn.Module):
-    def __init__(
-        self, margin: float = 1.0, reduction: str = "mean", distance: str = "euclidean"
-    ):
+    def __init__(self, margin: float = 1.0, reduction: str = 'mean', distance: str = 'euclidean'):
         super(TripletLoss, self).__init__()
         self.margin = margin
         self.reduction = reduction
         self.distance = distance
-
-    def forward(
-        self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor
-    ) -> torch.Tensor:
-        if self.distance == "euclidean":
+
+    def forward(self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor) -> torch.Tensor:
+        if self.distance == 'euclidean':
             pos_dist = torch.sum((anchor - positive) ** 2, dim=-1)
-
+
             if negative.dim() == 3:
                 anchor_expanded = anchor.unsqueeze(1) # [batch_size, 1, dim]
-                neg_dist = torch.sum(
-                    (anchor_expanded - negative) ** 2, dim=-1
-                )  # [batch_size, num_neg]
+                neg_dist = torch.sum((anchor_expanded - negative) ** 2, dim=-1) # [batch_size, num_neg]
             else:
                 neg_dist = torch.sum((anchor - negative) ** 2, dim=-1)
-
+
             if neg_dist.dim() == 2:
                 pos_dist = pos_dist.unsqueeze(1) # [batch_size, 1]
-
-        elif self.distance == "cosine":
+
+        elif self.distance == 'cosine':
             pos_dist = 1 - F.cosine_similarity(anchor, positive, dim=-1)
-
+
             if negative.dim() == 3:
                 anchor_expanded = anchor.unsqueeze(1) # [batch_size, 1, dim]
                 neg_dist = 1 - F.cosine_similarity(anchor_expanded, negative, dim=-1)
             else:
                 neg_dist = 1 - F.cosine_similarity(anchor, negative, dim=-1)
-
+
             if neg_dist.dim() == 2:
                 pos_dist = pos_dist.unsqueeze(1)
         else:
             raise ValueError(f"Unsupported distance: {self.distance}")
-
+
         loss = torch.clamp(pos_dist - neg_dist + self.margin, min=0)
-
-        if self.reduction == "mean":
+
+        if self.reduction == 'mean':
             return loss.mean()
-        elif self.reduction == "sum":
+        elif self.reduction == 'sum':
             return loss.sum()
         else:
             return loss
 
 
 class SampledSoftmaxLoss(nn.Module):
-    def __init__(self, reduction: str = "mean"):
+    def __init__(self, reduction: str = 'mean'):
         super(SampledSoftmaxLoss, self).__init__()
         self.reduction = reduction
-
-    def forward(
-        self, pos_logits: torch.Tensor, neg_logits: torch.Tensor
-    ) -> torch.Tensor:
+
+    def forward(self, pos_logits: torch.Tensor, neg_logits: torch.Tensor) -> torch.Tensor:
         pos_logits = pos_logits.unsqueeze(1) # [batch_size, 1]
-        all_logits = torch.cat(
-            [pos_logits, neg_logits], dim=1
-        )  # [batch_size, 1 + num_neg]
-        targets = torch.zeros(
-            all_logits.size(0), dtype=torch.long, device=all_logits.device
-        )
+        all_logits = torch.cat([pos_logits, neg_logits], dim=1) # [batch_size, 1 + num_neg]
+        targets = torch.zeros(all_logits.size(0), dtype=torch.long, device=all_logits.device)
         loss = F.cross_entropy(all_logits, targets, reduction=self.reduction)
-
+
         return loss
 
 
 class CosineContrastiveLoss(nn.Module):
-    def __init__(self, margin: float = 0.5, reduction: str = "mean"):
+    def __init__(self, margin: float = 0.5, reduction: str = 'mean'):
         super(CosineContrastiveLoss, self).__init__()
         self.margin = margin
         self.reduction = reduction
-
-    def forward(
-        self, user_emb: torch.Tensor, item_emb: torch.Tensor, labels: torch.Tensor
-    ) -> torch.Tensor:
+
+    def forward(self, user_emb: torch.Tensor, item_emb: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
         pos_loss = (1 - similarity) * labels
 
         neg_loss = torch.clamp(similarity - self.margin, min=0) * (1 - labels)
-
+
         loss = pos_loss + neg_loss
-
-        if self.reduction == "mean":
+
+        if self.reduction == 'mean':
             return loss.mean()
-        elif self.reduction == "sum":
+        elif self.reduction == 'sum':
             return loss.sum()
         else:
             return loss
 
 
 class InfoNCELoss(nn.Module):
-    def __init__(self, temperature: float = 0.07, reduction: str = "mean"):
+    def __init__(self, temperature: float = 0.07, reduction: str = 'mean'):
         super(InfoNCELoss, self).__init__()
         self.temperature = temperature
         self.reduction = reduction
-
-    def forward(
-        self, query: torch.Tensor, pos_key: torch.Tensor, neg_keys: torch.Tensor
-    ) -> torch.Tensor:
+
+    def forward(self, query: torch.Tensor, pos_key: torch.Tensor, neg_keys: torch.Tensor) -> torch.Tensor:
         pos_sim = torch.sum(query * pos_key, dim=-1) / self.temperature # [batch_size]
         pos_sim = pos_sim.unsqueeze(1) # [batch_size, 1]
         query_expanded = query.unsqueeze(1) # [batch_size, 1, dim]
-        neg_sim = (
-            torch.sum(query_expanded * neg_keys, dim=-1) / self.temperature
-        )  # [batch_size, num_neg]
+        neg_sim = torch.sum(query_expanded * neg_keys, dim=-1) / self.temperature # [batch_size, num_neg]
         logits = torch.cat([pos_sim, neg_sim], dim=1) # [batch_size, 1 + num_neg]
         labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
 
         loss = F.cross_entropy(logits, labels, reduction=self.reduction)
-
+
         return loss
 
 
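The pairwise and contrastive losses above all follow the shapes given in the inline comments. A shape-checking sketch (the batch size, negative count, and dimension below are made up):

# Assumed usage sketch; shapes follow the [batch_size, num_neg, dim] comments above.
import torch
from nextrec.loss import BPRLoss, TripletLoss, InfoNCELoss

batch, num_neg, dim = 32, 4, 64
pos_score = torch.randn(batch)            # score of the positive item per user
neg_score = torch.randn(batch, num_neg)   # scores of sampled negatives
print(BPRLoss()(pos_score, neg_score))    # scalar with the default reduction='mean'

anchor, positive = torch.randn(batch, dim), torch.randn(batch, dim)
negatives = torch.randn(batch, num_neg, dim)
print(TripletLoss(margin=1.0, distance='cosine')(anchor, positive, negatives))

query, pos_key = torch.randn(batch, dim), torch.randn(batch, dim)
neg_keys = torch.randn(batch, num_neg, dim)
print(InfoNCELoss(temperature=0.07)(query, pos_key, neg_keys))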
@@ -182,23 +164,22 @@ class ListNetLoss(nn.Module):
     ListNet loss using top-1 probability distribution
     Reference: Cao et al. Learning to Rank: From Pairwise Approach to Listwise Approach (ICML 2007)
     """
-
-    def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
+    def __init__(self, temperature: float = 1.0, reduction: str = 'mean'):
         super(ListNetLoss, self).__init__()
         self.temperature = temperature
         self.reduction = reduction
-
+
     def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         # Convert scores and labels to probability distributions
         pred_probs = F.softmax(scores / self.temperature, dim=1)
         true_probs = F.softmax(labels / self.temperature, dim=1)
-
+
         # Cross entropy between two distributions
         loss = -torch.sum(true_probs * torch.log(pred_probs + 1e-10), dim=1)
-
-        if self.reduction == "mean":
+
+        if self.reduction == 'mean':
             return loss.mean()
-        elif self.reduction == "sum":
+        elif self.reduction == 'sum':
             return loss.sum()
         else:
             return loss
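Written out, the forward pass above is the ListNet top-1 cross-entropy between the two softmax distributions (T is `temperature`, s the predicted scores, y the relevance labels; the small constant matches the 1e-10 in the code):

\mathcal{L}_{\mathrm{ListNet}} = -\sum_{j=1}^{n} \operatorname{softmax}(y/T)_j \,\log\!\bigl(\operatorname{softmax}(s/T)_j + 10^{-10}\bigr)

per sample, averaged over the batch when `reduction='mean'`.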
@@ -209,24 +190,19 @@ class ListMLELoss(nn.Module):
     ListMLE (Maximum Likelihood Estimation) loss
     Reference: Xia et al. Listwise approach to learning to rank: theory and algorithm (ICML 2008)
     """
-
-    def __init__(self, reduction: str = "mean"):
+    def __init__(self, reduction: str = 'mean'):
         super(ListMLELoss, self).__init__()
         self.reduction = reduction
-
+
     def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         # Sort by labels in descending order to get ground truth ranking
         sorted_labels, sorted_indices = torch.sort(labels, descending=True, dim=1)
-
+
         # Reorder scores according to ground truth ranking
         batch_size, list_size = scores.shape
-        batch_indices = (
-            torch.arange(batch_size, device=scores.device)
-            .unsqueeze(1)
-            .expand(-1, list_size)
-        )
+        batch_indices = torch.arange(batch_size, device=scores.device).unsqueeze(1).expand(-1, list_size)
         sorted_scores = scores[batch_indices, sorted_indices]
-
+
         # Compute log likelihood
         # For each position, compute log(exp(score_i) / sum(exp(score_j) for j >= i))
         loss = torch.tensor(0.0, device=scores.device)
@@ -235,10 +211,10 @@ class ListMLELoss(nn.Module):
             remaining_scores = sorted_scores[:, i:]
             log_sum_exp = torch.logsumexp(remaining_scores, dim=1)
             loss = loss + (log_sum_exp - sorted_scores[:, i]).sum()
-
-        if self.reduction == "mean":
+
+        if self.reduction == 'mean':
             return loss / batch_size
-        elif self.reduction == "sum":
+        elif self.reduction == 'sum':
             return loss
         else:
             return loss / batch_size
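The loop above accumulates, for the permutation π obtained by sorting `labels` in descending order, the Plackett-Luce negative log-likelihood of that permutation under the scores:

\mathcal{L}_{\mathrm{ListMLE}} = -\log P(\pi \mid s) = \sum_{i=1}^{n}\Bigl(\log\!\sum_{j=i}^{n}\exp\bigl(s_{\pi(j)}\bigr) - s_{\pi(i)}\Bigr)

summed over the batch and divided by `batch_size` when `reduction='mean'`.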
@@ -247,81 +223,72 @@ class ListMLELoss(nn.Module):
 class ApproxNDCGLoss(nn.Module):
     """
     Approximate NDCG loss for learning to rank
-    Reference: Qin et al. A General Approximation Framework for Direct Optimization of
+    Reference: Qin et al. A General Approximation Framework for Direct Optimization of
     Information Retrieval Measures (Information Retrieval 2010)
     """
-
-    def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
+    def __init__(self, temperature: float = 1.0, reduction: str = 'mean'):
         super(ApproxNDCGLoss, self).__init__()
         self.temperature = temperature
         self.reduction = reduction
-
+
     def _dcg(self, relevance: torch.Tensor, k: Optional[int] = None) -> torch.Tensor:
         if k is not None:
             relevance = relevance[:, :k]
-
+
         # DCG = sum(rel_i / log2(i + 2)) for i in range(list_size)
-        positions = torch.arange(
-            1, relevance.size(1) + 1, device=relevance.device, dtype=torch.float32
-        )
+        positions = torch.arange(1, relevance.size(1) + 1, device=relevance.device, dtype=torch.float32)
         discounts = torch.log2(positions + 1.0)
         dcg = torch.sum(relevance / discounts, dim=1)
-
+
         return dcg
-
-    def forward(
-        self, scores: torch.Tensor, labels: torch.Tensor, k: Optional[int] = None
-    ) -> torch.Tensor:
+
+    def forward(self, scores: torch.Tensor, labels: torch.Tensor, k: Optional[int] = None) -> torch.Tensor:
         """
         Args:
             scores: Predicted scores [batch_size, list_size]
             labels: Ground truth relevance labels [batch_size, list_size]
             k: Top-k items for NDCG@k (if None, use all items)
-
+
         Returns:
             Approximate NDCG loss (1 - NDCG)
         """
         batch_size = scores.size(0)
-
+
         # Use differentiable sorting approximation with softmax
         # Create pairwise comparison matrix
         scores_expanded = scores.unsqueeze(2) # [batch_size, list_size, 1]
-        scores_tiled = scores.unsqueeze(1)  # [batch_size, 1, list_size]
-
+        scores_tiled = scores.unsqueeze(1) # [batch_size, 1, list_size]
+
         # Compute pairwise probabilities using sigmoid
         pairwise_diff = (scores_expanded - scores_tiled) / self.temperature
-        pairwise_probs = torch.sigmoid(
-            pairwise_diff
-        )  # [batch_size, list_size, list_size]
-
+        pairwise_probs = torch.sigmoid(pairwise_diff) # [batch_size, list_size, list_size]
+
         # Approximate ranking positions
         # ranking_probs[i, j] ≈ probability that item i is ranked at position j
         # We use softmax approximation for differentiable ranking
         ranking_weights = F.softmax(scores / self.temperature, dim=1)
-
+
         # Sort labels to get ideal DCG
         ideal_labels, _ = torch.sort(labels, descending=True, dim=1)
         ideal_dcg = self._dcg(ideal_labels, k)
-
+
         # Compute approximate DCG using soft ranking
         # Weight each item's relevance by its soft ranking position
-        positions = torch.arange(
-            1, scores.size(1) + 1, device=scores.device, dtype=torch.float32
-        )
+        positions = torch.arange(1, scores.size(1) + 1, device=scores.device, dtype=torch.float32)
         discounts = 1.0 / torch.log2(positions + 1.0)
-
+
         # Approximate DCG by weighting relevance with ranking probabilities
         approx_dcg = torch.sum(labels * ranking_weights * discounts, dim=1)
-
+
         # Normalize by ideal DCG to get NDCG
         ndcg = approx_dcg / (ideal_dcg + 1e-10)
-
+
         # Loss is 1 - NDCG (we want to maximize NDCG, so minimize 1 - NDCG)
         loss = 1.0 - ndcg
-
-        if self.reduction == "mean":
+
+        if self.reduction == 'mean':
             return loss.mean()
-        elif self.reduction == "sum":
+        elif self.reduction == 'sum':
             return loss.sum()
         else:
             return loss
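For reference, `_dcg` computes the standard discounted cumulative gain, and `forward` replaces the hard ranking with a softmax weighting so the loss stays differentiable. In the notation of the code (s scores, y labels, T temperature):

\mathrm{DCG} = \sum_{i=1}^{n}\frac{y_i}{\log_2(i+1)},\qquad \widehat{\mathrm{DCG}} = \sum_{i=1}^{n} y_i\,\operatorname{softmax}(s/T)_i\,\frac{1}{\log_2(i+1)},\qquad \mathcal{L} = 1-\frac{\widehat{\mathrm{DCG}}}{\mathrm{DCG}_{\mathrm{ideal}}+10^{-10}}

where the ideal DCG is `_dcg` applied to the labels sorted in descending order; `pairwise_probs` is computed but the soft positions used for weighting come from the softmax term above.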
nextrec/models/match/__init__.py CHANGED
@@ -5,9 +5,9 @@ from .mind import MIND
 from .sdm import SDM
 
 __all__ = [
-    "DSSM",
-    "DSSM_v2",
-    "YoutubeDNN",
-    "MIND",
-    "SDM",
+    'DSSM',
+    'DSSM_v2',
+    'YoutubeDNN',
+    'MIND',
+    'SDM',
 ]