nextrec-0.4.2-py3-none-any.whl → nextrec-0.4.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/layers.py +32 -8
- nextrec/basic/loggers.py +1 -1
- nextrec/basic/metrics.py +2 -1
- nextrec/basic/model.py +3 -3
- nextrec/cli.py +41 -47
- nextrec/data/dataloader.py +1 -1
- nextrec/models/multi_task/esmm.py +23 -16
- nextrec/models/multi_task/mmoe.py +36 -17
- nextrec/models/multi_task/ple.py +18 -12
- nextrec/models/multi_task/poso.py +68 -37
- nextrec/models/multi_task/share_bottom.py +16 -2
- nextrec/models/ranking/afm.py +14 -14
- nextrec/models/ranking/autoint.py +2 -2
- nextrec/models/ranking/dcn.py +61 -19
- nextrec/models/ranking/dcn_v2.py +224 -45
- nextrec/models/ranking/deepfm.py +14 -9
- nextrec/models/ranking/dien.py +215 -82
- nextrec/models/ranking/din.py +95 -57
- nextrec/models/ranking/fibinet.py +92 -30
- nextrec/models/ranking/fm.py +44 -8
- nextrec/models/ranking/masknet.py +7 -7
- nextrec/models/ranking/pnn.py +105 -38
- nextrec/models/ranking/widedeep.py +8 -4
- nextrec/models/ranking/xdeepfm.py +10 -5
- nextrec/utils/config.py +9 -3
- nextrec/utils/file.py +2 -1
- nextrec/utils/model.py +22 -0
- {nextrec-0.4.2.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
- {nextrec-0.4.2.dist-info → nextrec-0.4.3.dist-info}/RECORD +33 -33
- {nextrec-0.4.2.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
- {nextrec-0.4.2.dist-info → nextrec-0.4.3.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.2.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
nextrec/models/ranking/dien.py
CHANGED
@@ -1,11 +1,49 @@
 """
 Date: create on 09/11/2025
-Author:
-
+Author: Yang Zhou, zyaztec@gmail.com
+Checkpoint: edit on 09/12/2025
 Reference:
-
-
-
+    [1] Zhou G, Mou N, Fan Y, et al. Deep interest evolution network for click-through
+        rate prediction[C] // Proceedings of the AAAI conference on artificial intelligence.
+        2019, 33(01): 5941-5948. (https://arxiv.org/abs/1809.03672)
+
+DIEN is a CTR prediction model that explicitly models how user interests evolve
+over time. It introduces a two-stage pipeline:
+(1) Interest Extraction: a GRU encodes raw behavior sequences into interest states
+(2) Interest Evolution: an attention-aware GRU (AUGRU) updates interests by
+    focusing on behaviors most related to the target item
+An auxiliary loss on next-click prediction guides the GRU to learn finer-grained
+interest transitions and alleviates vanishing signals in long sequences.
+
+Processing flow:
+- Behavior embeddings -> DynamicGRU -> interest trajectory
+- Target-aware attention scores highlight behaviors relevant to the candidate
+- AUGRU modulates GRU updates with attention to emphasize impactful behaviors
+- Final evolved interest, candidate embedding, and context features -> MLP -> CTR
+
+Key advantages:
+- Captures temporal evolution of user interests instead of a static summary
+- Target-aware attention steers the evolution toward the candidate item
+- AUGRU gates mitigate noise from irrelevant historical behaviors
+- Auxiliary loss provides additional supervision for sequential dynamics
+
+DIEN is a CTR prediction model that explicitly models the temporal evolution of user interests. Its core is a two-stage pipeline:
+(1) Interest extraction: a GRU encodes the raw behavior sequence into a trajectory of interest states
+(2) Interest evolution: a target-aware, attention-gated GRU (AUGRU) emphasizes behaviors related to the candidate item,
+    guiding how interests are updated over time
+An auxiliary loss on next-behavior clicks alleviates signal decay in long sequences and reinforces the learning of interest transitions.
+
+Flow overview:
+- Behavior embeddings are fed into DynamicGRU to obtain the interest trajectory
+- Target-related attention scores highlight the key behaviors
+- AUGRU modulates its updates with attention, suppressing noise from irrelevant history
+- The final evolved interest, candidate embedding, and other context features pass through an MLP to output the CTR
+
+Main advantages:
+- Models how interests evolve over time rather than forming a static aggregate
+- Target-aware attention aligns interest evolution with the candidate item
+- AUGRU gating weakens interference from irrelevant behaviors
+- The auxiliary loss provides an extra supervision signal for sequential dynamics
 """
 
 import torch
@@ -63,23 +101,28 @@ class AUGRU(nn.Module):
         batch_size, seq_len, _ = x.shape
         h = torch.zeros(batch_size, self.hidden_size, device=x.device)
         outputs = []
+
         for t in range(seq_len):
-            x_t = x[:, t, :]  # [
-            att_t = att_scores[:, t, :]  # [
+            x_t = x[:, t, :]  # [B, input_size]
+            att_t = att_scores[:, t, :]  # [B, 1]
 
             gi = F.linear(x_t, self.weight_ih, self.bias_ih)
             gh = F.linear(h, self.weight_hh, self.bias_hh)
             i_r, i_i, i_n = gi.chunk(3, 1)
             h_r, h_i, h_n = gh.chunk(3, 1)
 
-            resetgate = torch.sigmoid(i_r + h_r)
-
-            newgate = torch.tanh(i_n + resetgate * h_n)
-
-
+            resetgate = torch.sigmoid(i_r + h_r)  # r_t
+            updategate = torch.sigmoid(i_i + h_i)  # z_t
+            newgate = torch.tanh(i_n + resetgate * h_n)  # n_t
+
+            # att_t: [B, 1], broadcast to [B, H]
+            z_att = updategate * att_t
+
+            # h_t = (1 - z'_t) * h_{t-1} + z'_t * n_t
+            h = (1.0 - z_att) * h + z_att * newgate
             outputs.append(h.unsqueeze(1))
-        output = torch.cat(outputs, dim=1)
 
+        output = torch.cat(outputs, dim=1)  # [B, L, H]
         return output, h
 
 
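
For readers comparing the old and new cell, here is a minimal, self-contained sketch of the single AUGRU step the hunk above now spells out: the update gate z_t is scaled by the per-step attention weight a_t before interpolating between the previous state and the candidate state. The helper name `augru_step` and the explicit weight tensors are illustrative, not part of nextrec.

    # One AUGRU step: h_t = (1 - a_t*z_t) * h_{t-1} + a_t*z_t * n_t
    import torch
    import torch.nn.functional as F

    def augru_step(x_t, h_prev, att_t, w_ih, b_ih, w_hh, b_hh):
        gi = F.linear(x_t, w_ih, b_ih)      # [B, 3H]
        gh = F.linear(h_prev, w_hh, b_hh)   # [B, 3H]
        i_r, i_i, i_n = gi.chunk(3, dim=1)
        h_r, h_i, h_n = gh.chunk(3, dim=1)
        r_t = torch.sigmoid(i_r + h_r)               # reset gate
        z_t = torch.sigmoid(i_i + h_i)               # update gate
        n_t = torch.tanh(i_n + r_t * h_n)            # candidate state
        z_att = z_t * att_t                          # attention-scaled update gate
        return (1.0 - z_att) * h_prev + z_att * n_t  # h_t

    B, D, H = 2, 8, 4
    h_t = augru_step(
        torch.randn(B, D), torch.zeros(B, H), torch.rand(B, 1),
        torch.randn(3 * H, D), torch.zeros(3 * H),
        torch.randn(3 * H, H), torch.zeros(3 * H),
    )
    print(h_t.shape)  # torch.Size([2, 4])

When att_t is near zero, h_t stays close to h_{t-1}, which is how AUGRU keeps behaviors unrelated to the candidate item from moving the interest state.
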
@@ -112,11 +155,10 @@ class DynamicGRU(nn.Module):
         for weight in self.parameters():
             weight.data.uniform_(-std, std)
 
-    def forward(self, x
+    def forward(self, x):
         """
         Args:
             x: [batch_size, seq_len, input_size]
-            att_scores: [batch_size, seq_len] - attention scores for auxiliary loss
         Returns:
             output: [batch_size, seq_len, hidden_size]
             hidden: [batch_size, hidden_size] - final hidden state
@@ -137,14 +179,15 @@ class DynamicGRU(nn.Module):
             h_r, h_i, h_n = gh.chunk(3, 1)
 
             resetgate = torch.sigmoid(i_r + h_r)
-
+            updategate = torch.sigmoid(i_i + h_i)
             newgate = torch.tanh(i_n + resetgate * h_n)
-
+
+            # h_t = (1 - z_t) * h_{t-1} + z_t * n_t
+            h = (1.0 - updategate) * h + updategate * newgate
 
             outputs.append(h.unsqueeze(1))
 
         output = torch.cat(outputs, dim=1)  # [batch_size, seq_len, hidden_size]
-
         return output, h
 
 
@@ -159,18 +202,22 @@ class DIEN(BaseModel):
 
     def __init__(
         self,
-        dense_features: list[DenseFeature],
-        sparse_features: list[SparseFeature],
-        sequence_features: list[SequenceFeature],
-
+        dense_features: list[DenseFeature] | None = None,
+        sparse_features: list[SparseFeature] | None = None,
+        sequence_features: list[SequenceFeature] | None = None,
+        behavior_feature_name: str | None = None,
+        candidate_feature_name: str | None = None,
+        neg_behavior_feature_name: str | None = None,
+        mlp_params: dict | None = None,
         gru_hidden_size: int = 64,
-        attention_hidden_units: list[int] =
+        attention_hidden_units: list[int] | None = None,
         attention_activation: str = "sigmoid",
         use_negsampling: bool = False,
-
+        aux_loss_weight: float = 1.0,
+        target: list[str] | str | None = None,
         task: str | list[str] | None = None,
         optimizer: str = "adam",
-        optimizer_params: dict =
+        optimizer_params: dict | None = None,
         loss: str | nn.Module | None = "bce",
         loss_params: dict | list[dict] | None = None,
         device: str = "cpu",
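
With every argument now an optional keyword, a call site might look like the hedged sketch below. The feature lists are assumed to be built elsewhere with nextrec's DenseFeature/SparseFeature/SequenceFeature classes; the feature names, MLP kwargs, target, and task values are placeholders for illustration, not values taken from the package's documentation.

    # Hypothetical wiring against the 0.4.3 signature; only the keyword names come from this diff.
    from nextrec.models.ranking.dien import DIEN

    model = DIEN(
        sparse_features=sparse_features,              # includes the candidate item feature
        sequence_features=sequence_features,          # includes behavior and negative-behavior sequences
        behavior_feature_name="hist_item_id",         # assumed name
        candidate_feature_name="item_id",             # assumed name
        neg_behavior_feature_name="neg_hist_item_id", # assumed name
        use_negsampling=True,
        aux_loss_weight=0.5,
        gru_hidden_size=64,
        attention_hidden_units=[80, 40],              # matches the default added in this release
        mlp_params={"hidden_units": [128, 64]},       # assumed MLP kwargs
        target=["label"],                             # assumed target column
        task="binary",                                # assumed task value
    )
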
@@ -181,6 +228,15 @@ class DIEN(BaseModel):
         **kwargs,
     ):
 
+        dense_features = dense_features or []
+        sparse_features = sparse_features or []
+        sequence_features = sequence_features or []
+        mlp_params = mlp_params or {}
+        attention_hidden_units = attention_hidden_units or [80, 40]
+        optimizer_params = optimizer_params or {}
+        if loss is None:
+            loss = "bce"
+
         super(DIEN, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
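
The `arg = arg or default` lines above are the usual companion to the `| None = None` defaults introduced in the new signature: each instance gets a fresh list/dict instead of a shared mutable default. A minimal illustration of the pitfall being avoided (the function names are made up):

    # Mutable defaults are evaluated once and shared across calls; None + normalization is not.
    def shared_default(units=[80, 40]):
        units.append(1)
        return units

    def fresh_default(units=None):
        units = units or [80, 40]
        units.append(1)
        return units

    shared_default()
    print(shared_default())                  # [80, 40, 1, 1] -- state leaked from the first call
    print(fresh_default(), fresh_default())  # [80, 40, 1] [80, 40, 1]

One side effect of the `or` form is that an explicitly passed empty collection also falls back to the default (e.g. attention_hidden_units=[] becomes [80, 40]); `x if x is not None else default` would preserve the empty value.
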
@@ -196,31 +252,44 @@ class DIEN(BaseModel):
         )
 
         self.loss = loss
-        if self.loss is None:
-            self.loss = "bce"
-
         self.use_negsampling = use_negsampling
+        self.aux_loss_weight = float(aux_loss_weight)
+        self.auxiliary_cache = None
 
-        # Features classification
         if len(sequence_features) == 0:
             raise ValueError(
                 "DIEN requires at least one sequence feature for user behavior history"
             )
 
-
-
-
-
+        if behavior_feature_name is None:
+            raise ValueError(
+                "DIEN requires at least one sequence feature as behavior item feature"
+            )
+
+        if candidate_feature_name is None:
+            raise ValueError(
+                "DIEN requires at least one sparse_feature as candidate item feature"
+            )
+
+        self.behavior_feature = [
+            f for f in sequence_features if f.name == behavior_feature_name
+        ][0]
+        self.candidate_feature = [
+            f for f in sparse_features if f.name == candidate_feature_name
+        ][0]
 
         self.other_sparse_features = (
             sparse_features[:-1] if self.candidate_feature else sparse_features
         )
-
+
+        self.neg_behavior_feature = None
 
         # Embedding layer
         self.embedding = EmbeddingLayer(features=self.all_features)
 
         behavior_emb_dim = self.behavior_feature.embedding_dim
+
+        # projection candidate feature to match GRU hidden size if needed
         self.candidate_proj = None
         if (
             self.candidate_feature is not None
@@ -230,17 +299,16 @@ class DIEN(BaseModel):
                 self.candidate_feature.embedding_dim, gru_hidden_size
             )
 
-        #
+        # gru for interest extraction
         self.interest_extractor = DynamicGRU(
             input_size=behavior_emb_dim, hidden_size=gru_hidden_size
         )
 
-        # Attention layer for computing attention scores
         self.attention_layer = AttentionPoolingLayer(
             embedding_dim=gru_hidden_size,
             hidden_units=attention_hidden_units,
             activation=attention_activation,
-            use_softmax=False,
+            use_softmax=False,
         )
 
         # Interest Evolution Layer (AUGRU)
@@ -248,7 +316,26 @@ class DIEN(BaseModel):
             input_size=gru_hidden_size, hidden_size=gru_hidden_size
         )
 
-        #
+        # build auxiliary loss net if provided neg sampling and neg_behavior_feature_name
+        # auxiliary loss uses the interest states to predict the next behavior in the sequence
+        # that's the second task of DIEN
+        if self.use_negsampling:
+            neg_candidates = [
+                f for f in sequence_features if f.name == neg_behavior_feature_name
+            ]
+            if len(neg_candidates) == 0:
+                raise ValueError(
+                    f"use_negsampling=True requires a negative sequence feature named '{neg_behavior_feature_name}'"
+                )
+            self.neg_behavior_feature = neg_candidates[0]
+            self.auxiliary_net = nn.Sequential(
+                nn.Linear(gru_hidden_size + behavior_emb_dim, gru_hidden_size),
+                nn.PReLU(),
+                nn.Linear(gru_hidden_size, 1),
+            )
+        else:
+            self.auxiliary_net = None
+
         mlp_input_dim = 0
         if self.candidate_feature:
             mlp_input_dim += self.candidate_feature.embedding_dim
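
A quick shape check for the auxiliary scorer built above: it maps the concatenation of an interest state (gru_hidden_size) and a behavior embedding (behavior_emb_dim) to a single next-click logit per position. The dimensions below are illustrative.

    import torch
    import torch.nn as nn

    B, T, H, E = 2, 5, 64, 32  # batch, sequence length, gru_hidden_size, behavior_emb_dim
    aux_net = nn.Sequential(nn.Linear(H + E, H), nn.PReLU(), nn.Linear(H, 1))

    interest_states = torch.randn(B, T, H)    # h_1..h_T from the extractor GRU
    next_behavior_emb = torch.randn(B, T, E)  # embeddings of the next clicked items
    logits = aux_net(torch.cat([interest_states, next_behavior_emb], dim=-1)).squeeze(-1)
    print(logits.shape)  # torch.Size([2, 5]) -- one logit per (state, next behavior) pair
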
@@ -257,10 +344,10 @@ class DIEN(BaseModel):
             mlp_input_dim += sum(
                 [getattr(f, "embedding_dim", 1) or 1 for f in dense_features]
             )
-
+
         self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
         self.prediction_layer = PredictionLayer(task_type=self.task)
-
+
         self.register_regularization_weights(
             embedding_attr="embedding",
             include_modules=[
@@ -269,8 +356,10 @@ class DIEN(BaseModel):
                 "attention_layer",
                 "mlp",
                 "candidate_proj",
+                "auxiliary_net",
             ],
         )
+
         self.compile(
             optimizer=optimizer,
             optimizer_params=optimizer_params,
@@ -279,7 +368,7 @@ class DIEN(BaseModel):
         )
 
     def forward(self, x):
-
+        self.auxiliary_cache = None
         if self.candidate_feature:
             candidate_emb = self.embedding.embed_dict[
                 self.candidate_feature.embedding_name
@@ -289,87 +378,131 @@ class DIEN(BaseModel):
         else:
             raise ValueError("DIEN requires a candidate item feature")
 
-        # Get behavior sequence embedding
         behavior_seq = x[self.behavior_feature.name].long()  # [B, seq_len]
         behavior_emb = self.embedding.embed_dict[self.behavior_feature.embedding_name](
             behavior_seq
         )  # [B, seq_len, emb_dim]
 
-        # Create mask for padding
         if self.behavior_feature.padding_idx is not None:
-            mask = (
-                (behavior_seq != self.behavior_feature.padding_idx)
-                .unsqueeze(-1)
-                .float()
-            )
+            mask = (behavior_seq != self.behavior_feature.padding_idx).unsqueeze(-1)
         else:
-            mask = (behavior_seq != 0).unsqueeze(-1)
+            mask = (behavior_seq != 0).unsqueeze(-1)
+        mask = mask.float()  # [B, seq_len, 1]
 
-        # Step 1: Interest Extractor (GRU)
         interest_states, _ = self.interest_extractor(
             behavior_emb
         )  # [B, seq_len, hidden_size]
 
-        # Step 2: Compute attention scores for each time step
         batch_size, seq_len, hidden_size = interest_states.shape
 
-        # Project candidate to hidden_size if necessary (defined in __init__)
         if self.candidate_proj is not None:
             candidate_for_attention = self.candidate_proj(candidate_emb)
         else:
-            candidate_for_attention = candidate_emb
-
-        # Compute attention scores for AUGRU
-        attention_scores = []
+            candidate_for_attention = candidate_emb  # [B, hidden_size]
+        att_scores_list = []
         for t in range(seq_len):
-
-
-
-
-
-
-
-
-
-
-            )  # [B, 1]
-
+            # [B, 4H]
+            concat_feat = torch.cat(
+                [
+                    candidate_for_attention,
+                    interest_states[:, t, :],
+                    candidate_for_attention - interest_states[:, t, :],
+                    candidate_for_attention * interest_states[:, t, :],
+                ],
+                dim=-1,
+            )
+            score_t = self.attention_layer.attention_net(concat_feat)  # [B, 1]
+            att_scores_list.append(score_t)
+
+        # [B, seq_len, 1]
+        att_scores = torch.cat(att_scores_list, dim=1)
+
+        scores_flat = att_scores.squeeze(-1)  # [B, seq_len]
+        mask_flat = mask.squeeze(-1)  # [B, seq_len]
 
-
-
-        )  # [B, seq_len, 1]
-        attention_scores = torch.sigmoid(attention_scores)  # Normalize to [0, 1]
+        scores_flat = scores_flat.masked_fill(mask_flat == 0, -1e9)
+        att_weights = torch.softmax(scores_flat, dim=1)  # [B, seq_len]
+        att_weights = att_weights.unsqueeze(-1)  # [B, seq_len, 1]
 
-
-        attention_scores = attention_scores * mask
+        att_weights = att_weights * mask
 
-        #
+        # 6. Interest Evolution(AUGRU)
         final_states, final_interest = self.interest_evolution(
-            interest_states,
+            interest_states, att_weights
         )  # final_interest: [B, hidden_size]
 
-
+        if self.use_negsampling and self.training:
+            if self.neg_behavior_feature is None:
+                raise ValueError(
+                    "Negative behavior feature is not configured while use_negsampling=True"
+                )
+            neg_seq = x[self.neg_behavior_feature.name].long()
+            neg_behavior_emb = self.embedding.embed_dict[
+                self.neg_behavior_feature.embedding_name
+            ](neg_seq)
+            self.auxiliary_cache = {
+                "interest_states": interest_states,
+                "behavior_emb": behavior_emb,
+                "neg_behavior_emb": neg_behavior_emb,
+                "mask": mask,
+            }
+
         other_embeddings = []
         other_embeddings.append(candidate_emb)
         other_embeddings.append(final_interest)
 
-        # Other sparse features
         for feat in self.other_sparse_features:
             feat_emb = self.embedding.embed_dict[feat.embedding_name](
                 x[feat.name].long()
             )
             other_embeddings.append(feat_emb)
 
-
-        for feat in self.dense_features_list:
+        for feat in self.dense_features:
             val = x[feat.name].float()
             if val.dim() == 1:
                 val = val.unsqueeze(1)
             other_embeddings.append(val)
 
-        # Concatenate all features
         concat_input = torch.cat(other_embeddings, dim=-1)  # [B, total_dim]
 
-        # MLP prediction
         y = self.mlp(concat_input)  # [B, 1]
         return self.prediction_layer(y)
+
+    def compute_auxiliary_loss(self):
+        if not (self.training and self.use_negsampling and self.auxiliary_net):
+            return torch.tensor(0.0, device=self.device)
+        if self.auxiliary_cache is None:
+            return torch.tensor(0.0, device=self.device)
+
+        interest_states = self.auxiliary_cache["interest_states"]
+        behavior_emb = self.auxiliary_cache["behavior_emb"]
+        neg_behavior_emb = self.auxiliary_cache["neg_behavior_emb"]
+        mask = self.auxiliary_cache["mask"]
+
+        interest_states = interest_states[:, :-1, :]
+        pos_seq = behavior_emb[:, 1:, :]
+        neg_seq = neg_behavior_emb[:, 1:, :]
+        aux_mask = mask[:, 1:, :].squeeze(-1)
+
+        if aux_mask.sum() == 0:
+            return torch.tensor(0.0, device=self.device)
+
+        pos_input = torch.cat([interest_states, pos_seq], dim=-1)
+        neg_input = torch.cat([interest_states, neg_seq], dim=-1)
+        pos_logits = self.auxiliary_net(pos_input).squeeze(-1)
+        neg_logits = self.auxiliary_net(neg_input).squeeze(-1)
+
+        pos_loss = F.binary_cross_entropy_with_logits(
+            pos_logits, torch.ones_like(pos_logits), reduction="none"
+        )
+        neg_loss = F.binary_cross_entropy_with_logits(
+            neg_logits, torch.zeros_like(neg_logits), reduction="none"
+        )
+        aux_loss = (pos_loss + neg_loss) * aux_mask
+        aux_loss = aux_loss.sum() / torch.clamp(aux_mask.sum(), min=1.0)
+        return aux_loss
+
+    def compute_loss(self, y_pred, y_true):
+        main_loss = super().compute_loss(y_pred, y_true)
+        aux_loss = self.compute_auxiliary_loss()
+        return main_loss + self.aux_loss_weight * aux_loss