nextrec-0.1.11-py3-none-any.whl → nextrec-0.2.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +1 -2
- nextrec/basic/callback.py +1 -2
- nextrec/basic/features.py +39 -8
- nextrec/basic/layers.py +3 -4
- nextrec/basic/loggers.py +15 -10
- nextrec/basic/metrics.py +1 -2
- nextrec/basic/model.py +160 -125
- nextrec/basic/session.py +150 -0
- nextrec/data/__init__.py +13 -2
- nextrec/data/data_utils.py +74 -22
- nextrec/data/dataloader.py +513 -0
- nextrec/data/preprocessor.py +494 -134
- nextrec/loss/__init__.py +31 -24
- nextrec/loss/listwise.py +164 -0
- nextrec/loss/loss_utils.py +133 -106
- nextrec/loss/pairwise.py +105 -0
- nextrec/loss/pointwise.py +198 -0
- nextrec/models/match/dssm.py +26 -17
- nextrec/models/match/dssm_v2.py +20 -2
- nextrec/models/match/mind.py +18 -3
- nextrec/models/match/sdm.py +17 -2
- nextrec/models/match/youtube_dnn.py +23 -10
- nextrec/models/multi_task/esmm.py +8 -8
- nextrec/models/multi_task/mmoe.py +8 -8
- nextrec/models/multi_task/ple.py +8 -8
- nextrec/models/multi_task/share_bottom.py +8 -8
- nextrec/models/ranking/__init__.py +8 -0
- nextrec/models/ranking/afm.py +5 -4
- nextrec/models/ranking/autoint.py +6 -4
- nextrec/models/ranking/dcn.py +6 -4
- nextrec/models/ranking/deepfm.py +5 -4
- nextrec/models/ranking/dien.py +6 -4
- nextrec/models/ranking/din.py +6 -4
- nextrec/models/ranking/fibinet.py +6 -4
- nextrec/models/ranking/fm.py +6 -4
- nextrec/models/ranking/masknet.py +6 -4
- nextrec/models/ranking/pnn.py +6 -4
- nextrec/models/ranking/widedeep.py +6 -4
- nextrec/models/ranking/xdeepfm.py +6 -4
- nextrec/utils/__init__.py +7 -11
- nextrec/utils/embedding.py +2 -4
- nextrec/utils/initializer.py +4 -5
- nextrec/utils/optimizer.py +7 -8
- {nextrec-0.1.11.dist-info → nextrec-0.2.2.dist-info}/METADATA +3 -3
- nextrec-0.2.2.dist-info/RECORD +53 -0
- nextrec/basic/dataloader.py +0 -447
- nextrec/loss/match_losses.py +0 -294
- nextrec/utils/common.py +0 -14
- nextrec-0.1.11.dist-info/RECORD +0 -51
- {nextrec-0.1.11.dist-info → nextrec-0.2.2.dist-info}/WHEEL +0 -0
- {nextrec-0.1.11.dist-info → nextrec-0.2.2.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/pointwise.py
ADDED
@@ -0,0 +1,198 @@
+"""
+Pointwise loss functions, including imbalance-aware variants.
+"""
+
+from typing import Optional, Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class CosineContrastiveLoss(nn.Module):
+    """
+    Contrastive loss using cosine similarity for positive/negative pairs.
+    """
+
+    def __init__(self, margin: float = 0.5, reduction: str = "mean"):
+        super().__init__()
+        self.margin = margin
+        self.reduction = reduction
+
+    def forward(
+        self, user_emb: torch.Tensor, item_emb: torch.Tensor, labels: torch.Tensor
+    ) -> torch.Tensor:
+        labels = labels.float()
+        similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
+        pos_loss = torch.clamp(self.margin - similarity, min=0) * labels
+        neg_loss = torch.clamp(similarity - self.margin, min=0) * (1 - labels)
+        loss = pos_loss + neg_loss
+
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+
+class WeightedBCELoss(nn.Module):
+    """
+    Binary cross entropy with controllable positive class weight.
+    Supports probability or logit inputs via `logits` flag.
+    If `auto_balance=True` and `pos_weight` is None, the positive weight is
+    computed from the batch as (#neg / #pos) for stable imbalance handling.
+    """
+    def __init__(
+        self,
+        pos_weight: float | torch.Tensor | None = None,
+        reduction: str = "mean",
+        logits: bool = False,
+        auto_balance: bool = False,
+    ):
+        super().__init__()
+        self.reduction = reduction
+        self.logits = logits
+        self.auto_balance = auto_balance
+
+        if pos_weight is not None:
+            self.register_buffer(
+                "pos_weight",
+                torch.as_tensor(pos_weight, dtype=torch.float32),
+            )
+        else:
+            self.pos_weight = None
+
+    def _resolve_pos_weight(self, labels: torch.Tensor) -> torch.Tensor:
+        if self.pos_weight is not None:
+            return self.pos_weight.to(device=labels.device)
+
+        if not self.auto_balance:
+            return torch.tensor(1.0, device=labels.device, dtype=labels.dtype)
+
+        labels_float = labels.float()
+        pos = torch.clamp(labels_float.sum(), min=1.0)
+        neg = torch.clamp(labels_float.numel() - labels_float.sum(), min=1.0)
+        return (neg / pos).to(device=labels.device, dtype=labels.dtype)
+
+    def forward(self, inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+        labels = labels.float()
+        current_pos_weight = self._resolve_pos_weight(labels)
+        current_pos_weight = current_pos_weight.to(inputs.dtype)
+
+        if self.logits:
+            loss = F.binary_cross_entropy_with_logits(
+                inputs, labels, pos_weight=current_pos_weight, reduction="none"
+            )
+        else:
+            probs = torch.clamp(inputs, min=1e-6, max=1 - 1e-6)
+            base_loss = F.binary_cross_entropy(probs, labels, reduction="none")
+            loss = torch.where(labels == 1, base_loss * current_pos_weight, base_loss)
+
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        else:
+            return loss
+
+
+class FocalLoss(nn.Module):
+    """
+    Standard focal loss for binary or multi-class classification.
+    """
+
+    def __init__(
+        self,
+        gamma: float = 2.0,
+        alpha: Optional[float | Sequence[float] | torch.Tensor] = None,
+        reduction: str = "mean",
+        logits: bool = False,
+    ):
+        super().__init__()
+        self.gamma = gamma
+        self.reduction = reduction
+        self.logits = logits
+        self.alpha = alpha
+
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        # Multi-class case
+        if inputs.dim() > 1 and inputs.size(1) > 1:
+            log_probs = F.log_softmax(inputs, dim=1)
+            probs = log_probs.exp()
+            targets_one_hot = F.one_hot(targets.long(), num_classes=inputs.size(1)).float()
+
+            alpha = self._get_alpha(inputs)
+            alpha_factor = targets_one_hot * alpha
+            focal_weight = (1.0 - probs) ** self.gamma
+            loss = torch.sum(alpha_factor * focal_weight * (-log_probs), dim=1)
+        else:
+            targets = targets.float()
+            if self.logits:
+                ce_loss = F.binary_cross_entropy_with_logits(
+                    inputs, targets, reduction="none"
+                )
+                probs = torch.sigmoid(inputs)
+            else:
+                ce_loss = F.binary_cross_entropy(inputs, targets, reduction="none")
+                probs = torch.clamp(inputs, min=1e-6, max=1 - 1e-6)
+
+            p_t = probs * targets + (1 - probs) * (1 - targets)
+            alpha_factor = self._get_binary_alpha(targets, inputs.device)
+            focal_weight = (1.0 - p_t) ** self.gamma
+            loss = alpha_factor * focal_weight * ce_loss
+
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+    def _get_alpha(self, inputs: torch.Tensor) -> torch.Tensor:
+        if self.alpha is None:
+            return torch.ones_like(inputs)
+        if isinstance(self.alpha, torch.Tensor):
+            return self.alpha.to(inputs.device)
+        alpha_tensor = torch.tensor(self.alpha, device=inputs.device, dtype=inputs.dtype)
+        return alpha_tensor
+
+    def _get_binary_alpha(self, targets: torch.Tensor, device: torch.device) -> torch.Tensor:
+        if self.alpha is None:
+            return torch.ones_like(targets)
+        if isinstance(self.alpha, (float, int)):
+            return torch.where(targets == 1, self.alpha, 1 - float(self.alpha)).to(device)
+        alpha_tensor = torch.tensor(self.alpha, device=device, dtype=targets.dtype)
+        return torch.where(targets == 1, alpha_tensor, 1 - alpha_tensor)
+
+
+class ClassBalancedFocalLoss(nn.Module):
+    """
+    Focal loss weighted by effective number of samples per class.
+    Reference: "Class-Balanced Loss Based on Effective Number of Samples"
+    """
+
+    def __init__(
+        self,
+        class_counts: Sequence[int] | torch.Tensor,
+        beta: float = 0.9999,
+        gamma: float = 2.0,
+        reduction: str = "mean",
+    ):
+        super().__init__()
+        self.gamma = gamma
+        self.reduction = reduction
+        class_counts = torch.as_tensor(class_counts, dtype=torch.float32)
+        effective_num = 1.0 - torch.pow(beta, class_counts)
+        weights = (1.0 - beta) / (effective_num + 1e-12)
+        weights = weights / weights.sum() * len(weights)
+        self.register_buffer("class_weights", weights)
+
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        focal = FocalLoss(
+            gamma=self.gamma, alpha=self.class_weights, reduction="none", logits=True
+        )
+        loss = focal(inputs, targets)
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
nextrec/models/match/dssm.py
CHANGED
@@ -19,7 +19,8 @@ class DSSM(BaseMatchModel):
     """
     Deep Structured Semantic Model
 
-
+    Dual-tower model that encodes user and item features separately and
+    computes similarity via cosine or dot product.
     """
 
     @property
@@ -48,7 +49,13 @@ class DSSM(BaseMatchModel):
         embedding_l2_reg: float = 0.0,
         dense_l2_reg: float = 0.0,
         early_stop_patience: int = 20,
-
+        optimizer: str | torch.optim.Optimizer = "adam",
+        optimizer_params: dict | None = None,
+        scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+        scheduler_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        **kwargs):
 
         super(DSSM, self).__init__(
             user_dense_features=user_dense_features,
@@ -67,7 +74,7 @@ class DSSM(BaseMatchModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=early_stop_patience,
-
+            **kwargs
         )
 
         self.embedding_dim = embedding_dim
@@ -86,7 +93,7 @@ class DSSM(BaseMatchModel):
         if len(user_features) > 0:
             self.user_embedding = EmbeddingLayer(user_features)
 
-        #
+        # Compute user tower input dimension
         user_input_dim = 0
         for feat in user_dense_features or []:
             user_input_dim += 1
@@ -117,7 +124,7 @@ class DSSM(BaseMatchModel):
         if len(item_features) > 0:
             self.item_embedding = EmbeddingLayer(item_features)
 
-        #
+        # Compute item tower input dimension
         item_input_dim = 0
         for feat in item_dense_features or []:
             item_input_dim += 1
@@ -136,7 +143,6 @@ class DSSM(BaseMatchModel):
             activation=dnn_activation
         )
 
-        # Register regularization weights
         self._register_regularization_weights(
             embedding_attr='user_embedding',
             include_modules=['user_dnn']
@@ -146,28 +152,33 @@ class DSSM(BaseMatchModel):
             include_modules=['item_dnn']
         )
 
+        if optimizer_params is None:
+            optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
+
         self.compile(
-            optimizer=
-            optimizer_params=
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=scheduler,
+            scheduler_params=scheduler_params,
+            loss=loss,
+            loss_params=loss_params,
         )
 
         self.to(device)
 
     def user_tower(self, user_input: dict) -> torch.Tensor:
         """
-        User tower
+        User tower encodes user features into embeddings.
 
         Args:
-            user_input: user
+            user_input: user feature dict
 
         Returns:
            user_emb: [batch_size, embedding_dim]
        """
-        # Get embeddings of the user features
        all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
        user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)
 
-        # Pass through the user DNN
        user_emb = self.user_dnn(user_emb)
 
        # L2 normalize for cosine similarity
@@ -178,19 +189,17 @@ class DSSM(BaseMatchModel):
 
     def item_tower(self, item_input: dict) -> torch.Tensor:
         """
-        Item tower
+        Item tower encodes item features into embeddings.
 
         Args:
-            item_input: item
+            item_input: item feature dict
 
         Returns:
-            item_emb: [batch_size, embedding_dim]
+            item_emb: [batch_size, embedding_dim] or [batch_size, num_items, embedding_dim]
         """
-        # Get embeddings of the item features
        all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
        item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
 
-        # Pass through the item DNN
        item_emb = self.item_dnn(item_emb)
 
        # L2 normalize for cosine similarity
nextrec/models/match/dssm_v2.py
CHANGED
@@ -44,7 +44,13 @@ class DSSM_v2(BaseMatchModel):
         embedding_l2_reg: float = 0.0,
         dense_l2_reg: float = 0.0,
         early_stop_patience: int = 20,
-
+        optimizer: str | torch.optim.Optimizer = "adam",
+        optimizer_params: dict | None = None,
+        scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+        scheduler_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        **kwargs):
 
         super(DSSM_v2, self).__init__(
             user_dense_features=user_dense_features,
@@ -63,7 +69,7 @@ class DSSM_v2(BaseMatchModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=early_stop_patience,
-
+            **kwargs
         )
 
         self.embedding_dim = embedding_dim
@@ -137,6 +143,18 @@ class DSSM_v2(BaseMatchModel):
             include_modules=['item_dnn']
         )
 
+        if optimizer_params is None:
+            optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
+
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=scheduler,
+            scheduler_params=scheduler_params,
+            loss=loss,
+            loss_params=loss_params,
+        )
+
         self.to(device)
 
     def user_tower(self, user_input: dict) -> torch.Tensor:
nextrec/models/match/mind.py
CHANGED
@@ -41,7 +41,7 @@ class MIND(BaseMatchModel):
         item_dnn_hidden_units: list[int] = [256, 128],
         dnn_activation: str = 'relu',
         dnn_dropout: float = 0.0,
-        training_mode: Literal['pointwise', 'pairwise', 'listwise'] = '
+        training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'pointwise',
         num_negative_samples: int = 100,
         temperature: float = 1.0,
         similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
@@ -51,7 +51,13 @@ class MIND(BaseMatchModel):
         embedding_l2_reg: float = 0.0,
         dense_l2_reg: float = 0.0,
         early_stop_patience: int = 20,
-
+        optimizer: str | torch.optim.Optimizer = "adam",
+        optimizer_params: dict | None = None,
+        scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+        scheduler_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        **kwargs):
 
         super(MIND, self).__init__(
             user_dense_features=user_dense_features,
@@ -70,7 +76,7 @@ class MIND(BaseMatchModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=early_stop_patience,
-
+            **kwargs
         )
 
         self.embedding_dim = embedding_dim
@@ -152,6 +158,15 @@ class MIND(BaseMatchModel):
             include_modules=['item_dnn'] if self.item_dnn else []
         )
 
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=scheduler,
+            scheduler_params=scheduler_params,
+            loss=loss,
+            loss_params=loss_params,
+        )
+
         self.to(device)
 
     def user_tower(self, user_input: dict) -> torch.Tensor:
nextrec/models/match/sdm.py
CHANGED
@@ -52,7 +52,13 @@ class SDM(BaseMatchModel):
         embedding_l2_reg: float = 0.0,
         dense_l2_reg: float = 0.0,
         early_stop_patience: int = 20,
-
+        optimizer: str | torch.optim.Optimizer = "adam",
+        optimizer_params: dict | None = None,
+        scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+        scheduler_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        **kwargs):
 
         super(SDM, self).__init__(
             user_dense_features=user_dense_features,
@@ -71,7 +77,7 @@ class SDM(BaseMatchModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=early_stop_patience,
-
+            **kwargs
         )
 
         self.embedding_dim = embedding_dim
@@ -179,6 +185,15 @@ class SDM(BaseMatchModel):
             include_modules=['item_dnn'] if self.item_dnn else []
         )
 
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=scheduler,
+            scheduler_params=scheduler_params,
+            loss=loss,
+            loss_params=loss_params,
+        )
+
         self.to(device)
 
     def user_tower(self, user_input: dict) -> torch.Tensor:
|
@@ -17,11 +17,10 @@ from nextrec.basic.layers import MLP, EmbeddingLayer, AveragePooling
|
|
|
17
17
|
|
|
18
18
|
class YoutubeDNN(BaseMatchModel):
|
|
19
19
|
"""
|
|
20
|
-
YouTube Deep Neural Network for Recommendations
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
训练:sampled softmax loss (listwise)
|
|
20
|
+
YouTube Deep Neural Network for Recommendations.
|
|
21
|
+
User tower: behavior sequence + user features -> user embedding.
|
|
22
|
+
Item tower: item features -> item embedding.
|
|
23
|
+
Training usually uses listwise / sampled softmax style objectives.
|
|
25
24
|
"""
|
|
26
25
|
|
|
27
26
|
@property
|
|
@@ -50,7 +49,13 @@ class YoutubeDNN(BaseMatchModel):
|
|
|
50
49
|
embedding_l2_reg: float = 0.0,
|
|
51
50
|
dense_l2_reg: float = 0.0,
|
|
52
51
|
early_stop_patience: int = 20,
|
|
53
|
-
|
|
52
|
+
optimizer: str | torch.optim.Optimizer = "adam",
|
|
53
|
+
optimizer_params: dict | None = None,
|
|
54
|
+
scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
|
|
55
|
+
scheduler_params: dict | None = None,
|
|
56
|
+
loss: str | nn.Module | list[str | nn.Module] | None = "bce",
|
|
57
|
+
loss_params: dict | list[dict] | None = None,
|
|
58
|
+
**kwargs):
|
|
54
59
|
|
|
55
60
|
super(YoutubeDNN, self).__init__(
|
|
56
61
|
user_dense_features=user_dense_features,
|
|
@@ -69,7 +74,7 @@ class YoutubeDNN(BaseMatchModel):
|
|
|
69
74
|
embedding_l2_reg=embedding_l2_reg,
|
|
70
75
|
dense_l2_reg=dense_l2_reg,
|
|
71
76
|
early_stop_patience=early_stop_patience,
|
|
72
|
-
|
|
77
|
+
**kwargs
|
|
73
78
|
)
|
|
74
79
|
|
|
75
80
|
self.embedding_dim = embedding_dim
|
|
@@ -94,7 +99,7 @@ class YoutubeDNN(BaseMatchModel):
|
|
|
94
99
|
for feat in user_sparse_features or []:
|
|
95
100
|
user_input_dim += feat.embedding_dim
|
|
96
101
|
for feat in user_sequence_features or []:
|
|
97
|
-
#
|
|
102
|
+
# Sequence features are pooled before entering the DNN
|
|
98
103
|
user_input_dim += feat.embedding_dim
|
|
99
104
|
|
|
100
105
|
user_dnn_units = user_dnn_hidden_units + [embedding_dim]
|
|
@@ -144,12 +149,20 @@ class YoutubeDNN(BaseMatchModel):
|
|
|
144
149
|
include_modules=['item_dnn']
|
|
145
150
|
)
|
|
146
151
|
|
|
152
|
+
self.compile(
|
|
153
|
+
optimizer=optimizer,
|
|
154
|
+
optimizer_params=optimizer_params,
|
|
155
|
+
scheduler=scheduler,
|
|
156
|
+
scheduler_params=scheduler_params,
|
|
157
|
+
loss=loss,
|
|
158
|
+
loss_params=loss_params,
|
|
159
|
+
)
|
|
160
|
+
|
|
147
161
|
self.to(device)
|
|
148
162
|
|
|
149
163
|
def user_tower(self, user_input: dict) -> torch.Tensor:
|
|
150
164
|
"""
|
|
151
|
-
User tower
|
|
152
|
-
处理用户历史行为序列和其他用户特征
|
|
165
|
+
User tower to encode historical behavior sequences and user features.
|
|
153
166
|
"""
|
|
154
167
|
all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
|
|
155
168
|
user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)
|
|
nextrec/models/multi_task/esmm.py
CHANGED
@@ -1,9 +1,7 @@
 """
 Date: create on 09/11/2025
-Author:
-
-Reference:
-[1] Ma X, Zhao L, Huang G, et al. Entire space multi-task model: An effective approach for estimating post-click conversion rate[C]//SIGIR. 2018: 1137-1140.
+Author: Yang Zhou,zyaztec@gmail.com
+Reference: [1] Ma X, Zhao L, Huang G, et al. Entire space multi-task model: An effective approach for estimating post-click conversion rate[C]//SIGIR. 2018: 1137-1140.
 """
 
 import torch
@@ -46,12 +44,13 @@ class ESMM(BaseModel):
         optimizer: str = "adam",
         optimizer_params: dict = {},
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
         device: str = 'cpu',
-        model_id: str = "baseline",
         embedding_l1_reg=1e-6,
         dense_l1_reg=1e-5,
         embedding_l2_reg=1e-5,
-        dense_l2_reg=1e-4
+        dense_l2_reg=1e-4,
+        **kwargs):
 
         # ESMM requires exactly 2 targets: ctr and ctcvr
         if len(target) != 2:
@@ -69,7 +68,7 @@ class ESMM(BaseModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=20,
-
+            **kwargs
         )
 
         self.loss = loss
@@ -106,7 +105,8 @@ class ESMM(BaseModel):
         self.compile(
             optimizer=optimizer,
             optimizer_params=optimizer_params,
-            loss=loss
+            loss=loss,
+            loss_params=loss_params,
         )
 
     def forward(self, x):
nextrec/models/multi_task/mmoe.py
CHANGED
@@ -1,9 +1,7 @@
 """
 Date: create on 09/11/2025
-Author:
-
-Reference:
-[1] Ma J, Zhao Z, Yi X, et al. Modeling task relationships in multi-task learning with multi-gate mixture-of-experts[C]//KDD. 2018: 1930-1939.
+Author: Yang Zhou,zyaztec@gmail.com
+Reference: [1] Ma J, Zhao Z, Yi X, et al. Modeling task relationships in multi-task learning with multi-gate mixture-of-experts[C]//KDD. 2018: 1930-1939.
 """
 
 import torch
@@ -44,12 +42,13 @@ class MMOE(BaseModel):
         optimizer: str = "adam",
         optimizer_params: dict = {},
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
         device: str = 'cpu',
-        model_id: str = "baseline",
         embedding_l1_reg=1e-6,
         dense_l1_reg=1e-5,
         embedding_l2_reg=1e-5,
-        dense_l2_reg=1e-4
+        dense_l2_reg=1e-4,
+        **kwargs):
 
         super(MMOE, self).__init__(
             dense_features=dense_features,
@@ -63,7 +62,7 @@ class MMOE(BaseModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=20,
-
+            **kwargs
         )
 
         self.loss = loss
@@ -128,7 +127,8 @@ class MMOE(BaseModel):
         self.compile(
             optimizer=optimizer,
             optimizer_params=optimizer_params,
-            loss=loss
+            loss=loss,
+            loss_params=loss_params,
         )
 
     def forward(self, x):
nextrec/models/multi_task/ple.py
CHANGED
@@ -1,9 +1,7 @@
 """
 Date: create on 09/11/2025
-Author:
-
-Reference:
-[1] Tang H, Liu J, Zhao M, et al. Progressive layered extraction (ple): A novel multi-task learning (mtl) model for personalized recommendations[C]//RecSys. 2020: 269-278.
+Author: Yang Zhou,zyaztec@gmail.com
+Reference: [1] Tang H, Liu J, Zhao M, et al. Progressive layered extraction (ple): A novel multi-task learning (mtl) model for personalized recommendations[C]//RecSys. 2020: 269-278.
 """
 
 import torch
@@ -47,12 +45,13 @@ class PLE(BaseModel):
         optimizer: str = "adam",
         optimizer_params: dict = {},
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
         device: str = 'cpu',
-        model_id: str = "baseline",
         embedding_l1_reg=1e-6,
         dense_l1_reg=1e-5,
         embedding_l2_reg=1e-5,
-        dense_l2_reg=1e-4
+        dense_l2_reg=1e-4,
+        **kwargs):
 
         super(PLE, self).__init__(
             dense_features=dense_features,
@@ -66,7 +65,7 @@ class PLE(BaseModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=20,
-
+            **kwargs
         )
 
         self.loss = loss
@@ -166,7 +165,8 @@ class PLE(BaseModel):
         self.compile(
            optimizer=optimizer,
            optimizer_params=optimizer_params,
-           loss=loss
+           loss=loss,
+           loss_params=loss_params,
        )
 
     def forward(self, x):
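Across ESMM, MMOE, and PLE the change follows the same pattern: the model_id argument is dropped, **kwargs is forwarded to BaseModel, and loss together with the new loss_params reaches compile(). A hedged sketch for MMOE (feature and task arguments elided; passing one loss module per target is an assumption based on the list[str | nn.Module] type hint, not a documented guarantee):

from nextrec.models.multi_task.mmoe import MMOE
from nextrec.loss.pointwise import FocalLoss, WeightedBCELoss

model = MMOE(
    # ... dense/sparse features and the two-target task setup as before ...
    loss=[WeightedBCELoss(auto_balance=True), FocalLoss(gamma=2.0, logits=True)],  # assumed: one loss per target
    optimizer="adam",
    optimizer_params={"lr": 1e-3},
)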