nextrec 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. nextrec/__init__.py +4 -4
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -9
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/dataloader.py +168 -127
  6. nextrec/basic/features.py +24 -27
  7. nextrec/basic/layers.py +328 -159
  8. nextrec/basic/loggers.py +50 -37
  9. nextrec/basic/metrics.py +255 -147
  10. nextrec/basic/model.py +817 -462
  11. nextrec/data/__init__.py +5 -5
  12. nextrec/data/data_utils.py +16 -12
  13. nextrec/data/preprocessor.py +276 -252
  14. nextrec/loss/__init__.py +12 -12
  15. nextrec/loss/loss_utils.py +30 -22
  16. nextrec/loss/match_losses.py +116 -83
  17. nextrec/models/match/__init__.py +5 -5
  18. nextrec/models/match/dssm.py +70 -61
  19. nextrec/models/match/dssm_v2.py +61 -51
  20. nextrec/models/match/mind.py +89 -71
  21. nextrec/models/match/sdm.py +93 -81
  22. nextrec/models/match/youtube_dnn.py +62 -53
  23. nextrec/models/multi_task/esmm.py +49 -43
  24. nextrec/models/multi_task/mmoe.py +65 -56
  25. nextrec/models/multi_task/ple.py +92 -65
  26. nextrec/models/multi_task/share_bottom.py +48 -42
  27. nextrec/models/ranking/__init__.py +7 -7
  28. nextrec/models/ranking/afm.py +39 -30
  29. nextrec/models/ranking/autoint.py +70 -57
  30. nextrec/models/ranking/dcn.py +43 -35
  31. nextrec/models/ranking/deepfm.py +34 -28
  32. nextrec/models/ranking/dien.py +115 -79
  33. nextrec/models/ranking/din.py +84 -60
  34. nextrec/models/ranking/fibinet.py +51 -35
  35. nextrec/models/ranking/fm.py +28 -26
  36. nextrec/models/ranking/masknet.py +31 -31
  37. nextrec/models/ranking/pnn.py +30 -31
  38. nextrec/models/ranking/widedeep.py +36 -31
  39. nextrec/models/ranking/xdeepfm.py +46 -39
  40. nextrec/utils/__init__.py +9 -9
  41. nextrec/utils/embedding.py +1 -1
  42. nextrec/utils/initializer.py +23 -15
  43. nextrec/utils/optimizer.py +14 -10
  44. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/METADATA +6 -40
  45. nextrec-0.1.2.dist-info/RECORD +51 -0
  46. nextrec-0.1.1.dist-info/RECORD +0 -51
  47. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/WHEEL +0 -0
  48. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/layers.py CHANGED
