nextrec 0.1.11__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +1 -2
  3. nextrec/basic/callback.py +1 -2
  4. nextrec/basic/features.py +39 -8
  5. nextrec/basic/layers.py +3 -4
  6. nextrec/basic/loggers.py +15 -10
  7. nextrec/basic/metrics.py +1 -2
  8. nextrec/basic/model.py +160 -125
  9. nextrec/basic/session.py +150 -0
  10. nextrec/data/__init__.py +13 -2
  11. nextrec/data/data_utils.py +74 -22
  12. nextrec/data/dataloader.py +513 -0
  13. nextrec/data/preprocessor.py +494 -134
  14. nextrec/loss/__init__.py +31 -24
  15. nextrec/loss/listwise.py +164 -0
  16. nextrec/loss/loss_utils.py +133 -106
  17. nextrec/loss/pairwise.py +105 -0
  18. nextrec/loss/pointwise.py +198 -0
  19. nextrec/models/match/dssm.py +26 -17
  20. nextrec/models/match/dssm_v2.py +20 -2
  21. nextrec/models/match/mind.py +18 -3
  22. nextrec/models/match/sdm.py +17 -2
  23. nextrec/models/match/youtube_dnn.py +23 -10
  24. nextrec/models/multi_task/esmm.py +8 -8
  25. nextrec/models/multi_task/mmoe.py +8 -8
  26. nextrec/models/multi_task/ple.py +8 -8
  27. nextrec/models/multi_task/share_bottom.py +8 -8
  28. nextrec/models/ranking/__init__.py +8 -0
  29. nextrec/models/ranking/afm.py +5 -4
  30. nextrec/models/ranking/autoint.py +6 -4
  31. nextrec/models/ranking/dcn.py +6 -4
  32. nextrec/models/ranking/deepfm.py +5 -4
  33. nextrec/models/ranking/dien.py +6 -4
  34. nextrec/models/ranking/din.py +6 -4
  35. nextrec/models/ranking/fibinet.py +6 -4
  36. nextrec/models/ranking/fm.py +6 -4
  37. nextrec/models/ranking/masknet.py +6 -4
  38. nextrec/models/ranking/pnn.py +6 -4
  39. nextrec/models/ranking/widedeep.py +6 -4
  40. nextrec/models/ranking/xdeepfm.py +6 -4
  41. nextrec/utils/__init__.py +7 -11
  42. nextrec/utils/embedding.py +2 -4
  43. nextrec/utils/initializer.py +4 -5
  44. nextrec/utils/optimizer.py +7 -8
  45. {nextrec-0.1.11.dist-info → nextrec-0.2.2.dist-info}/METADATA +3 -3
  46. nextrec-0.2.2.dist-info/RECORD +53 -0
  47. nextrec/basic/dataloader.py +0 -447
  48. nextrec/loss/match_losses.py +0 -294
  49. nextrec/utils/common.py +0 -14
  50. nextrec-0.1.11.dist-info/RECORD +0 -51
  51. {nextrec-0.1.11.dist-info → nextrec-0.2.2.dist-info}/WHEEL +0 -0
  52. {nextrec-0.1.11.dist-info → nextrec-0.2.2.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/pointwise.py (new file)
@@ -0,0 +1,198 @@
+"""
+Pointwise loss functions, including imbalance-aware variants.
+"""
+
+from typing import Optional, Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class CosineContrastiveLoss(nn.Module):
+    """
+    Contrastive loss using cosine similarity for positive/negative pairs.
+    """
+
+    def __init__(self, margin: float = 0.5, reduction: str = "mean"):
+        super().__init__()
+        self.margin = margin
+        self.reduction = reduction
+
+    def forward(
+        self, user_emb: torch.Tensor, item_emb: torch.Tensor, labels: torch.Tensor
+    ) -> torch.Tensor:
+        labels = labels.float()
+        similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
+        pos_loss = torch.clamp(self.margin - similarity, min=0) * labels
+        neg_loss = torch.clamp(similarity - self.margin, min=0) * (1 - labels)
+        loss = pos_loss + neg_loss
+
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+
+class WeightedBCELoss(nn.Module):
+    """
+    Binary cross entropy with controllable positive class weight.
+    Supports probability or logit inputs via `logits` flag.
+    If `auto_balance=True` and `pos_weight` is None, the positive weight is
+    computed from the batch as (#neg / #pos) for stable imbalance handling.
+    """
+    def __init__(
+        self,
+        pos_weight: float | torch.Tensor | None = None,
+        reduction: str = "mean",
+        logits: bool = False,
+        auto_balance: bool = False,
+    ):
+        super().__init__()
+        self.reduction = reduction
+        self.logits = logits
+        self.auto_balance = auto_balance
+
+        if pos_weight is not None:
+            self.register_buffer(
+                "pos_weight",
+                torch.as_tensor(pos_weight, dtype=torch.float32),
+            )
+        else:
+            self.pos_weight = None
+
+    def _resolve_pos_weight(self, labels: torch.Tensor) -> torch.Tensor:
+        if self.pos_weight is not None:
+            return self.pos_weight.to(device=labels.device)
+
+        if not self.auto_balance:
+            return torch.tensor(1.0, device=labels.device, dtype=labels.dtype)
+
+        labels_float = labels.float()
+        pos = torch.clamp(labels_float.sum(), min=1.0)
+        neg = torch.clamp(labels_float.numel() - labels_float.sum(), min=1.0)
+        return (neg / pos).to(device=labels.device, dtype=labels.dtype)
+
+    def forward(self, inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+        labels = labels.float()
+        current_pos_weight = self._resolve_pos_weight(labels)
+        current_pos_weight = current_pos_weight.to(inputs.dtype)
+
+        if self.logits:
+            loss = F.binary_cross_entropy_with_logits(
+                inputs, labels, pos_weight=current_pos_weight, reduction="none"
+            )
+        else:
+            probs = torch.clamp(inputs, min=1e-6, max=1 - 1e-6)
+            base_loss = F.binary_cross_entropy(probs, labels, reduction="none")
+            loss = torch.where(labels == 1, base_loss * current_pos_weight, base_loss)
+
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        else:
+            return loss
+
+
+class FocalLoss(nn.Module):
+    """
+    Standard focal loss for binary or multi-class classification.
+    """
+
+    def __init__(
+        self,
+        gamma: float = 2.0,
+        alpha: Optional[float | Sequence[float] | torch.Tensor] = None,
+        reduction: str = "mean",
+        logits: bool = False,
+    ):
+        super().__init__()
+        self.gamma = gamma
+        self.reduction = reduction
+        self.logits = logits
+        self.alpha = alpha
+
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        # Multi-class case
+        if inputs.dim() > 1 and inputs.size(1) > 1:
+            log_probs = F.log_softmax(inputs, dim=1)
+            probs = log_probs.exp()
+            targets_one_hot = F.one_hot(targets.long(), num_classes=inputs.size(1)).float()
+
+            alpha = self._get_alpha(inputs)
+            alpha_factor = targets_one_hot * alpha
+            focal_weight = (1.0 - probs) ** self.gamma
+            loss = torch.sum(alpha_factor * focal_weight * (-log_probs), dim=1)
+        else:
+            targets = targets.float()
+            if self.logits:
+                ce_loss = F.binary_cross_entropy_with_logits(
+                    inputs, targets, reduction="none"
+                )
+                probs = torch.sigmoid(inputs)
+            else:
+                ce_loss = F.binary_cross_entropy(inputs, targets, reduction="none")
+                probs = torch.clamp(inputs, min=1e-6, max=1 - 1e-6)
+
+            p_t = probs * targets + (1 - probs) * (1 - targets)
+            alpha_factor = self._get_binary_alpha(targets, inputs.device)
+            focal_weight = (1.0 - p_t) ** self.gamma
+            loss = alpha_factor * focal_weight * ce_loss
+
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+    def _get_alpha(self, inputs: torch.Tensor) -> torch.Tensor:
+        if self.alpha is None:
+            return torch.ones_like(inputs)
+        if isinstance(self.alpha, torch.Tensor):
+            return self.alpha.to(inputs.device)
+        alpha_tensor = torch.tensor(self.alpha, device=inputs.device, dtype=inputs.dtype)
+        return alpha_tensor
+
+    def _get_binary_alpha(self, targets: torch.Tensor, device: torch.device) -> torch.Tensor:
+        if self.alpha is None:
+            return torch.ones_like(targets)
+        if isinstance(self.alpha, (float, int)):
+            return torch.where(targets == 1, self.alpha, 1 - float(self.alpha)).to(device)
+        alpha_tensor = torch.tensor(self.alpha, device=device, dtype=targets.dtype)
+        return torch.where(targets == 1, alpha_tensor, 1 - alpha_tensor)
+
+
+class ClassBalancedFocalLoss(nn.Module):
+    """
+    Focal loss weighted by effective number of samples per class.
+    Reference: "Class-Balanced Loss Based on Effective Number of Samples"
+    """
+
+    def __init__(
+        self,
+        class_counts: Sequence[int] | torch.Tensor,
+        beta: float = 0.9999,
+        gamma: float = 2.0,
+        reduction: str = "mean",
+    ):
+        super().__init__()
+        self.gamma = gamma
+        self.reduction = reduction
+        class_counts = torch.as_tensor(class_counts, dtype=torch.float32)
+        effective_num = 1.0 - torch.pow(beta, class_counts)
+        weights = (1.0 - beta) / (effective_num + 1e-12)
+        weights = weights / weights.sum() * len(weights)
+        self.register_buffer("class_weights", weights)
+
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        focal = FocalLoss(
+            gamma=self.gamma, alpha=self.class_weights, reduction="none", logits=True
+        )
+        loss = focal(inputs, targets)
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
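
The module above is new in 0.2.2. A minimal usage sketch follows; it assumes the classes are importable from nextrec.loss.pointwise as the file path suggests, and the tensors are purely illustrative.

import torch
from nextrec.loss.pointwise import WeightedBCELoss, FocalLoss, ClassBalancedFocalLoss

logits = torch.randn(8)                     # raw binary scores from a model
labels = torch.randint(0, 2, (8,)).float()  # binary targets

# pos_weight is derived per batch as (#neg / #pos) when auto_balance=True
bce = WeightedBCELoss(logits=True, auto_balance=True)
print(bce(logits, labels))

# Binary focal loss on logits, down-weighting easy examples via gamma
focal = FocalLoss(gamma=2.0, alpha=0.25, logits=True)
print(focal(logits, labels))

# Class-balanced focal loss; class_counts would come from the training data
cb_focal = ClassBalancedFocalLoss(class_counts=[9000, 1000], beta=0.999)
multi_logits = torch.randn(8, 2)
multi_targets = torch.randint(0, 2, (8,))
print(cb_focal(multi_logits, multi_targets))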
nextrec/models/match/dssm.py
@@ -19,7 +19,8 @@ class DSSM(BaseMatchModel):
     """
     Deep Structured Semantic Model
 
-    双塔模型,分别对useritem特征编码为embedding,通过余弦相似度或点积计算匹配分数
+    Dual-tower model that encodes user and item features separately and
+    computes similarity via cosine or dot product.
     """
 
     @property
@@ -48,7 +49,13 @@ class DSSM(BaseMatchModel):
                  embedding_l2_reg: float = 0.0,
                  dense_l2_reg: float = 0.0,
                  early_stop_patience: int = 20,
-                 model_id: str = 'dssm'):
+                 optimizer: str | torch.optim.Optimizer = "adam",
+                 optimizer_params: dict | None = None,
+                 scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+                 scheduler_params: dict | None = None,
+                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+                 loss_params: dict | list[dict] | None = None,
+                 **kwargs):
 
         super(DSSM, self).__init__(
             user_dense_features=user_dense_features,
@@ -67,7 +74,7 @@ class DSSM(BaseMatchModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=early_stop_patience,
-            model_id=model_id
+            **kwargs
         )
 
         self.embedding_dim = embedding_dim
@@ -86,7 +93,7 @@ class DSSM(BaseMatchModel):
         if len(user_features) > 0:
             self.user_embedding = EmbeddingLayer(user_features)
 
-        # 计算user tower输入维度
+        # Compute user tower input dimension
         user_input_dim = 0
         for feat in user_dense_features or []:
             user_input_dim += 1
@@ -117,7 +124,7 @@ class DSSM(BaseMatchModel):
         if len(item_features) > 0:
             self.item_embedding = EmbeddingLayer(item_features)
 
-        # 计算item tower输入维度
+        # Compute item tower input dimension
         item_input_dim = 0
         for feat in item_dense_features or []:
             item_input_dim += 1
@@ -136,7 +143,6 @@ class DSSM(BaseMatchModel):
             activation=dnn_activation
         )
 
-        # 注册正则化权重
         self._register_regularization_weights(
             embedding_attr='user_embedding',
             include_modules=['user_dnn']
@@ -146,28 +152,33 @@ class DSSM(BaseMatchModel):
             include_modules=['item_dnn']
        )
 
+        if optimizer_params is None:
+            optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
+
         self.compile(
-            optimizer="adam",
-            optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=scheduler,
+            scheduler_params=scheduler_params,
+            loss=loss,
+            loss_params=loss_params,
         )
 
         self.to(device)
 
     def user_tower(self, user_input: dict) -> torch.Tensor:
         """
-        User tower: user特征编码为embedding
+        User tower encodes user features into embeddings.
 
         Args:
-            user_input: user特征字典
+            user_input: user feature dict
 
         Returns:
             user_emb: [batch_size, embedding_dim]
         """
-        # 获取user特征的embedding
         all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
         user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)
 
-        # 通过user DNN
         user_emb = self.user_dnn(user_emb)
 
         # L2 normalize for cosine similarity
@@ -178,19 +189,17 @@ class DSSM(BaseMatchModel):
 
     def item_tower(self, item_input: dict) -> torch.Tensor:
         """
-        Item tower: item特征编码为embedding
+        Item tower encodes item features into embeddings.
 
         Args:
-            item_input: item特征字典
+            item_input: item feature dict
 
         Returns:
-            item_emb: [batch_size, embedding_dim] [batch_size, num_items, embedding_dim]
+            item_emb: [batch_size, embedding_dim] or [batch_size, num_items, embedding_dim]
         """
-        # 获取item特征的embedding
         all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
         item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
 
-        # 通过item DNN
         item_emb = self.item_dnn(item_emb)
 
         # L2 normalize for cosine similarity
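
The same change repeats in the remaining match models below (DSSM_v2, MIND, SDM, YoutubeDNN): the hard-coded compile call is replaced by constructor arguments that are forwarded to compile(). The following is a self-contained sketch of that pattern, not nextrec's actual BaseMatchModel; the toy class, registry, and helper names are illustrative only.

import torch
import torch.nn as nn

class TinyTower(nn.Module):
    """Toy stand-in for a match model, showing constructor -> compile forwarding."""

    def __init__(self,
                 optimizer: str | torch.optim.Optimizer = "adam",
                 optimizer_params: dict | None = None,
                 loss: str | nn.Module | None = "bce"):
        super().__init__()
        self.net = nn.Linear(8, 1)
        if optimizer_params is None:
            # same fallback defaults the DSSM constructor uses above
            optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
        self._compile(optimizer, optimizer_params, loss)

    def _compile(self, optimizer, optimizer_params, loss):
        # Resolve string shortcuts; pass already-built objects through unchanged.
        optimizers = {"adam": torch.optim.Adam, "sgd": torch.optim.SGD}
        losses = {"bce": nn.BCEWithLogitsLoss}
        if isinstance(optimizer, str):
            optimizer = optimizers[optimizer.lower()](self.parameters(), **optimizer_params)
        self.optimizer = optimizer
        self.loss_fn = losses[loss]() if isinstance(loss, str) else loss

model = TinyTower(optimizer="sgd", optimizer_params={"lr": 0.1}, loss="bce")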
nextrec/models/match/dssm_v2.py
@@ -44,7 +44,13 @@ class DSSM_v2(BaseMatchModel):
                  embedding_l2_reg: float = 0.0,
                  dense_l2_reg: float = 0.0,
                  early_stop_patience: int = 20,
-                 model_id: str = 'dssm_v2'):
+                 optimizer: str | torch.optim.Optimizer = "adam",
+                 optimizer_params: dict | None = None,
+                 scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+                 scheduler_params: dict | None = None,
+                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+                 loss_params: dict | list[dict] | None = None,
+                 **kwargs):
 
         super(DSSM_v2, self).__init__(
             user_dense_features=user_dense_features,
@@ -63,7 +69,7 @@ class DSSM_v2(BaseMatchModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=early_stop_patience,
-            model_id=model_id
+            **kwargs
         )
 
         self.embedding_dim = embedding_dim
@@ -137,6 +143,18 @@ class DSSM_v2(BaseMatchModel):
             include_modules=['item_dnn']
         )
 
+        if optimizer_params is None:
+            optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
+
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=scheduler,
+            scheduler_params=scheduler_params,
+            loss=loss,
+            loss_params=loss_params,
+        )
+
         self.to(device)
 
     def user_tower(self, user_input: dict) -> torch.Tensor:
nextrec/models/match/mind.py
@@ -41,7 +41,7 @@ class MIND(BaseMatchModel):
                  item_dnn_hidden_units: list[int] = [256, 128],
                  dnn_activation: str = 'relu',
                  dnn_dropout: float = 0.0,
-                 training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'listwise',
+                 training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'pointwise',
                  num_negative_samples: int = 100,
                  temperature: float = 1.0,
                  similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
@@ -51,7 +51,13 @@ class MIND(BaseMatchModel):
                  embedding_l2_reg: float = 0.0,
                  dense_l2_reg: float = 0.0,
                  early_stop_patience: int = 20,
-                 model_id: str = 'mind'):
+                 optimizer: str | torch.optim.Optimizer = "adam",
+                 optimizer_params: dict | None = None,
+                 scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+                 scheduler_params: dict | None = None,
+                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+                 loss_params: dict | list[dict] | None = None,
+                 **kwargs):
 
         super(MIND, self).__init__(
             user_dense_features=user_dense_features,
@@ -70,7 +76,7 @@ class MIND(BaseMatchModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=early_stop_patience,
-            model_id=model_id
+            **kwargs
         )
 
         self.embedding_dim = embedding_dim
@@ -152,6 +158,15 @@ class MIND(BaseMatchModel):
             include_modules=['item_dnn'] if self.item_dnn else []
         )
 
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=scheduler,
+            scheduler_params=scheduler_params,
+            loss=loss,
+            loss_params=loss_params,
+        )
+
         self.to(device)
 
     def user_tower(self, user_input: dict) -> torch.Tensor:
nextrec/models/match/sdm.py
@@ -52,7 +52,13 @@ class SDM(BaseMatchModel):
                  embedding_l2_reg: float = 0.0,
                  dense_l2_reg: float = 0.0,
                  early_stop_patience: int = 20,
-                 model_id: str = 'sdm'):
+                 optimizer: str | torch.optim.Optimizer = "adam",
+                 optimizer_params: dict | None = None,
+                 scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+                 scheduler_params: dict | None = None,
+                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+                 loss_params: dict | list[dict] | None = None,
+                 **kwargs):
 
         super(SDM, self).__init__(
             user_dense_features=user_dense_features,
@@ -71,7 +77,7 @@ class SDM(BaseMatchModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=early_stop_patience,
-            model_id=model_id
+            **kwargs
         )
 
         self.embedding_dim = embedding_dim
@@ -179,6 +185,15 @@ class SDM(BaseMatchModel):
             include_modules=['item_dnn'] if self.item_dnn else []
         )
 
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=scheduler,
+            scheduler_params=scheduler_params,
+            loss=loss,
+            loss_params=loss_params,
+        )
+
         self.to(device)
 
     def user_tower(self, user_input: dict) -> torch.Tensor:
nextrec/models/match/youtube_dnn.py
@@ -17,11 +17,10 @@ from nextrec.basic.layers import MLP, EmbeddingLayer, AveragePooling
 
 class YoutubeDNN(BaseMatchModel):
     """
-    YouTube Deep Neural Network for Recommendations
-
-    用户塔:历史行为序列 + 用户特征 -> 用户embedding
-    物品塔:物品特征 -> 物品embedding
-    训练:sampled softmax loss (listwise)
+    YouTube Deep Neural Network for Recommendations.
+    User tower: behavior sequence + user features -> user embedding.
+    Item tower: item features -> item embedding.
+    Training usually uses listwise / sampled softmax style objectives.
     """
 
     @property
@@ -50,7 +49,13 @@ class YoutubeDNN(BaseMatchModel):
                  embedding_l2_reg: float = 0.0,
                  dense_l2_reg: float = 0.0,
                  early_stop_patience: int = 20,
-                 model_id: str = 'youtube_dnn'):
+                 optimizer: str | torch.optim.Optimizer = "adam",
+                 optimizer_params: dict | None = None,
+                 scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+                 scheduler_params: dict | None = None,
+                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+                 loss_params: dict | list[dict] | None = None,
+                 **kwargs):
 
         super(YoutubeDNN, self).__init__(
             user_dense_features=user_dense_features,
@@ -69,7 +74,7 @@ class YoutubeDNN(BaseMatchModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=early_stop_patience,
-            model_id=model_id
+            **kwargs
         )
 
         self.embedding_dim = embedding_dim
@@ -94,7 +99,7 @@ class YoutubeDNN(BaseMatchModel):
         for feat in user_sparse_features or []:
             user_input_dim += feat.embedding_dim
         for feat in user_sequence_features or []:
-            # 序列特征通过平均池化聚合
+            # Sequence features are pooled before entering the DNN
            user_input_dim += feat.embedding_dim
 
         user_dnn_units = user_dnn_hidden_units + [embedding_dim]
@@ -144,12 +149,20 @@ class YoutubeDNN(BaseMatchModel):
             include_modules=['item_dnn']
         )
 
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=scheduler,
+            scheduler_params=scheduler_params,
+            loss=loss,
+            loss_params=loss_params,
+        )
+
         self.to(device)
 
     def user_tower(self, user_input: dict) -> torch.Tensor:
         """
-        User tower
-        处理用户历史行为序列和其他用户特征
+        User tower to encode historical behavior sequences and user features.
         """
         all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
         user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)
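
The rewritten YoutubeDNN docstring mentions listwise / sampled-softmax style training, and 0.2.2 also adds nextrec/loss/listwise.py. As a rough illustration of that kind of objective (in-batch negatives; not necessarily what listwise.py actually implements), a short sketch:

import torch
import torch.nn.functional as F

def in_batch_sampled_softmax(user_emb: torch.Tensor,
                             item_emb: torch.Tensor,
                             temperature: float = 1.0) -> torch.Tensor:
    # Score every user against every item in the batch; the diagonal holds the
    # positive (user, item) pairs, and the off-diagonal entries act as negatives.
    logits = user_emb @ item_emb.t() / temperature
    targets = torch.arange(user_emb.size(0), device=user_emb.device)
    return F.cross_entropy(logits, targets)

user_emb = F.normalize(torch.randn(32, 64), dim=-1)
item_emb = F.normalize(torch.randn(32, 64), dim=-1)
print(in_batch_sampled_softmax(user_emb, item_emb, temperature=0.2))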
nextrec/models/multi_task/esmm.py
@@ -1,9 +1,7 @@
 """
 Date: create on 09/11/2025
-Author:
-    Yang Zhou,zyaztec@gmail.com
-Reference:
-    [1] Ma X, Zhao L, Huang G, et al. Entire space multi-task model: An effective approach for estimating post-click conversion rate[C]//SIGIR. 2018: 1137-1140.
+Author: Yang Zhou,zyaztec@gmail.com
+Reference: [1] Ma X, Zhao L, Huang G, et al. Entire space multi-task model: An effective approach for estimating post-click conversion rate[C]//SIGIR. 2018: 1137-1140.
 """
 
 import torch
@@ -46,12 +44,13 @@ class ESMM(BaseModel):
                  optimizer: str = "adam",
                  optimizer_params: dict = {},
                  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+                 loss_params: dict | list[dict] | None = None,
                  device: str = 'cpu',
-                 model_id: str = "baseline",
                  embedding_l1_reg=1e-6,
                  dense_l1_reg=1e-5,
                  embedding_l2_reg=1e-5,
-                 dense_l2_reg=1e-4):
+                 dense_l2_reg=1e-4,
+                 **kwargs):
 
         # ESMM requires exactly 2 targets: ctr and ctcvr
         if len(target) != 2:
@@ -69,7 +68,7 @@ class ESMM(BaseModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=20,
-            model_id=model_id
+            **kwargs
         )
 
         self.loss = loss
@@ -106,7 +105,8 @@ class ESMM(BaseModel):
         self.compile(
             optimizer=optimizer,
             optimizer_params=optimizer_params,
-            loss=loss
+            loss=loss,
+            loss_params=loss_params,
         )
 
     def forward(self, x):
nextrec/models/multi_task/mmoe.py
@@ -1,9 +1,7 @@
 """
 Date: create on 09/11/2025
-Author:
-    Yang Zhou,zyaztec@gmail.com
-Reference:
-    [1] Ma J, Zhao Z, Yi X, et al. Modeling task relationships in multi-task learning with multi-gate mixture-of-experts[C]//KDD. 2018: 1930-1939.
+Author: Yang Zhou,zyaztec@gmail.com
+Reference: [1] Ma J, Zhao Z, Yi X, et al. Modeling task relationships in multi-task learning with multi-gate mixture-of-experts[C]//KDD. 2018: 1930-1939.
 """
 
 import torch
@@ -44,12 +42,13 @@ class MMOE(BaseModel):
                  optimizer: str = "adam",
                  optimizer_params: dict = {},
                  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+                 loss_params: dict | list[dict] | None = None,
                  device: str = 'cpu',
-                 model_id: str = "baseline",
                  embedding_l1_reg=1e-6,
                  dense_l1_reg=1e-5,
                  embedding_l2_reg=1e-5,
-                 dense_l2_reg=1e-4):
+                 dense_l2_reg=1e-4,
+                 **kwargs):
 
         super(MMOE, self).__init__(
             dense_features=dense_features,
@@ -63,7 +62,7 @@ class MMOE(BaseModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=20,
-            model_id=model_id
+            **kwargs
         )
 
         self.loss = loss
@@ -128,7 +127,8 @@ class MMOE(BaseModel):
         self.compile(
             optimizer=optimizer,
             optimizer_params=optimizer_params,
-            loss=loss
+            loss=loss,
+            loss_params=loss_params,
         )
 
     def forward(self, x):
nextrec/models/multi_task/ple.py
@@ -1,9 +1,7 @@
 """
 Date: create on 09/11/2025
-Author:
-    Yang Zhou,zyaztec@gmail.com
-Reference:
-    [1] Tang H, Liu J, Zhao M, et al. Progressive layered extraction (ple): A novel multi-task learning (mtl) model for personalized recommendations[C]//RecSys. 2020: 269-278.
+Author: Yang Zhou,zyaztec@gmail.com
+Reference: [1] Tang H, Liu J, Zhao M, et al. Progressive layered extraction (ple): A novel multi-task learning (mtl) model for personalized recommendations[C]//RecSys. 2020: 269-278.
 """
 
 import torch
@@ -47,12 +45,13 @@ class PLE(BaseModel):
                  optimizer: str = "adam",
                  optimizer_params: dict = {},
                  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+                 loss_params: dict | list[dict] | None = None,
                  device: str = 'cpu',
-                 model_id: str = "baseline",
                  embedding_l1_reg=1e-6,
                  dense_l1_reg=1e-5,
                  embedding_l2_reg=1e-5,
-                 dense_l2_reg=1e-4):
+                 dense_l2_reg=1e-4,
+                 **kwargs):
 
         super(PLE, self).__init__(
             dense_features=dense_features,
@@ -66,7 +65,7 @@ class PLE(BaseModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=20,
-            model_id=model_id
+            **kwargs
        )
 
         self.loss = loss
@@ -166,7 +165,8 @@ class PLE(BaseModel):
         self.compile(
             optimizer=optimizer,
             optimizer_params=optimizer_params,
-            loss=loss
+            loss=loss,
+            loss_params=loss_params,
         )
 
     def forward(self, x):
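
ESMM, MMOE, and PLE (and share_bottom, per the file list) now accept a loss_params argument alongside loss, and loss may be a list with one entry per task. The snippet below is a generic sketch of how such per-task specs can be resolved into loss modules; the registry and helper name are illustrative and not nextrec's actual loss factory.

import torch.nn as nn

LOSS_REGISTRY: dict[str, type[nn.Module]] = {
    "bce": nn.BCEWithLogitsLoss,
    "mse": nn.MSELoss,
}

def build_losses(loss: str | nn.Module | list[str | nn.Module],
                 loss_params: dict | list[dict] | None,
                 num_tasks: int) -> list[nn.Module]:
    # Broadcast a single spec across tasks; otherwise pair specs with their params.
    specs = loss if isinstance(loss, list) else [loss] * num_tasks
    params = loss_params if isinstance(loss_params, list) else [loss_params or {}] * num_tasks
    built = []
    for spec, kwargs in zip(specs, params):
        built.append(LOSS_REGISTRY[spec](**(kwargs or {})) if isinstance(spec, str) else spec)
    return built

# One loss spec per task, each with its own keyword arguments.
print(build_losses(loss=["bce", "mse"], loss_params=[{}, {"reduction": "sum"}], num_tasks=2))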