nextrec-0.1.1-py3-none-any.whl → nextrec-0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +4 -4
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -9
- nextrec/basic/callback.py +1 -0
- nextrec/basic/dataloader.py +168 -127
- nextrec/basic/features.py +24 -27
- nextrec/basic/layers.py +328 -159
- nextrec/basic/loggers.py +50 -37
- nextrec/basic/metrics.py +255 -147
- nextrec/basic/model.py +817 -462
- nextrec/data/__init__.py +5 -5
- nextrec/data/data_utils.py +16 -12
- nextrec/data/preprocessor.py +276 -252
- nextrec/loss/__init__.py +12 -12
- nextrec/loss/loss_utils.py +30 -22
- nextrec/loss/match_losses.py +116 -83
- nextrec/models/match/__init__.py +5 -5
- nextrec/models/match/dssm.py +70 -61
- nextrec/models/match/dssm_v2.py +61 -51
- nextrec/models/match/mind.py +89 -71
- nextrec/models/match/sdm.py +93 -81
- nextrec/models/match/youtube_dnn.py +62 -53
- nextrec/models/multi_task/esmm.py +49 -43
- nextrec/models/multi_task/mmoe.py +65 -56
- nextrec/models/multi_task/ple.py +92 -65
- nextrec/models/multi_task/share_bottom.py +48 -42
- nextrec/models/ranking/__init__.py +7 -7
- nextrec/models/ranking/afm.py +39 -30
- nextrec/models/ranking/autoint.py +70 -57
- nextrec/models/ranking/dcn.py +43 -35
- nextrec/models/ranking/deepfm.py +34 -28
- nextrec/models/ranking/dien.py +115 -79
- nextrec/models/ranking/din.py +84 -60
- nextrec/models/ranking/fibinet.py +51 -35
- nextrec/models/ranking/fm.py +28 -26
- nextrec/models/ranking/masknet.py +31 -31
- nextrec/models/ranking/pnn.py +30 -31
- nextrec/models/ranking/widedeep.py +36 -31
- nextrec/models/ranking/xdeepfm.py +46 -39
- nextrec/utils/__init__.py +9 -9
- nextrec/utils/embedding.py +1 -1
- nextrec/utils/initializer.py +23 -15
- nextrec/utils/optimizer.py +14 -10
- {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/METADATA +6 -40
- nextrec-0.1.2.dist-info/RECORD +51 -0
- nextrec-0.1.1.dist-info/RECORD +0 -51
- {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/WHEEL +0 -0
- {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/layers.py
CHANGED
@@ -50,7 +50,14 @@ __all__ = [
 
 
 class PredictionLayer(nn.Module):
-    _CLASSIFICATION_TASKS = {
+    _CLASSIFICATION_TASKS = {
+        "classification",
+        "binary",
+        "ctr",
+        "ranking",
+        "match",
+        "matching",
+    }
     _REGRESSION_TASKS = {"regression", "continuous"}
     _MULTICLASS_TASKS = {"multiclass", "softmax"}
 
@@ -213,7 +220,9 @@ class EmbeddingLayer(nn.Module):
             elif feature.combiner == "concat":
                 pooling_layer = ConcatPooling()
             else:
-                raise ValueError(
+                raise ValueError(
+                    f"Unknown combiner for {feature.name}: {feature.combiner}"
+                )
 
             feature_mask = InputMask()(x, feature, seq_input)
             sparse_embeds.append(pooling_layer(seq_emb, feature_mask).unsqueeze(1))
@@ -245,7 +254,9 @@ class EmbeddingLayer(nn.Module):
 
         if target_dim is not None:
             aligned_dense = [
-                emb.unsqueeze(1)
+                emb.unsqueeze(1)
+                for emb in dense_embeds
+                if emb.shape[-1] == target_dim
             ]
             output_embeddings.extend(aligned_dense)
 
@@ -257,7 +268,9 @@ class EmbeddingLayer(nn.Module):
 
         return torch.cat(output_embeddings, dim=1)
 
-    def _project_dense(
+    def _project_dense(
+        self, feature: DenseFeature, x: dict[str, torch.Tensor]
+    ) -> torch.Tensor:
         if feature.name not in x:
             raise KeyError(f"Dense feature '{feature.name}' is missing from input.")
 
@@ -280,6 +293,7 @@ class EmbeddingLayer(nn.Module):
     def _compute_output_dim(self):
         return
 
+
 class InputMask(nn.Module):
     """Utility module to build sequence masks for pooling layers."""
 
@@ -289,9 +303,9 @@ class InputMask(nn.Module):
     def forward(self, x, fea, seq_tensor=None):
         values = seq_tensor if seq_tensor is not None else x[fea.name]
         if fea.padding_idx is not None:
-            mask =
+            mask = values.long() != fea.padding_idx
         else:
-            mask =
+            mask = values.long() != 0
         if mask.dim() == 1:
             mask = mask.unsqueeze(-1)
         return mask.unsqueeze(1).float()
@@ -319,7 +333,7 @@ class ConcatPooling(nn.Module):
         super().__init__()
 
     def forward(self, x, mask=None):
-        return x.flatten(start_dim=1, end_dim=2)
+        return x.flatten(start_dim=1, end_dim=2)
 
 
 class AveragePooling(nn.Module):
@@ -353,7 +367,9 @@ class SumPooling(nn.Module):
 class MLP(nn.Module):
     """Stacked fully connected layers used in the deep component."""
 
-    def __init__(
+    def __init__(
+        self, input_dim, output_layer=True, dims=None, dropout=0, activation="relu"
+    ):
         super().__init__()
         if dims is None:
             dims = []
@@ -380,7 +396,7 @@ class FM(nn.Module):
         self.reduce_sum = reduce_sum
 
     def forward(self, x):
-        square_of_sum = torch.sum(x, dim=1)**2
+        square_of_sum = torch.sum(x, dim=1) ** 2
         sum_of_square = torch.sum(x**2, dim=1)
         ix = square_of_sum - sum_of_square
         if self.reduce_sum:
@@ -399,7 +415,16 @@ class CIN(nn.Module):
         prev_dim, fc_input_dim = input_dim, 0
         for i in range(self.num_layers):
             cross_layer_size = cin_size[i]
-            self.conv_layers.append(
+            self.conv_layers.append(
+                torch.nn.Conv1d(
+                    input_dim * prev_dim,
+                    cross_layer_size,
+                    1,
+                    stride=1,
+                    dilation=1,
+                    bias=True,
+                )
+            )
             if self.split_half and i != self.num_layers - 1:
                 cross_layer_size //= 2
             prev_dim = cross_layer_size
@@ -421,6 +446,7 @@ class CIN(nn.Module):
             xs.append(x)
         return self.fc(torch.sum(torch.cat(xs, dim=1), 2))
 
+
 class CrossLayer(nn.Module):
     """Single cross layer used in DCN (Wang et al., 2017)."""
 
@@ -440,8 +466,12 @@ class CrossNetwork(nn.Module):
     def __init__(self, input_dim, num_layers):
         super().__init__()
         self.num_layers = num_layers
-        self.w = torch.nn.ModuleList(
-
+        self.w = torch.nn.ModuleList(
+            [torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers)]
+        )
+        self.b = torch.nn.ParameterList(
+            [torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)]
+        )
 
     def forward(self, x):
         """
@@ -453,21 +483,30 @@ class CrossNetwork(nn.Module):
             x = x0 * xw + self.b[i] + x
         return x
 
+
 class CrossNetV2(nn.Module):
     """Vector-wise cross network proposed in DCN V2 (Wang et al., 2021)."""
+
     def __init__(self, input_dim, num_layers):
         super().__init__()
         self.num_layers = num_layers
-        self.w = torch.nn.ModuleList(
-
-
+        self.w = torch.nn.ModuleList(
+            [
+                torch.nn.Linear(input_dim, input_dim, bias=False)
+                for _ in range(num_layers)
+            ]
+        )
+        self.b = torch.nn.ParameterList(
+            [torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)]
+        )
 
     def forward(self, x):
         x0 = x
         for i in range(self.num_layers):
-            x =x0*self.w[i](x) + self.b[i] + x
+            x = x0 * self.w[i](x) + self.b[i] + x
         return x
 
+
 class CrossNetMix(nn.Module):
     """Mixture of low-rank cross experts from DCN V2 (Wang et al., 2021)."""
 
@@ -477,18 +516,46 @@ class CrossNetMix(nn.Module):
         self.num_experts = num_experts
 
         # U: (input_dim, low_rank)
-        self.u_list = torch.nn.ParameterList(
-
+        self.u_list = torch.nn.ParameterList(
+            [
+                nn.Parameter(
+                    nn.init.xavier_normal_(
+                        torch.empty(num_experts, input_dim, low_rank)
+                    )
+                )
+                for i in range(self.num_layers)
+            ]
+        )
         # V: (input_dim, low_rank)
-        self.v_list = torch.nn.ParameterList(
-
+        self.v_list = torch.nn.ParameterList(
+            [
+                nn.Parameter(
+                    nn.init.xavier_normal_(
+                        torch.empty(num_experts, input_dim, low_rank)
+                    )
+                )
+                for i in range(self.num_layers)
+            ]
+        )
         # C: (low_rank, low_rank)
-        self.c_list = torch.nn.ParameterList(
-
-
-
-
-
+        self.c_list = torch.nn.ParameterList(
+            [
+                nn.Parameter(
+                    nn.init.xavier_normal_(torch.empty(num_experts, low_rank, low_rank))
+                )
+                for i in range(self.num_layers)
+            ]
+        )
+        self.gating = nn.ModuleList(
+            [nn.Linear(input_dim, 1, bias=False) for i in range(self.num_experts)]
+        )
+
+        self.bias = torch.nn.ParameterList(
+            [
+                nn.Parameter(nn.init.zeros_(torch.empty(input_dim, 1)))
+                for i in range(self.num_layers)
+            ]
+        )
 
     def forward(self, x):
         x_0 = x.unsqueeze(2)  # (bs, in_features, 1)
@@ -503,7 +570,9 @@ class CrossNetMix(nn.Module):
 
                 # (2) E(x_l)
                 # project the input x_l to $\mathbb{R}^{r}$
-                v_x = torch.matmul(
+                v_x = torch.matmul(
+                    self.v_list[i][expert_id].t(), x_l
+                )  # (bs, low_rank, 1)
 
                 # nonlinear activation in low rank space
                 v_x = torch.tanh(v_x)
@@ -511,7 +580,9 @@ class CrossNetMix(nn.Module):
                 v_x = torch.tanh(v_x)
 
                 # project back to $\mathbb{R}^{d}$
-                uv_x = torch.matmul(
+                uv_x = torch.matmul(
+                    self.u_list[i][expert_id], v_x
+                )  # (bs, in_features, 1)
 
                 dot_ = uv_x + self.bias[i]
                 dot_ = x_0 * dot_  # Hadamard-product
@@ -519,53 +590,78 @@ class CrossNetMix(nn.Module):
                 output_of_experts.append(dot_.squeeze(2))
 
             # (3) mixture of low-rank experts
-            output_of_experts = torch.stack(
-
+            output_of_experts = torch.stack(
+                output_of_experts, 2
+            )  # (bs, in_features, num_experts)
+            gating_score_experts = torch.stack(
+                gating_score_experts, 1
+            )  # (bs, num_experts, 1)
             moe_out = torch.matmul(output_of_experts, gating_score_experts.softmax(1))
             x_l = moe_out + x_l  # (bs, in_features, 1)
 
         x_l = x_l.squeeze()  # (bs, in_features)
         return x_l
 
+
 class SENETLayer(nn.Module):
     """Squeeze-and-Excitation block adopted by FiBiNET (Huang et al., 2019)."""
 
     def __init__(self, num_fields, reduction_ratio=3):
         super(SENETLayer, self).__init__()
-        reduced_size = max(1, int(num_fields/ reduction_ratio))
-        self.mlp = nn.Sequential(
-
-
-
+        reduced_size = max(1, int(num_fields / reduction_ratio))
+        self.mlp = nn.Sequential(
+            nn.Linear(num_fields, reduced_size, bias=False),
+            nn.ReLU(),
+            nn.Linear(reduced_size, num_fields, bias=False),
+            nn.ReLU(),
+        )
+
     def forward(self, x):
         z = torch.mean(x, dim=-1, out=None)
         a = self.mlp(z)
-        v = x*a.unsqueeze(-1)
+        v = x * a.unsqueeze(-1)
         return v
 
+
 class BiLinearInteractionLayer(nn.Module):
     """Bilinear feature interaction from FiBiNET (Huang et al., 2019)."""
 
-    def __init__(self, input_dim, num_fields, bilinear_type
+    def __init__(self, input_dim, num_fields, bilinear_type="field_interaction"):
         super(BiLinearInteractionLayer, self).__init__()
         self.bilinear_type = bilinear_type
         if self.bilinear_type == "field_all":
             self.bilinear_layer = nn.Linear(input_dim, input_dim, bias=False)
         elif self.bilinear_type == "field_each":
-            self.bilinear_layer = nn.ModuleList(
+            self.bilinear_layer = nn.ModuleList(
+                [nn.Linear(input_dim, input_dim, bias=False) for i in range(num_fields)]
+            )
         elif self.bilinear_type == "field_interaction":
-            self.bilinear_layer = nn.ModuleList(
+            self.bilinear_layer = nn.ModuleList(
+                [
+                    nn.Linear(input_dim, input_dim, bias=False)
+                    for i, j in combinations(range(num_fields), 2)
+                ]
+            )
         else:
             raise NotImplementedError()
 
     def forward(self, x):
         feature_emb = torch.split(x, 1, dim=1)
         if self.bilinear_type == "field_all":
-            bilinear_list = [
+            bilinear_list = [
+                self.bilinear_layer(v_i) * v_j
+                for v_i, v_j in combinations(feature_emb, 2)
+            ]
         elif self.bilinear_type == "field_each":
-            bilinear_list = [
+            bilinear_list = [
+                self.bilinear_layer[i](feature_emb[i]) * feature_emb[j]
+                for i, j in combinations(range(len(feature_emb)), 2)
+            ]
        elif self.bilinear_type == "field_interaction":
-            bilinear_list = [
+            bilinear_list = [
+                self.bilinear_layer[i](v[0]) * v[1]
+                for i, v in enumerate(combinations(feature_emb, 2))
+            ]
         return torch.cat(bilinear_list, dim=1)
 
 
@@ -578,17 +674,23 @@ class MultiInterestSA(nn.Module):
         self.interest_num = interest_num
         if hidden_dim == None:
             self.hidden_dim = self.embedding_dim * 4
-        self.W1 = torch.nn.Parameter(
-
-
+        self.W1 = torch.nn.Parameter(
+            torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True
+        )
+        self.W2 = torch.nn.Parameter(
+            torch.rand(self.hidden_dim, self.interest_num), requires_grad=True
+        )
+        self.W3 = torch.nn.Parameter(
+            torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True
+        )
 
     def forward(self, seq_emb, mask=None):
-        H = torch.einsum(
+        H = torch.einsum("bse, ed -> bsd", seq_emb, self.W1).tanh()
         if mask != None:
-            A = torch.einsum(
+            A = torch.einsum("bsd, dk -> bsk", H, self.W2) + -1.0e9 * (1 - mask.float())
            A = F.softmax(A, dim=1)
         else:
-            A = F.softmax(torch.einsum(
+            A = F.softmax(torch.einsum("bsd, dk -> bsk", H, self.W2), dim=1)
         A = A.permute(0, 2, 1)
         multi_interest_emb = torch.matmul(A, seq_emb)
         return multi_interest_emb
@@ -597,7 +699,15 @@ class MultiInterestSA(nn.Module):
 class CapsuleNetwork(nn.Module):
     """Dynamic routing capsule network used in MIND (Li et al., 2019)."""
 
-    def __init__(
+    def __init__(
+        self,
+        embedding_dim,
+        seq_len,
+        bilinear_type=2,
+        interest_num=4,
+        routing_times=3,
+        relu_layer=False,
+    ):
         super(CapsuleNetwork, self).__init__()
         self.embedding_dim = embedding_dim  # h
         self.seq_len = seq_len  # s
@@ -607,13 +717,24 @@ class CapsuleNetwork(nn.Module):
 
         self.relu_layer = relu_layer
         self.stop_grad = True
-        self.relu = nn.Sequential(
+        self.relu = nn.Sequential(
+            nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU()
+        )
         if self.bilinear_type == 0:  # MIND
             self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
         elif self.bilinear_type == 1:
-            self.linear = nn.Linear(
+            self.linear = nn.Linear(
+                self.embedding_dim, self.embedding_dim * self.interest_num, bias=False
+            )
         else:
-            self.w = nn.Parameter(
+            self.w = nn.Parameter(
+                torch.Tensor(
+                    1,
+                    self.seq_len,
+                    self.interest_num * self.embedding_dim,
+                    self.embedding_dim,
+                )
+            )
             nn.init.xavier_uniform_(self.w)
 
     def forward(self, item_eb, mask):
@@ -624,11 +745,15 @@ class CapsuleNetwork(nn.Module):
             item_eb_hat = self.linear(item_eb)
         else:
             u = torch.unsqueeze(item_eb, dim=2)
-            item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)
+            item_eb_hat = torch.sum(self.w[:, : self.seq_len, :, :] * u, dim=3)
 
-        item_eb_hat = torch.reshape(
+        item_eb_hat = torch.reshape(
+            item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim)
+        )
         item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
-        item_eb_hat = torch.reshape(
+        item_eb_hat = torch.reshape(
+            item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim)
+        )
 
         if self.stop_grad:
             item_eb_hat_iter = item_eb_hat.detach()
@@ -636,34 +761,47 @@ class CapsuleNetwork(nn.Module):
             item_eb_hat_iter = item_eb_hat
 
         if self.bilinear_type > 0:
-            capsule_weight = torch.zeros(
-
-
-
-
+            capsule_weight = torch.zeros(
+                item_eb_hat.shape[0],
+                self.interest_num,
+                self.seq_len,
+                device=item_eb.device,
+                requires_grad=False,
+            )
         else:
-            capsule_weight = torch.randn(
-
-
-
-
+            capsule_weight = torch.randn(
+                item_eb_hat.shape[0],
+                self.interest_num,
+                self.seq_len,
+                device=item_eb.device,
+                requires_grad=False,
+            )
 
         for i in range(self.routing_times):  # 动态路由传播3次
             atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
             paddings = torch.zeros_like(atten_mask, dtype=torch.float)
 
             capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
-            capsule_softmax_weight = torch.where(
+            capsule_softmax_weight = torch.where(
+                torch.eq(atten_mask, 0), paddings, capsule_softmax_weight
+            )
             capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)
 
             if i < 2:
-                interest_capsule = torch.matmul(
+                interest_capsule = torch.matmul(
+                    capsule_softmax_weight, item_eb_hat_iter
+                )
                 cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
                 scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
                 interest_capsule = scalar_factor * interest_capsule
 
-                delta_weight = torch.matmul(
-
+                delta_weight = torch.matmul(
+                    item_eb_hat_iter,
+                    torch.transpose(interest_capsule, 2, 3).contiguous(),
+                )
+                delta_weight = torch.reshape(
+                    delta_weight, (-1, self.interest_num, self.seq_len)
+                )
                 capsule_weight = capsule_weight + delta_weight
             else:
                 interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat)
@@ -671,7 +809,9 @@ class CapsuleNetwork(nn.Module):
                 scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
                 interest_capsule = scalar_factor * interest_capsule
 
-        interest_capsule = torch.reshape(
+        interest_capsule = torch.reshape(
+            interest_capsule, (-1, self.interest_num, self.embedding_dim)
+        )
 
         if self.relu_layer:
             interest_capsule = self.relu(interest_capsule)
@@ -683,18 +823,18 @@ class FFM(nn.Module):
     """Field-aware Factorization Machine (Juan et al., 2016)."""
 
     def __init__(self, num_fields, reduce_sum=True):
-        super().__init__()
+        super().__init__()
         self.num_fields = num_fields
         self.reduce_sum = reduce_sum
 
     def forward(self, x):
         # compute (non-redundant) second order field-aware feature crossings
         crossed_embeddings = []
-        for i in range(self.num_fields-1):
-            for j in range(i+1, self.num_fields):
-                crossed_embeddings.append(x[:, i, j, :] *
+        for i in range(self.num_fields - 1):
+            for j in range(i + 1, self.num_fields):
+                crossed_embeddings.append(x[:, i, j, :] * x[:, j, i, :])
         crossed_embeddings = torch.stack(crossed_embeddings, dim=1)
-
+
         # if reduce_sum is true, the crossing operation is effectively inner product, other wise Hadamard-product
         if self.reduce_sum:
             crossed_embeddings = torch.sum(crossed_embeddings, dim=-1, keepdim=True)
@@ -705,49 +845,57 @@ class CEN(nn.Module):
     """Field-attentive interaction network from FAT-DeepFFM (Wang et al., 2020)."""
 
     def __init__(self, embed_dim, num_field_crosses, reduction_ratio):
-        super().__init__()
-
+        super().__init__()
+
         # convolution weight (Eq.7 FAT-DeepFFM)
-        self.u = torch.nn.Parameter(
+        self.u = torch.nn.Parameter(
+            torch.rand(num_field_crosses, embed_dim), requires_grad=True
+        )
 
         # two FC layers that computes the field attention
-        self.mlp_att = MLP(
-
-
-
+        self.mlp_att = MLP(
+            num_field_crosses,
+            dims=[num_field_crosses // reduction_ratio, num_field_crosses],
+            output_layer=False,
+            activation="relu",
+        )
+
+    def forward(self, em):
         # compute descriptor vector (Eq.7 FAT-DeepFFM), output shape [batch_size, num_field_crosses]
         d = F.relu((self.u.squeeze(0) * em).sum(-1))
-
-        # compute field attention (Eq.9), output shape [batch_size, num_field_crosses]
-        s = self.mlp_att(d)
+
+        # compute field attention (Eq.9), output shape [batch_size, num_field_crosses]
+        s = self.mlp_att(d)
 
         # rescale original embedding with field attention (Eq.10), output shape [batch_size, num_field_crosses, embed_dim]
-        aem = s.unsqueeze(-1) * em
+        aem = s.unsqueeze(-1) * em
         return aem.flatten(start_dim=1)
 
 
 class MultiHeadSelfAttention(nn.Module):
     """Multi-head self-attention layer from AutoInt (Song et al., 2019)."""
-
+
     def __init__(self, embedding_dim, num_heads=2, dropout=0.0, use_residual=True):
         super().__init__()
         if embedding_dim % num_heads != 0:
-            raise ValueError(
-
+            raise ValueError(
+                f"embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})"
+            )
+
         self.embedding_dim = embedding_dim
         self.num_heads = num_heads
         self.head_dim = embedding_dim // num_heads
         self.use_residual = use_residual
-
+
         self.W_Q = nn.Linear(embedding_dim, embedding_dim, bias=False)
         self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False)
         self.W_V = nn.Linear(embedding_dim, embedding_dim, bias=False)
-
+
         if self.use_residual:
             self.W_Res = nn.Linear(embedding_dim, embedding_dim, bias=False)
-
+
         self.dropout = nn.Dropout(dropout)
-
+
     def forward(self, x):
         """
         Args:
@@ -756,37 +904,47 @@ class MultiHeadSelfAttention(nn.Module):
            output: [batch_size, num_fields, embedding_dim]
         """
         batch_size, num_fields, _ = x.shape
-
+
         # Linear projections
         Q = self.W_Q(x)  # [batch_size, num_fields, embedding_dim]
         K = self.W_K(x)
         V = self.W_V(x)
-
+
         # Split into multiple heads: [batch_size, num_heads, num_fields, head_dim]
-        Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
-
-
-
+        Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
+            1, 2
+        )
+        K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
+            1, 2
+        )
+        V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
+            1, 2
+        )
+
         # Attention scores
-        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim
+        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
         attention_weights = F.softmax(scores, dim=-1)
         attention_weights = self.dropout(attention_weights)
-
+
         # Apply attention to values
-        attention_output = torch.matmul(
-
+        attention_output = torch.matmul(
+            attention_weights, V
+        )  # [batch_size, num_heads, num_fields, head_dim]
+
         # Concatenate heads
         attention_output = attention_output.transpose(1, 2).contiguous()
-        attention_output = attention_output.view(
-
+        attention_output = attention_output.view(
+            batch_size, num_fields, self.embedding_dim
+        )
+
         # Residual connection
         if self.use_residual:
             output = attention_output + self.W_Res(x)
         else:
            output = attention_output
-
+
         output = F.relu(output)
-
+
         return output
 
 
@@ -795,25 +953,31 @@ class AttentionPoolingLayer(nn.Module):
     Attention pooling layer for DIN/DIEN
     Computes attention weights between query (candidate item) and keys (user behavior sequence)
     """
-
-    def __init__(
+
+    def __init__(
+        self,
+        embedding_dim,
+        hidden_units=[80, 40],
+        activation="sigmoid",
+        use_softmax=True,
+    ):
         super().__init__()
         self.embedding_dim = embedding_dim
         self.use_softmax = use_softmax
-
+
         # Build attention network
         # Input: [query, key, query-key, query*key] -> 4 * embedding_dim
         input_dim = 4 * embedding_dim
         layers = []
-
+
         for hidden_unit in hidden_units:
             layers.append(nn.Linear(input_dim, hidden_unit))
            layers.append(activation_layer(activation))
            input_dim = hidden_unit
-
+
         layers.append(nn.Linear(input_dim, 1))
         self.attention_net = nn.Sequential(*layers)
-
+
     def forward(self, query, keys, keys_length=None, mask=None):
         """
         Args:
@@ -825,48 +989,52 @@ class AttentionPoolingLayer(nn.Module):
            output: [batch_size, embedding_dim] - attention pooled representation
         """
         batch_size, seq_len, emb_dim = keys.shape
-
+
         # Expand query to match sequence length: [batch_size, seq_len, embedding_dim]
         query_expanded = query.unsqueeze(1).expand(-1, seq_len, -1)
-
+
         # Compute attention features: [query, key, query-key, query*key]
-        attention_input = torch.cat(
-            query_expanded,
-
-
-            query_expanded * keys
-        ], dim=-1)  # [batch_size, seq_len, 4*embedding_dim]
-
+        attention_input = torch.cat(
+            [query_expanded, keys, query_expanded - keys, query_expanded * keys], dim=-1
+        )  # [batch_size, seq_len, 4*embedding_dim]
+
         # Compute attention scores
-        attention_scores = self.attention_net(
-
+        attention_scores = self.attention_net(
+            attention_input
+        )  # [batch_size, seq_len, 1]
+
         # Apply mask if provided
         if mask is not None:
             attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
-
+
         # Apply softmax to get attention weights
         if self.use_softmax:
-            attention_weights = F.softmax(
+            attention_weights = F.softmax(
+                attention_scores, dim=1
+            )  # [batch_size, seq_len, 1]
         else:
             attention_weights = attention_scores
-
+
         # Weighted sum of keys
-        output = torch.sum(
-
+        output = torch.sum(
+            attention_weights * keys, dim=1
+        )  # [batch_size, embedding_dim]
+
         return output
 
 
 class DynamicGRU(nn.Module):
     """Dynamic GRU unit with auxiliary loss path from DIEN (Zhou et al., 2019)."""
+
     """
     GRU with dynamic routing for DIEN
     """
-
+
     def __init__(self, input_size, hidden_size, bias=True):
         super().__init__()
         self.input_size = input_size
         self.hidden_size = hidden_size
-
+
         # GRU parameters
         self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
         self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
@@ -874,16 +1042,16 @@ class DynamicGRU(nn.Module):
             self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size))
             self.bias_hh = nn.Parameter(torch.randn(3 * hidden_size))
         else:
-            self.register_parameter(
-            self.register_parameter(
-
+            self.register_parameter("bias_ih", None)
+            self.register_parameter("bias_hh", None)
+
         self.reset_parameters()
-
+
     def reset_parameters(self):
         std = 1.0 / (self.hidden_size) ** 0.5
         for weight in self.parameters():
            weight.data.uniform_(-std, std)
-
+
     def forward(self, x, att_scores=None):
         """
         Args:
@@ -894,60 +1062,61 @@ class DynamicGRU(nn.Module):
            hidden: [batch_size, hidden_size] - final hidden state
         """
         batch_size, seq_len, _ = x.shape
-
+
         # Initialize hidden state
         h = torch.zeros(batch_size, self.hidden_size, device=x.device)
-
+
         outputs = []
         for t in range(seq_len):
             x_t = x[:, t, :]  # [batch_size, input_size]
-
+
             # GRU computation
            gi = F.linear(x_t, self.weight_ih, self.bias_ih)
            gh = F.linear(h, self.weight_hh, self.bias_hh)
            i_r, i_i, i_n = gi.chunk(3, 1)
            h_r, h_i, h_n = gh.chunk(3, 1)
-
+
            resetgate = torch.sigmoid(i_r + h_r)
            inputgate = torch.sigmoid(i_i + h_i)
            newgate = torch.tanh(i_n + resetgate * h_n)
            h = newgate + inputgate * (h - newgate)
-
+
            outputs.append(h.unsqueeze(1))
-
+
         output = torch.cat(outputs, dim=1)  # [batch_size, seq_len, hidden_size]
-
+
         return output, h
 
 
 class AUGRU(nn.Module):
     """Attention-aware GRU update gate used in DIEN (Zhou et al., 2019)."""
+
     """
     Attention-based GRU for DIEN
     Uses attention scores to weight the update of hidden states
     """
-
+
     def __init__(self, input_size, hidden_size, bias=True):
         super().__init__()
         self.input_size = input_size
         self.hidden_size = hidden_size
-
+
         self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
         self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
         if bias:
            self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size))
            self.bias_hh = nn.Parameter(torch.randn(3 * hidden_size))
         else:
-            self.register_parameter(
-            self.register_parameter(
-
+            self.register_parameter("bias_ih", None)
+            self.register_parameter("bias_hh", None)
+
         self.reset_parameters()
-
+
     def reset_parameters(self):
         std = 1.0 / (self.hidden_size) ** 0.5
         for weight in self.parameters():
            weight.data.uniform_(-std, std)
-
+
     def forward(self, x, att_scores):
         """
         Args:
@@ -958,28 +1127,28 @@ class AUGRU(nn.Module):
            hidden: [batch_size, hidden_size] - final hidden state
         """
         batch_size, seq_len, _ = x.shape
-
+
         h = torch.zeros(batch_size, self.hidden_size, device=x.device)
-
+
         outputs = []
         for t in range(seq_len):
            x_t = x[:, t, :]  # [batch_size, input_size]
            att_t = att_scores[:, t, :]  # [batch_size, 1]
-
+
            gi = F.linear(x_t, self.weight_ih, self.bias_ih)
            gh = F.linear(h, self.weight_hh, self.bias_hh)
            i_r, i_i, i_n = gi.chunk(3, 1)
            h_r, h_i, h_n = gh.chunk(3, 1)
-
+
            resetgate = torch.sigmoid(i_r + h_r)
            inputgate = torch.sigmoid(i_i + h_i)
            newgate = torch.tanh(i_n + resetgate * h_n)
-
+
            # Use attention score to control update
            h = (1 - att_t) * h + att_t * newgate
-
+
            outputs.append(h.unsqueeze(1))
-
+
         output = torch.cat(outputs, dim=1)
-
-        return output, h
+
+        return output, h