@@ -50,7 +50,14 @@ __all__ = [
 
 
 class PredictionLayer(nn.Module):
-    _CLASSIFICATION_TASKS = {"classification", "binary", "ctr", "ranking", "match", "matching"}
+    _CLASSIFICATION_TASKS = {
+        "classification",
+        "binary",
+        "ctr",
+        "ranking",
+        "match",
+        "matching",
+    }
     _REGRESSION_TASKS = {"regression", "continuous"}
     _MULTICLASS_TASKS = {"multiclass", "softmax"}
 
@@ -213,7 +220,9 @@ class EmbeddingLayer(nn.Module):
                 elif feature.combiner == "concat":
                     pooling_layer = ConcatPooling()
                 else:
-                    raise ValueError(f"Unknown combiner for {feature.name}: {feature.combiner}")
+                    raise ValueError(
+                        f"Unknown combiner for {feature.name}: {feature.combiner}"
+                    )
 
                 feature_mask = InputMask()(x, feature, seq_input)
                 sparse_embeds.append(pooling_layer(seq_emb, feature_mask).unsqueeze(1))
@@ -245,7 +254,9 @@ class EmbeddingLayer(nn.Module):
 
             if target_dim is not None:
                 aligned_dense = [
-                    emb.unsqueeze(1) for emb in dense_embeds if emb.shape[-1] == target_dim
+                    emb.unsqueeze(1)
+                    for emb in dense_embeds
+                    if emb.shape[-1] == target_dim
                 ]
                 output_embeddings.extend(aligned_dense)
 
@@ -257,7 +268,9 @@ class EmbeddingLayer(nn.Module):
 
         return torch.cat(output_embeddings, dim=1)
 
-    def _project_dense(self, feature: DenseFeature, x: dict[str, torch.Tensor]) -> torch.Tensor:
+    def _project_dense(
+        self, feature: DenseFeature, x: dict[str, torch.Tensor]
+    ) -> torch.Tensor:
         if feature.name not in x:
             raise KeyError(f"Dense feature '{feature.name}' is missing from input.")
 
@@ -280,6 +293,7 @@ class EmbeddingLayer(nn.Module):
     def _compute_output_dim(self):
         return
 
+
 class InputMask(nn.Module):
     """Utility module to build sequence masks for pooling layers."""
 
@@ -289,9 +303,9 @@ class InputMask(nn.Module):
     def forward(self, x, fea, seq_tensor=None):
         values = seq_tensor if seq_tensor is not None else x[fea.name]
         if fea.padding_idx is not None:
-            mask = (values.long() != fea.padding_idx)
+            mask = values.long() != fea.padding_idx
         else:
-            mask = (values.long() != 0)
+            mask = values.long() != 0
         if mask.dim() == 1:
             mask = mask.unsqueeze(-1)
         return mask.unsqueeze(1).float()
@@ -319,7 +333,7 @@ class ConcatPooling(nn.Module):
         super().__init__()
 
     def forward(self, x, mask=None):
-        return x.flatten(start_dim=1, end_dim=2)
+        return x.flatten(start_dim=1, end_dim=2)
 
 
 class AveragePooling(nn.Module):
@@ -353,7 +367,9 @@ class SumPooling(nn.Module):
 class MLP(nn.Module):
     """Stacked fully connected layers used in the deep component."""
 
-    def __init__(self, input_dim, output_layer=True, dims=None, dropout=0, activation="relu"):
+    def __init__(
+        self, input_dim, output_layer=True, dims=None, dropout=0, activation="relu"
+    ):
         super().__init__()
         if dims is None:
             dims = []
@@ -380,7 +396,7 @@ class FM(nn.Module):
         self.reduce_sum = reduce_sum
 
     def forward(self, x):
-        square_of_sum = torch.sum(x, dim=1)**2
+        square_of_sum = torch.sum(x, dim=1) ** 2
         sum_of_square = torch.sum(x**2, dim=1)
         ix = square_of_sum - sum_of_square
         if self.reduce_sum:
@@ -399,7 +415,16 @@ class CIN(nn.Module):
         prev_dim, fc_input_dim = input_dim, 0
         for i in range(self.num_layers):
             cross_layer_size = cin_size[i]
-            self.conv_layers.append(torch.nn.Conv1d(input_dim * prev_dim, cross_layer_size, 1, stride=1, dilation=1, bias=True))
+            self.conv_layers.append(
+                torch.nn.Conv1d(
+                    input_dim * prev_dim,
+                    cross_layer_size,
+                    1,
+                    stride=1,
+                    dilation=1,
+                    bias=True,
+                )
+            )
             if self.split_half and i != self.num_layers - 1:
                 cross_layer_size //= 2
             prev_dim = cross_layer_size
@@ -421,6 +446,7 @@ class CIN(nn.Module):
             xs.append(x)
         return self.fc(torch.sum(torch.cat(xs, dim=1), 2))
 
+
 class CrossLayer(nn.Module):
     """Single cross layer used in DCN (Wang et al., 2017)."""
 
@@ -440,8 +466,12 @@ class CrossNetwork(nn.Module):
     def __init__(self, input_dim, num_layers):
         super().__init__()
         self.num_layers = num_layers
-        self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers)])
-        self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])
+        self.w = torch.nn.ModuleList(
+            [torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers)]
+        )
+        self.b = torch.nn.ParameterList(
+            [torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)]
+        )
 
     def forward(self, x):
         """
@@ -453,21 +483,30 @@ class CrossNetwork(nn.Module):
             x = x0 * xw + self.b[i] + x
         return x
 
+
 class CrossNetV2(nn.Module):
     """Vector-wise cross network proposed in DCN V2 (Wang et al., 2021)."""
+
     def __init__(self, input_dim, num_layers):
         super().__init__()
         self.num_layers = num_layers
-        self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, input_dim, bias=False) for _ in range(num_layers)])
-        self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])
-
+        self.w = torch.nn.ModuleList(
+            [
+                torch.nn.Linear(input_dim, input_dim, bias=False)
+                for _ in range(num_layers)
+            ]
+        )
+        self.b = torch.nn.ParameterList(
+            [torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)]
+        )
 
     def forward(self, x):
         x0 = x
         for i in range(self.num_layers):
-            x =x0*self.w[i](x) + self.b[i] + x
+            x = x0 * self.w[i](x) + self.b[i] + x
         return x
 
+
 class CrossNetMix(nn.Module):
     """Mixture of low-rank cross experts from DCN V2 (Wang et al., 2021)."""
 
@@ -477,18 +516,46 @@ class CrossNetMix(nn.Module):
         self.num_experts = num_experts
 
         # U: (input_dim, low_rank)
-        self.u_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
-            torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
+        self.u_list = torch.nn.ParameterList(
+            [
+                nn.Parameter(
+                    nn.init.xavier_normal_(
+                        torch.empty(num_experts, input_dim, low_rank)
+                    )
+                )
+                for i in range(self.num_layers)
+            ]
+        )
         # V: (input_dim, low_rank)
-        self.v_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
-            torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
+        self.v_list = torch.nn.ParameterList(
+            [
+                nn.Parameter(
+                    nn.init.xavier_normal_(
+                        torch.empty(num_experts, input_dim, low_rank)
+                    )
+                )
+                for i in range(self.num_layers)
+            ]
+        )
         # C: (low_rank, low_rank)
-        self.c_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
-            torch.empty(num_experts, low_rank, low_rank))) for i in range(self.num_layers)])
-        self.gating = nn.ModuleList([nn.Linear(input_dim, 1, bias=False) for i in range(self.num_experts)])
-
-        self.bias = torch.nn.ParameterList([nn.Parameter(nn.init.zeros_(
-            torch.empty(input_dim, 1))) for i in range(self.num_layers)])
+        self.c_list = torch.nn.ParameterList(
+            [
+                nn.Parameter(
+                    nn.init.xavier_normal_(torch.empty(num_experts, low_rank, low_rank))
+                )
+                for i in range(self.num_layers)
+            ]
+        )
+        self.gating = nn.ModuleList(
+            [nn.Linear(input_dim, 1, bias=False) for i in range(self.num_experts)]
+        )
+
+        self.bias = torch.nn.ParameterList(
+            [
+                nn.Parameter(nn.init.zeros_(torch.empty(input_dim, 1)))
+                for i in range(self.num_layers)
+            ]
+        )
 
     def forward(self, x):
         x_0 = x.unsqueeze(2)  # (bs, in_features, 1)
@@ -503,7 +570,9 @@ class CrossNetMix(nn.Module):
 
                 # (2) E(x_l)
                 # project the input x_l to $\mathbb{R}^{r}$
-                v_x = torch.matmul(self.v_list[i][expert_id].t(), x_l)  # (bs, low_rank, 1)
+                v_x = torch.matmul(
+                    self.v_list[i][expert_id].t(), x_l
+                )  # (bs, low_rank, 1)
 
                 # nonlinear activation in low rank space
                 v_x = torch.tanh(v_x)
@@ -511,7 +580,9 @@ class CrossNetMix(nn.Module):
                 v_x = torch.tanh(v_x)
 
                 # project back to $\mathbb{R}^{d}$
-                uv_x = torch.matmul(self.u_list[i][expert_id], v_x)  # (bs, in_features, 1)
+                uv_x = torch.matmul(
+                    self.u_list[i][expert_id], v_x
+                )  # (bs, in_features, 1)
 
                 dot_ = uv_x + self.bias[i]
                 dot_ = x_0 * dot_  # Hadamard-product
@@ -519,53 +590,78 @@ class CrossNetMix(nn.Module):
                 output_of_experts.append(dot_.squeeze(2))
 
             # (3) mixture of low-rank experts
-            output_of_experts = torch.stack(output_of_experts, 2)  # (bs, in_features, num_experts)
-            gating_score_experts = torch.stack(gating_score_experts, 1)  # (bs, num_experts, 1)
+            output_of_experts = torch.stack(
+                output_of_experts, 2
+            )  # (bs, in_features, num_experts)
+            gating_score_experts = torch.stack(
+                gating_score_experts, 1
+            )  # (bs, num_experts, 1)
             moe_out = torch.matmul(output_of_experts, gating_score_experts.softmax(1))
             x_l = moe_out + x_l  # (bs, in_features, 1)
 
         x_l = x_l.squeeze()  # (bs, in_features)
         return x_l
 
+
 class SENETLayer(nn.Module):
     """Squeeze-and-Excitation block adopted by FiBiNET (Huang et al., 2019)."""
 
     def __init__(self, num_fields, reduction_ratio=3):
         super(SENETLayer, self).__init__()
-        reduced_size = max(1, int(num_fields/ reduction_ratio))
-        self.mlp = nn.Sequential(nn.Linear(num_fields, reduced_size, bias=False),
-                                 nn.ReLU(),
-                                 nn.Linear(reduced_size, num_fields, bias=False),
-                                 nn.ReLU())
+        reduced_size = max(1, int(num_fields / reduction_ratio))
+        self.mlp = nn.Sequential(
+            nn.Linear(num_fields, reduced_size, bias=False),
+            nn.ReLU(),
+            nn.Linear(reduced_size, num_fields, bias=False),
+            nn.ReLU(),
+        )
+
     def forward(self, x):
         z = torch.mean(x, dim=-1, out=None)
         a = self.mlp(z)
-        v = x*a.unsqueeze(-1)
+        v = x * a.unsqueeze(-1)
         return v
 
+
 class BiLinearInteractionLayer(nn.Module):
     """Bilinear feature interaction from FiBiNET (Huang et al., 2019)."""
 
-    def __init__(self, input_dim, num_fields, bilinear_type = "field_interaction"):
+    def __init__(self, input_dim, num_fields, bilinear_type="field_interaction"):
         super(BiLinearInteractionLayer, self).__init__()
         self.bilinear_type = bilinear_type
         if self.bilinear_type == "field_all":
             self.bilinear_layer = nn.Linear(input_dim, input_dim, bias=False)
         elif self.bilinear_type == "field_each":
-            self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i in range(num_fields)])
+            self.bilinear_layer = nn.ModuleList(
+                [nn.Linear(input_dim, input_dim, bias=False) for i in range(num_fields)]
+            )
         elif self.bilinear_type == "field_interaction":
-            self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i,j in combinations(range(num_fields), 2)])
+            self.bilinear_layer = nn.ModuleList(
+                [
+                    nn.Linear(input_dim, input_dim, bias=False)
+                    for i, j in combinations(range(num_fields), 2)
+                ]
+            )
         else:
             raise NotImplementedError()
 
     def forward(self, x):
         feature_emb = torch.split(x, 1, dim=1)
         if self.bilinear_type == "field_all":
-            bilinear_list = [self.bilinear_layer(v_i)*v_j for v_i, v_j in combinations(feature_emb, 2)]
+            bilinear_list = [
+                self.bilinear_layer(v_i) * v_j
+                for v_i, v_j in combinations(feature_emb, 2)
+            ]
         elif self.bilinear_type == "field_each":
-            bilinear_list = [self.bilinear_layer[i](feature_emb[i])*feature_emb[j] for i,j in combinations(range(len(feature_emb)), 2)]
+            bilinear_list = [
+                self.bilinear_layer[i](feature_emb[i]) * feature_emb[j]
+                for i, j in combinations(range(len(feature_emb)), 2)
+            ]
         elif self.bilinear_type == "field_interaction":
-            bilinear_list = [self.bilinear_layer[i](v[0])*v[1] for i,v in enumerate(combinations(feature_emb, 2))]
+            bilinear_list = [
+                self.bilinear_layer[i](v[0]) * v[1]
+                for i, v in enumerate(combinations(feature_emb, 2))
+            ]
         return torch.cat(bilinear_list, dim=1)
 
 
@@ -578,17 +674,23 @@ class MultiInterestSA(nn.Module):
         self.interest_num = interest_num
         if hidden_dim == None:
             self.hidden_dim = self.embedding_dim * 4
-        self.W1 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True)
-        self.W2 = torch.nn.Parameter(torch.rand(self.hidden_dim, self.interest_num), requires_grad=True)
-        self.W3 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True)
+        self.W1 = torch.nn.Parameter(
+            torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True
+        )
+        self.W2 = torch.nn.Parameter(
+            torch.rand(self.hidden_dim, self.interest_num), requires_grad=True
+        )
+        self.W3 = torch.nn.Parameter(
+            torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True
+        )
 
     def forward(self, seq_emb, mask=None):
-        H = torch.einsum('bse, ed -> bsd', seq_emb, self.W1).tanh()
+        H = torch.einsum("bse, ed -> bsd", seq_emb, self.W1).tanh()
         if mask != None:
-            A = torch.einsum('bsd, dk -> bsk', H, self.W2) + -1.e9 * (1 - mask.float())
+            A = torch.einsum("bsd, dk -> bsk", H, self.W2) + -1.0e9 * (1 - mask.float())
             A = F.softmax(A, dim=1)
         else:
-            A = F.softmax(torch.einsum('bsd, dk -> bsk', H, self.W2), dim=1)
+            A = F.softmax(torch.einsum("bsd, dk -> bsk", H, self.W2), dim=1)
         A = A.permute(0, 2, 1)
         multi_interest_emb = torch.matmul(A, seq_emb)
         return multi_interest_emb
@@ -597,7 +699,15 @@ class MultiInterestSA(nn.Module):
 class CapsuleNetwork(nn.Module):
     """Dynamic routing capsule network used in MIND (Li et al., 2019)."""
 
-    def __init__(self, embedding_dim, seq_len, bilinear_type=2, interest_num=4, routing_times=3, relu_layer=False):
+    def __init__(
+        self,
+        embedding_dim,
+        seq_len,
+        bilinear_type=2,
+        interest_num=4,
+        routing_times=3,
+        relu_layer=False,
+    ):
         super(CapsuleNetwork, self).__init__()
         self.embedding_dim = embedding_dim  # h
         self.seq_len = seq_len  # s
@@ -607,13 +717,24 @@ class CapsuleNetwork(nn.Module):
 
         self.relu_layer = relu_layer
         self.stop_grad = True
-        self.relu = nn.Sequential(nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU())
+        self.relu = nn.Sequential(
+            nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU()
+        )
         if self.bilinear_type == 0:  # MIND
             self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
         elif self.bilinear_type == 1:
-            self.linear = nn.Linear(self.embedding_dim, self.embedding_dim * self.interest_num, bias=False)
+            self.linear = nn.Linear(
+                self.embedding_dim, self.embedding_dim * self.interest_num, bias=False
+            )
         else:
-            self.w = nn.Parameter(torch.Tensor(1, self.seq_len, self.interest_num * self.embedding_dim, self.embedding_dim))
+            self.w = nn.Parameter(
+                torch.Tensor(
+                    1,
+                    self.seq_len,
+                    self.interest_num * self.embedding_dim,
+                    self.embedding_dim,
+                )
+            )
             nn.init.xavier_uniform_(self.w)
 
     def forward(self, item_eb, mask):
@@ -624,11 +745,15 @@ class CapsuleNetwork(nn.Module):
             item_eb_hat = self.linear(item_eb)
         else:
             u = torch.unsqueeze(item_eb, dim=2)
-            item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)
+            item_eb_hat = torch.sum(self.w[:, : self.seq_len, :, :] * u, dim=3)
 
-        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim))
+        item_eb_hat = torch.reshape(
+            item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim)
+        )
         item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
-        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim))
+        item_eb_hat = torch.reshape(
+            item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim)
+        )
 
         if self.stop_grad:
             item_eb_hat_iter = item_eb_hat.detach()
@@ -636,34 +761,47 @@ class CapsuleNetwork(nn.Module):
             item_eb_hat_iter = item_eb_hat
 
         if self.bilinear_type > 0:
-            capsule_weight = torch.zeros(item_eb_hat.shape[0],
-                                         self.interest_num,
-                                         self.seq_len,
-                                         device=item_eb.device,
-                                         requires_grad=False)
+            capsule_weight = torch.zeros(
+                item_eb_hat.shape[0],
+                self.interest_num,
+                self.seq_len,
+                device=item_eb.device,
+                requires_grad=False,
+            )
         else:
-            capsule_weight = torch.randn(item_eb_hat.shape[0],
-                                         self.interest_num,
-                                         self.seq_len,
-                                         device=item_eb.device,
-                                         requires_grad=False)
+            capsule_weight = torch.randn(
+                item_eb_hat.shape[0],
+                self.interest_num,
+                self.seq_len,
+                device=item_eb.device,
+                requires_grad=False,
+            )
 
         for i in range(self.routing_times):  # 动态路由传播3次
             atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
             paddings = torch.zeros_like(atten_mask, dtype=torch.float)
 
             capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
-            capsule_softmax_weight = torch.where(torch.eq(atten_mask, 0), paddings, capsule_softmax_weight)
+            capsule_softmax_weight = torch.where(
+                torch.eq(atten_mask, 0), paddings, capsule_softmax_weight
+            )
             capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)
 
             if i < 2:
-                interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat_iter)
+                interest_capsule = torch.matmul(
+                    capsule_softmax_weight, item_eb_hat_iter
+                )
                 cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
                 scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
                 interest_capsule = scalar_factor * interest_capsule
 
-                delta_weight = torch.matmul(item_eb_hat_iter, torch.transpose(interest_capsule, 2, 3).contiguous())
-                delta_weight = torch.reshape(delta_weight, (-1, self.interest_num, self.seq_len))
+                delta_weight = torch.matmul(
+                    item_eb_hat_iter,
+                    torch.transpose(interest_capsule, 2, 3).contiguous(),
+                )
+                delta_weight = torch.reshape(
+                    delta_weight, (-1, self.interest_num, self.seq_len)
+                )
                 capsule_weight = capsule_weight + delta_weight
             else:
                 interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat)
@@ -671,7 +809,9 @@ class CapsuleNetwork(nn.Module):
                 scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
                 interest_capsule = scalar_factor * interest_capsule
 
-        interest_capsule = torch.reshape(interest_capsule, (-1, self.interest_num, self.embedding_dim))
+        interest_capsule = torch.reshape(
+            interest_capsule, (-1, self.interest_num, self.embedding_dim)
+        )
 
         if self.relu_layer:
             interest_capsule = self.relu(interest_capsule)
@@ -683,18 +823,18 @@ class FFM(nn.Module):
     """Field-aware Factorization Machine (Juan et al., 2016)."""
 
     def __init__(self, num_fields, reduce_sum=True):
-        super().__init__()
+        super().__init__()
         self.num_fields = num_fields
         self.reduce_sum = reduce_sum
 
     def forward(self, x):
         # compute (non-redundant) second order field-aware feature crossings
         crossed_embeddings = []
-        for i in range(self.num_fields-1):
-            for j in range(i+1, self.num_fields):
-                crossed_embeddings.append(x[:, i, j, :] * x[:, j, i, :])
+        for i in range(self.num_fields - 1):
+            for j in range(i + 1, self.num_fields):
+                crossed_embeddings.append(x[:, i, j, :] * x[:, j, i, :])
         crossed_embeddings = torch.stack(crossed_embeddings, dim=1)
-
+
         # if reduce_sum is true, the crossing operation is effectively inner product, other wise Hadamard-product
         if self.reduce_sum:
             crossed_embeddings = torch.sum(crossed_embeddings, dim=-1, keepdim=True)
@@ -705,49 +845,57 @@ class CEN(nn.Module):
     """Field-attentive interaction network from FAT-DeepFFM (Wang et al., 2020)."""
 
     def __init__(self, embed_dim, num_field_crosses, reduction_ratio):
-        super().__init__()
-
+        super().__init__()
+
         # convolution weight (Eq.7 FAT-DeepFFM)
-        self.u = torch.nn.Parameter(torch.rand(num_field_crosses, embed_dim), requires_grad=True)
+        self.u = torch.nn.Parameter(
+            torch.rand(num_field_crosses, embed_dim), requires_grad=True
+        )
 
         # two FC layers that computes the field attention
-        self.mlp_att = MLP(num_field_crosses, dims=[num_field_crosses//reduction_ratio, num_field_crosses], output_layer=False, activation="relu")
-
-
-    def forward(self, em):
+        self.mlp_att = MLP(
+            num_field_crosses,
+            dims=[num_field_crosses // reduction_ratio, num_field_crosses],
+            output_layer=False,
+            activation="relu",
+        )
+
+    def forward(self, em):
         # compute descriptor vector (Eq.7 FAT-DeepFFM), output shape [batch_size, num_field_crosses]
         d = F.relu((self.u.squeeze(0) * em).sum(-1))
-
-        # compute field attention (Eq.9), output shape [batch_size, num_field_crosses]
-        s = self.mlp_att(d)
+
+        # compute field attention (Eq.9), output shape [batch_size, num_field_crosses]
+        s = self.mlp_att(d)
 
         # rescale original embedding with field attention (Eq.10), output shape [batch_size, num_field_crosses, embed_dim]
-        aem = s.unsqueeze(-1) * em
+        aem = s.unsqueeze(-1) * em
         return aem.flatten(start_dim=1)
 
 
 class MultiHeadSelfAttention(nn.Module):
     """Multi-head self-attention layer from AutoInt (Song et al., 2019)."""
-
+
     def __init__(self, embedding_dim, num_heads=2, dropout=0.0, use_residual=True):
         super().__init__()
         if embedding_dim % num_heads != 0:
-            raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})")
-
+            raise ValueError(
+                f"embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})"
+            )
+
         self.embedding_dim = embedding_dim
         self.num_heads = num_heads
         self.head_dim = embedding_dim // num_heads
         self.use_residual = use_residual
-
+
         self.W_Q = nn.Linear(embedding_dim, embedding_dim, bias=False)
         self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False)
         self.W_V = nn.Linear(embedding_dim, embedding_dim, bias=False)
-
+
        if self.use_residual:
             self.W_Res = nn.Linear(embedding_dim, embedding_dim, bias=False)
-
+
         self.dropout = nn.Dropout(dropout)
-
+
     def forward(self, x):
         """
         Args:
