nextrec-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. nextrec/__init__.py +41 -0
  2. nextrec/__version__.py +1 -0
  3. nextrec/basic/__init__.py +0 -0
  4. nextrec/basic/activation.py +92 -0
  5. nextrec/basic/callback.py +35 -0
  6. nextrec/basic/dataloader.py +447 -0
  7. nextrec/basic/features.py +87 -0
  8. nextrec/basic/layers.py +985 -0
  9. nextrec/basic/loggers.py +124 -0
  10. nextrec/basic/metrics.py +557 -0
  11. nextrec/basic/model.py +1438 -0
  12. nextrec/data/__init__.py +27 -0
  13. nextrec/data/data_utils.py +132 -0
  14. nextrec/data/preprocessor.py +662 -0
  15. nextrec/loss/__init__.py +35 -0
  16. nextrec/loss/loss_utils.py +136 -0
  17. nextrec/loss/match_losses.py +294 -0
  18. nextrec/models/generative/hstu.py +0 -0
  19. nextrec/models/generative/tiger.py +0 -0
  20. nextrec/models/match/__init__.py +13 -0
  21. nextrec/models/match/dssm.py +200 -0
  22. nextrec/models/match/dssm_v2.py +162 -0
  23. nextrec/models/match/mind.py +210 -0
  24. nextrec/models/match/sdm.py +253 -0
  25. nextrec/models/match/youtube_dnn.py +172 -0
  26. nextrec/models/multi_task/esmm.py +129 -0
  27. nextrec/models/multi_task/mmoe.py +161 -0
  28. nextrec/models/multi_task/ple.py +260 -0
  29. nextrec/models/multi_task/share_bottom.py +126 -0
  30. nextrec/models/ranking/__init__.py +17 -0
  31. nextrec/models/ranking/afm.py +118 -0
  32. nextrec/models/ranking/autoint.py +140 -0
  33. nextrec/models/ranking/dcn.py +120 -0
  34. nextrec/models/ranking/deepfm.py +95 -0
  35. nextrec/models/ranking/dien.py +214 -0
  36. nextrec/models/ranking/din.py +181 -0
  37. nextrec/models/ranking/fibinet.py +130 -0
  38. nextrec/models/ranking/fm.py +87 -0
  39. nextrec/models/ranking/masknet.py +125 -0
  40. nextrec/models/ranking/pnn.py +128 -0
  41. nextrec/models/ranking/widedeep.py +105 -0
  42. nextrec/models/ranking/xdeepfm.py +117 -0
  43. nextrec/utils/__init__.py +18 -0
  44. nextrec/utils/common.py +14 -0
  45. nextrec/utils/embedding.py +19 -0
  46. nextrec/utils/initializer.py +47 -0
  47. nextrec/utils/optimizer.py +75 -0
  48. nextrec-0.1.1.dist-info/METADATA +302 -0
  49. nextrec-0.1.1.dist-info/RECORD +51 -0
  50. nextrec-0.1.1.dist-info/WHEEL +4 -0
  51. nextrec-0.1.1.dist-info/licenses/LICENSE +21 -0
nextrec/loss/loss_utils.py
@@ -0,0 +1,136 @@
+ """
+ Loss utilities for NextRec
+
+ Date: create on 09/11/2025
+ Author:
+     Yang Zhou,zyaztec@gmail.com
+ """
+ import torch
+ import torch.nn as nn
+ from typing import Literal
+
+ from nextrec.loss.match_losses import (
+     BPRLoss,
+     HingeLoss,
+     TripletLoss,
+     SampledSoftmaxLoss,
+     CosineContrastiveLoss,
+     InfoNCELoss
+ )
+
+ # Valid task types for validation
+ VALID_TASK_TYPES = ['binary', 'multiclass', 'regression', 'multivariate_regression', 'match', 'ranking', 'multitask', 'multilabel']
+
+
+ def get_loss_fn(
+     task_type: str = "binary",
+     training_mode: str | None = None,
+     loss: str | nn.Module | None = None,
+     **loss_kwargs
+ ) -> nn.Module:
+     """
+     Get loss function based on task type and training mode.
+
+     Examples:
+         # Ranking task (binary classification)
+         >>> loss_fn = get_loss_fn(task_type="binary", loss="bce")
+
+         # Match task with pointwise training
+         >>> loss_fn = get_loss_fn(task_type="match", training_mode="pointwise")
+
+         # Match task with pairwise training
+         >>> loss_fn = get_loss_fn(task_type="match", training_mode="pairwise", loss="bpr")
+
+         # Match task with listwise training
+         >>> loss_fn = get_loss_fn(task_type="match", training_mode="listwise", loss="sampled_softmax")
+     """
+
+     if isinstance(loss, nn.Module):
+         return loss
+
+     if task_type == "match":
+         if training_mode == "pointwise":
+             # Pointwise training uses binary cross entropy
+             if loss is None or loss == "bce" or loss == "binary_crossentropy":
+                 return nn.BCELoss(**loss_kwargs)
+             elif loss == "cosine_contrastive":
+                 return CosineContrastiveLoss(**loss_kwargs)
+             elif isinstance(loss, str):
+                 raise ValueError(f"Unsupported pointwise loss: {loss}")
+
+         elif training_mode == "pairwise":
+             if loss is None or loss == "bpr":
+                 return BPRLoss(**loss_kwargs)
+             elif loss == "hinge":
+                 return HingeLoss(**loss_kwargs)
+             elif loss == "triplet":
+                 return TripletLoss(**loss_kwargs)
+             elif isinstance(loss, str):
+                 raise ValueError(f"Unsupported pairwise loss: {loss}")
+
+         elif training_mode == "listwise":
+             if loss is None or loss == "sampled_softmax" or loss == "softmax":
+                 return SampledSoftmaxLoss(**loss_kwargs)
+             elif loss == "infonce":
+                 return InfoNCELoss(**loss_kwargs)
+             elif loss == "crossentropy" or loss == "ce":
+                 return nn.CrossEntropyLoss(**loss_kwargs)
+             elif isinstance(loss, str):
+                 raise ValueError(f"Unsupported listwise loss: {loss}")
+
+         else:
+             raise ValueError(f"Unknown training_mode: {training_mode}")
+
+     elif task_type in ["ranking", "multitask", "binary"]:
+         if loss is None or loss == "bce" or loss == "binary_crossentropy":
+             return nn.BCELoss(**loss_kwargs)
+         elif loss == "mse":
+             return nn.MSELoss(**loss_kwargs)
+         elif loss == "mae":
+             return nn.L1Loss(**loss_kwargs)
+         elif loss == "crossentropy" or loss == "ce":
+             return nn.CrossEntropyLoss(**loss_kwargs)
+         elif isinstance(loss, str):
+             raise ValueError(f"Unsupported loss function: {loss}")
+
+     elif task_type == "multiclass":
+         if loss is None or loss == "crossentropy" or loss == "ce":
+             return nn.CrossEntropyLoss(**loss_kwargs)
+         elif isinstance(loss, str):
+             raise ValueError(f"Unsupported multiclass loss: {loss}")
+
+     elif task_type == "regression":
+         if loss is None or loss == "mse":
+             return nn.MSELoss(**loss_kwargs)
+         elif loss == "mae":
+             return nn.L1Loss(**loss_kwargs)
+         elif isinstance(loss, str):
+             raise ValueError(f"Unsupported regression loss: {loss}")
+
+     else:
+         raise ValueError(f"Unsupported task_type: {task_type}")
+
+     return loss
+
+
+ def validate_training_mode(
+     training_mode: str,
+     support_training_modes: list[str],
+     model_name: str = "Model"
+ ) -> None:
+     """
+     Validate that the requested training mode is supported by the model.
+
+     Args:
+         training_mode: Requested training mode
+         support_training_modes: List of supported training modes
+         model_name: Name of the model (for error messages)
+
+     Raises:
+         ValueError: If training mode is not supported
+     """
+     if training_mode not in support_training_modes:
+         raise ValueError(
+             f"{model_name} does not support training_mode='{training_mode}'. "
+             f"Supported modes: {support_training_modes}"
+         )
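For orientation, a minimal usage sketch of get_loss_fn as exposed by this file. This is a sketch only: the tensor shapes and values are illustrative, and the import path simply follows the file layout listed above.

import torch
from nextrec.loss.loss_utils import get_loss_fn

# Pairwise matching: get_loss_fn returns BPRLoss, which takes positive and negative scores.
loss_fn = get_loss_fn(task_type="match", training_mode="pairwise", loss="bpr")
pos_score = torch.randn(32)       # [batch_size]
neg_score = torch.randn(32, 4)    # [batch_size, num_neg]
print(loss_fn(pos_score, neg_score))

# Binary ranking: nn.BCELoss expects probabilities in [0, 1] and float targets.
bce = get_loss_fn(task_type="binary", loss="bce")
preds = torch.sigmoid(torch.randn(32))
labels = torch.randint(0, 2, (32,)).float()
print(bce(preds, labels))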
nextrec/loss/match_losses.py
@@ -0,0 +1,294 @@
+ """
+ Loss functions for matching tasks
+
+ Date: create on 13/11/2025
+ Author:
+     Yang Zhou,zyaztec@gmail.com
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional
+
+
+ class BPRLoss(nn.Module):
+     def __init__(self, reduction: str = 'mean'):
+         super(BPRLoss, self).__init__()
+         self.reduction = reduction
+
+     def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
+         if neg_score.dim() == 2:
+             pos_score = pos_score.unsqueeze(1)  # [batch_size, 1]
+             diff = pos_score - neg_score  # [batch_size, num_neg]
+             loss = -torch.log(torch.sigmoid(diff) + 1e-8)
+             if self.reduction == 'mean':
+                 return loss.mean()
+             elif self.reduction == 'sum':
+                 return loss.sum()
+             else:
+                 return loss
+         else:
+             diff = pos_score - neg_score
+             loss = -torch.log(torch.sigmoid(diff) + 1e-8)
+             if self.reduction == 'mean':
+                 return loss.mean()
+             elif self.reduction == 'sum':
+                 return loss.sum()
+             else:
+                 return loss
+
+
+ class HingeLoss(nn.Module):
+     def __init__(self, margin: float = 1.0, reduction: str = 'mean'):
+         super(HingeLoss, self).__init__()
+         self.margin = margin
+         self.reduction = reduction
+
+     def forward(self, pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
+         if neg_score.dim() == 2:
+             pos_score = pos_score.unsqueeze(1)  # [batch_size, 1]
+
+         diff = pos_score - neg_score
+         loss = torch.clamp(self.margin - diff, min=0)
+
+         if self.reduction == 'mean':
+             return loss.mean()
+         elif self.reduction == 'sum':
+             return loss.sum()
+         else:
+             return loss
+
+
+ class TripletLoss(nn.Module):
+     def __init__(self, margin: float = 1.0, reduction: str = 'mean', distance: str = 'euclidean'):
+         super(TripletLoss, self).__init__()
+         self.margin = margin
+         self.reduction = reduction
+         self.distance = distance
+
+     def forward(self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor) -> torch.Tensor:
+         if self.distance == 'euclidean':
+             pos_dist = torch.sum((anchor - positive) ** 2, dim=-1)
+
+             if negative.dim() == 3:
+                 anchor_expanded = anchor.unsqueeze(1)  # [batch_size, 1, dim]
+                 neg_dist = torch.sum((anchor_expanded - negative) ** 2, dim=-1)  # [batch_size, num_neg]
+             else:
+                 neg_dist = torch.sum((anchor - negative) ** 2, dim=-1)
+
+             if neg_dist.dim() == 2:
+                 pos_dist = pos_dist.unsqueeze(1)  # [batch_size, 1]
+
+         elif self.distance == 'cosine':
+             pos_dist = 1 - F.cosine_similarity(anchor, positive, dim=-1)
+
+             if negative.dim() == 3:
+                 anchor_expanded = anchor.unsqueeze(1)  # [batch_size, 1, dim]
+                 neg_dist = 1 - F.cosine_similarity(anchor_expanded, negative, dim=-1)
+             else:
+                 neg_dist = 1 - F.cosine_similarity(anchor, negative, dim=-1)
+
+             if neg_dist.dim() == 2:
+                 pos_dist = pos_dist.unsqueeze(1)
+         else:
+             raise ValueError(f"Unsupported distance: {self.distance}")
+
+         loss = torch.clamp(pos_dist - neg_dist + self.margin, min=0)
+
+         if self.reduction == 'mean':
+             return loss.mean()
+         elif self.reduction == 'sum':
+             return loss.sum()
+         else:
+             return loss
+
+
+ class SampledSoftmaxLoss(nn.Module):
+     def __init__(self, reduction: str = 'mean'):
+         super(SampledSoftmaxLoss, self).__init__()
+         self.reduction = reduction
+
+     def forward(self, pos_logits: torch.Tensor, neg_logits: torch.Tensor) -> torch.Tensor:
+         pos_logits = pos_logits.unsqueeze(1)  # [batch_size, 1]
+         all_logits = torch.cat([pos_logits, neg_logits], dim=1)  # [batch_size, 1 + num_neg]
+         targets = torch.zeros(all_logits.size(0), dtype=torch.long, device=all_logits.device)
+         loss = F.cross_entropy(all_logits, targets, reduction=self.reduction)
+
+         return loss
+
+
+ class CosineContrastiveLoss(nn.Module):
+     def __init__(self, margin: float = 0.5, reduction: str = 'mean'):
+         super(CosineContrastiveLoss, self).__init__()
+         self.margin = margin
+         self.reduction = reduction
+
+     def forward(self, user_emb: torch.Tensor, item_emb: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+         similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
+         pos_loss = (1 - similarity) * labels
+
+         neg_loss = torch.clamp(similarity - self.margin, min=0) * (1 - labels)
+
+         loss = pos_loss + neg_loss
+
+         if self.reduction == 'mean':
+             return loss.mean()
+         elif self.reduction == 'sum':
+             return loss.sum()
+         else:
+             return loss
+
+
+ class InfoNCELoss(nn.Module):
+     def __init__(self, temperature: float = 0.07, reduction: str = 'mean'):
+         super(InfoNCELoss, self).__init__()
+         self.temperature = temperature
+         self.reduction = reduction
+
+     def forward(self, query: torch.Tensor, pos_key: torch.Tensor, neg_keys: torch.Tensor) -> torch.Tensor:
+         pos_sim = torch.sum(query * pos_key, dim=-1) / self.temperature  # [batch_size]
+         pos_sim = pos_sim.unsqueeze(1)  # [batch_size, 1]
+         query_expanded = query.unsqueeze(1)  # [batch_size, 1, dim]
+         neg_sim = torch.sum(query_expanded * neg_keys, dim=-1) / self.temperature  # [batch_size, num_neg]
+         logits = torch.cat([pos_sim, neg_sim], dim=1)  # [batch_size, 1 + num_neg]
+         labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
+
+         loss = F.cross_entropy(logits, labels, reduction=self.reduction)
+
+         return loss
+
+
+ class ListNetLoss(nn.Module):
+     """
+     ListNet loss using top-1 probability distribution
+     Reference: Cao et al. Learning to Rank: From Pairwise Approach to Listwise Approach (ICML 2007)
+     """
+     def __init__(self, temperature: float = 1.0, reduction: str = 'mean'):
+         super(ListNetLoss, self).__init__()
+         self.temperature = temperature
+         self.reduction = reduction
+
+     def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+         # Convert scores and labels to probability distributions
+         pred_probs = F.softmax(scores / self.temperature, dim=1)
+         true_probs = F.softmax(labels / self.temperature, dim=1)
+
+         # Cross entropy between two distributions
+         loss = -torch.sum(true_probs * torch.log(pred_probs + 1e-10), dim=1)
+
+         if self.reduction == 'mean':
+             return loss.mean()
+         elif self.reduction == 'sum':
+             return loss.sum()
+         else:
+             return loss
+
+
+ class ListMLELoss(nn.Module):
+     """
+     ListMLE (Maximum Likelihood Estimation) loss
+     Reference: Xia et al. Listwise approach to learning to rank: theory and algorithm (ICML 2008)
+     """
+     def __init__(self, reduction: str = 'mean'):
+         super(ListMLELoss, self).__init__()
+         self.reduction = reduction
+
+     def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+         # Sort by labels in descending order to get ground truth ranking
+         sorted_labels, sorted_indices = torch.sort(labels, descending=True, dim=1)
+
+         # Reorder scores according to ground truth ranking
+         batch_size, list_size = scores.shape
+         batch_indices = torch.arange(batch_size, device=scores.device).unsqueeze(1).expand(-1, list_size)
+         sorted_scores = scores[batch_indices, sorted_indices]
+
+         # Compute log likelihood
+         # For each position, compute log(exp(score_i) / sum(exp(score_j) for j >= i))
+         loss = torch.tensor(0.0, device=scores.device)
+         for i in range(list_size):
+             # Log-sum-exp trick for numerical stability
+             remaining_scores = sorted_scores[:, i:]
+             log_sum_exp = torch.logsumexp(remaining_scores, dim=1)
+             loss = loss + (log_sum_exp - sorted_scores[:, i]).sum()
+
+         if self.reduction == 'mean':
+             return loss / batch_size
+         elif self.reduction == 'sum':
+             return loss
+         else:
+             return loss / batch_size
+
+
+ class ApproxNDCGLoss(nn.Module):
+     """
+     Approximate NDCG loss for learning to rank
+     Reference: Qin et al. A General Approximation Framework for Direct Optimization of
+     Information Retrieval Measures (Information Retrieval 2010)
+     """
+     def __init__(self, temperature: float = 1.0, reduction: str = 'mean'):
+         super(ApproxNDCGLoss, self).__init__()
+         self.temperature = temperature
+         self.reduction = reduction
+
+     def _dcg(self, relevance: torch.Tensor, k: Optional[int] = None) -> torch.Tensor:
+         if k is not None:
+             relevance = relevance[:, :k]
+
+         # DCG = sum(rel_i / log2(i + 2)) for i in range(list_size)
+         positions = torch.arange(1, relevance.size(1) + 1, device=relevance.device, dtype=torch.float32)
+         discounts = torch.log2(positions + 1.0)
+         dcg = torch.sum(relevance / discounts, dim=1)
+
+         return dcg
+
+     def forward(self, scores: torch.Tensor, labels: torch.Tensor, k: Optional[int] = None) -> torch.Tensor:
+         """
+         Args:
+             scores: Predicted scores [batch_size, list_size]
+             labels: Ground truth relevance labels [batch_size, list_size]
+             k: Top-k items for NDCG@k (if None, use all items)
+
+         Returns:
+             Approximate NDCG loss (1 - NDCG)
+         """
+         batch_size = scores.size(0)
+
+         # Use differentiable sorting approximation with softmax
+         # Create pairwise comparison matrix
+         scores_expanded = scores.unsqueeze(2)  # [batch_size, list_size, 1]
+         scores_tiled = scores.unsqueeze(1)  # [batch_size, 1, list_size]
+
+         # Compute pairwise probabilities using sigmoid
+         pairwise_diff = (scores_expanded - scores_tiled) / self.temperature
+         pairwise_probs = torch.sigmoid(pairwise_diff)  # [batch_size, list_size, list_size]
+
+         # Approximate ranking positions
+         # ranking_probs[i, j] ≈ probability that item i is ranked at position j
+         # We use softmax approximation for differentiable ranking
+         ranking_weights = F.softmax(scores / self.temperature, dim=1)
+
+         # Sort labels to get ideal DCG
+         ideal_labels, _ = torch.sort(labels, descending=True, dim=1)
+         ideal_dcg = self._dcg(ideal_labels, k)
+
+         # Compute approximate DCG using soft ranking
+         # Weight each item's relevance by its soft ranking position
+         positions = torch.arange(1, scores.size(1) + 1, device=scores.device, dtype=torch.float32)
+         discounts = 1.0 / torch.log2(positions + 1.0)
+
+         # Approximate DCG by weighting relevance with ranking probabilities
+         approx_dcg = torch.sum(labels * ranking_weights * discounts, dim=1)
+
+         # Normalize by ideal DCG to get NDCG
+         ndcg = approx_dcg / (ideal_dcg + 1e-10)
+
+         # Loss is 1 - NDCG (we want to maximize NDCG, so minimize 1 - NDCG)
+         loss = 1.0 - ndcg
+
+         if self.reduction == 'mean':
+             return loss.mean()
+         elif self.reduction == 'sum':
+             return loss.sum()
+         else:
+             return loss
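As a quick shape check on the classes above, here is a small sketch calling three of the losses; the batch size, negative count, and embedding dimension are illustrative assumptions, not values taken from the package.

import torch
from nextrec.loss.match_losses import BPRLoss, InfoNCELoss, ListMLELoss

batch, num_neg, dim = 8, 4, 64

# BPRLoss: positive scores [batch] against negative scores [batch, num_neg]
bpr = BPRLoss()
print(bpr(torch.randn(batch), torch.randn(batch, num_neg)))

# InfoNCELoss: query/positive embeddings [batch, dim], negatives [batch, num_neg, dim]
infonce = InfoNCELoss(temperature=0.07)
print(infonce(torch.randn(batch, dim), torch.randn(batch, dim), torch.randn(batch, num_neg, dim)))

# ListMLELoss: per-list scores and graded relevance labels, both [batch, list_size]
listmle = ListMLELoss()
print(listmle(torch.randn(batch, 10), torch.randint(0, 3, (batch, 10)).float()))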
nextrec/models/generative/hstu.py
File without changes
nextrec/models/generative/tiger.py
File without changes
nextrec/models/match/__init__.py
@@ -0,0 +1,13 @@
+ from .dssm import DSSM
+ from .dssm_v2 import DSSM_v2
+ from .youtube_dnn import YoutubeDNN
+ from .mind import MIND
+ from .sdm import SDM
+
+ __all__ = [
+     'DSSM',
+     'DSSM_v2',
+     'YoutubeDNN',
+     'MIND',
+     'SDM',
+ ]
nextrec/models/match/dssm.py
@@ -0,0 +1,200 @@
+ """
+ Date: create on 09/11/2025
+ Author:
+     Yang Zhou, zyaztec@gmail.com
+ Reference:
+     [1] Huang P S, He X, Gao J, et al. Learning deep structured semantic models for web search using clickthrough data[C]
+     //Proceedings of the 22nd ACM international conference on Information & Knowledge Management. 2013: 2333-2338.
+ """
+ import torch
+ import torch.nn as nn
+ from typing import Optional, Literal
+
+ from nextrec.basic.model import BaseMatchModel
+ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
+ from nextrec.basic.layers import MLP, EmbeddingLayer
+
+
+ class DSSM(BaseMatchModel):
+     """
+     Deep Structured Semantic Model
+
+     Two-tower model: user and item features are encoded into separate embeddings,
+     and the match score is computed via cosine similarity or dot product.
+     """
+
+     @property
+     def model_name(self) -> str:
+         return "DSSM"
+
+     def __init__(self,
+                  user_dense_features: list[DenseFeature] | None = None,
+                  user_sparse_features: list[SparseFeature] | None = None,
+                  user_sequence_features: list[SequenceFeature] | None = None,
+                  item_dense_features: list[DenseFeature] | None = None,
+                  item_sparse_features: list[SparseFeature] | None = None,
+                  item_sequence_features: list[SequenceFeature] | None = None,
+                  user_dnn_hidden_units: list[int] = [256, 128, 64],
+                  item_dnn_hidden_units: list[int] = [256, 128, 64],
+                  embedding_dim: int = 64,
+                  dnn_activation: str = 'relu',
+                  dnn_dropout: float = 0.0,
+                  training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'pointwise',
+                  num_negative_samples: int = 4,
+                  temperature: float = 1.0,
+                  similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'cosine',
+                  device: str = 'cpu',
+                  embedding_l1_reg: float = 0.0,
+                  dense_l1_reg: float = 0.0,
+                  embedding_l2_reg: float = 0.0,
+                  dense_l2_reg: float = 0.0,
+                  early_stop_patience: int = 20,
+                  model_id: str = 'dssm'):
+
+         super(DSSM, self).__init__(
+             user_dense_features=user_dense_features,
+             user_sparse_features=user_sparse_features,
+             user_sequence_features=user_sequence_features,
+             item_dense_features=item_dense_features,
+             item_sparse_features=item_sparse_features,
+             item_sequence_features=item_sequence_features,
+             training_mode=training_mode,
+             num_negative_samples=num_negative_samples,
+             temperature=temperature,
+             similarity_metric=similarity_metric,
+             device=device,
+             embedding_l1_reg=embedding_l1_reg,
+             dense_l1_reg=dense_l1_reg,
+             embedding_l2_reg=embedding_l2_reg,
+             dense_l2_reg=dense_l2_reg,
+             early_stop_patience=early_stop_patience,
+             model_id=model_id
+         )
+
+         self.embedding_dim = embedding_dim
+         self.user_dnn_hidden_units = user_dnn_hidden_units
+         self.item_dnn_hidden_units = item_dnn_hidden_units
+
+         # User tower embedding layer
+         user_features = []
+         if user_dense_features:
+             user_features.extend(user_dense_features)
+         if user_sparse_features:
+             user_features.extend(user_sparse_features)
+         if user_sequence_features:
+             user_features.extend(user_sequence_features)
+
+         if len(user_features) > 0:
+             self.user_embedding = EmbeddingLayer(user_features)
+
+         # Compute the user tower input dimension
+         user_input_dim = 0
+         for feat in user_dense_features or []:
+             user_input_dim += 1
+         for feat in user_sparse_features or []:
+             user_input_dim += feat.embedding_dim
+         for feat in user_sequence_features or []:
+             user_input_dim += feat.embedding_dim
+
+         # User DNN
+         user_dnn_units = user_dnn_hidden_units + [embedding_dim]
+         self.user_dnn = MLP(
+             input_dim=user_input_dim,
+             dims=user_dnn_units,
+             output_layer=False,
+             dropout=dnn_dropout,
+             activation=dnn_activation
+         )
+
+         # Item tower embedding layer
+         item_features = []
+         if item_dense_features:
+             item_features.extend(item_dense_features)
+         if item_sparse_features:
+             item_features.extend(item_sparse_features)
+         if item_sequence_features:
+             item_features.extend(item_sequence_features)
+
+         if len(item_features) > 0:
+             self.item_embedding = EmbeddingLayer(item_features)
+
+         # Compute the item tower input dimension
+         item_input_dim = 0
+         for feat in item_dense_features or []:
+             item_input_dim += 1
+         for feat in item_sparse_features or []:
+             item_input_dim += feat.embedding_dim
+         for feat in item_sequence_features or []:
+             item_input_dim += feat.embedding_dim
+
+         # Item DNN
+         item_dnn_units = item_dnn_hidden_units + [embedding_dim]
+         self.item_dnn = MLP(
+             input_dim=item_input_dim,
+             dims=item_dnn_units,
+             output_layer=False,
+             dropout=dnn_dropout,
+             activation=dnn_activation
+         )
+
+         # Register regularization weights
+         self._register_regularization_weights(
+             embedding_attr='user_embedding',
+             include_modules=['user_dnn']
+         )
+         self._register_regularization_weights(
+             embedding_attr='item_embedding',
+             include_modules=['item_dnn']
+         )
+
+         self.compile(
+             optimizer="adam",
+             optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
+         )
+
+         self.to(device)
+
+     def user_tower(self, user_input: dict) -> torch.Tensor:
+         """
+         User tower: encode user features into an embedding.
+
+         Args:
+             user_input: dict of user features
+
+         Returns:
+             user_emb: [batch_size, embedding_dim]
+         """
+         # Embed the user features
+         all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
+         user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)
+
+         # Pass through the user DNN
+         user_emb = self.user_dnn(user_emb)
+
+         # L2 normalize for cosine similarity
+         if self.similarity_metric == 'cosine':
+             user_emb = torch.nn.functional.normalize(user_emb, p=2, dim=1)
+
+         return user_emb
+
+     def item_tower(self, item_input: dict) -> torch.Tensor:
+         """
+         Item tower: encode item features into an embedding.
+
+         Args:
+             item_input: dict of item features
+
+         Returns:
+             item_emb: [batch_size, embedding_dim] or [batch_size, num_items, embedding_dim]
+         """
+         # Embed the item features
+         all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
+         item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
+
+         # Pass through the item DNN
+         item_emb = self.item_dnn(item_emb)
+
+         # L2 normalize for cosine similarity
+         if self.similarity_metric == 'cosine':
+             item_emb = torch.nn.functional.normalize(item_emb, p=2, dim=1)
+
+         return item_emb
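To make the two-tower flow concrete, here is a framework-free sketch of the computation DSSM performs, written in plain PyTorch rather than with nextrec's feature and embedding classes (whose constructors are not part of this diff); all dimensions and layer sizes are illustrative.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-in towers: each maps concatenated feature embeddings to a d-dimensional vector.
d = 64
user_tower = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, d))
item_tower = nn.Sequential(nn.Linear(96, 256), nn.ReLU(), nn.Linear(256, d))

user_x = torch.randn(32, 128)  # concatenated user feature embeddings
item_x = torch.randn(32, 96)   # concatenated item feature embeddings

# L2-normalize both sides, as DSSM does when similarity_metric == 'cosine'.
user_emb = F.normalize(user_tower(user_x), p=2, dim=1)
item_emb = F.normalize(item_tower(item_x), p=2, dim=1)

# With normalized embeddings, the dot product equals cosine similarity.
scores = (user_emb * item_emb).sum(dim=1)  # [batch_size]
print(scores.shape)

In the package itself, these scores would then feed one of the matching losses above according to training_mode (pointwise BCE, pairwise BPR, or listwise sampled softmax), as wired up by get_loss_fn in nextrec/loss/loss_utils.py.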