nextrec 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between versions exactly as they appear in the public registry.
- nextrec/__init__.py +4 -4
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +9 -10
- nextrec/basic/callback.py +0 -1
- nextrec/basic/dataloader.py +127 -168
- nextrec/basic/features.py +27 -24
- nextrec/basic/layers.py +159 -328
- nextrec/basic/loggers.py +37 -50
- nextrec/basic/metrics.py +147 -255
- nextrec/basic/model.py +462 -817
- nextrec/data/__init__.py +5 -5
- nextrec/data/data_utils.py +12 -16
- nextrec/data/preprocessor.py +252 -276
- nextrec/loss/__init__.py +12 -12
- nextrec/loss/loss_utils.py +22 -30
- nextrec/loss/match_losses.py +83 -116
- nextrec/models/match/__init__.py +5 -5
- nextrec/models/match/dssm.py +61 -70
- nextrec/models/match/dssm_v2.py +51 -61
- nextrec/models/match/mind.py +71 -89
- nextrec/models/match/sdm.py +81 -93
- nextrec/models/match/youtube_dnn.py +53 -62
- nextrec/models/multi_task/esmm.py +43 -49
- nextrec/models/multi_task/mmoe.py +56 -65
- nextrec/models/multi_task/ple.py +65 -92
- nextrec/models/multi_task/share_bottom.py +42 -48
- nextrec/models/ranking/__init__.py +7 -7
- nextrec/models/ranking/afm.py +30 -39
- nextrec/models/ranking/autoint.py +57 -70
- nextrec/models/ranking/dcn.py +35 -43
- nextrec/models/ranking/deepfm.py +28 -34
- nextrec/models/ranking/dien.py +79 -115
- nextrec/models/ranking/din.py +60 -84
- nextrec/models/ranking/fibinet.py +35 -51
- nextrec/models/ranking/fm.py +26 -28
- nextrec/models/ranking/masknet.py +31 -31
- nextrec/models/ranking/pnn.py +31 -30
- nextrec/models/ranking/widedeep.py +31 -36
- nextrec/models/ranking/xdeepfm.py +39 -46
- nextrec/utils/__init__.py +9 -9
- nextrec/utils/embedding.py +1 -1
- nextrec/utils/initializer.py +15 -23
- nextrec/utils/optimizer.py +10 -14
- {nextrec-0.1.3.dist-info → nextrec-0.1.7.dist-info}/METADATA +16 -7
- nextrec-0.1.7.dist-info/RECORD +51 -0
- nextrec-0.1.3.dist-info/RECORD +0 -51
- {nextrec-0.1.3.dist-info → nextrec-0.1.7.dist-info}/WHEEL +0 -0
- {nextrec-0.1.3.dist-info → nextrec-0.1.7.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/layers.py
CHANGED
@@ -50,14 +50,7 @@ __all__ = [
 
 
 class PredictionLayer(nn.Module):
-    _CLASSIFICATION_TASKS = {
-        "classification",
-        "binary",
-        "ctr",
-        "ranking",
-        "match",
-        "matching",
-    }
+    _CLASSIFICATION_TASKS = {"classification", "binary", "ctr", "ranking", "match", "matching"}
     _REGRESSION_TASKS = {"regression", "continuous"}
     _MULTICLASS_TASKS = {"multiclass", "softmax"}
 
@@ -220,9 +213,7 @@ class EmbeddingLayer(nn.Module):
                 elif feature.combiner == "concat":
                     pooling_layer = ConcatPooling()
                 else:
-                    raise ValueError(
-                        f"Unknown combiner for {feature.name}: {feature.combiner}"
-                    )
+                    raise ValueError(f"Unknown combiner for {feature.name}: {feature.combiner}")
 
                 feature_mask = InputMask()(x, feature, seq_input)
                 sparse_embeds.append(pooling_layer(seq_emb, feature_mask).unsqueeze(1))
@@ -254,9 +245,7 @@ class EmbeddingLayer(nn.Module):
 
         if target_dim is not None:
             aligned_dense = [
-                emb.unsqueeze(1)
-                for emb in dense_embeds
-                if emb.shape[-1] == target_dim
+                emb.unsqueeze(1) for emb in dense_embeds if emb.shape[-1] == target_dim
             ]
             output_embeddings.extend(aligned_dense)
 
@@ -268,9 +257,7 @@ class EmbeddingLayer(nn.Module):
 
         return torch.cat(output_embeddings, dim=1)
 
-    def _project_dense(
-        self, feature: DenseFeature, x: dict[str, torch.Tensor]
-    ) -> torch.Tensor:
+    def _project_dense(self, feature: DenseFeature, x: dict[str, torch.Tensor]) -> torch.Tensor:
         if feature.name not in x:
             raise KeyError(f"Dense feature '{feature.name}' is missing from input.")
 
@@ -293,7 +280,6 @@ class EmbeddingLayer(nn.Module):
     def _compute_output_dim(self):
         return
 
-
 class InputMask(nn.Module):
     """Utility module to build sequence masks for pooling layers."""
 
@@ -303,9 +289,9 @@ class InputMask(nn.Module):
     def forward(self, x, fea, seq_tensor=None):
         values = seq_tensor if seq_tensor is not None else x[fea.name]
         if fea.padding_idx is not None:
-            mask = values.long() != fea.padding_idx
+            mask = (values.long() != fea.padding_idx)
         else:
-            mask = values.long() != 0
+            mask = (values.long() != 0)
         if mask.dim() == 1:
             mask = mask.unsqueeze(-1)
         return mask.unsqueeze(1).float()
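Note: the parenthesized comparisons above are purely cosmetic. For readers following the masking logic, a minimal standalone sketch of how such a padding mask behaves (toy tensors, not nextrec code):

```python
import torch

# Two padded ID sequences; index 0 is padding.
seq = torch.tensor([[5, 3, 0, 0], [7, 0, 0, 0]])
mask = (seq.long() != 0)           # True where a real token is present
mask = mask.unsqueeze(1).float()   # [batch, 1, seq_len], broadcastable over embeddings
print(mask)  # tensor([[[1., 1., 0., 0.]], [[1., 0., 0., 0.]]])
```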
@@ -333,7 +319,7 @@ class ConcatPooling(nn.Module):
         super().__init__()
 
     def forward(self, x, mask=None):
-        return x.flatten(start_dim=1, end_dim=2)
+        return x.flatten(start_dim=1, end_dim=2)
 
 
 class AveragePooling(nn.Module):
@@ -367,9 +353,7 @@ class SumPooling(nn.Module):
 class MLP(nn.Module):
     """Stacked fully connected layers used in the deep component."""
 
-    def __init__(
-        self, input_dim, output_layer=True, dims=None, dropout=0, activation="relu"
-    ):
+    def __init__(self, input_dim, output_layer=True, dims=None, dropout=0, activation="relu"):
         super().__init__()
         if dims is None:
             dims = []
@@ -396,7 +380,7 @@ class FM(nn.Module):
         self.reduce_sum = reduce_sum
 
     def forward(self, x):
-        square_of_sum = torch.sum(x, dim=1) ** 2
+        square_of_sum = torch.sum(x, dim=1)**2
         sum_of_square = torch.sum(x**2, dim=1)
         ix = square_of_sum - sum_of_square
         if self.reduce_sum:
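Note: the forward pass above relies on the standard FM identity: the sum of all pairwise embedding interactions equals half of square-of-sum minus sum-of-squares, computed per embedding dimension. A standalone sketch with assumed shapes (the 0.5 factor of the textbook formula is applied outside this hunk, if at all):

```python
import torch

x = torch.randn(32, 10, 16)                 # [batch, num_fields, embed_dim]
square_of_sum = torch.sum(x, dim=1) ** 2    # (sum_f v_f)^2, per dimension
sum_of_square = torch.sum(x ** 2, dim=1)    # sum_f v_f^2, per dimension
ix = square_of_sum - sum_of_square          # = 2 * sum_{f<g} v_f * v_g
second_order = 0.5 * ix.sum(dim=1, keepdim=True)  # scalar FM interaction score
```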
@@ -415,16 +399,7 @@ class CIN(nn.Module):
         prev_dim, fc_input_dim = input_dim, 0
         for i in range(self.num_layers):
             cross_layer_size = cin_size[i]
-            self.conv_layers.append(
-                torch.nn.Conv1d(
-                    input_dim * prev_dim,
-                    cross_layer_size,
-                    1,
-                    stride=1,
-                    dilation=1,
-                    bias=True,
-                )
-            )
+            self.conv_layers.append(torch.nn.Conv1d(input_dim * prev_dim, cross_layer_size, 1, stride=1, dilation=1, bias=True))
             if self.split_half and i != self.num_layers - 1:
                 cross_layer_size //= 2
             prev_dim = cross_layer_size
@@ -446,7 +421,6 @@ class CIN(nn.Module):
             xs.append(x)
         return self.fc(torch.sum(torch.cat(xs, dim=1), 2))
 
-
 class CrossLayer(nn.Module):
     """Single cross layer used in DCN (Wang et al., 2017)."""
 
@@ -466,12 +440,8 @@ class CrossNetwork(nn.Module):
     def __init__(self, input_dim, num_layers):
         super().__init__()
         self.num_layers = num_layers
-        self.w = torch.nn.ModuleList(
-            [torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers)]
-        )
-        self.b = torch.nn.ParameterList(
-            [torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)]
-        )
+        self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers)])
+        self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])
 
     def forward(self, x):
         """
@@ -483,30 +453,21 @@ class CrossNetwork(nn.Module):
             x = x0 * xw + self.b[i] + x
         return x
 
-
 class CrossNetV2(nn.Module):
     """Vector-wise cross network proposed in DCN V2 (Wang et al., 2021)."""
-
     def __init__(self, input_dim, num_layers):
         super().__init__()
         self.num_layers = num_layers
-        self.w = torch.nn.ModuleList(
-            [
-                torch.nn.Linear(input_dim, input_dim, bias=False)
-                for _ in range(num_layers)
-            ]
-        )
-        self.b = torch.nn.ParameterList(
-            [torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)]
-        )
+        self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, input_dim, bias=False) for _ in range(num_layers)])
+        self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])
+
 
     def forward(self, x):
         x0 = x
         for i in range(self.num_layers):
-            x = x0 * self.w[i](x) + self.b[i] + x
+            x =x0*self.w[i](x) + self.b[i] + x
         return x
 
-
 class CrossNetMix(nn.Module):
     """Mixture of low-rank cross experts from DCN V2 (Wang et al., 2021)."""
 
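Note: the two collapsed constructors make the difference between DCN and DCN V2 visible at a glance: v1 learns a weight vector per layer (Linear(input_dim, 1)), v2 a full matrix (Linear(input_dim, input_dim)). One step of each recurrence, as a standalone sketch with assumed shapes:

```python
import torch
import torch.nn as nn

d = 8
x0 = torch.randn(4, d)              # initial features, [batch, d]
x = x0.clone()
b = torch.zeros(d)

w_v1 = nn.Linear(d, 1, bias=False)  # DCN:    x' = x0 * (x . w) + b + x
x_v1 = x0 * w_v1(x) + b + x         # w_v1(x) is [batch, 1], broadcast over d

w_v2 = nn.Linear(d, d, bias=False)  # DCN V2: x' = x0 * (W x) + b + x
x_v2 = x0 * w_v2(x) + b + x         # element-wise product with a full linear map
```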
@@ -516,46 +477,18 @@ class CrossNetMix(nn.Module):
         self.num_experts = num_experts
 
         # U: (input_dim, low_rank)
-        self.u_list = torch.nn.ParameterList(
-            [
-                nn.Parameter(
-                    nn.init.xavier_normal_(
-                        torch.empty(num_experts, input_dim, low_rank)
-                    )
-                )
-                for i in range(self.num_layers)
-            ]
-        )
+        self.u_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
+            torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
         # V: (input_dim, low_rank)
-        self.v_list = torch.nn.ParameterList(
-            [
-                nn.Parameter(
-                    nn.init.xavier_normal_(
-                        torch.empty(num_experts, input_dim, low_rank)
-                    )
-                )
-                for i in range(self.num_layers)
-            ]
-        )
+        self.v_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
+            torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
         # C: (low_rank, low_rank)
-        self.c_list = torch.nn.ParameterList(
-            [
-                nn.Parameter(
-                    nn.init.xavier_normal_(torch.empty(num_experts, low_rank, low_rank))
-                )
-                for i in range(self.num_layers)
-            ]
-        )
-        self.gating = nn.ModuleList(
-            [nn.Linear(input_dim, 1, bias=False) for i in range(self.num_experts)]
-        )
-
-        self.bias = torch.nn.ParameterList(
-            [
-                nn.Parameter(nn.init.zeros_(torch.empty(input_dim, 1)))
-                for i in range(self.num_layers)
-            ]
-        )
+        self.c_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
+            torch.empty(num_experts, low_rank, low_rank))) for i in range(self.num_layers)])
+        self.gating = nn.ModuleList([nn.Linear(input_dim, 1, bias=False) for i in range(self.num_experts)])
+
+        self.bias = torch.nn.ParameterList([nn.Parameter(nn.init.zeros_(
+            torch.empty(input_dim, 1))) for i in range(self.num_layers)])
 
     def forward(self, x):
         x_0 = x.unsqueeze(2)  # (bs, in_features, 1)
@@ -570,9 +503,7 @@ class CrossNetMix(nn.Module):
 
                 # (2) E(x_l)
                 # project the input x_l to $\mathbb{R}^{r}$
-                v_x = torch.matmul(
-                    self.v_list[i][expert_id].t(), x_l
-                )  # (bs, low_rank, 1)
+                v_x = torch.matmul(self.v_list[i][expert_id].t(), x_l)  # (bs, low_rank, 1)
 
                 # nonlinear activation in low rank space
                 v_x = torch.tanh(v_x)
@@ -580,9 +511,7 @@ class CrossNetMix(nn.Module):
                 v_x = torch.tanh(v_x)
 
                 # project back to $\mathbb{R}^{d}$
-                uv_x = torch.matmul(
-                    self.u_list[i][expert_id], v_x
-                )  # (bs, in_features, 1)
+                uv_x = torch.matmul(self.u_list[i][expert_id], v_x)  # (bs, in_features, 1)
 
                 dot_ = uv_x + self.bias[i]
                 dot_ = x_0 * dot_  # Hadamard-product
@@ -590,78 +519,53 @@ class CrossNetMix(nn.Module):
                 output_of_experts.append(dot_.squeeze(2))
 
             # (3) mixture of low-rank experts
-            output_of_experts = torch.stack(
-                output_of_experts, 2
-            )  # (bs, in_features, num_experts)
-            gating_score_experts = torch.stack(
-                gating_score_experts, 1
-            )  # (bs, num_experts, 1)
+            output_of_experts = torch.stack(output_of_experts, 2)  # (bs, in_features, num_experts)
+            gating_score_experts = torch.stack(gating_score_experts, 1)  # (bs, num_experts, 1)
             moe_out = torch.matmul(output_of_experts, gating_score_experts.softmax(1))
             x_l = moe_out + x_l  # (bs, in_features, 1)
 
         x_l = x_l.squeeze()  # (bs, in_features)
         return x_l
 
-
 class SENETLayer(nn.Module):
     """Squeeze-and-Excitation block adopted by FiBiNET (Huang et al., 2019)."""
 
     def __init__(self, num_fields, reduction_ratio=3):
         super(SENETLayer, self).__init__()
-        reduced_size = max(1, int(num_fields / reduction_ratio))
-        self.mlp = nn.Sequential(
-            nn.Linear(num_fields, reduced_size, bias=False),
-            nn.ReLU(),
-            nn.Linear(reduced_size, num_fields, bias=False),
-            nn.ReLU(),
-        )
-
+        reduced_size = max(1, int(num_fields/ reduction_ratio))
+        self.mlp = nn.Sequential(nn.Linear(num_fields, reduced_size, bias=False),
+                                 nn.ReLU(),
+                                 nn.Linear(reduced_size, num_fields, bias=False),
+                                 nn.ReLU())
     def forward(self, x):
         z = torch.mean(x, dim=-1, out=None)
         a = self.mlp(z)
-        v = x * a.unsqueeze(-1)
+        v = x*a.unsqueeze(-1)
         return v
 
-
 class BiLinearInteractionLayer(nn.Module):
     """Bilinear feature interaction from FiBiNET (Huang et al., 2019)."""
 
-    def __init__(self, input_dim, num_fields, bilinear_type="field_interaction"):
+    def __init__(self, input_dim, num_fields, bilinear_type = "field_interaction"):
         super(BiLinearInteractionLayer, self).__init__()
         self.bilinear_type = bilinear_type
         if self.bilinear_type == "field_all":
             self.bilinear_layer = nn.Linear(input_dim, input_dim, bias=False)
         elif self.bilinear_type == "field_each":
-            self.bilinear_layer = nn.ModuleList(
-                [nn.Linear(input_dim, input_dim, bias=False) for i in range(num_fields)]
-            )
+            self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i in range(num_fields)])
         elif self.bilinear_type == "field_interaction":
-            self.bilinear_layer = nn.ModuleList(
-                [
-                    nn.Linear(input_dim, input_dim, bias=False)
-                    for i, j in combinations(range(num_fields), 2)
-                ]
-            )
+            self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i,j in combinations(range(num_fields), 2)])
         else:
             raise NotImplementedError()
 
     def forward(self, x):
         feature_emb = torch.split(x, 1, dim=1)
         if self.bilinear_type == "field_all":
-            bilinear_list = [
-                self.bilinear_layer(v_i) * v_j
-                for v_i, v_j in combinations(feature_emb, 2)
-            ]
+            bilinear_list = [self.bilinear_layer(v_i)*v_j for v_i, v_j in combinations(feature_emb, 2)]
         elif self.bilinear_type == "field_each":
-            bilinear_list = [
-                self.bilinear_layer[i](feature_emb[i]) * feature_emb[j]
-                for i, j in combinations(range(len(feature_emb)), 2)
-            ]
+            bilinear_list = [self.bilinear_layer[i](feature_emb[i])*feature_emb[j] for i,j in combinations(range(len(feature_emb)), 2)]
         elif self.bilinear_type == "field_interaction":
-            bilinear_list = [
-                self.bilinear_layer[i](v[0]) * v[1]
-                for i, v in enumerate(combinations(feature_emb, 2))
-            ]
+            bilinear_list = [self.bilinear_layer[i](v[0])*v[1] for i,v in enumerate(combinations(feature_emb, 2))]
         return torch.cat(bilinear_list, dim=1)
 
 
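Note: the SENETLayer rewrite keeps the usual squeeze-and-excitation flow: mean-pool each field embedding to a scalar (squeeze), run the field vector through a bottleneck MLP (excitation), then rescale the embeddings. A standalone sketch with assumed sizes:

```python
import torch
import torch.nn as nn

num_fields, embed_dim, ratio = 6, 16, 3
x = torch.randn(32, num_fields, embed_dim)       # [batch, fields, dim]

reduced = max(1, num_fields // ratio)
mlp = nn.Sequential(
    nn.Linear(num_fields, reduced, bias=False), nn.ReLU(),
    nn.Linear(reduced, num_fields, bias=False), nn.ReLU(),
)

z = x.mean(dim=-1)          # squeeze: one scalar per field -> [batch, fields]
a = mlp(z)                  # excitation: per-field weights
v = x * a.unsqueeze(-1)     # re-weight each field embedding
```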
@@ -674,23 +578,17 @@ class MultiInterestSA(nn.Module):
         self.interest_num = interest_num
         if hidden_dim == None:
             self.hidden_dim = self.embedding_dim * 4
-        self.W1 = torch.nn.Parameter(
-            torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True
-        )
-        self.W2 = torch.nn.Parameter(
-            torch.rand(self.hidden_dim, self.interest_num), requires_grad=True
-        )
-        self.W3 = torch.nn.Parameter(
-            torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True
-        )
+        self.W1 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True)
+        self.W2 = torch.nn.Parameter(torch.rand(self.hidden_dim, self.interest_num), requires_grad=True)
+        self.W3 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True)
 
     def forward(self, seq_emb, mask=None):
-        H = torch.einsum("bse, ed -> bsd", seq_emb, self.W1).tanh()
+        H = torch.einsum('bse, ed -> bsd', seq_emb, self.W1).tanh()
         if mask != None:
-            A = torch.einsum("bsd, dk -> bsk", H, self.W2) + -1.e9 * (1 - mask.float())
+            A = torch.einsum('bsd, dk -> bsk', H, self.W2) + -1.e9 * (1 - mask.float())
             A = F.softmax(A, dim=1)
         else:
-            A = F.softmax(torch.einsum("bsd, dk -> bsk", H, self.W2), dim=1)
+            A = F.softmax(torch.einsum('bsd, dk -> bsk', H, self.W2), dim=1)
         A = A.permute(0, 2, 1)
         multi_interest_emb = torch.matmul(A, seq_emb)
         return multi_interest_emb
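Note: the restored einsum strings document the self-attentive pooling shapes: 'bse, ed -> bsd' projects the sequence into a hidden space, 'bsd, dk -> bsk' scores one attention column per interest, and masked positions are pushed to -1e9 before the softmax. A shape walk-through under assumed sizes:

```python
import torch
import torch.nn.functional as F

b, s, e, d, k = 2, 5, 8, 32, 4        # batch, seq_len, embed, hidden, interests
seq_emb = torch.randn(b, s, e)
W1, W2 = torch.rand(e, d), torch.rand(d, k)
mask = torch.ones(b, s, 1)            # 1 = real token, 0 = padding

H = torch.einsum('bse, ed -> bsd', seq_emb, W1).tanh()        # [b, s, d]
A = torch.einsum('bsd, dk -> bsk', H, W2) + -1.e9 * (1 - mask.float())
A = F.softmax(A, dim=1)               # attend over sequence positions
A = A.permute(0, 2, 1)                # [b, k, s]
multi_interest = torch.matmul(A, seq_emb)   # [b, k, e], one vector per interest
```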
@@ -699,15 +597,7 @@ class MultiInterestSA(nn.Module):
 class CapsuleNetwork(nn.Module):
     """Dynamic routing capsule network used in MIND (Li et al., 2019)."""
 
-    def __init__(
-        self,
-        embedding_dim,
-        seq_len,
-        bilinear_type=2,
-        interest_num=4,
-        routing_times=3,
-        relu_layer=False,
-    ):
+    def __init__(self, embedding_dim, seq_len, bilinear_type=2, interest_num=4, routing_times=3, relu_layer=False):
         super(CapsuleNetwork, self).__init__()
         self.embedding_dim = embedding_dim  # h
         self.seq_len = seq_len  # s
@@ -717,24 +607,13 @@ class CapsuleNetwork(nn.Module):
 
         self.relu_layer = relu_layer
         self.stop_grad = True
-        self.relu = nn.Sequential(
-            nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU()
-        )
+        self.relu = nn.Sequential(nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU())
         if self.bilinear_type == 0:  # MIND
             self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
         elif self.bilinear_type == 1:
-            self.linear = nn.Linear(
-                self.embedding_dim, self.embedding_dim * self.interest_num, bias=False
-            )
+            self.linear = nn.Linear(self.embedding_dim, self.embedding_dim * self.interest_num, bias=False)
         else:
-            self.w = nn.Parameter(
-                torch.Tensor(
-                    1,
-                    self.seq_len,
-                    self.interest_num * self.embedding_dim,
-                    self.embedding_dim,
-                )
-            )
+            self.w = nn.Parameter(torch.Tensor(1, self.seq_len, self.interest_num * self.embedding_dim, self.embedding_dim))
             nn.init.xavier_uniform_(self.w)
 
     def forward(self, item_eb, mask):
@@ -745,15 +624,11 @@ class CapsuleNetwork(nn.Module):
             item_eb_hat = self.linear(item_eb)
         else:
             u = torch.unsqueeze(item_eb, dim=2)
-            item_eb_hat = torch.sum(self.w[:, : self.seq_len, :, :] * u, dim=3)
+            item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)
 
-        item_eb_hat = torch.reshape(
-            item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim)
-        )
+        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim))
         item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
-        item_eb_hat = torch.reshape(
-            item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim)
-        )
+        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim))
 
         if self.stop_grad:
             item_eb_hat_iter = item_eb_hat.detach()
@@ -761,47 +636,34 @@ class CapsuleNetwork(nn.Module):
             item_eb_hat_iter = item_eb_hat
 
         if self.bilinear_type > 0:
-            capsule_weight = torch.zeros(
-                item_eb_hat.shape[0],
-                self.interest_num,
-                self.seq_len,
-                device=item_eb.device,
-                requires_grad=False,
-            )
+            capsule_weight = torch.zeros(item_eb_hat.shape[0],
+                                         self.interest_num,
+                                         self.seq_len,
+                                         device=item_eb.device,
+                                         requires_grad=False)
         else:
-            capsule_weight = torch.randn(
-                item_eb_hat.shape[0],
-                self.interest_num,
-                self.seq_len,
-                device=item_eb.device,
-                requires_grad=False,
-            )
+            capsule_weight = torch.randn(item_eb_hat.shape[0],
+                                         self.interest_num,
+                                         self.seq_len,
+                                         device=item_eb.device,
+                                         requires_grad=False)
 
         for i in range(self.routing_times):  # dynamic routing, repeated 3 times
             atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
             paddings = torch.zeros_like(atten_mask, dtype=torch.float)
 
             capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
-            capsule_softmax_weight = torch.where(
-                torch.eq(atten_mask, 0), paddings, capsule_softmax_weight
-            )
+            capsule_softmax_weight = torch.where(torch.eq(atten_mask, 0), paddings, capsule_softmax_weight)
             capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)
 
             if i < 2:
-                interest_capsule = torch.matmul(
-                    capsule_softmax_weight, item_eb_hat_iter
-                )
+                interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat_iter)
                 cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
                 scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
                 interest_capsule = scalar_factor * interest_capsule
 
-                delta_weight = torch.matmul(
-                    item_eb_hat_iter,
-                    torch.transpose(interest_capsule, 2, 3).contiguous(),
-                )
-                delta_weight = torch.reshape(
-                    delta_weight, (-1, self.interest_num, self.seq_len)
-                )
+                delta_weight = torch.matmul(item_eb_hat_iter, torch.transpose(interest_capsule, 2, 3).contiguous())
+                delta_weight = torch.reshape(delta_weight, (-1, self.interest_num, self.seq_len))
                 capsule_weight = capsule_weight + delta_weight
             else:
                 interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat)
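Note: the re-flowed routing loop applies the capsule "squash" nonlinearity after each weighted sum, v * ||v||^2 / ((1 + ||v||^2) * ||v||), which is exactly what the cap_norm / scalar_factor pair computes. In isolation:

```python
import torch

def squash(v, eps=1e-9):
    # v: [..., dim]; shrinks short vectors toward 0, caps long ones below unit norm
    cap_norm = torch.sum(torch.square(v), dim=-1, keepdim=True)
    scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + eps)
    return scalar_factor * v

caps = torch.randn(2, 4, 8)        # [batch, interest_num, embedding_dim]
print(squash(caps).norm(dim=-1))   # every norm falls strictly below 1
```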
@@ -809,9 +671,7 @@ class CapsuleNetwork(nn.Module):
                 scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
                 interest_capsule = scalar_factor * interest_capsule
 
-        interest_capsule = torch.reshape(
-            interest_capsule, (-1, self.interest_num, self.embedding_dim)
-        )
+        interest_capsule = torch.reshape(interest_capsule, (-1, self.interest_num, self.embedding_dim))
 
         if self.relu_layer:
             interest_capsule = self.relu(interest_capsule)
@@ -823,18 +683,18 @@ class FFM(nn.Module):
     """Field-aware Factorization Machine (Juan et al., 2016)."""
 
     def __init__(self, num_fields, reduce_sum=True):
-        super().__init__()
+        super().__init__()
         self.num_fields = num_fields
         self.reduce_sum = reduce_sum
 
     def forward(self, x):
         # compute (non-redundant) second order field-aware feature crossings
         crossed_embeddings = []
-        for i in range(self.num_fields - 1):
-            for j in range(i + 1, self.num_fields):
-                crossed_embeddings.append(x[:, i, j, :] * x[:, j, i, :])
+        for i in range(self.num_fields-1):
+            for j in range(i+1, self.num_fields):
+                crossed_embeddings.append(x[:, i, j, :] * x[:, j, i, :])
         crossed_embeddings = torch.stack(crossed_embeddings, dim=1)
-
+
         # if reduce_sum is true, the crossing operation is effectively inner product, otherwise Hadamard-product
         if self.reduce_sum:
             crossed_embeddings = torch.sum(crossed_embeddings, dim=-1, keepdim=True)
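Note: the reconstructed double loop enumerates each unordered field pair once; x[:, i, j, :] is field i's embedding specialized for field j, so x[:, i, j, :] * x[:, j, i, :] is the FFM cross term. A toy run with assumed shapes:

```python
import torch

batch, num_fields, embed_dim = 2, 4, 8
# FFM keeps one embedding per (field, target-field) pair
x = torch.randn(batch, num_fields, num_fields, embed_dim)

crossed = []
for i in range(num_fields - 1):
    for j in range(i + 1, num_fields):
        crossed.append(x[:, i, j, :] * x[:, j, i, :])
crossed = torch.stack(crossed, dim=1)      # [batch, C(num_fields, 2), embed_dim]
score = crossed.sum(dim=-1, keepdim=True)  # reduce_sum=True turns each cross into an inner product
```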
@@ -845,57 +705,49 @@ class CEN(nn.Module):
     """Field-attentive interaction network from FAT-DeepFFM (Wang et al., 2020)."""
 
     def __init__(self, embed_dim, num_field_crosses, reduction_ratio):
-        super().__init__()
-
+        super().__init__()
+
         # convolution weight (Eq.7 FAT-DeepFFM)
-        self.u = torch.nn.Parameter(
-            torch.rand(num_field_crosses, embed_dim), requires_grad=True
-        )
+        self.u = torch.nn.Parameter(torch.rand(num_field_crosses, embed_dim), requires_grad=True)
 
         # two FC layers that computes the field attention
-        self.mlp_att = MLP(
-            num_field_crosses,
-            dims=[num_field_crosses // reduction_ratio, num_field_crosses],
-            output_layer=False,
-            activation="relu",
-        )
-
-    def forward(self, em):
+        self.mlp_att = MLP(num_field_crosses, dims=[num_field_crosses//reduction_ratio, num_field_crosses], output_layer=False, activation="relu")
+
+
+    def forward(self, em):
         # compute descriptor vector (Eq.7 FAT-DeepFFM), output shape [batch_size, num_field_crosses]
         d = F.relu((self.u.squeeze(0) * em).sum(-1))
-
-        # compute field attention (Eq.9), output shape [batch_size, num_field_crosses]
-        s = self.mlp_att(d)
+
+        # compute field attention (Eq.9), output shape [batch_size, num_field_crosses]
+        s = self.mlp_att(d)
 
         # rescale original embedding with field attention (Eq.10), output shape [batch_size, num_field_crosses, embed_dim]
-        aem = s.unsqueeze(-1) * em
+        aem = s.unsqueeze(-1) * em
         return aem.flatten(start_dim=1)
 
 
 class MultiHeadSelfAttention(nn.Module):
     """Multi-head self-attention layer from AutoInt (Song et al., 2019)."""
-
+
     def __init__(self, embedding_dim, num_heads=2, dropout=0.0, use_residual=True):
         super().__init__()
         if embedding_dim % num_heads != 0:
-            raise ValueError(
-                f"embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})"
-            )
-
+            raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})")
+
         self.embedding_dim = embedding_dim
         self.num_heads = num_heads
         self.head_dim = embedding_dim // num_heads
         self.use_residual = use_residual
-
+
         self.W_Q = nn.Linear(embedding_dim, embedding_dim, bias=False)
         self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False)
         self.W_V = nn.Linear(embedding_dim, embedding_dim, bias=False)
-
+
         if self.use_residual:
             self.W_Res = nn.Linear(embedding_dim, embedding_dim, bias=False)
-
+
         self.dropout = nn.Dropout(dropout)
-
+
     def forward(self, x):
         """
         Args:
@@ -904,47 +756,37 @@ class MultiHeadSelfAttention(nn.Module):
             output: [batch_size, num_fields, embedding_dim]
         """
         batch_size, num_fields, _ = x.shape
-
+
         # Linear projections
         Q = self.W_Q(x)  # [batch_size, num_fields, embedding_dim]
         K = self.W_K(x)
         V = self.W_V(x)
-
+
         # Split into multiple heads: [batch_size, num_heads, num_fields, head_dim]
-        Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
-            1, 2
-        )
-        K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
-            1, 2
-        )
-        V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
-            1, 2
-        )
-
+        Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
+        K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
+        V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
+
         # Attention scores
-        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
+        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
         attention_weights = F.softmax(scores, dim=-1)
         attention_weights = self.dropout(attention_weights)
-
+
         # Apply attention to values
-        attention_output = torch.matmul(
-            attention_weights, V
-        )  # [batch_size, num_heads, num_fields, head_dim]
-
+        attention_output = torch.matmul(attention_weights, V)  # [batch_size, num_heads, num_fields, head_dim]
+
         # Concatenate heads
         attention_output = attention_output.transpose(1, 2).contiguous()
-        attention_output = attention_output.view(
-            batch_size, num_fields, self.embedding_dim
-        )
-
+        attention_output = attention_output.view(batch_size, num_fields, self.embedding_dim)
+
         # Residual connection
         if self.use_residual:
             output = attention_output + self.W_Res(x)
         else:
             output = attention_output
-
+
         output = F.relu(output)
-
+
         return output
 
 
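Note: the collapsed view/transpose lines are the standard multi-head split, and the whole forward is plain scaled dot-product attention over feature fields. A condensed standalone equivalent (assumed sizes, not the nextrec module itself):

```python
import torch
import torch.nn.functional as F

batch, fields, dim, heads = 2, 6, 16, 2
head_dim = dim // heads
Q = K = V = torch.randn(batch, fields, dim)

# split heads: [batch, heads, fields, head_dim]
q = Q.view(batch, fields, heads, head_dim).transpose(1, 2)
k = K.view(batch, fields, heads, head_dim).transpose(1, 2)
v = V.view(batch, fields, heads, head_dim).transpose(1, 2)

scores = torch.matmul(q, k.transpose(-2, -1)) / (head_dim ** 0.5)
out = torch.matmul(F.softmax(scores, dim=-1), v)   # [batch, heads, fields, head_dim]
out = out.transpose(1, 2).contiguous().view(batch, fields, dim)  # merge heads
```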
@@ -953,31 +795,25 @@ class AttentionPoolingLayer(nn.Module):
     Attention pooling layer for DIN/DIEN
     Computes attention weights between query (candidate item) and keys (user behavior sequence)
     """
-
-    def __init__(
-        self,
-        embedding_dim,
-        hidden_units=[80, 40],
-        activation="sigmoid",
-        use_softmax=True,
-    ):
+
+    def __init__(self, embedding_dim, hidden_units=[80, 40], activation='sigmoid', use_softmax=True):
         super().__init__()
         self.embedding_dim = embedding_dim
         self.use_softmax = use_softmax
-
+
         # Build attention network
         # Input: [query, key, query-key, query*key] -> 4 * embedding_dim
         input_dim = 4 * embedding_dim
         layers = []
-
+
         for hidden_unit in hidden_units:
             layers.append(nn.Linear(input_dim, hidden_unit))
             layers.append(activation_layer(activation))
             input_dim = hidden_unit
-
+
         layers.append(nn.Linear(input_dim, 1))
         self.attention_net = nn.Sequential(*layers)
-
+
     def forward(self, query, keys, keys_length=None, mask=None):
         """
         Args:
@@ -989,52 +825,48 @@ class AttentionPoolingLayer(nn.Module):
             output: [batch_size, embedding_dim] - attention pooled representation
         """
         batch_size, seq_len, emb_dim = keys.shape
-
+
         # Expand query to match sequence length: [batch_size, seq_len, embedding_dim]
         query_expanded = query.unsqueeze(1).expand(-1, seq_len, -1)
-
+
         # Compute attention features: [query, key, query-key, query*key]
-        attention_input = torch.cat(
-            [query_expanded, keys, query_expanded - keys, query_expanded * keys],
-            dim=-1,
-        )  # [batch_size, seq_len, 4*embedding_dim]
+        attention_input = torch.cat([
+            query_expanded,
+            keys,
+            query_expanded - keys,
+            query_expanded * keys
+        ], dim=-1)  # [batch_size, seq_len, 4*embedding_dim]
+
         # Compute attention scores
-        attention_scores = self.attention_net(
-            attention_input
-        )  # [batch_size, seq_len, 1]
-
+        attention_scores = self.attention_net(attention_input)  # [batch_size, seq_len, 1]
+
         # Apply mask if provided
         if mask is not None:
             attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
-
+
         # Apply softmax to get attention weights
         if self.use_softmax:
-            attention_weights = F.softmax(
-                attention_scores, dim=1
-            )  # [batch_size, seq_len, 1]
+            attention_weights = F.softmax(attention_scores, dim=1)  # [batch_size, seq_len, 1]
         else:
             attention_weights = attention_scores
-
+
         # Weighted sum of keys
-        output = torch.sum(
-            attention_weights * keys, dim=1
-        )  # [batch_size, embedding_dim]
-
+        output = torch.sum(attention_weights * keys, dim=1)  # [batch_size, embedding_dim]
+
         return output
 
 
 class DynamicGRU(nn.Module):
     """Dynamic GRU unit with auxiliary loss path from DIEN (Zhou et al., 2019)."""
-
     """
     GRU with dynamic routing for DIEN
     """
-
+
     def __init__(self, input_size, hidden_size, bias=True):
         super().__init__()
         self.input_size = input_size
         self.hidden_size = hidden_size
-
+
         # GRU parameters
         self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
         self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
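Note: the re-flowed torch.cat above assembles DIN's attention input, the concatenation [query, key, query - key, query * key], before scoring each behavior against the candidate item. A minimal sketch of just that step (toy shapes):

```python
import torch

batch, seq_len, dim = 2, 5, 8
query = torch.randn(batch, dim)            # candidate item embedding
keys = torch.randn(batch, seq_len, dim)    # user behavior sequence

q = query.unsqueeze(1).expand(-1, seq_len, -1)
att_in = torch.cat([q, keys, q - keys, q * keys], dim=-1)  # [batch, seq_len, 4*dim]
```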
@@ -1042,16 +874,16 @@ class DynamicGRU(nn.Module):
             self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size))
             self.bias_hh = nn.Parameter(torch.randn(3 * hidden_size))
         else:
-            self.register_parameter("bias_ih", None)
-            self.register_parameter("bias_hh", None)
-
+            self.register_parameter('bias_ih', None)
+            self.register_parameter('bias_hh', None)
+
         self.reset_parameters()
-
+
     def reset_parameters(self):
         std = 1.0 / (self.hidden_size) ** 0.5
         for weight in self.parameters():
             weight.data.uniform_(-std, std)
-
+
     def forward(self, x, att_scores=None):
         """
         Args:
@@ -1062,61 +894,60 @@ class DynamicGRU(nn.Module):
             hidden: [batch_size, hidden_size] - final hidden state
         """
         batch_size, seq_len, _ = x.shape
-
+
         # Initialize hidden state
         h = torch.zeros(batch_size, self.hidden_size, device=x.device)
-
+
         outputs = []
         for t in range(seq_len):
             x_t = x[:, t, :]  # [batch_size, input_size]
-
+
             # GRU computation
             gi = F.linear(x_t, self.weight_ih, self.bias_ih)
             gh = F.linear(h, self.weight_hh, self.bias_hh)
             i_r, i_i, i_n = gi.chunk(3, 1)
             h_r, h_i, h_n = gh.chunk(3, 1)
-
+
             resetgate = torch.sigmoid(i_r + h_r)
             inputgate = torch.sigmoid(i_i + h_i)
             newgate = torch.tanh(i_n + resetgate * h_n)
             h = newgate + inputgate * (h - newgate)
-
+
             outputs.append(h.unsqueeze(1))
-
+
         output = torch.cat(outputs, dim=1)  # [batch_size, seq_len, hidden_size]
-
+
         return output, h
 
 
 class AUGRU(nn.Module):
     """Attention-aware GRU update gate used in DIEN (Zhou et al., 2019)."""
-
     """
     Attention-based GRU for DIEN
     Uses attention scores to weight the update of hidden states
     """
-
+
     def __init__(self, input_size, hidden_size, bias=True):
         super().__init__()
         self.input_size = input_size
         self.hidden_size = hidden_size
-
+
         self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
         self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
         if bias:
             self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size))
             self.bias_hh = nn.Parameter(torch.randn(3 * hidden_size))
         else:
-            self.register_parameter("bias_ih", None)
-            self.register_parameter("bias_hh", None)
-
+            self.register_parameter('bias_ih', None)
+            self.register_parameter('bias_hh', None)
+
         self.reset_parameters()
-
+
     def reset_parameters(self):
         std = 1.0 / (self.hidden_size) ** 0.5
         for weight in self.parameters():
             weight.data.uniform_(-std, std)
-
+
     def forward(self, x, att_scores):
         """
         Args:
@@ -1127,28 +958,28 @@ class AUGRU(nn.Module):
             hidden: [batch_size, hidden_size] - final hidden state
         """
         batch_size, seq_len, _ = x.shape
-
+
         h = torch.zeros(batch_size, self.hidden_size, device=x.device)
-
+
         outputs = []
         for t in range(seq_len):
            x_t = x[:, t, :]  # [batch_size, input_size]
             att_t = att_scores[:, t, :]  # [batch_size, 1]
-
+
             gi = F.linear(x_t, self.weight_ih, self.bias_ih)
             gh = F.linear(h, self.weight_hh, self.bias_hh)
             i_r, i_i, i_n = gi.chunk(3, 1)
             h_r, h_i, h_n = gh.chunk(3, 1)
-
+
             resetgate = torch.sigmoid(i_r + h_r)
             inputgate = torch.sigmoid(i_i + h_i)
             newgate = torch.tanh(i_n + resetgate * h_n)
-
+
             # Use attention score to control update
             h = (1 - att_t) * h + att_t * newgate
-
+
             outputs.append(h.unsqueeze(1))
-
+
         output = torch.cat(outputs, dim=1)
-
-        return output, h
+
+        return output, h
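Note: AUGRU's only departure from the plain GRU above it is the final update, where the per-step attention score stands in for the learned update gate: h_t = (1 - a_t) * h_{t-1} + a_t * h~_t. A single-step sketch with assumed shapes:

```python
import torch

batch, hidden = 2, 8
h_prev = torch.zeros(batch, hidden)
newgate = torch.tanh(torch.randn(batch, hidden))  # candidate state h~_t
att_t = torch.rand(batch, 1)                      # attention score a_t in [0, 1]

# att_t -> 0 keeps the previous state; att_t -> 1 adopts the candidate
h = (1 - att_t) * h_prev + att_t * newgate
```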