@@ -756,37 +904,47 @@ class MultiHeadSelfAttention(nn.Module):
             output: [batch_size, num_fields, embedding_dim]
         """
         batch_size, num_fields, _ = x.shape
-
+
         # Linear projections
         Q = self.W_Q(x)  # [batch_size, num_fields, embedding_dim]
         K = self.W_K(x)
         V = self.W_V(x)
-
+
         # Split into multiple heads: [batch_size, num_heads, num_fields, head_dim]
-        Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
-        K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
-        V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
-
+        Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
+            1, 2
+        )
+        K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
+            1, 2
+        )
+        V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
+            1, 2
+        )
+
         # Attention scores
-        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
+        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
         attention_weights = F.softmax(scores, dim=-1)
         attention_weights = self.dropout(attention_weights)
-
+
         # Apply attention to values
-        attention_output = torch.matmul(attention_weights, V)  # [batch_size, num_heads, num_fields, head_dim]
-
+        attention_output = torch.matmul(
+            attention_weights, V
+        )  # [batch_size, num_heads, num_fields, head_dim]
+
         # Concatenate heads
         attention_output = attention_output.transpose(1, 2).contiguous()
-        attention_output = attention_output.view(batch_size, num_fields, self.embedding_dim)
-
+        attention_output = attention_output.view(
+            batch_size, num_fields, self.embedding_dim
+        )
+
         # Residual connection
         if self.use_residual:
             output = attention_output + self.W_Res(x)
         else:
             output = attention_output
-
+
         output = F.relu(output)
-
+
         return output
 
 
@@ -795,25 +953,31 @@ class AttentionPoolingLayer(nn.Module):
     Attention pooling layer for DIN/DIEN
     Computes attention weights between query (candidate item) and keys (user behavior sequence)
     """
