nextrec-0.2.1-py3-none-any.whl → nextrec-0.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/features.py +2 -1
  3. nextrec/basic/layers.py +2 -2
  4. nextrec/basic/model.py +82 -49
  5. nextrec/data/__init__.py +2 -4
  6. nextrec/data/dataloader.py +3 -3
  7. nextrec/data/preprocessor.py +2 -2
  8. nextrec/loss/__init__.py +31 -24
  9. nextrec/loss/listwise.py +162 -4
  10. nextrec/loss/loss_utils.py +133 -105
  11. nextrec/loss/pairwise.py +103 -4
  12. nextrec/loss/pointwise.py +196 -4
  13. nextrec/models/match/dssm.py +24 -15
  14. nextrec/models/match/dssm_v2.py +18 -0
  15. nextrec/models/match/mind.py +16 -1
  16. nextrec/models/match/sdm.py +15 -0
  17. nextrec/models/match/youtube_dnn.py +21 -8
  18. nextrec/models/multi_task/esmm.py +5 -5
  19. nextrec/models/multi_task/mmoe.py +5 -5
  20. nextrec/models/multi_task/ple.py +5 -5
  21. nextrec/models/multi_task/share_bottom.py +5 -5
  22. nextrec/models/ranking/__init__.py +8 -0
  23. nextrec/models/ranking/afm.py +3 -1
  24. nextrec/models/ranking/autoint.py +3 -1
  25. nextrec/models/ranking/dcn.py +3 -1
  26. nextrec/models/ranking/deepfm.py +3 -1
  27. nextrec/models/ranking/dien.py +3 -1
  28. nextrec/models/ranking/din.py +3 -1
  29. nextrec/models/ranking/fibinet.py +3 -1
  30. nextrec/models/ranking/fm.py +3 -1
  31. nextrec/models/ranking/masknet.py +3 -1
  32. nextrec/models/ranking/pnn.py +3 -1
  33. nextrec/models/ranking/widedeep.py +3 -1
  34. nextrec/models/ranking/xdeepfm.py +3 -1
  35. nextrec/utils/__init__.py +5 -5
  36. nextrec/utils/initializer.py +3 -3
  37. nextrec/utils/optimizer.py +6 -6
  38. {nextrec-0.2.1.dist-info → nextrec-0.2.3.dist-info}/METADATA +2 -2
  39. nextrec-0.2.3.dist-info/RECORD +53 -0
  40. nextrec/loss/match_losses.py +0 -293
  41. nextrec-0.2.1.dist-info/RECORD +0 -54
  42. {nextrec-0.2.1.dist-info → nextrec-0.2.3.dist-info}/WHEEL +0 -0
  43. {nextrec-0.2.1.dist-info → nextrec-0.2.3.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/loss_utils.py CHANGED
@@ -1,135 +1,163 @@
 """
-Loss utilities for NextRec
-
-Date: create on 09/11/2025
-Author: Yang Zhou,zyaztec@gmail.com
+Loss utilities for NextRec.
 """
-import torch
-import torch.nn as nn
+
 from typing import Literal
 
-from nextrec.loss.match_losses import (
-    BPRLoss,
-    HingeLoss,
-    TripletLoss,
+import torch.nn as nn
+
+from nextrec.loss.listwise import (
+    ApproxNDCGLoss,
+    InfoNCELoss,
+    ListMLELoss,
+    ListNetLoss,
     SampledSoftmaxLoss,
-    CosineContrastiveLoss,
-    InfoNCELoss
+)
+from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
+from nextrec.loss.pointwise import (
+    ClassBalancedFocalLoss,
+    CosineContrastiveLoss,
+    FocalLoss,
+    WeightedBCELoss,
 )
 
 # Valid task types for validation
-VALID_TASK_TYPES = ['binary', 'multiclass', 'regression', 'multivariate_regression', 'match', 'ranking', 'multitask', 'multilabel']
+VALID_TASK_TYPES = [
+    "binary",
+    "multiclass",
+    "regression",
+    "multivariate_regression",
+    "match",
+    "ranking",
+    "multitask",
+    "multilabel",
+]
 
 
 def get_loss_fn(
     task_type: str = "binary",
     training_mode: str | None = None,
     loss: str | nn.Module | None = None,
-    **loss_kwargs
+    **loss_kwargs,
 ) -> nn.Module:
     """
     Get loss function based on task type and training mode.
-
-    Examples:
-        # Ranking task (binary classification)
-        >>> loss_fn = get_loss_fn(task_type="binary", loss="bce")
-
-        # Match task with pointwise training
-        >>> loss_fn = get_loss_fn(task_type="match", training_mode="pointwise")
-
-        # Match task with pairwise training
-        >>> loss_fn = get_loss_fn(task_type="match", training_mode="pairwise", loss="bpr")
-
-        # Match task with listwise training
-        >>> loss_fn = get_loss_fn(task_type="match", training_mode="listwise", loss="sampled_softmax")
     """
 
     if isinstance(loss, nn.Module):
         return loss
 
+    # Common mappings
     if task_type == "match":
-        if training_mode == "pointwise":
-            # Pointwise training uses binary cross entropy
-            if loss is None or loss == "bce" or loss == "binary_crossentropy":
-                return nn.BCELoss(**loss_kwargs)
-            elif loss == "cosine_contrastive":
-                return CosineContrastiveLoss(**loss_kwargs)
-            elif isinstance(loss, str):
-                raise ValueError(f"Unsupported pointwise loss: {loss}")
-
-        elif training_mode == "pairwise":
-            if loss is None or loss == "bpr":
-                return BPRLoss(**loss_kwargs)
-            elif loss == "hinge":
-                return HingeLoss(**loss_kwargs)
-            elif loss == "triplet":
-                return TripletLoss(**loss_kwargs)
-            elif isinstance(loss, str):
-                raise ValueError(f"Unsupported pairwise loss: {loss}")
-
-        elif training_mode == "listwise":
-            if loss is None or loss == "sampled_softmax" or loss == "softmax":
-                return SampledSoftmaxLoss(**loss_kwargs)
-            elif loss == "infonce":
-                return InfoNCELoss(**loss_kwargs)
-            elif loss == "crossentropy" or loss == "ce":
-                return nn.CrossEntropyLoss(**loss_kwargs)
-            elif isinstance(loss, str):
-                raise ValueError(f"Unsupported listwise loss: {loss}")
-
-        else:
-            raise ValueError(f"Unknown training_mode: {training_mode}")
-
-    elif task_type in ["ranking", "multitask", "binary"]:
-        if loss is None or loss == "bce" or loss == "binary_crossentropy":
-            return nn.BCELoss(**loss_kwargs)
-        elif loss == "mse":
-            return nn.MSELoss(**loss_kwargs)
-        elif loss == "mae":
-            return nn.L1Loss(**loss_kwargs)
-        elif loss == "crossentropy" or loss == "ce":
-            return nn.CrossEntropyLoss(**loss_kwargs)
-        elif isinstance(loss, str):
-            raise ValueError(f"Unsupported loss function: {loss}")
+        return _get_match_loss(training_mode, loss, **loss_kwargs)
 
-    elif task_type == "multiclass":
-        if loss is None or loss == "crossentropy" or loss == "ce":
-            return nn.CrossEntropyLoss(**loss_kwargs)
-        elif isinstance(loss, str):
-            raise ValueError(f"Unsupported multiclass loss: {loss}")
-
-    elif task_type == "regression":
+    if task_type in ["ranking", "multitask", "binary", "multilabel"]:
+        return _get_classification_loss(loss, **loss_kwargs)
+
+    if task_type == "multiclass":
+        return _get_multiclass_loss(loss, **loss_kwargs)
+
+    if task_type == "regression":
         if loss is None or loss == "mse":
             return nn.MSELoss(**loss_kwargs)
-        elif loss == "mae":
+        if loss == "mae":
            return nn.L1Loss(**loss_kwargs)
-        elif isinstance(loss, str):
+        if isinstance(loss, str):
            raise ValueError(f"Unsupported regression loss: {loss}")
-
-    else:
-        raise ValueError(f"Unsupported task_type: {task_type}")
-
-    return loss
-
-
-def validate_training_mode(
-    training_mode: str,
-    support_training_modes: list[str],
-    model_name: str = "Model"
-) -> None:
+
+    raise ValueError(f"Unsupported task_type: {task_type}")
+
+
+def _get_match_loss(training_mode: str | None, loss: str | None, **loss_kwargs) -> nn.Module:
+    if training_mode == "pointwise":
+        if loss is None or loss in {"bce", "binary_crossentropy"}:
+            return nn.BCELoss(**loss_kwargs)
+        if loss == "weighted_bce":
+            return WeightedBCELoss(**loss_kwargs)
+        if loss == "focal":
+            return FocalLoss(**loss_kwargs)
+        if loss == "class_balanced_focal":
+            return _build_cb_focal(loss_kwargs)
+        if loss == "cosine_contrastive":
+            return CosineContrastiveLoss(**loss_kwargs)
+        if isinstance(loss, str):
+            raise ValueError(f"Unsupported pointwise loss: {loss}")
+
+    if training_mode == "pairwise":
+        if loss is None or loss == "bpr":
+            return BPRLoss(**loss_kwargs)
+        if loss == "hinge":
+            return HingeLoss(**loss_kwargs)
+        if loss == "triplet":
+            return TripletLoss(**loss_kwargs)
+        if isinstance(loss, str):
+            raise ValueError(f"Unsupported pairwise loss: {loss}")
+
+    if training_mode == "listwise":
+        if loss is None or loss in {"sampled_softmax", "softmax"}:
+            return SampledSoftmaxLoss(**loss_kwargs)
+        if loss == "infonce":
+            return InfoNCELoss(**loss_kwargs)
+        if loss == "listnet":
+            return ListNetLoss(**loss_kwargs)
+        if loss == "listmle":
+            return ListMLELoss(**loss_kwargs)
+        if loss == "approx_ndcg":
+            return ApproxNDCGLoss(**loss_kwargs)
+        if loss in {"crossentropy", "ce"}:
+            return nn.CrossEntropyLoss(**loss_kwargs)
+        if isinstance(loss, str):
+            raise ValueError(f"Unsupported listwise loss: {loss}")
+
+    raise ValueError(f"Unknown training_mode: {training_mode}")
+
+
+def _get_classification_loss(loss: str | None, **loss_kwargs) -> nn.Module:
+    if loss is None or loss in {"bce", "binary_crossentropy"}:
+        return nn.BCELoss(**loss_kwargs)
+    if loss == "weighted_bce":
+        return WeightedBCELoss(**loss_kwargs)
+    if loss == "focal":
+        return FocalLoss(**loss_kwargs)
+    if loss == "class_balanced_focal":
+        return _build_cb_focal(loss_kwargs)
+    if loss == "mse":
+        return nn.MSELoss(**loss_kwargs)
+    if loss == "mae":
+        return nn.L1Loss(**loss_kwargs)
+    if loss in {"crossentropy", "ce"}:
+        return nn.CrossEntropyLoss(**loss_kwargs)
+    if isinstance(loss, str):
+        raise ValueError(f"Unsupported loss function: {loss}")
+    raise ValueError("Loss must be specified for classification task.")
+
+
+def _get_multiclass_loss(loss: str | None, **loss_kwargs) -> nn.Module:
+    if loss is None or loss in {"crossentropy", "ce"}:
+        return nn.CrossEntropyLoss(**loss_kwargs)
+    if loss == "focal":
+        return FocalLoss(**loss_kwargs)
+    if loss == "class_balanced_focal":
+        return _build_cb_focal(loss_kwargs)
+    if isinstance(loss, str):
+        raise ValueError(f"Unsupported multiclass loss: {loss}")
+    raise ValueError("Loss must be specified for multiclass task.")
+
+
+def _build_cb_focal(loss_kwargs: dict) -> ClassBalancedFocalLoss:
+    if "class_counts" not in loss_kwargs:
+        raise ValueError("class_balanced_focal requires `class_counts` argument.")
+    return ClassBalancedFocalLoss(**loss_kwargs)
+
+
+def get_loss_kwargs(loss_params: dict | list[dict] | None, index: int = 0) -> dict:
     """
-    Validate that the requested training mode is supported by the model.
-
-    Args:
-        training_mode: Requested training mode
-        support_training_modes: List of supported training modes
-        model_name: Name of the model (for error messages)
-
-    Raises:
-        ValueError: If training mode is not supported
+    Resolve per-task loss kwargs from a dict or list of dicts.
     """
-    if training_mode not in support_training_modes:
-        raise ValueError(
-            f"{model_name} does not support training_mode='{training_mode}'. "
-            f"Supported modes: {support_training_modes}"
-        )
+    if loss_params is None:
+        return {}
+    if isinstance(loss_params, list):
+        if index < len(loss_params) and loss_params[index] is not None:
+            return loss_params[index]
+        return {}
+    return loss_params
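
Note: the refactor above replaces the old if/elif chain with per-mode helpers and registers the new imbalance-aware losses. A minimal usage sketch of the resulting API; only get_loss_fn, get_loss_kwargs, and the loss-name strings come from this diff, the tensors are made-up example data:

import torch

from nextrec.loss.loss_utils import get_loss_fn, get_loss_kwargs

# Pairwise match training: BPR is the default when loss is None.
bpr = get_loss_fn(task_type="match", training_mode="pairwise")

# Extra kwargs are forwarded to the loss constructor (here FocalLoss(gamma=2.0)).
focal = get_loss_fn(task_type="binary", loss="focal", gamma=2.0)

# class_balanced_focal refuses to build without class_counts (see _build_cb_focal).
cb = get_loss_fn(task_type="binary", loss="class_balanced_focal", class_counts=[9000, 1000])

# get_loss_kwargs resolves per-task params for multi-task setups.
params = get_loss_kwargs([{"gamma": 1.5}, None], index=0)  # -> {"gamma": 1.5}

probs = torch.rand(8)                       # example predicted probabilities
labels = torch.randint(0, 2, (8,)).float()  # example binary labels
print(focal(probs, labels))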
nextrec/loss/pairwise.py CHANGED
@@ -1,6 +1,105 @@
 """
-Loss functions for pairwise tasks
-
-Date: create on 22/11/2025
-Author: Yang Zhou,zyaztec@gmail.com
+Pairwise loss functions for learning-to-rank and matching tasks.
 """
+
+from typing import Literal
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BPRLoss(nn.Module):
+    """
+    Bayesian Personalized Ranking loss with support for multiple negatives.
+    """
+
+    def __init__(self, reduction: str = "mean"):
+        super().__init__()
+        self.reduction = reduction
+
+    def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
+        if neg_score.dim() == 2:
+            pos_score = pos_score.unsqueeze(1)
+            diff = pos_score - neg_score
+        else:
+            diff = pos_score - neg_score
+
+        loss = -torch.log(torch.sigmoid(diff) + 1e-8)
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+
+class HingeLoss(nn.Module):
+    """
+    Hinge loss for pairwise ranking.
+    """
+
+    def __init__(self, margin: float = 1.0, reduction: str = "mean"):
+        super().__init__()
+        self.margin = margin
+        self.reduction = reduction
+
+    def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
+        if neg_score.dim() == 2:
+            pos_score = pos_score.unsqueeze(1)
+
+        diff = pos_score - neg_score
+        loss = torch.clamp(self.margin - diff, min=0)
+
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+
+class TripletLoss(nn.Module):
+    """
+    Triplet margin loss with cosine or euclidean distance.
+    """
+
+    def __init__(
+        self,
+        margin: float = 1.0,
+        reduction: str = "mean",
+        distance: Literal["euclidean", "cosine"] = "euclidean",
+    ):
+        super().__init__()
+        self.margin = margin
+        self.reduction = reduction
+        self.distance = distance
+
+    def forward(
+        self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor
+    ) -> torch.Tensor:
+        if self.distance == "euclidean":
+            pos_dist = torch.sum((anchor - positive) ** 2, dim=-1)
+            if negative.dim() == 3:
+                anchor_expanded = anchor.unsqueeze(1)
+                neg_dist = torch.sum((anchor_expanded - negative) ** 2, dim=-1)
+            else:
+                neg_dist = torch.sum((anchor - negative) ** 2, dim=-1)
+            if neg_dist.dim() == 2:
+                pos_dist = pos_dist.unsqueeze(1)
+        elif self.distance == "cosine":
+            pos_dist = 1 - F.cosine_similarity(anchor, positive, dim=-1)
+            if negative.dim() == 3:
+                anchor_expanded = anchor.unsqueeze(1)
+                neg_dist = 1 - F.cosine_similarity(anchor_expanded, negative, dim=-1)
+            else:
+                neg_dist = 1 - F.cosine_similarity(anchor, negative, dim=-1)
+            if neg_dist.dim() == 2:
+                pos_dist = pos_dist.unsqueeze(1)
+        else:
+            raise ValueError(f"Unsupported distance: {self.distance}")
+
+        loss = torch.clamp(pos_dist - neg_dist + self.margin, min=0)
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
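
Note: per the dim() checks above, all three pairwise losses accept either one negative score per example or a batch of sampled negatives. A shape sketch with made-up example tensors:

import torch

from nextrec.loss.pairwise import BPRLoss, TripletLoss

pos = torch.randn(32)        # [B]: one positive score per user
negs = torch.randn(32, 5)    # [B, N]: five sampled negatives per user
loss = BPRLoss()(pos, negs)  # pos is unsqueezed to [B, 1] and broadcast over N

anchor = torch.randn(32, 64)
positive = torch.randn(32, 64)
negatives = torch.randn(32, 5, 64)  # [B, N, D] triggers the dim() == 3 branch
loss = TripletLoss(margin=0.5, distance="cosine")(anchor, positive, negatives)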
nextrec/loss/pointwise.py CHANGED
@@ -1,6 +1,198 @@
 """
-Loss functions for pointwise tasks
-
-Date: create on 22/11/2025
-Author: Yang Zhou,zyaztec@gmail.com
+Pointwise loss functions, including imbalance-aware variants.
 """
+
+from typing import Optional, Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class CosineContrastiveLoss(nn.Module):
+    """
+    Contrastive loss using cosine similarity for positive/negative pairs.
+    """
+
+    def __init__(self, margin: float = 0.5, reduction: str = "mean"):
+        super().__init__()
+        self.margin = margin
+        self.reduction = reduction
+
+    def forward(
+        self, user_emb: torch.Tensor, item_emb: torch.Tensor, labels: torch.Tensor
+    ) -> torch.Tensor:
+        labels = labels.float()
+        similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
+        pos_loss = torch.clamp(self.margin - similarity, min=0) * labels
+        neg_loss = torch.clamp(similarity - self.margin, min=0) * (1 - labels)
+        loss = pos_loss + neg_loss
+
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+
+class WeightedBCELoss(nn.Module):
+    """
+    Binary cross entropy with controllable positive class weight.
+    Supports probability or logit inputs via `logits` flag.
+    If `auto_balance=True` and `pos_weight` is None, the positive weight is
+    computed from the batch as (#neg / #pos) for stable imbalance handling.
+    """
+    def __init__(
+        self,
+        pos_weight: float | torch.Tensor | None = None,
+        reduction: str = "mean",
+        logits: bool = False,
+        auto_balance: bool = False,
+    ):
+        super().__init__()
+        self.reduction = reduction
+        self.logits = logits
+        self.auto_balance = auto_balance
+
+        if pos_weight is not None:
+            self.register_buffer(
+                "pos_weight",
+                torch.as_tensor(pos_weight, dtype=torch.float32),
+            )
+        else:
+            self.pos_weight = None
+
+    def _resolve_pos_weight(self, labels: torch.Tensor) -> torch.Tensor:
+        if self.pos_weight is not None:
+            return self.pos_weight.to(device=labels.device)
+
+        if not self.auto_balance:
+            return torch.tensor(1.0, device=labels.device, dtype=labels.dtype)
+
+        labels_float = labels.float()
+        pos = torch.clamp(labels_float.sum(), min=1.0)
+        neg = torch.clamp(labels_float.numel() - labels_float.sum(), min=1.0)
+        return (neg / pos).to(device=labels.device, dtype=labels.dtype)
+
+    def forward(self, inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+        labels = labels.float()
+        current_pos_weight = self._resolve_pos_weight(labels)
+        current_pos_weight = current_pos_weight.to(inputs.dtype)
+
+        if self.logits:
+            loss = F.binary_cross_entropy_with_logits(
+                inputs, labels, pos_weight=current_pos_weight, reduction="none"
+            )
+        else:
+            probs = torch.clamp(inputs, min=1e-6, max=1 - 1e-6)
+            base_loss = F.binary_cross_entropy(probs, labels, reduction="none")
+            loss = torch.where(labels == 1, base_loss * current_pos_weight, base_loss)
+
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        else:
+            return loss
+
+
+class FocalLoss(nn.Module):
+    """
+    Standard focal loss for binary or multi-class classification.
+    """
+
+    def __init__(
+        self,
+        gamma: float = 2.0,
+        alpha: Optional[float | Sequence[float] | torch.Tensor] = None,
+        reduction: str = "mean",
+        logits: bool = False,
+    ):
+        super().__init__()
+        self.gamma = gamma
+        self.reduction = reduction
+        self.logits = logits
+        self.alpha = alpha
+
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        # Multi-class case
+        if inputs.dim() > 1 and inputs.size(1) > 1:
+            log_probs = F.log_softmax(inputs, dim=1)
+            probs = log_probs.exp()
+            targets_one_hot = F.one_hot(targets.long(), num_classes=inputs.size(1)).float()
+
+            alpha = self._get_alpha(inputs)
+            alpha_factor = targets_one_hot * alpha
+            focal_weight = (1.0 - probs) ** self.gamma
+            loss = torch.sum(alpha_factor * focal_weight * (-log_probs), dim=1)
+        else:
+            targets = targets.float()
+            if self.logits:
+                ce_loss = F.binary_cross_entropy_with_logits(
+                    inputs, targets, reduction="none"
+                )
+                probs = torch.sigmoid(inputs)
+            else:
+                ce_loss = F.binary_cross_entropy(inputs, targets, reduction="none")
+                probs = torch.clamp(inputs, min=1e-6, max=1 - 1e-6)
+
+            p_t = probs * targets + (1 - probs) * (1 - targets)
+            alpha_factor = self._get_binary_alpha(targets, inputs.device)
+            focal_weight = (1.0 - p_t) ** self.gamma
+            loss = alpha_factor * focal_weight * ce_loss
+
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+    def _get_alpha(self, inputs: torch.Tensor) -> torch.Tensor:
+        if self.alpha is None:
+            return torch.ones_like(inputs)
+        if isinstance(self.alpha, torch.Tensor):
+            return self.alpha.to(inputs.device)
+        alpha_tensor = torch.tensor(self.alpha, device=inputs.device, dtype=inputs.dtype)
+        return alpha_tensor
+
+    def _get_binary_alpha(self, targets: torch.Tensor, device: torch.device) -> torch.Tensor:
+        if self.alpha is None:
+            return torch.ones_like(targets)
+        if isinstance(self.alpha, (float, int)):
+            return torch.where(targets == 1, self.alpha, 1 - float(self.alpha)).to(device)
+        alpha_tensor = torch.tensor(self.alpha, device=device, dtype=targets.dtype)
+        return torch.where(targets == 1, alpha_tensor, 1 - alpha_tensor)
+
+
+class ClassBalancedFocalLoss(nn.Module):
+    """
+    Focal loss weighted by effective number of samples per class.
+    Reference: "Class-Balanced Loss Based on Effective Number of Samples"
+    """
+
+    def __init__(
+        self,
+        class_counts: Sequence[int] | torch.Tensor,
+        beta: float = 0.9999,
+        gamma: float = 2.0,
+        reduction: str = "mean",
+    ):
+        super().__init__()
+        self.gamma = gamma
+        self.reduction = reduction
+        class_counts = torch.as_tensor(class_counts, dtype=torch.float32)
+        effective_num = 1.0 - torch.pow(beta, class_counts)
+        weights = (1.0 - beta) / (effective_num + 1e-12)
+        weights = weights / weights.sum() * len(weights)
+        self.register_buffer("class_weights", weights)
+
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        focal = FocalLoss(
+            gamma=self.gamma, alpha=self.class_weights, reduction="none", logits=True
+        )
+        loss = focal(inputs, targets)
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
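
Note: ClassBalancedFocalLoss follows the effective-number scheme of Cui et al., E_n = (1 - beta^n) / (1 - beta), with each class weighted proportionally to 1/E_n. A worked example of the constructor arithmetic above, with illustrative counts:

import torch

class_counts = torch.tensor([9900.0, 100.0])  # 99:1 imbalance, made-up counts
beta = 0.999
effective_num = 1.0 - torch.pow(beta, class_counts)  # ~[0.99995, 0.09516]
weights = (1.0 - beta) / (effective_num + 1e-12)     # rare class gets ~10.5x raw weight
weights = weights / weights.sum() * len(weights)     # rescaled to sum to num_classes
print(weights)  # ~tensor([0.1739, 1.8261])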
nextrec/models/match/dssm.py CHANGED
@@ -19,7 +19,8 @@ class DSSM(BaseMatchModel):
     """
     Deep Structured Semantic Model
 
-    双塔模型,分别对user和item特征编码为embedding,通过余弦相似度或点积计算匹配分数
+    Dual-tower model that encodes user and item features separately and
+    computes similarity via cosine or dot product.
     """
 
     @property
@@ -48,6 +49,12 @@ class DSSM(BaseMatchModel):
                  embedding_l2_reg: float = 0.0,
                  dense_l2_reg: float = 0.0,
                  early_stop_patience: int = 20,
+                 optimizer: str | torch.optim.Optimizer = "adam",
+                 optimizer_params: dict | None = None,
+                 scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+                 scheduler_params: dict | None = None,
+                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+                 loss_params: dict | list[dict] | None = None,
                  **kwargs):
 
         super(DSSM, self).__init__(
@@ -86,7 +93,7 @@ class DSSM(BaseMatchModel):
         if len(user_features) > 0:
             self.user_embedding = EmbeddingLayer(user_features)
 
-        # 计算user tower输入维度
+        # Compute user tower input dimension
         user_input_dim = 0
         for feat in user_dense_features or []:
             user_input_dim += 1
@@ -117,7 +124,7 @@ class DSSM(BaseMatchModel):
         if len(item_features) > 0:
             self.item_embedding = EmbeddingLayer(item_features)
 
-        # 计算item tower输入维度
+        # Compute item tower input dimension
         item_input_dim = 0
         for feat in item_dense_features or []:
             item_input_dim += 1
@@ -136,7 +143,6 @@ class DSSM(BaseMatchModel):
             activation=dnn_activation
         )
 
-        # 注册正则化权重
        self._register_regularization_weights(
            embedding_attr='user_embedding',
            include_modules=['user_dnn']
@@ -146,28 +152,33 @@ class DSSM(BaseMatchModel):
             include_modules=['item_dnn']
         )
 
+        if optimizer_params is None:
+            optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
+
         self.compile(
-            optimizer="adam",
-            optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=scheduler,
+            scheduler_params=scheduler_params,
+            loss=loss,
+            loss_params=loss_params,
         )
 
         self.to(device)
 
     def user_tower(self, user_input: dict) -> torch.Tensor:
         """
-        User tower: 将user特征编码为embedding
+        User tower encodes user features into embeddings.
 
         Args:
-            user_input: user特征字典
+            user_input: user feature dict
 
         Returns:
             user_emb: [batch_size, embedding_dim]
         """
-        # 获取user特征的embedding
         all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
         user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)
 
-        # 通过user DNN
         user_emb = self.user_dnn(user_emb)
 
         # L2 normalize for cosine similarity
@@ -178,19 +189,17 @@ class DSSM(BaseMatchModel):
 
     def item_tower(self, item_input: dict) -> torch.Tensor:
         """
-        Item tower: 将item特征编码为embedding
+        Item tower encodes item features into embeddings.
 
         Args:
-            item_input: item特征字典
+            item_input: item feature dict
 
         Returns:
-            item_emb: [batch_size, embedding_dim] [batch_size, num_items, embedding_dim]
+            item_emb: [batch_size, embedding_dim] or [batch_size, num_items, embedding_dim]
         """
-        # 获取item特征的embedding
         all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
         item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
 
-        # 通过item DNN
         item_emb = self.item_dnn(item_emb)
 
         # L2 normalize for cosine similarity
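
Note: the DSSM changes above stop hard-coding the Adam optimizer and BCE loss and instead thread optimizer/scheduler/loss settings through to compile(). A hypothetical construction sketch; only the new parameter names come from this diff, while the feature arguments and the "step" scheduler key are assumptions:

model = DSSM(
    user_features=user_features,  # assumed defined elsewhere per your feature schema
    item_features=item_features,
    optimizer="adam",
    optimizer_params={"lr": 5e-4, "weight_decay": 1e-6},  # overrides the old built-in default
    scheduler="step",                                     # assumption: string keys resolved by compile()
    scheduler_params={"step_size": 10, "gamma": 0.5},
    loss="focal",                                         # any key accepted by get_loss_fn
    loss_params={"gamma": 2.0},
)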