nextrec 0.1.4__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (48)
  1. nextrec/__init__.py +4 -4
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +9 -10
  4. nextrec/basic/callback.py +0 -1
  5. nextrec/basic/dataloader.py +127 -168
  6. nextrec/basic/features.py +27 -24
  7. nextrec/basic/layers.py +159 -328
  8. nextrec/basic/loggers.py +37 -50
  9. nextrec/basic/metrics.py +147 -255
  10. nextrec/basic/model.py +462 -817
  11. nextrec/data/__init__.py +5 -5
  12. nextrec/data/data_utils.py +12 -16
  13. nextrec/data/preprocessor.py +252 -276
  14. nextrec/loss/__init__.py +12 -12
  15. nextrec/loss/loss_utils.py +22 -30
  16. nextrec/loss/match_losses.py +83 -116
  17. nextrec/models/match/__init__.py +5 -5
  18. nextrec/models/match/dssm.py +61 -70
  19. nextrec/models/match/dssm_v2.py +51 -61
  20. nextrec/models/match/mind.py +71 -89
  21. nextrec/models/match/sdm.py +81 -93
  22. nextrec/models/match/youtube_dnn.py +53 -62
  23. nextrec/models/multi_task/esmm.py +43 -49
  24. nextrec/models/multi_task/mmoe.py +56 -65
  25. nextrec/models/multi_task/ple.py +65 -92
  26. nextrec/models/multi_task/share_bottom.py +42 -48
  27. nextrec/models/ranking/__init__.py +7 -7
  28. nextrec/models/ranking/afm.py +30 -39
  29. nextrec/models/ranking/autoint.py +57 -70
  30. nextrec/models/ranking/dcn.py +35 -43
  31. nextrec/models/ranking/deepfm.py +28 -34
  32. nextrec/models/ranking/dien.py +79 -115
  33. nextrec/models/ranking/din.py +60 -84
  34. nextrec/models/ranking/fibinet.py +35 -51
  35. nextrec/models/ranking/fm.py +26 -28
  36. nextrec/models/ranking/masknet.py +31 -31
  37. nextrec/models/ranking/pnn.py +31 -30
  38. nextrec/models/ranking/widedeep.py +31 -36
  39. nextrec/models/ranking/xdeepfm.py +39 -46
  40. nextrec/utils/__init__.py +9 -9
  41. nextrec/utils/embedding.py +1 -1
  42. nextrec/utils/initializer.py +15 -23
  43. nextrec/utils/optimizer.py +10 -14
  44. {nextrec-0.1.4.dist-info → nextrec-0.1.7.dist-info}/METADATA +16 -7
  45. nextrec-0.1.7.dist-info/RECORD +51 -0
  46. nextrec-0.1.4.dist-info/RECORD +0 -51
  47. {nextrec-0.1.4.dist-info → nextrec-0.1.7.dist-info}/WHEEL +0 -0
  48. {nextrec-0.1.4.dist-info → nextrec-0.1.7.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/layers.py CHANGED
@@ -50,14 +50,7 @@ __all__ = [
50
50
 
51
51
 
52
52
  class PredictionLayer(nn.Module):
53
- _CLASSIFICATION_TASKS = {
54
- "classification",
55
- "binary",
56
- "ctr",
57
- "ranking",
58
- "match",
59
- "matching",
60
- }
53
+ _CLASSIFICATION_TASKS = {"classification", "binary", "ctr", "ranking", "match", "matching"}
61
54
  _REGRESSION_TASKS = {"regression", "continuous"}
62
55
  _MULTICLASS_TASKS = {"multiclass", "softmax"}
63
56
 
@@ -220,9 +213,7 @@ class EmbeddingLayer(nn.Module):
220
213
  elif feature.combiner == "concat":
221
214
  pooling_layer = ConcatPooling()
222
215
  else:
223
- raise ValueError(
224
- f"Unknown combiner for {feature.name}: {feature.combiner}"
225
- )
216
+ raise ValueError(f"Unknown combiner for {feature.name}: {feature.combiner}")
226
217
 
227
218
  feature_mask = InputMask()(x, feature, seq_input)
228
219
  sparse_embeds.append(pooling_layer(seq_emb, feature_mask).unsqueeze(1))
@@ -254,9 +245,7 @@ class EmbeddingLayer(nn.Module):
254
245
 
255
246
  if target_dim is not None:
256
247
  aligned_dense = [
257
- emb.unsqueeze(1)
258
- for emb in dense_embeds
259
- if emb.shape[-1] == target_dim
248
+ emb.unsqueeze(1) for emb in dense_embeds if emb.shape[-1] == target_dim
260
249
  ]
261
250
  output_embeddings.extend(aligned_dense)
262
251
 
@@ -268,9 +257,7 @@ class EmbeddingLayer(nn.Module):
268
257
 
269
258
  return torch.cat(output_embeddings, dim=1)
270
259
 
271
- def _project_dense(
272
- self, feature: DenseFeature, x: dict[str, torch.Tensor]
273
- ) -> torch.Tensor:
260
+ def _project_dense(self, feature: DenseFeature, x: dict[str, torch.Tensor]) -> torch.Tensor:
274
261
  if feature.name not in x:
275
262
  raise KeyError(f"Dense feature '{feature.name}' is missing from input.")
276
263
 
@@ -293,7 +280,6 @@ class EmbeddingLayer(nn.Module):
293
280
  def _compute_output_dim(self):
294
281
  return
295
282
 
296
-
297
283
  class InputMask(nn.Module):
298
284
  """Utility module to build sequence masks for pooling layers."""
299
285
 
@@ -303,9 +289,9 @@ class InputMask(nn.Module):
303
289
  def forward(self, x, fea, seq_tensor=None):
304
290
  values = seq_tensor if seq_tensor is not None else x[fea.name]
305
291
  if fea.padding_idx is not None:
306
- mask = values.long() != fea.padding_idx
292
+ mask = (values.long() != fea.padding_idx)
307
293
  else:
308
- mask = values.long() != 0
294
+ mask = (values.long() != 0)
309
295
  if mask.dim() == 1:
310
296
  mask = mask.unsqueeze(-1)
311
297
  return mask.unsqueeze(1).float()
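For reference, the masking behaviour above can be reproduced in isolation. A minimal sketch (illustrative, not part of the package) that builds the same kind of padding mask for a batch of ID sequences, assuming padding index 0:

```python
import torch

# Illustrative only: build a padding mask for a batch of ID sequences (0 = padding).
seq = torch.tensor([[5, 3, 0, 0],
                    [7, 2, 9, 0]])        # [batch, seq_len]
mask = (seq.long() != 0)                  # True at real tokens
if mask.dim() == 1:                       # mirrors the 1-D corner case handled above
    mask = mask.unsqueeze(-1)
print(mask.unsqueeze(1).float().shape)    # [batch, 1, seq_len]
```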
@@ -333,7 +319,7 @@ class ConcatPooling(nn.Module):
333
319
  super().__init__()
334
320
 
335
321
  def forward(self, x, mask=None):
336
- return x.flatten(start_dim=1, end_dim=2)
322
+ return x.flatten(start_dim=1, end_dim=2)
337
323
 
338
324
 
339
325
  class AveragePooling(nn.Module):
@@ -367,9 +353,7 @@ class SumPooling(nn.Module):
367
353
  class MLP(nn.Module):
368
354
  """Stacked fully connected layers used in the deep component."""
369
355
 
370
- def __init__(
371
- self, input_dim, output_layer=True, dims=None, dropout=0, activation="relu"
372
- ):
356
+ def __init__(self, input_dim, output_layer=True, dims=None, dropout=0, activation="relu"):
373
357
  super().__init__()
374
358
  if dims is None:
375
359
  dims = []
@@ -396,7 +380,7 @@ class FM(nn.Module):
396
380
  self.reduce_sum = reduce_sum
397
381
 
398
382
  def forward(self, x):
399
- square_of_sum = torch.sum(x, dim=1) ** 2
383
+ square_of_sum = torch.sum(x, dim=1)**2
400
384
  sum_of_square = torch.sum(x**2, dim=1)
401
385
  ix = square_of_sum - sum_of_square
402
386
  if self.reduce_sum:
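The forward pass above relies on the standard factorization-machine identity, sum_{i<j} <v_i, v_j> = 0.5 * [(sum_i v_i)^2 - sum_i v_i^2], summed over the embedding dimension. A standalone sketch with the conventional 0.5 factor (illustrative; the package's own reduction may differ):

```python
import torch

def fm_second_order(x: torch.Tensor) -> torch.Tensor:
    """x: [batch, num_fields, embed_dim] -> [batch, 1] pairwise interaction score."""
    square_of_sum = torch.sum(x, dim=1) ** 2     # (sum_i v_i)^2, [batch, embed_dim]
    sum_of_square = torch.sum(x ** 2, dim=1)     # sum_i v_i^2,   [batch, embed_dim]
    return 0.5 * torch.sum(square_of_sum - sum_of_square, dim=1, keepdim=True)

scores = fm_second_order(torch.randn(8, 10, 16))   # [8, 1]
```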
@@ -415,16 +399,7 @@ class CIN(nn.Module):
415
399
  prev_dim, fc_input_dim = input_dim, 0
416
400
  for i in range(self.num_layers):
417
401
  cross_layer_size = cin_size[i]
418
- self.conv_layers.append(
419
- torch.nn.Conv1d(
420
- input_dim * prev_dim,
421
- cross_layer_size,
422
- 1,
423
- stride=1,
424
- dilation=1,
425
- bias=True,
426
- )
427
- )
402
+ self.conv_layers.append(torch.nn.Conv1d(input_dim * prev_dim, cross_layer_size, 1, stride=1, dilation=1, bias=True))
428
403
  if self.split_half and i != self.num_layers - 1:
429
404
  cross_layer_size //= 2
430
405
  prev_dim = cross_layer_size
@@ -446,7 +421,6 @@ class CIN(nn.Module):
446
421
  xs.append(x)
447
422
  return self.fc(torch.sum(torch.cat(xs, dim=1), 2))
448
423
 
449
-
450
424
  class CrossLayer(nn.Module):
451
425
  """Single cross layer used in DCN (Wang et al., 2017)."""
452
426
 
@@ -466,12 +440,8 @@ class CrossNetwork(nn.Module):
466
440
  def __init__(self, input_dim, num_layers):
467
441
  super().__init__()
468
442
  self.num_layers = num_layers
469
- self.w = torch.nn.ModuleList(
470
- [torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers)]
471
- )
472
- self.b = torch.nn.ParameterList(
473
- [torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)]
474
- )
443
+ self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers)])
444
+ self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])
475
445
 
476
446
  def forward(self, x):
477
447
  """
@@ -483,30 +453,21 @@ class CrossNetwork(nn.Module):
483
453
  x = x0 * xw + self.b[i] + x
484
454
  return x
485
455
 
486
-
487
456
  class CrossNetV2(nn.Module):
488
457
  """Vector-wise cross network proposed in DCN V2 (Wang et al., 2021)."""
489
-
490
458
  def __init__(self, input_dim, num_layers):
491
459
  super().__init__()
492
460
  self.num_layers = num_layers
493
- self.w = torch.nn.ModuleList(
494
- [
495
- torch.nn.Linear(input_dim, input_dim, bias=False)
496
- for _ in range(num_layers)
497
- ]
498
- )
499
- self.b = torch.nn.ParameterList(
500
- [torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)]
501
- )
461
+ self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, input_dim, bias=False) for _ in range(num_layers)])
462
+ self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])
463
+
502
464
 
503
465
  def forward(self, x):
504
466
  x0 = x
505
467
  for i in range(self.num_layers):
506
- x = x0 * self.w[i](x) + self.b[i] + x
468
+ x =x0*self.w[i](x) + self.b[i] + x
507
469
  return x
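Both cross networks iterate the same residual update; in the DCN-V2 form above it is x_{l+1} = x0 * W_l(x_l) + b_l + x_l (element-wise product with the original input). A compact standalone sketch (hypothetical helper, not the package class):

```python
import torch
import torch.nn as nn

class TinyCrossV2(nn.Module):
    """Illustrative DCN-V2 style stack: x_{l+1} = x0 * W_l(x_l) + b_l + x_l."""
    def __init__(self, input_dim: int, num_layers: int = 2):
        super().__init__()
        self.w = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for _ in range(num_layers)])
        self.b = nn.ParameterList([nn.Parameter(torch.zeros(input_dim)) for _ in range(num_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x0 = x
        for w, b in zip(self.w, self.b):
            x = x0 * w(x) + b + x
        return x

out = TinyCrossV2(16)(torch.randn(4, 16))   # [4, 16]
```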
508
470
 
509
-
510
471
  class CrossNetMix(nn.Module):
511
472
  """Mixture of low-rank cross experts from DCN V2 (Wang et al., 2021)."""
512
473
 
@@ -516,46 +477,18 @@ class CrossNetMix(nn.Module):
516
477
  self.num_experts = num_experts
517
478
 
518
479
  # U: (input_dim, low_rank)
519
- self.u_list = torch.nn.ParameterList(
520
- [
521
- nn.Parameter(
522
- nn.init.xavier_normal_(
523
- torch.empty(num_experts, input_dim, low_rank)
524
- )
525
- )
526
- for i in range(self.num_layers)
527
- ]
528
- )
480
+ self.u_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
481
+ torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
529
482
  # V: (input_dim, low_rank)
530
- self.v_list = torch.nn.ParameterList(
531
- [
532
- nn.Parameter(
533
- nn.init.xavier_normal_(
534
- torch.empty(num_experts, input_dim, low_rank)
535
- )
536
- )
537
- for i in range(self.num_layers)
538
- ]
539
- )
483
+ self.v_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
484
+ torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
540
485
  # C: (low_rank, low_rank)
541
- self.c_list = torch.nn.ParameterList(
542
- [
543
- nn.Parameter(
544
- nn.init.xavier_normal_(torch.empty(num_experts, low_rank, low_rank))
545
- )
546
- for i in range(self.num_layers)
547
- ]
548
- )
549
- self.gating = nn.ModuleList(
550
- [nn.Linear(input_dim, 1, bias=False) for i in range(self.num_experts)]
551
- )
552
-
553
- self.bias = torch.nn.ParameterList(
554
- [
555
- nn.Parameter(nn.init.zeros_(torch.empty(input_dim, 1)))
556
- for i in range(self.num_layers)
557
- ]
558
- )
486
+ self.c_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
487
+ torch.empty(num_experts, low_rank, low_rank))) for i in range(self.num_layers)])
488
+ self.gating = nn.ModuleList([nn.Linear(input_dim, 1, bias=False) for i in range(self.num_experts)])
489
+
490
+ self.bias = torch.nn.ParameterList([nn.Parameter(nn.init.zeros_(
491
+ torch.empty(input_dim, 1))) for i in range(self.num_layers)])
559
492
 
560
493
  def forward(self, x):
561
494
  x_0 = x.unsqueeze(2) # (bs, in_features, 1)
@@ -570,9 +503,7 @@ class CrossNetMix(nn.Module):
570
503
 
571
504
  # (2) E(x_l)
572
505
  # project the input x_l to $\mathbb{R}^{r}$
573
- v_x = torch.matmul(
574
- self.v_list[i][expert_id].t(), x_l
575
- ) # (bs, low_rank, 1)
506
+ v_x = torch.matmul(self.v_list[i][expert_id].t(), x_l) # (bs, low_rank, 1)
576
507
 
577
508
  # nonlinear activation in low rank space
578
509
  v_x = torch.tanh(v_x)
@@ -580,9 +511,7 @@ class CrossNetMix(nn.Module):
580
511
  v_x = torch.tanh(v_x)
581
512
 
582
513
  # project back to $\mathbb{R}^{d}$
583
- uv_x = torch.matmul(
584
- self.u_list[i][expert_id], v_x
585
- ) # (bs, in_features, 1)
514
+ uv_x = torch.matmul(self.u_list[i][expert_id], v_x) # (bs, in_features, 1)
586
515
 
587
516
  dot_ = uv_x + self.bias[i]
588
517
  dot_ = x_0 * dot_ # Hadamard-product
@@ -590,78 +519,53 @@ class CrossNetMix(nn.Module):
590
519
  output_of_experts.append(dot_.squeeze(2))
591
520
 
592
521
  # (3) mixture of low-rank experts
593
- output_of_experts = torch.stack(
594
- output_of_experts, 2
595
- ) # (bs, in_features, num_experts)
596
- gating_score_experts = torch.stack(
597
- gating_score_experts, 1
598
- ) # (bs, num_experts, 1)
522
+ output_of_experts = torch.stack(output_of_experts, 2) # (bs, in_features, num_experts)
523
+ gating_score_experts = torch.stack(gating_score_experts, 1) # (bs, num_experts, 1)
599
524
  moe_out = torch.matmul(output_of_experts, gating_score_experts.softmax(1))
600
525
  x_l = moe_out + x_l # (bs, in_features, 1)
601
526
 
602
527
  x_l = x_l.squeeze() # (bs, in_features)
603
528
  return x_l
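The loop above mixes low-rank experts with a gating softmax. A single-layer sketch of that mixture using plain tensors (illustrative shapes and names, not the package API):

```python
import torch
import torch.nn as nn

batch, d, rank, num_experts = 4, 16, 4, 2
x0 = torch.randn(batch, d, 1)                       # original input as column vectors
x_l = x0.clone()                                    # current layer state
U = torch.randn(num_experts, d, rank)               # project back up to d
V = torch.randn(num_experts, d, rank)               # project down to the low rank
C = torch.randn(num_experts, rank, rank)            # low-rank interaction
gates = nn.ModuleList([nn.Linear(d, 1, bias=False) for _ in range(num_experts)])
bias = torch.zeros(d, 1)

expert_outs, gate_scores = [], []
for e in range(num_experts):
    v_x = torch.tanh(torch.matmul(V[e].t(), x_l))   # (batch, rank, 1)
    v_x = torch.tanh(torch.matmul(C[e], v_x))
    uv_x = torch.matmul(U[e], v_x)                  # (batch, d, 1)
    expert_outs.append((x0 * (uv_x + bias)).squeeze(2))
    gate_scores.append(gates[e](x_l.squeeze(2)))    # (batch, 1)

moe = torch.matmul(torch.stack(expert_outs, 2),             # (batch, d, num_experts)
                   torch.stack(gate_scores, 1).softmax(1))  # (batch, num_experts, 1)
x_next = (moe + x_l).squeeze(2)                             # residual update, (batch, d)
```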
604
529
 
605
-
606
530
  class SENETLayer(nn.Module):
607
531
  """Squeeze-and-Excitation block adopted by FiBiNET (Huang et al., 2019)."""
608
532
 
609
533
  def __init__(self, num_fields, reduction_ratio=3):
610
534
  super(SENETLayer, self).__init__()
611
- reduced_size = max(1, int(num_fields / reduction_ratio))
612
- self.mlp = nn.Sequential(
613
- nn.Linear(num_fields, reduced_size, bias=False),
614
- nn.ReLU(),
615
- nn.Linear(reduced_size, num_fields, bias=False),
616
- nn.ReLU(),
617
- )
618
-
535
+ reduced_size = max(1, int(num_fields/ reduction_ratio))
536
+ self.mlp = nn.Sequential(nn.Linear(num_fields, reduced_size, bias=False),
537
+ nn.ReLU(),
538
+ nn.Linear(reduced_size, num_fields, bias=False),
539
+ nn.ReLU())
619
540
  def forward(self, x):
620
541
  z = torch.mean(x, dim=-1, out=None)
621
542
  a = self.mlp(z)
622
- v = x * a.unsqueeze(-1)
543
+ v = x*a.unsqueeze(-1)
623
544
  return v
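The block above squeezes each field to a scalar (mean over the embedding dimension), passes the field vector through a small two-layer MLP, and rescales the fields with the resulting weights. A standalone sketch (illustrative, not the package code):

```python
import torch
import torch.nn as nn

num_fields, embed_dim, reduction = 6, 16, 3
senet = nn.Sequential(
    nn.Linear(num_fields, max(1, num_fields // reduction), bias=False),
    nn.ReLU(),
    nn.Linear(max(1, num_fields // reduction), num_fields, bias=False),
    nn.ReLU(),
)

x = torch.randn(4, num_fields, embed_dim)   # field embeddings
z = x.mean(dim=-1)                          # squeeze: [batch, num_fields]
a = senet(z)                                # excitation weights per field
v = x * a.unsqueeze(-1)                     # re-weighted embeddings, same shape as x
```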
624
545
 
625
-
626
546
  class BiLinearInteractionLayer(nn.Module):
627
547
  """Bilinear feature interaction from FiBiNET (Huang et al., 2019)."""
628
548
 
629
- def __init__(self, input_dim, num_fields, bilinear_type="field_interaction"):
549
+ def __init__(self, input_dim, num_fields, bilinear_type = "field_interaction"):
630
550
  super(BiLinearInteractionLayer, self).__init__()
631
551
  self.bilinear_type = bilinear_type
632
552
  if self.bilinear_type == "field_all":
633
553
  self.bilinear_layer = nn.Linear(input_dim, input_dim, bias=False)
634
554
  elif self.bilinear_type == "field_each":
635
- self.bilinear_layer = nn.ModuleList(
636
- [nn.Linear(input_dim, input_dim, bias=False) for i in range(num_fields)]
637
- )
555
+ self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i in range(num_fields)])
638
556
  elif self.bilinear_type == "field_interaction":
639
- self.bilinear_layer = nn.ModuleList(
640
- [
641
- nn.Linear(input_dim, input_dim, bias=False)
642
- for i, j in combinations(range(num_fields), 2)
643
- ]
644
- )
557
+ self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i,j in combinations(range(num_fields), 2)])
645
558
  else:
646
559
  raise NotImplementedError()
647
560
 
648
561
  def forward(self, x):
649
562
  feature_emb = torch.split(x, 1, dim=1)
650
563
  if self.bilinear_type == "field_all":
651
- bilinear_list = [
652
- self.bilinear_layer(v_i) * v_j
653
- for v_i, v_j in combinations(feature_emb, 2)
654
- ]
564
+ bilinear_list = [self.bilinear_layer(v_i)*v_j for v_i, v_j in combinations(feature_emb, 2)]
655
565
  elif self.bilinear_type == "field_each":
656
- bilinear_list = [
657
- self.bilinear_layer[i](feature_emb[i]) * feature_emb[j]
658
- for i, j in combinations(range(len(feature_emb)), 2)
659
- ]
566
+ bilinear_list = [self.bilinear_layer[i](feature_emb[i])*feature_emb[j] for i,j in combinations(range(len(feature_emb)), 2)]
660
567
  elif self.bilinear_type == "field_interaction":
661
- bilinear_list = [
662
- self.bilinear_layer[i](v[0]) * v[1]
663
- for i, v in enumerate(combinations(feature_emb, 2))
664
- ]
568
+ bilinear_list = [self.bilinear_layer[i](v[0])*v[1] for i,v in enumerate(combinations(feature_emb, 2))]
665
569
  return torch.cat(bilinear_list, dim=1)
666
570
 
667
571
 
@@ -674,23 +578,17 @@ class MultiInterestSA(nn.Module):
674
578
  self.interest_num = interest_num
675
579
  if hidden_dim == None:
676
580
  self.hidden_dim = self.embedding_dim * 4
677
- self.W1 = torch.nn.Parameter(
678
- torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True
679
- )
680
- self.W2 = torch.nn.Parameter(
681
- torch.rand(self.hidden_dim, self.interest_num), requires_grad=True
682
- )
683
- self.W3 = torch.nn.Parameter(
684
- torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True
685
- )
581
+ self.W1 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True)
582
+ self.W2 = torch.nn.Parameter(torch.rand(self.hidden_dim, self.interest_num), requires_grad=True)
583
+ self.W3 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True)
686
584
 
687
585
  def forward(self, seq_emb, mask=None):
688
- H = torch.einsum("bse, ed -> bsd", seq_emb, self.W1).tanh()
586
+ H = torch.einsum('bse, ed -> bsd', seq_emb, self.W1).tanh()
689
587
  if mask != None:
690
- A = torch.einsum("bsd, dk -> bsk", H, self.W2) + -1.0e9 * (1 - mask.float())
588
+ A = torch.einsum('bsd, dk -> bsk', H, self.W2) + -1.e9 * (1 - mask.float())
691
589
  A = F.softmax(A, dim=1)
692
590
  else:
693
- A = F.softmax(torch.einsum("bsd, dk -> bsk", H, self.W2), dim=1)
591
+ A = F.softmax(torch.einsum('bsd, dk -> bsk', H, self.W2), dim=1)
694
592
  A = A.permute(0, 2, 1)
695
593
  multi_interest_emb = torch.matmul(A, seq_emb)
696
594
  return multi_interest_emb
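MultiInterestSA above follows self-attentive pooling: A = softmax(tanh(E W1) W2) over the sequence axis, and the interest vectors are A^T E. A minimal sketch with masking (illustrative, not the package class):

```python
import torch
import torch.nn.functional as F

batch, seq_len, d, k = 2, 5, 8, 3
E = torch.randn(batch, seq_len, d)                  # behaviour sequence embeddings
W1 = torch.randn(d, 4 * d)
W2 = torch.randn(4 * d, k)
mask = torch.ones(batch, seq_len, 1)                # 1 = real item, 0 = padding

H = torch.einsum("bse,ed->bsd", E, W1).tanh()
A = torch.einsum("bsd,dk->bsk", H, W2) - 1e9 * (1 - mask)   # suppress padded steps
A = F.softmax(A, dim=1)                             # attention over the sequence
interests = torch.matmul(A.permute(0, 2, 1), E)     # [batch, k, d]
```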
@@ -699,15 +597,7 @@ class MultiInterestSA(nn.Module):
699
597
  class CapsuleNetwork(nn.Module):
700
598
  """Dynamic routing capsule network used in MIND (Li et al., 2019)."""
701
599
 
702
- def __init__(
703
- self,
704
- embedding_dim,
705
- seq_len,
706
- bilinear_type=2,
707
- interest_num=4,
708
- routing_times=3,
709
- relu_layer=False,
710
- ):
600
+ def __init__(self, embedding_dim, seq_len, bilinear_type=2, interest_num=4, routing_times=3, relu_layer=False):
711
601
  super(CapsuleNetwork, self).__init__()
712
602
  self.embedding_dim = embedding_dim # h
713
603
  self.seq_len = seq_len # s
@@ -717,24 +607,13 @@ class CapsuleNetwork(nn.Module):
717
607
 
718
608
  self.relu_layer = relu_layer
719
609
  self.stop_grad = True
720
- self.relu = nn.Sequential(
721
- nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU()
722
- )
610
+ self.relu = nn.Sequential(nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU())
723
611
  if self.bilinear_type == 0: # MIND
724
612
  self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
725
613
  elif self.bilinear_type == 1:
726
- self.linear = nn.Linear(
727
- self.embedding_dim, self.embedding_dim * self.interest_num, bias=False
728
- )
614
+ self.linear = nn.Linear(self.embedding_dim, self.embedding_dim * self.interest_num, bias=False)
729
615
  else:
730
- self.w = nn.Parameter(
731
- torch.Tensor(
732
- 1,
733
- self.seq_len,
734
- self.interest_num * self.embedding_dim,
735
- self.embedding_dim,
736
- )
737
- )
616
+ self.w = nn.Parameter(torch.Tensor(1, self.seq_len, self.interest_num * self.embedding_dim, self.embedding_dim))
738
617
  nn.init.xavier_uniform_(self.w)
739
618
 
740
619
  def forward(self, item_eb, mask):
@@ -745,15 +624,11 @@ class CapsuleNetwork(nn.Module):
745
624
  item_eb_hat = self.linear(item_eb)
746
625
  else:
747
626
  u = torch.unsqueeze(item_eb, dim=2)
748
- item_eb_hat = torch.sum(self.w[:, : self.seq_len, :, :] * u, dim=3)
627
+ item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)
749
628
 
750
- item_eb_hat = torch.reshape(
751
- item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim)
752
- )
629
+ item_eb_hat = torch.reshape(item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim))
753
630
  item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
754
- item_eb_hat = torch.reshape(
755
- item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim)
756
- )
631
+ item_eb_hat = torch.reshape(item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim))
757
632
 
758
633
  if self.stop_grad:
759
634
  item_eb_hat_iter = item_eb_hat.detach()
@@ -761,47 +636,34 @@ class CapsuleNetwork(nn.Module):
761
636
  item_eb_hat_iter = item_eb_hat
762
637
 
763
638
  if self.bilinear_type > 0:
764
- capsule_weight = torch.zeros(
765
- item_eb_hat.shape[0],
766
- self.interest_num,
767
- self.seq_len,
768
- device=item_eb.device,
769
- requires_grad=False,
770
- )
639
+ capsule_weight = torch.zeros(item_eb_hat.shape[0],
640
+ self.interest_num,
641
+ self.seq_len,
642
+ device=item_eb.device,
643
+ requires_grad=False)
771
644
  else:
772
- capsule_weight = torch.randn(
773
- item_eb_hat.shape[0],
774
- self.interest_num,
775
- self.seq_len,
776
- device=item_eb.device,
777
- requires_grad=False,
778
- )
645
+ capsule_weight = torch.randn(item_eb_hat.shape[0],
646
+ self.interest_num,
647
+ self.seq_len,
648
+ device=item_eb.device,
649
+ requires_grad=False)
779
650
 
780
651
  for i in range(self.routing_times): # 动态路由传播3次
781
652
  atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
782
653
  paddings = torch.zeros_like(atten_mask, dtype=torch.float)
783
654
 
784
655
  capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
785
- capsule_softmax_weight = torch.where(
786
- torch.eq(atten_mask, 0), paddings, capsule_softmax_weight
787
- )
656
+ capsule_softmax_weight = torch.where(torch.eq(atten_mask, 0), paddings, capsule_softmax_weight)
788
657
  capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)
789
658
 
790
659
  if i < 2:
791
- interest_capsule = torch.matmul(
792
- capsule_softmax_weight, item_eb_hat_iter
793
- )
660
+ interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat_iter)
794
661
  cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
795
662
  scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
796
663
  interest_capsule = scalar_factor * interest_capsule
797
664
 
798
- delta_weight = torch.matmul(
799
- item_eb_hat_iter,
800
- torch.transpose(interest_capsule, 2, 3).contiguous(),
801
- )
802
- delta_weight = torch.reshape(
803
- delta_weight, (-1, self.interest_num, self.seq_len)
804
- )
665
+ delta_weight = torch.matmul(item_eb_hat_iter, torch.transpose(interest_capsule, 2, 3).contiguous())
666
+ delta_weight = torch.reshape(delta_weight, (-1, self.interest_num, self.seq_len))
805
667
  capsule_weight = capsule_weight + delta_weight
806
668
  else:
807
669
  interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat)
@@ -809,9 +671,7 @@ class CapsuleNetwork(nn.Module):
809
671
  scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
810
672
  interest_capsule = scalar_factor * interest_capsule
811
673
 
812
- interest_capsule = torch.reshape(
813
- interest_capsule, (-1, self.interest_num, self.embedding_dim)
814
- )
674
+ interest_capsule = torch.reshape(interest_capsule, (-1, self.interest_num, self.embedding_dim))
815
675
 
816
676
  if self.relu_layer:
817
677
  interest_capsule = self.relu(interest_capsule)
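The scalar_factor expression above is the capsule "squash" non-linearity, squash(v) = (|v|^2 / (1 + |v|^2)) * v / |v|. As a standalone helper (illustrative):

```python
import torch

def squash(v: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """Capsule squash: keep direction, shrink the norm into [0, 1)."""
    sq_norm = torch.sum(v ** 2, dim=-1, keepdim=True)
    scale = sq_norm / (1.0 + sq_norm) / torch.sqrt(sq_norm + eps)
    return scale * v

caps = squash(torch.randn(2, 4, 16))   # e.g. [batch, interest_num, embedding_dim]
```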
@@ -823,18 +683,18 @@ class FFM(nn.Module):
823
683
  """Field-aware Factorization Machine (Juan et al., 2016)."""
824
684
 
825
685
  def __init__(self, num_fields, reduce_sum=True):
826
- super().__init__()
686
+ super().__init__()
827
687
  self.num_fields = num_fields
828
688
  self.reduce_sum = reduce_sum
829
689
 
830
690
  def forward(self, x):
831
691
  # compute (non-redundant) second order field-aware feature crossings
832
692
  crossed_embeddings = []
833
- for i in range(self.num_fields - 1):
834
- for j in range(i + 1, self.num_fields):
835
- crossed_embeddings.append(x[:, i, j, :] * x[:, j, i, :])
693
+ for i in range(self.num_fields-1):
694
+ for j in range(i+1, self.num_fields):
695
+ crossed_embeddings.append(x[:, i, j, :] * x[:, j, i, :])
836
696
  crossed_embeddings = torch.stack(crossed_embeddings, dim=1)
837
-
697
+
838
698
  # if reduce_sum is true, the crossing operation is effectively inner product, other wise Hadamard-product
839
699
  if self.reduce_sum:
840
700
  crossed_embeddings = torch.sum(crossed_embeddings, dim=-1, keepdim=True)
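FFM above expects field-aware embeddings shaped [batch, num_fields, num_fields, embed_dim], where x[:, i, j] is the embedding of feature i specialised for field j; the (i, j) cross is x[:, i, j] * x[:, j, i]. A compact sketch (illustrative):

```python
import torch

batch, num_fields, embed_dim = 2, 4, 8
x = torch.randn(batch, num_fields, num_fields, embed_dim)   # field-aware embeddings

crosses = [x[:, i, j, :] * x[:, j, i, :]
           for i in range(num_fields - 1)
           for j in range(i + 1, num_fields)]
crosses = torch.stack(crosses, dim=1)                 # [batch, C(num_fields, 2), embed_dim]
scores = crosses.sum(dim=-1, keepdim=True)            # reduce_sum turns each cross into an inner product
```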
@@ -845,57 +705,49 @@ class CEN(nn.Module):
845
705
  """Field-attentive interaction network from FAT-DeepFFM (Wang et al., 2020)."""
846
706
 
847
707
  def __init__(self, embed_dim, num_field_crosses, reduction_ratio):
848
- super().__init__()
849
-
708
+ super().__init__()
709
+
850
710
  # convolution weight (Eq.7 FAT-DeepFFM)
851
- self.u = torch.nn.Parameter(
852
- torch.rand(num_field_crosses, embed_dim), requires_grad=True
853
- )
711
+ self.u = torch.nn.Parameter(torch.rand(num_field_crosses, embed_dim), requires_grad=True)
854
712
 
855
713
  # two FC layers that computes the field attention
856
- self.mlp_att = MLP(
857
- num_field_crosses,
858
- dims=[num_field_crosses // reduction_ratio, num_field_crosses],
859
- output_layer=False,
860
- activation="relu",
861
- )
862
-
863
- def forward(self, em):
714
+ self.mlp_att = MLP(num_field_crosses, dims=[num_field_crosses//reduction_ratio, num_field_crosses], output_layer=False, activation="relu")
715
+
716
+
717
+ def forward(self, em):
864
718
  # compute descriptor vector (Eq.7 FAT-DeepFFM), output shape [batch_size, num_field_crosses]
865
719
  d = F.relu((self.u.squeeze(0) * em).sum(-1))
866
-
867
- # compute field attention (Eq.9), output shape [batch_size, num_field_crosses]
868
- s = self.mlp_att(d)
720
+
721
+ # compute field attention (Eq.9), output shape [batch_size, num_field_crosses]
722
+ s = self.mlp_att(d)
869
723
 
870
724
  # rescale original embedding with field attention (Eq.10), output shape [batch_size, num_field_crosses, embed_dim]
871
- aem = s.unsqueeze(-1) * em
725
+ aem = s.unsqueeze(-1) * em
872
726
  return aem.flatten(start_dim=1)
873
727
 
874
728
 
875
729
  class MultiHeadSelfAttention(nn.Module):
876
730
  """Multi-head self-attention layer from AutoInt (Song et al., 2019)."""
877
-
731
+
878
732
  def __init__(self, embedding_dim, num_heads=2, dropout=0.0, use_residual=True):
879
733
  super().__init__()
880
734
  if embedding_dim % num_heads != 0:
881
- raise ValueError(
882
- f"embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})"
883
- )
884
-
735
+ raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})")
736
+
885
737
  self.embedding_dim = embedding_dim
886
738
  self.num_heads = num_heads
887
739
  self.head_dim = embedding_dim // num_heads
888
740
  self.use_residual = use_residual
889
-
741
+
890
742
  self.W_Q = nn.Linear(embedding_dim, embedding_dim, bias=False)
891
743
  self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False)
892
744
  self.W_V = nn.Linear(embedding_dim, embedding_dim, bias=False)
893
-
745
+
894
746
  if self.use_residual:
895
747
  self.W_Res = nn.Linear(embedding_dim, embedding_dim, bias=False)
896
-
748
+
897
749
  self.dropout = nn.Dropout(dropout)
898
-
750
+
899
751
  def forward(self, x):
900
752
  """
901
753
  Args:
@@ -904,47 +756,37 @@ class MultiHeadSelfAttention(nn.Module):
904
756
  output: [batch_size, num_fields, embedding_dim]
905
757
  """
906
758
  batch_size, num_fields, _ = x.shape
907
-
759
+
908
760
  # Linear projections
909
761
  Q = self.W_Q(x) # [batch_size, num_fields, embedding_dim]
910
762
  K = self.W_K(x)
911
763
  V = self.W_V(x)
912
-
764
+
913
765
  # Split into multiple heads: [batch_size, num_heads, num_fields, head_dim]
914
- Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
915
- 1, 2
916
- )
917
- K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
918
- 1, 2
919
- )
920
- V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
921
- 1, 2
922
- )
923
-
766
+ Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
767
+ K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
768
+ V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
769
+
924
770
  # Attention scores
925
- scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
771
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
926
772
  attention_weights = F.softmax(scores, dim=-1)
927
773
  attention_weights = self.dropout(attention_weights)
928
-
774
+
929
775
  # Apply attention to values
930
- attention_output = torch.matmul(
931
- attention_weights, V
932
- ) # [batch_size, num_heads, num_fields, head_dim]
933
-
776
+ attention_output = torch.matmul(attention_weights, V) # [batch_size, num_heads, num_fields, head_dim]
777
+
934
778
  # Concatenate heads
935
779
  attention_output = attention_output.transpose(1, 2).contiguous()
936
- attention_output = attention_output.view(
937
- batch_size, num_fields, self.embedding_dim
938
- )
939
-
780
+ attention_output = attention_output.view(batch_size, num_fields, self.embedding_dim)
781
+
940
782
  # Residual connection
941
783
  if self.use_residual:
942
784
  output = attention_output + self.W_Res(x)
943
785
  else:
944
786
  output = attention_output
945
-
787
+
946
788
  output = F.relu(output)
947
-
789
+
948
790
  return output
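The layer above is standard multi-head scaled dot-product attention over the field axis, softmax(Q K^T / sqrt(d_head)) V per head. The core computation can be cross-checked against PyTorch's built-in helper (available in PyTorch >= 2.0); shapes follow the layer above:

```python
import torch
import torch.nn.functional as F

batch, num_heads, num_fields, head_dim = 2, 2, 6, 8
Q = torch.randn(batch, num_heads, num_fields, head_dim)
K = torch.randn(batch, num_heads, num_fields, head_dim)
V = torch.randn(batch, num_heads, num_fields, head_dim)

scores = torch.matmul(Q, K.transpose(-2, -1)) / (head_dim ** 0.5)
manual = torch.matmul(F.softmax(scores, dim=-1), V)

# F.scaled_dot_product_attention (PyTorch >= 2.0) computes the same thing with no dropout.
builtin = F.scaled_dot_product_attention(Q, K, V)
print(torch.allclose(manual, builtin, atol=1e-5))   # True
```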
949
791
 
950
792
 
@@ -953,31 +795,25 @@ class AttentionPoolingLayer(nn.Module):
953
795
  Attention pooling layer for DIN/DIEN
954
796
  Computes attention weights between query (candidate item) and keys (user behavior sequence)
955
797
  """
956
-
957
- def __init__(
958
- self,
959
- embedding_dim,
960
- hidden_units=[80, 40],
961
- activation="sigmoid",
962
- use_softmax=True,
963
- ):
798
+
799
+ def __init__(self, embedding_dim, hidden_units=[80, 40], activation='sigmoid', use_softmax=True):
964
800
  super().__init__()
965
801
  self.embedding_dim = embedding_dim
966
802
  self.use_softmax = use_softmax
967
-
803
+
968
804
  # Build attention network
969
805
  # Input: [query, key, query-key, query*key] -> 4 * embedding_dim
970
806
  input_dim = 4 * embedding_dim
971
807
  layers = []
972
-
808
+
973
809
  for hidden_unit in hidden_units:
974
810
  layers.append(nn.Linear(input_dim, hidden_unit))
975
811
  layers.append(activation_layer(activation))
976
812
  input_dim = hidden_unit
977
-
813
+
978
814
  layers.append(nn.Linear(input_dim, 1))
979
815
  self.attention_net = nn.Sequential(*layers)
980
-
816
+
981
817
  def forward(self, query, keys, keys_length=None, mask=None):
982
818
  """
983
819
  Args:
@@ -989,52 +825,48 @@ class AttentionPoolingLayer(nn.Module):
989
825
  output: [batch_size, embedding_dim] - attention pooled representation
990
826
  """
991
827
  batch_size, seq_len, emb_dim = keys.shape
992
-
828
+
993
829
  # Expand query to match sequence length: [batch_size, seq_len, embedding_dim]
994
830
  query_expanded = query.unsqueeze(1).expand(-1, seq_len, -1)
995
-
831
+
996
832
  # Compute attention features: [query, key, query-key, query*key]
997
- attention_input = torch.cat(
998
- [query_expanded, keys, query_expanded - keys, query_expanded * keys], dim=-1
999
- ) # [batch_size, seq_len, 4*embedding_dim]
1000
-
833
+ attention_input = torch.cat([
834
+ query_expanded,
835
+ keys,
836
+ query_expanded - keys,
837
+ query_expanded * keys
838
+ ], dim=-1) # [batch_size, seq_len, 4*embedding_dim]
839
+
1001
840
  # Compute attention scores
1002
- attention_scores = self.attention_net(
1003
- attention_input
1004
- ) # [batch_size, seq_len, 1]
1005
-
841
+ attention_scores = self.attention_net(attention_input) # [batch_size, seq_len, 1]
842
+
1006
843
  # Apply mask if provided
1007
844
  if mask is not None:
1008
845
  attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
1009
-
846
+
1010
847
  # Apply softmax to get attention weights
1011
848
  if self.use_softmax:
1012
- attention_weights = F.softmax(
1013
- attention_scores, dim=1
1014
- ) # [batch_size, seq_len, 1]
849
+ attention_weights = F.softmax(attention_scores, dim=1) # [batch_size, seq_len, 1]
1015
850
  else:
1016
851
  attention_weights = attention_scores
1017
-
852
+
1018
853
  # Weighted sum of keys
1019
- output = torch.sum(
1020
- attention_weights * keys, dim=1
1021
- ) # [batch_size, embedding_dim]
1022
-
854
+ output = torch.sum(attention_weights * keys, dim=1) # [batch_size, embedding_dim]
855
+
1023
856
  return output
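The pooling above scores each behaviour against the candidate item with the feature set [q, k, q - k, q * k] and returns a weighted sum of the behaviours. A condensed functional sketch (illustrative, not the package class):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

embed_dim, seq_len, batch = 8, 5, 2
att_net = nn.Sequential(nn.Linear(4 * embed_dim, 16), nn.Sigmoid(), nn.Linear(16, 1))

query = torch.randn(batch, embed_dim)            # candidate item embedding
keys = torch.randn(batch, seq_len, embed_dim)    # user behaviour sequence

q = query.unsqueeze(1).expand(-1, seq_len, -1)
att_in = torch.cat([q, keys, q - keys, q * keys], dim=-1)   # [batch, seq_len, 4*embed_dim]
weights = F.softmax(att_net(att_in), dim=1)                 # [batch, seq_len, 1]
pooled = torch.sum(weights * keys, dim=1)                   # [batch, embed_dim]
```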
1024
857
 
1025
858
 
1026
859
  class DynamicGRU(nn.Module):
1027
860
  """Dynamic GRU unit with auxiliary loss path from DIEN (Zhou et al., 2019)."""
1028
-
1029
861
  """
1030
862
  GRU with dynamic routing for DIEN
1031
863
  """
1032
-
864
+
1033
865
  def __init__(self, input_size, hidden_size, bias=True):
1034
866
  super().__init__()
1035
867
  self.input_size = input_size
1036
868
  self.hidden_size = hidden_size
1037
-
869
+
1038
870
  # GRU parameters
1039
871
  self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
1040
872
  self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
@@ -1042,16 +874,16 @@ class DynamicGRU(nn.Module):
1042
874
  self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size))
1043
875
  self.bias_hh = nn.Parameter(torch.randn(3 * hidden_size))
1044
876
  else:
1045
- self.register_parameter("bias_ih", None)
1046
- self.register_parameter("bias_hh", None)
1047
-
877
+ self.register_parameter('bias_ih', None)
878
+ self.register_parameter('bias_hh', None)
879
+
1048
880
  self.reset_parameters()
1049
-
881
+
1050
882
  def reset_parameters(self):
1051
883
  std = 1.0 / (self.hidden_size) ** 0.5
1052
884
  for weight in self.parameters():
1053
885
  weight.data.uniform_(-std, std)
1054
-
886
+
1055
887
  def forward(self, x, att_scores=None):
1056
888
  """
1057
889
  Args:
@@ -1062,61 +894,60 @@ class DynamicGRU(nn.Module):
1062
894
  hidden: [batch_size, hidden_size] - final hidden state
1063
895
  """
1064
896
  batch_size, seq_len, _ = x.shape
1065
-
897
+
1066
898
  # Initialize hidden state
1067
899
  h = torch.zeros(batch_size, self.hidden_size, device=x.device)
1068
-
900
+
1069
901
  outputs = []
1070
902
  for t in range(seq_len):
1071
903
  x_t = x[:, t, :] # [batch_size, input_size]
1072
-
904
+
1073
905
  # GRU computation
1074
906
  gi = F.linear(x_t, self.weight_ih, self.bias_ih)
1075
907
  gh = F.linear(h, self.weight_hh, self.bias_hh)
1076
908
  i_r, i_i, i_n = gi.chunk(3, 1)
1077
909
  h_r, h_i, h_n = gh.chunk(3, 1)
1078
-
910
+
1079
911
  resetgate = torch.sigmoid(i_r + h_r)
1080
912
  inputgate = torch.sigmoid(i_i + h_i)
1081
913
  newgate = torch.tanh(i_n + resetgate * h_n)
1082
914
  h = newgate + inputgate * (h - newgate)
1083
-
915
+
1084
916
  outputs.append(h.unsqueeze(1))
1085
-
917
+
1086
918
  output = torch.cat(outputs, dim=1) # [batch_size, seq_len, hidden_size]
1087
-
919
+
1088
920
  return output, h
1089
921
 
1090
922
 
1091
923
  class AUGRU(nn.Module):
1092
924
  """Attention-aware GRU update gate used in DIEN (Zhou et al., 2019)."""
1093
-
1094
925
  """
1095
926
  Attention-based GRU for DIEN
1096
927
  Uses attention scores to weight the update of hidden states
1097
928
  """
1098
-
929
+
1099
930
  def __init__(self, input_size, hidden_size, bias=True):
1100
931
  super().__init__()
1101
932
  self.input_size = input_size
1102
933
  self.hidden_size = hidden_size
1103
-
934
+
1104
935
  self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
1105
936
  self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
1106
937
  if bias:
1107
938
  self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size))
1108
939
  self.bias_hh = nn.Parameter(torch.randn(3 * hidden_size))
1109
940
  else:
1110
- self.register_parameter("bias_ih", None)
1111
- self.register_parameter("bias_hh", None)
1112
-
941
+ self.register_parameter('bias_ih', None)
942
+ self.register_parameter('bias_hh', None)
943
+
1113
944
  self.reset_parameters()
1114
-
945
+
1115
946
  def reset_parameters(self):
1116
947
  std = 1.0 / (self.hidden_size) ** 0.5
1117
948
  for weight in self.parameters():
1118
949
  weight.data.uniform_(-std, std)
1119
-
950
+
1120
951
  def forward(self, x, att_scores):
1121
952
  """
1122
953
  Args:
@@ -1127,28 +958,28 @@ class AUGRU(nn.Module):
1127
958
  hidden: [batch_size, hidden_size] - final hidden state
1128
959
  """
1129
960
  batch_size, seq_len, _ = x.shape
1130
-
961
+
1131
962
  h = torch.zeros(batch_size, self.hidden_size, device=x.device)
1132
-
963
+
1133
964
  outputs = []
1134
965
  for t in range(seq_len):
1135
966
  x_t = x[:, t, :] # [batch_size, input_size]
1136
967
  att_t = att_scores[:, t, :] # [batch_size, 1]
1137
-
968
+
1138
969
  gi = F.linear(x_t, self.weight_ih, self.bias_ih)
1139
970
  gh = F.linear(h, self.weight_hh, self.bias_hh)
1140
971
  i_r, i_i, i_n = gi.chunk(3, 1)
1141
972
  h_r, h_i, h_n = gh.chunk(3, 1)
1142
-
973
+
1143
974
  resetgate = torch.sigmoid(i_r + h_r)
1144
975
  inputgate = torch.sigmoid(i_i + h_i)
1145
976
  newgate = torch.tanh(i_n + resetgate * h_n)
1146
-
977
+
1147
978
  # Use attention score to control update
1148
979
  h = (1 - att_t) * h + att_t * newgate
1149
-
980
+
1150
981
  outputs.append(h.unsqueeze(1))
1151
-
982
+
1152
983
  output = torch.cat(outputs, dim=1)
1153
-
1154
- return output, h
984
+
985
+ return output, h
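AUGRU replaces the GRU update gate with the per-step attention score, h_t = (1 - a_t) * h_{t-1} + a_t * h~_t, which is what the loop above implements. A single-step sketch with explicit gates (illustrative, not the package module):

```python
import torch
import torch.nn.functional as F

def augru_step(x_t, h_prev, att_t, w_ih, w_hh, b_ih=None, b_hh=None):
    """One AUGRU step: the attention score att_t replaces the update gate."""
    gi = F.linear(x_t, w_ih, b_ih)        # [batch, 3*hidden]
    gh = F.linear(h_prev, w_hh, b_hh)
    i_r, _, i_n = gi.chunk(3, dim=1)
    h_r, _, h_n = gh.chunk(3, dim=1)
    reset = torch.sigmoid(i_r + h_r)
    cand = torch.tanh(i_n + reset * h_n)             # candidate state h~_t
    return (1 - att_t) * h_prev + att_t * cand       # attention-gated update

batch, input_size, hidden = 2, 8, 16
w_ih = torch.randn(3 * hidden, input_size)
w_hh = torch.randn(3 * hidden, hidden)
h = augru_step(torch.randn(batch, input_size), torch.zeros(batch, hidden),
               torch.rand(batch, 1), w_ih, w_hh)     # [batch, hidden]
```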