-
-    def __init__(self, embedding_dim, hidden_units=[80, 40], activation='sigmoid', use_softmax=True):
+
+    def __init__(
+        self,
+        embedding_dim,
+        hidden_units=[80, 40],
+        activation="sigmoid",
+        use_softmax=True,
+    ):
         super().__init__()
         self.embedding_dim = embedding_dim
         self.use_softmax = use_softmax
-
+
         # Build attention network
         # Input: [query, key, query-key, query*key] -> 4 * embedding_dim
         input_dim = 4 * embedding_dim
         layers = []
-
+
        for hidden_unit in hidden_units:
             layers.append(nn.Linear(input_dim, hidden_unit))
             layers.append(activation_layer(activation))
             input_dim = hidden_unit
-
+
         layers.append(nn.Linear(input_dim, 1))
         self.attention_net = nn.Sequential(*layers)
-
+
     def forward(self, query, keys, keys_length=None, mask=None):
         """
         Args:
@@ -825,48 +989,52 @@ class AttentionPoolingLayer(nn.Module):
             output: [batch_size, embedding_dim] - attention pooled representation
         """
         batch_size, seq_len, emb_dim = keys.shape
-
+
         # Expand query to match sequence length: [batch_size, seq_len, embedding_dim]
         query_expanded = query.unsqueeze(1).expand(-1, seq_len, -1)
