nextrec-0.1.11-py3-none-any.whl → nextrec-0.2.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +1 -2
  3. nextrec/basic/callback.py +1 -2
  4. nextrec/basic/features.py +39 -8
  5. nextrec/basic/layers.py +3 -4
  6. nextrec/basic/loggers.py +15 -10
  7. nextrec/basic/metrics.py +1 -2
  8. nextrec/basic/model.py +160 -125
  9. nextrec/basic/session.py +150 -0
  10. nextrec/data/__init__.py +13 -2
  11. nextrec/data/data_utils.py +74 -22
  12. nextrec/data/dataloader.py +513 -0
  13. nextrec/data/preprocessor.py +494 -134
  14. nextrec/loss/__init__.py +31 -24
  15. nextrec/loss/listwise.py +164 -0
  16. nextrec/loss/loss_utils.py +133 -106
  17. nextrec/loss/pairwise.py +105 -0
  18. nextrec/loss/pointwise.py +198 -0
  19. nextrec/models/match/dssm.py +26 -17
  20. nextrec/models/match/dssm_v2.py +20 -2
  21. nextrec/models/match/mind.py +18 -3
  22. nextrec/models/match/sdm.py +17 -2
  23. nextrec/models/match/youtube_dnn.py +23 -10
  24. nextrec/models/multi_task/esmm.py +8 -8
  25. nextrec/models/multi_task/mmoe.py +8 -8
  26. nextrec/models/multi_task/ple.py +8 -8
  27. nextrec/models/multi_task/share_bottom.py +8 -8
  28. nextrec/models/ranking/__init__.py +8 -0
  29. nextrec/models/ranking/afm.py +5 -4
  30. nextrec/models/ranking/autoint.py +6 -4
  31. nextrec/models/ranking/dcn.py +6 -4
  32. nextrec/models/ranking/deepfm.py +5 -4
  33. nextrec/models/ranking/dien.py +6 -4
  34. nextrec/models/ranking/din.py +6 -4
  35. nextrec/models/ranking/fibinet.py +6 -4
  36. nextrec/models/ranking/fm.py +6 -4
  37. nextrec/models/ranking/masknet.py +6 -4
  38. nextrec/models/ranking/pnn.py +6 -4
  39. nextrec/models/ranking/widedeep.py +6 -4
  40. nextrec/models/ranking/xdeepfm.py +6 -4
  41. nextrec/utils/__init__.py +7 -11
  42. nextrec/utils/embedding.py +2 -4
  43. nextrec/utils/initializer.py +4 -5
  44. nextrec/utils/optimizer.py +7 -8
  45. {nextrec-0.1.11.dist-info → nextrec-0.2.2.dist-info}/METADATA +3 -3
  46. nextrec-0.2.2.dist-info/RECORD +53 -0
  47. nextrec/basic/dataloader.py +0 -447
  48. nextrec/loss/match_losses.py +0 -294
  49. nextrec/utils/common.py +0 -14
  50. nextrec-0.1.11.dist-info/RECORD +0 -51
  51. {nextrec-0.1.11.dist-info → nextrec-0.2.2.dist-info}/WHEEL +0 -0
  52. {nextrec-0.1.11.dist-info → nextrec-0.2.2.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/__init__.py CHANGED
@@ -1,35 +1,42 @@
-from nextrec.loss.match_losses import (
-    BPRLoss,
-    HingeLoss,
-    TripletLoss,
-    SampledSoftmaxLoss,
-    CosineContrastiveLoss,
+from nextrec.loss.listwise import (
+    ApproxNDCGLoss,
     InfoNCELoss,
-    ListNetLoss,
     ListMLELoss,
-    ApproxNDCGLoss,
+    ListNetLoss,
+    SampledSoftmaxLoss,
+)
+from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
+from nextrec.loss.pointwise import (
+    ClassBalancedFocalLoss,
+    CosineContrastiveLoss,
+    FocalLoss,
+    WeightedBCELoss,
 )
-
 from nextrec.loss.loss_utils import (
     get_loss_fn,
-    validate_training_mode,
+    get_loss_kwargs,
     VALID_TASK_TYPES,
 )
 
 __all__ = [
-    # Match losses
-    'BPRLoss',
-    'HingeLoss',
-    'TripletLoss',
-    'SampledSoftmaxLoss',
-    'CosineContrastiveLoss',
-    'InfoNCELoss',
-    # Listwise losses
-    'ListNetLoss',
-    'ListMLELoss',
-    'ApproxNDCGLoss',
+    # Pointwise
+    "CosineContrastiveLoss",
+    "WeightedBCELoss",
+    "FocalLoss",
+    "ClassBalancedFocalLoss",
+    # Pairwise
+    "BPRLoss",
+    "HingeLoss",
+    "TripletLoss",
+    # Listwise
+    "SampledSoftmaxLoss",
+    "InfoNCELoss",
+    "ListNetLoss",
+    "ListMLELoss",
+    "ApproxNDCGLoss",
     # Utilities
-    'get_loss_fn',
-    'validate_training_mode',
-    'VALID_TASK_TYPES',
+    "get_loss_fn",
+    "get_loss_kwargs",
+    "validate_training_mode",
+    "VALID_TASK_TYPES",
 ]
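The package root re-exports every loss class, so downstream imports keep working without knowing the new pointwise/pairwise/listwise split. A minimal usage sketch (illustrative only, not part of the diff; it uses only classes whose constructors appear in the hunks below):

import torch
from nextrec.loss import BPRLoss, ListNetLoss

# Pairwise: one positive score per row, N sampled negative scores.
bpr = BPRLoss()
pos = torch.randn(32)       # [B]
neg = torch.randn(32, 4)    # [B, N], BPRLoss broadcasts pos against the negatives
bpr_loss = bpr(pos, neg)

# Listwise: predicted scores and graded labels share shape [B, L].
listnet = ListNetLoss(temperature=1.0)
scores, labels = torch.randn(32, 10), torch.rand(32, 10)
listnet_loss = listnet(scores, labels)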
nextrec/loss/listwise.py ADDED
@@ -0,0 +1,164 @@
+"""
+Listwise loss functions for ranking and contrastive training.
+"""
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SampledSoftmaxLoss(nn.Module):
+    """
+    Softmax over one positive and multiple sampled negatives.
+    """
+
+    def __init__(self, reduction: str = "mean"):
+        super().__init__()
+        self.reduction = reduction
+
+    def forward(self, pos_logits: torch.Tensor, neg_logits: torch.Tensor) -> torch.Tensor:
+        pos_logits = pos_logits.unsqueeze(1)
+        all_logits = torch.cat([pos_logits, neg_logits], dim=1)
+        targets = torch.zeros(all_logits.size(0), dtype=torch.long, device=all_logits.device)
+        loss = F.cross_entropy(all_logits, targets, reduction=self.reduction)
+        return loss
+
+
+class InfoNCELoss(nn.Module):
+    """
+    InfoNCE loss for contrastive learning with one positive and many negatives.
+    """
+
+    def __init__(self, temperature: float = 0.07, reduction: str = "mean"):
+        super().__init__()
+        self.temperature = temperature
+        self.reduction = reduction
+
+    def forward(
+        self, query: torch.Tensor, pos_key: torch.Tensor, neg_keys: torch.Tensor
+    ) -> torch.Tensor:
+        pos_sim = torch.sum(query * pos_key, dim=-1) / self.temperature
+        pos_sim = pos_sim.unsqueeze(1)
+        query_expanded = query.unsqueeze(1)
+        neg_sim = torch.sum(query_expanded * neg_keys, dim=-1) / self.temperature
+        logits = torch.cat([pos_sim, neg_sim], dim=1)
+        labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
+        loss = F.cross_entropy(logits, labels, reduction=self.reduction)
+        return loss
+
+
+class ListNetLoss(nn.Module):
+    """
+    ListNet loss using top-1 probability distribution.
+    Reference: Cao et al. (ICML 2007)
+    """
+
+    def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
+        super().__init__()
+        self.temperature = temperature
+        self.reduction = reduction
+
+    def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+        pred_probs = F.softmax(scores / self.temperature, dim=1)
+        true_probs = F.softmax(labels / self.temperature, dim=1)
+        loss = -torch.sum(true_probs * torch.log(pred_probs + 1e-10), dim=1)
+
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+
+class ListMLELoss(nn.Module):
+    """
+    ListMLE (Maximum Likelihood Estimation) loss.
+    Reference: Xia et al. (ICML 2008)
+    """
+
+    def __init__(self, reduction: str = "mean"):
+        super().__init__()
+        self.reduction = reduction
+
+    def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+        sorted_labels, sorted_indices = torch.sort(labels, descending=True, dim=1)
+        batch_size, list_size = scores.shape
+        batch_indices = torch.arange(batch_size, device=scores.device).unsqueeze(1).expand(-1, list_size)
+        sorted_scores = scores[batch_indices, sorted_indices]
+
+        loss = torch.tensor(0.0, device=scores.device)
+        for i in range(list_size):
+            remaining_scores = sorted_scores[:, i:]
+            log_sum_exp = torch.logsumexp(remaining_scores, dim=1)
+            loss = loss + (log_sum_exp - sorted_scores[:, i]).sum()
+
+        if self.reduction == "mean":
+            return loss / batch_size
+        if self.reduction == "sum":
+            return loss
+        return loss / batch_size
+
+
+class ApproxNDCGLoss(nn.Module):
+    """
+    Approximate NDCG loss for learning to rank.
+    Reference: Qin et al. (2010)
+    """
+
+    def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
+        super().__init__()
+        self.temperature = temperature
+        self.reduction = reduction
+
+    def _ideal_dcg(self, labels: torch.Tensor, k: Optional[int]) -> torch.Tensor:
+        # labels: [B, L]
+        sorted_labels, _ = torch.sort(labels, dim=1, descending=True)
+        if k is not None:
+            sorted_labels = sorted_labels[:, :k]
+
+        gains = torch.pow(2.0, sorted_labels) - 1.0  # [B, K]
+        positions = torch.arange(
+            1, gains.size(1) + 1, device=gains.device, dtype=torch.float32
+        )  # [K]
+        discounts = 1.0 / torch.log2(positions + 1.0)  # [K]
+        ideal_dcg = torch.sum(gains * discounts, dim=1)  # [B]
+        return ideal_dcg
+
+    def forward(
+        self, scores: torch.Tensor, labels: torch.Tensor, k: Optional[int] = None
+    ) -> torch.Tensor:
+        """
+        scores: [B, L]
+        labels: [B, L]
+        """
+        batch_size, list_size = scores.shape
+        device = scores.device
+
+        # diff[b, i, j] = (s_j - s_i) / T
+        scores_i = scores.unsqueeze(2)  # [B, L, 1]
+        scores_j = scores.unsqueeze(1)  # [B, 1, L]
+        diff = (scores_j - scores_i) / self.temperature  # [B, L, L]
+
+        P_ji = torch.sigmoid(diff)  # [B, L, L]
+        eye = torch.eye(list_size, device=device).unsqueeze(0)  # [1, L, L]
+        P_ji = P_ji * (1.0 - eye)
+
+        exp_rank = 1.0 + P_ji.sum(dim=-1)  # [B, L]
+
+        discounts = 1.0 / torch.log2(exp_rank + 1.0)  # [B, L]
+
+        gains = torch.pow(2.0, labels) - 1.0  # [B, L]
+        approx_dcg = torch.sum(gains * discounts, dim=1)  # [B]
+
+        ideal_dcg = self._ideal_dcg(labels, k)  # [B]
+
+        ndcg = approx_dcg / (ideal_dcg + 1e-10)  # [B]
+        loss = 1.0 - ndcg
+
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
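A shape-level smoke test of the new module (a sketch, not shipped code; tensor shapes follow the [B, L] convention in the docstrings above):

import torch
from nextrec.loss.listwise import ApproxNDCGLoss, ListMLELoss, SampledSoftmaxLoss

B, L, N = 8, 12, 5

# One positive logit per row plus N sampled negative logits.
loss_ss = SampledSoftmaxLoss()(torch.randn(B), torch.randn(B, N))

# Scores and graded relevance labels share shape [B, L].
scores = torch.randn(B, L, requires_grad=True)
labels = torch.randint(0, 3, (B, L)).float()
loss_mle = ListMLELoss()(scores, labels)

# ApproxNDCG replaces hard ranks with sigmoid-smoothed ranks, so it backprops;
# note that k only truncates the ideal-DCG normalizer, not the soft DCG itself.
loss_ndcg = ApproxNDCGLoss(temperature=0.5)(scores, labels, k=10)
loss_ndcg.backward()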
nextrec/loss/loss_utils.py CHANGED
@@ -1,136 +1,163 @@
 """
-Loss utilities for NextRec
-
-Date: create on 09/11/2025
-Author:
-    Yang Zhou,zyaztec@gmail.com
+Loss utilities for NextRec.
 """
-import torch
-import torch.nn as nn
+
 from typing import Literal
 
-from nextrec.loss.match_losses import (
-    BPRLoss,
-    HingeLoss,
-    TripletLoss,
+import torch.nn as nn
+
+from nextrec.loss.listwise import (
+    ApproxNDCGLoss,
+    InfoNCELoss,
+    ListMLELoss,
+    ListNetLoss,
     SampledSoftmaxLoss,
-    CosineContrastiveLoss,
-    InfoNCELoss
+)
+from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
+from nextrec.loss.pointwise import (
+    ClassBalancedFocalLoss,
+    CosineContrastiveLoss,
+    FocalLoss,
+    WeightedBCELoss,
 )
 
 # Valid task types for validation
-VALID_TASK_TYPES = ['binary', 'multiclass', 'regression', 'multivariate_regression', 'match', 'ranking', 'multitask', 'multilabel']
+VALID_TASK_TYPES = [
+    "binary",
+    "multiclass",
+    "regression",
+    "multivariate_regression",
+    "match",
+    "ranking",
+    "multitask",
+    "multilabel",
+]
 
 
 def get_loss_fn(
     task_type: str = "binary",
     training_mode: str | None = None,
     loss: str | nn.Module | None = None,
-    **loss_kwargs
+    **loss_kwargs,
 ) -> nn.Module:
     """
     Get loss function based on task type and training mode.
-
-    Examples:
-        # Ranking task (binary classification)
-        >>> loss_fn = get_loss_fn(task_type="binary", loss="bce")
-
-        # Match task with pointwise training
-        >>> loss_fn = get_loss_fn(task_type="match", training_mode="pointwise")
-
-        # Match task with pairwise training
-        >>> loss_fn = get_loss_fn(task_type="match", training_mode="pairwise", loss="bpr")
-
-        # Match task with listwise training
-        >>> loss_fn = get_loss_fn(task_type="match", training_mode="listwise", loss="sampled_softmax")
     """
 
     if isinstance(loss, nn.Module):
         return loss
 
+    # Common mappings
     if task_type == "match":
-        if training_mode == "pointwise":
-            # Pointwise training uses binary cross entropy
-            if loss is None or loss == "bce" or loss == "binary_crossentropy":
-                return nn.BCELoss(**loss_kwargs)
-            elif loss == "cosine_contrastive":
-                return CosineContrastiveLoss(**loss_kwargs)
-            elif isinstance(loss, str):
-                raise ValueError(f"Unsupported pointwise loss: {loss}")
-
-        elif training_mode == "pairwise":
-            if loss is None or loss == "bpr":
-                return BPRLoss(**loss_kwargs)
-            elif loss == "hinge":
-                return HingeLoss(**loss_kwargs)
-            elif loss == "triplet":
-                return TripletLoss(**loss_kwargs)
-            elif isinstance(loss, str):
-                raise ValueError(f"Unsupported pairwise loss: {loss}")
-
-        elif training_mode == "listwise":
-            if loss is None or loss == "sampled_softmax" or loss == "softmax":
-                return SampledSoftmaxLoss(**loss_kwargs)
-            elif loss == "infonce":
-                return InfoNCELoss(**loss_kwargs)
-            elif loss == "crossentropy" or loss == "ce":
-                return nn.CrossEntropyLoss(**loss_kwargs)
-            elif isinstance(loss, str):
-                raise ValueError(f"Unsupported listwise loss: {loss}")
-
-        else:
-            raise ValueError(f"Unknown training_mode: {training_mode}")
-
-    elif task_type in ["ranking", "multitask", "binary"]:
-        if loss is None or loss == "bce" or loss == "binary_crossentropy":
-            return nn.BCELoss(**loss_kwargs)
-        elif loss == "mse":
-            return nn.MSELoss(**loss_kwargs)
-        elif loss == "mae":
-            return nn.L1Loss(**loss_kwargs)
-        elif loss == "crossentropy" or loss == "ce":
-            return nn.CrossEntropyLoss(**loss_kwargs)
-        elif isinstance(loss, str):
-            raise ValueError(f"Unsupported loss function: {loss}")
+        return _get_match_loss(training_mode, loss, **loss_kwargs)
 
-    elif task_type == "multiclass":
-        if loss is None or loss == "crossentropy" or loss == "ce":
-            return nn.CrossEntropyLoss(**loss_kwargs)
-        elif isinstance(loss, str):
-            raise ValueError(f"Unsupported multiclass loss: {loss}")
-
-    elif task_type == "regression":
+    if task_type in ["ranking", "multitask", "binary", "multilabel"]:
+        return _get_classification_loss(loss, **loss_kwargs)
+
+    if task_type == "multiclass":
+        return _get_multiclass_loss(loss, **loss_kwargs)
+
+    if task_type == "regression":
         if loss is None or loss == "mse":
             return nn.MSELoss(**loss_kwargs)
-        elif loss == "mae":
+        if loss == "mae":
            return nn.L1Loss(**loss_kwargs)
-        elif isinstance(loss, str):
+        if isinstance(loss, str):
            raise ValueError(f"Unsupported regression loss: {loss}")
-
-    else:
-        raise ValueError(f"Unsupported task_type: {task_type}")
-
-    return loss
-
-
-def validate_training_mode(
-    training_mode: str,
-    support_training_modes: list[str],
-    model_name: str = "Model"
-) -> None:
+
+    raise ValueError(f"Unsupported task_type: {task_type}")
+
+
+def _get_match_loss(training_mode: str | None, loss: str | None, **loss_kwargs) -> nn.Module:
+    if training_mode == "pointwise":
+        if loss is None or loss in {"bce", "binary_crossentropy"}:
+            return nn.BCELoss(**loss_kwargs)
+        if loss == "weighted_bce":
+            return WeightedBCELoss(**loss_kwargs)
+        if loss == "focal":
+            return FocalLoss(**loss_kwargs)
+        if loss == "class_balanced_focal":
+            return _build_cb_focal(loss_kwargs)
+        if loss == "cosine_contrastive":
+            return CosineContrastiveLoss(**loss_kwargs)
+        if isinstance(loss, str):
+            raise ValueError(f"Unsupported pointwise loss: {loss}")
+
+    if training_mode == "pairwise":
+        if loss is None or loss == "bpr":
+            return BPRLoss(**loss_kwargs)
+        if loss == "hinge":
+            return HingeLoss(**loss_kwargs)
+        if loss == "triplet":
+            return TripletLoss(**loss_kwargs)
+        if isinstance(loss, str):
+            raise ValueError(f"Unsupported pairwise loss: {loss}")
+
+    if training_mode == "listwise":
+        if loss is None or loss in {"sampled_softmax", "softmax"}:
+            return SampledSoftmaxLoss(**loss_kwargs)
+        if loss == "infonce":
+            return InfoNCELoss(**loss_kwargs)
+        if loss == "listnet":
+            return ListNetLoss(**loss_kwargs)
+        if loss == "listmle":
+            return ListMLELoss(**loss_kwargs)
+        if loss == "approx_ndcg":
+            return ApproxNDCGLoss(**loss_kwargs)
+        if loss in {"crossentropy", "ce"}:
+            return nn.CrossEntropyLoss(**loss_kwargs)
+        if isinstance(loss, str):
+            raise ValueError(f"Unsupported listwise loss: {loss}")
+
+    raise ValueError(f"Unknown training_mode: {training_mode}")
+
+
+def _get_classification_loss(loss: str | None, **loss_kwargs) -> nn.Module:
+    if loss is None or loss in {"bce", "binary_crossentropy"}:
+        return nn.BCELoss(**loss_kwargs)
+    if loss == "weighted_bce":
+        return WeightedBCELoss(**loss_kwargs)
+    if loss == "focal":
+        return FocalLoss(**loss_kwargs)
+    if loss == "class_balanced_focal":
+        return _build_cb_focal(loss_kwargs)
+    if loss == "mse":
+        return nn.MSELoss(**loss_kwargs)
+    if loss == "mae":
+        return nn.L1Loss(**loss_kwargs)
+    if loss in {"crossentropy", "ce"}:
+        return nn.CrossEntropyLoss(**loss_kwargs)
+    if isinstance(loss, str):
+        raise ValueError(f"Unsupported loss function: {loss}")
+    raise ValueError("Loss must be specified for classification task.")
+
+
+def _get_multiclass_loss(loss: str | None, **loss_kwargs) -> nn.Module:
+    if loss is None or loss in {"crossentropy", "ce"}:
+        return nn.CrossEntropyLoss(**loss_kwargs)
+    if loss == "focal":
+        return FocalLoss(**loss_kwargs)
+    if loss == "class_balanced_focal":
+        return _build_cb_focal(loss_kwargs)
+    if isinstance(loss, str):
+        raise ValueError(f"Unsupported multiclass loss: {loss}")
+    raise ValueError("Loss must be specified for multiclass task.")
+
+
+def _build_cb_focal(loss_kwargs: dict) -> ClassBalancedFocalLoss:
+    if "class_counts" not in loss_kwargs:
+        raise ValueError("class_balanced_focal requires `class_counts` argument.")
+    return ClassBalancedFocalLoss(**loss_kwargs)
+
+
+def get_loss_kwargs(loss_params: dict | list[dict] | None, index: int = 0) -> dict:
     """
-    Validate that the requested training mode is supported by the model.
-
-    Args:
-        training_mode: Requested training mode
-        support_training_modes: List of supported training modes
-        model_name: Name of the model (for error messages)
-
-    Raises:
-        ValueError: If training mode is not supported
+    Resolve per-task loss kwargs from a dict or list of dicts.
     """
-    if training_mode not in support_training_modes:
-        raise ValueError(
-            f"{model_name} does not support training_mode='{training_mode}'. "
-            f"Supported modes: {support_training_modes}"
-        )
+    if loss_params is None:
+        return {}
+    if isinstance(loss_params, list):
+        if index < len(loss_params) and loss_params[index] is not None:
+            return loss_params[index]
+        return {}
+    return loss_params
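The refactor splits the old if/elif chain into per-mode helpers behind the same get_loss_fn entry point, and get_loss_kwargs resolves per-task kwargs for multi-task setups. A usage sketch consistent with the branches above (the list form mirrors get_loss_kwargs's signature; the specific values are illustrative):

from nextrec.loss import get_loss_fn, get_loss_kwargs

# Binary/ranking tasks default to BCELoss.
bce = get_loss_fn(task_type="binary", loss="bce")

# Match + listwise now also dispatches to the new ListNet/ListMLE/ApproxNDCG branches.
listnet = get_loss_fn(task_type="match", training_mode="listwise", loss="listnet")

# Per-task kwargs: a plain dict applies as-is, a list is indexed per task.
params = [{"temperature": 0.5}, None]
get_loss_kwargs(params, index=0)   # -> {"temperature": 0.5}
get_loss_kwargs(params, index=1)   # -> {} (None or missing entries fall back to empty)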
nextrec/loss/pairwise.py ADDED
@@ -0,0 +1,105 @@
+"""
+Pairwise loss functions for learning-to-rank and matching tasks.
+"""
+
+from typing import Literal
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BPRLoss(nn.Module):
+    """
+    Bayesian Personalized Ranking loss with support for multiple negatives.
+    """
+
+    def __init__(self, reduction: str = "mean"):
+        super().__init__()
+        self.reduction = reduction
+
+    def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
+        if neg_score.dim() == 2:
+            pos_score = pos_score.unsqueeze(1)
+            diff = pos_score - neg_score
+        else:
+            diff = pos_score - neg_score
+
+        loss = -torch.log(torch.sigmoid(diff) + 1e-8)
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+
+class HingeLoss(nn.Module):
+    """
+    Hinge loss for pairwise ranking.
+    """
+
+    def __init__(self, margin: float = 1.0, reduction: str = "mean"):
+        super().__init__()
+        self.margin = margin
+        self.reduction = reduction
+
+    def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
+        if neg_score.dim() == 2:
+            pos_score = pos_score.unsqueeze(1)
+
+        diff = pos_score - neg_score
+        loss = torch.clamp(self.margin - diff, min=0)
+
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+
+class TripletLoss(nn.Module):
+    """
+    Triplet margin loss with cosine or euclidean distance.
+    """
+
+    def __init__(
+        self,
+        margin: float = 1.0,
+        reduction: str = "mean",
+        distance: Literal["euclidean", "cosine"] = "euclidean",
+    ):
+        super().__init__()
+        self.margin = margin
+        self.reduction = reduction
+        self.distance = distance
+
+    def forward(
+        self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor
+    ) -> torch.Tensor:
+        if self.distance == "euclidean":
+            pos_dist = torch.sum((anchor - positive) ** 2, dim=-1)
+            if negative.dim() == 3:
+                anchor_expanded = anchor.unsqueeze(1)
+                neg_dist = torch.sum((anchor_expanded - negative) ** 2, dim=-1)
+            else:
+                neg_dist = torch.sum((anchor - negative) ** 2, dim=-1)
+            if neg_dist.dim() == 2:
+                pos_dist = pos_dist.unsqueeze(1)
+        elif self.distance == "cosine":
+            pos_dist = 1 - F.cosine_similarity(anchor, positive, dim=-1)
+            if negative.dim() == 3:
+                anchor_expanded = anchor.unsqueeze(1)
+                neg_dist = 1 - F.cosine_similarity(anchor_expanded, negative, dim=-1)
+            else:
+                neg_dist = 1 - F.cosine_similarity(anchor, negative, dim=-1)
+            if neg_dist.dim() == 2:
+                pos_dist = pos_dist.unsqueeze(1)
+        else:
+            raise ValueError(f"Unsupported distance: {self.distance}")
+
+        loss = torch.clamp(pos_dist - neg_dist + self.margin, min=0)
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
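Finally, a shape-level sketch for the pairwise module (illustrative, not part of the release; both classes accept a single negative or a stack of negatives per row):

import torch
from nextrec.loss.pairwise import HingeLoss, TripletLoss

B, N, D = 16, 4, 32

# HingeLoss broadcasts one positive score [B] against N negative scores [B, N].
hinge_loss = HingeLoss(margin=1.0)(torch.randn(B), torch.randn(B, N))

# TripletLoss takes embeddings; negatives may be [B, D] or [B, N, D].
triplet = TripletLoss(margin=0.5, distance="cosine")
anchor, positive = torch.randn(B, D), torch.randn(B, D)
negatives = torch.randn(B, N, D)
triplet_loss = triplet(anchor, positive, negatives)  # pos_dist [B, 1] vs neg_dist [B, N]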