-
+
         # Compute attention features: [query, key, query-key, query*key]
-        attention_input = torch.cat([
-            query_expanded,
-            keys,
-            query_expanded - keys,
-            query_expanded * keys
-        ], dim=-1)  # [batch_size, seq_len, 4*embedding_dim]
-
+        attention_input = torch.cat(
+            [query_expanded, keys, query_expanded - keys, query_expanded * keys], dim=-1
+        )  # [batch_size, seq_len, 4*embedding_dim]
+
         # Compute attention scores
-        attention_scores = self.attention_net(attention_input)  # [batch_size, seq_len, 1]
-
+        attention_scores = self.attention_net(
+            attention_input
+        )  # [batch_size, seq_len, 1]
+
         # Apply mask if provided
         if mask is not None:
             attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
-
+
         # Apply softmax to get attention weights
         if self.use_softmax:
-            attention_weights = F.softmax(attention_scores, dim=1)  # [batch_size, seq_len, 1]
+            attention_weights = F.softmax(
+                attention_scores, dim=1
+            )  # [batch_size, seq_len, 1]
         else:
             attention_weights = attention_scores
-
+
         # Weighted sum of keys
-        output = torch.sum(attention_weights * keys, dim=1)  # [batch_size, embedding_dim]
-
+        output = torch.sum(
+            attention_weights * keys, dim=1
+        )  # [batch_size, embedding_dim]
+
         return output
 
 
 class DynamicGRU(nn.Module):
     """Dynamic GRU unit with auxiliary loss path from DIEN (Zhou et al., 2019)."""
+
     """
     GRU with dynamic routing for DIEN
     """
-
+
     def __init__(self, input_size, hidden_size, bias=True):
         super().__init__()
         self.input_size = input_size
         self.hidden_size = hidden_size
-
+
         # GRU parameters
         self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
         self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
@@ -874,16 +1042,16 @@ class DynamicGRU(nn.Module):
             self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size))
             self.bias_hh = nn.Parameter(torch.randn(3 * hidden_size))
         else:
-            self.register_parameter('bias_ih', None)
-            self.register_parameter('bias_hh', None)
-
+            self.register_parameter("bias_ih", None)
+            self.register_parameter("bias_hh", None)
+
         self.reset_parameters()
-
+
     def reset_parameters(self):
         std = 1.0 / (self.hidden_size) ** 0.5
         for weight in self.parameters():
             weight.data.uniform_(-std, std)
-
+
     def forward(self, x, att_scores=None):
         """
         Args:
@@ -894,60 +1062,61 @@ class DynamicGRU(nn.Module):
             hidden: [batch_size, hidden_size] - final hidden state
         """
         batch_size, seq_len, _ = x.shape
-
+
         # Initialize hidden state
         h = torch.zeros(batch_size, self.hidden_size, device=x.device)
-
+
         outputs = []
         for t in range(seq_len):
             x_t = x[:, t, :]  # [batch_size, input_size]
-
+
             # GRU computation
             gi = F.linear(x_t, self.weight_ih, self.bias_ih)
             gh = F.linear(h, self.weight_hh, self.bias_hh)
             i_r, i_i, i_n = gi.chunk(3, 1)
             h_r, h_i, h_n = gh.chunk(3, 1)
-
+
             resetgate = torch.sigmoid(i_r + h_r)
             inputgate = torch.sigmoid(i_i + h_i)
             newgate = torch.tanh(i_n + resetgate * h_n)
             h = newgate + inputgate * (h - newgate)
-
+
             outputs.append(h.unsqueeze(1))
-
+
         output = torch.cat(outputs, dim=1)  # [batch_size, seq_len, hidden_size]
-
+
         return output, h
 
 
 class AUGRU(nn.Module):
     """Attention-aware GRU update gate used in DIEN (Zhou et al., 2019)."""
+
     """
     Attention-based GRU for DIEN
     Uses attention scores to weight the update of hidden states
     """
-
+
     def __init__(self, input_size, hidden_size, bias=True):
         super().__init__()
         self.input_size = input_size
         self.hidden_size = hidden_size
-
+
         self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
         self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
         if bias:
             self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size))
             self.bias_hh = nn.Parameter(torch.randn(3 * hidden_size))
         else:
-            self.register_parameter('bias_ih', None)
-            self.register_parameter('bias_hh', None)
-
+            self.register_parameter("bias_ih", None)
+            self.register_parameter("bias_hh", None)
+
         self.reset_parameters()
-
+
     def reset_parameters(self):
         std = 1.0 / (self.hidden_size) ** 0.5
         for weight in self.parameters():
             weight.data.uniform_(-std, std)
-
+
     def forward(self, x, att_scores):
         """
         Args:
@@ -958,28 +1127,28 @@ class AUGRU(nn.Module):
            hidden: [batch_size, hidden_size] - final hidden state
        """
         batch_size, seq_len, _ = x.shape
-
+
         h = torch.zeros(batch_size, self.hidden_size, device=x.device)
-
+
         outputs = []
         for t in range(seq_len):
             x_t = x[:, t, :]  # [batch_size, input_size]
             att_t = att_scores[:, t, :]  # [batch_size, 1]
-
+
             gi = F.linear(x_t, self.weight_ih, self.bias_ih)
             gh = F.linear(h, self.weight_hh, self.bias_hh)
             i_r, i_i, i_n = gi.chunk(3, 1)
             h_r, h_i, h_n = gh.chunk(3, 1)
-
+
             resetgate = torch.sigmoid(i_r + h_r)
             inputgate = torch.sigmoid(i_i + h_i)
             newgate = torch.tanh(i_n + resetgate * h_n)
-
+
             # Use attention score to control update
             h = (1 - att_t) * h + att_t * newgate
-
+
             outputs.append(h.unsqueeze(1))
-
+
         output = torch.cat(outputs, dim=1)
-
-        return output, h
+
+        return output, h