torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +54 -54
  3. torch_rechub/basic/callback.py +33 -33
  4. torch_rechub/basic/features.py +87 -94
  5. torch_rechub/basic/initializers.py +92 -92
  6. torch_rechub/basic/layers.py +994 -720
  7. torch_rechub/basic/loss_func.py +223 -34
  8. torch_rechub/basic/metaoptimizer.py +76 -72
  9. torch_rechub/basic/metric.py +251 -250
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -11
  14. torch_rechub/models/matching/comirec.py +193 -188
  15. torch_rechub/models/matching/dssm.py +72 -66
  16. torch_rechub/models/matching/dssm_facebook.py +77 -79
  17. torch_rechub/models/matching/dssm_senet.py +28 -16
  18. torch_rechub/models/matching/gru4rec.py +85 -87
  19. torch_rechub/models/matching/mind.py +103 -101
  20. torch_rechub/models/matching/narm.py +82 -76
  21. torch_rechub/models/matching/sasrec.py +143 -140
  22. torch_rechub/models/matching/sine.py +148 -151
  23. torch_rechub/models/matching/stamp.py +81 -83
  24. torch_rechub/models/matching/youtube_dnn.py +75 -71
  25. torch_rechub/models/matching/youtube_sbc.py +98 -98
  26. torch_rechub/models/multi_task/__init__.py +7 -5
  27. torch_rechub/models/multi_task/aitm.py +83 -84
  28. torch_rechub/models/multi_task/esmm.py +56 -55
  29. torch_rechub/models/multi_task/mmoe.py +58 -58
  30. torch_rechub/models/multi_task/ple.py +116 -130
  31. torch_rechub/models/multi_task/shared_bottom.py +45 -45
  32. torch_rechub/models/ranking/__init__.py +14 -11
  33. torch_rechub/models/ranking/afm.py +65 -63
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -63
  36. torch_rechub/models/ranking/dcn.py +38 -38
  37. torch_rechub/models/ranking/dcn_v2.py +59 -69
  38. torch_rechub/models/ranking/deepffm.py +131 -123
  39. torch_rechub/models/ranking/deepfm.py +43 -42
  40. torch_rechub/models/ranking/dien.py +191 -191
  41. torch_rechub/models/ranking/din.py +93 -91
  42. torch_rechub/models/ranking/edcn.py +101 -117
  43. torch_rechub/models/ranking/fibinet.py +42 -50
  44. torch_rechub/models/ranking/widedeep.py +41 -41
  45. torch_rechub/trainers/__init__.py +4 -3
  46. torch_rechub/trainers/ctr_trainer.py +288 -128
  47. torch_rechub/trainers/match_trainer.py +336 -170
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +356 -207
  50. torch_rechub/trainers/seq_trainer.py +427 -0
  51. torch_rechub/utils/data.py +492 -360
  52. torch_rechub/utils/hstu_utils.py +198 -0
  53. torch_rechub/utils/match.py +457 -274
  54. torch_rechub/utils/model_utils.py +233 -0
  55. torch_rechub/utils/mtl.py +136 -126
  56. torch_rechub/utils/onnx_export.py +220 -0
  57. torch_rechub/utils/visualization.py +271 -0
  58. torch_rechub-0.0.5.dist-info/METADATA +402 -0
  59. torch_rechub-0.0.5.dist-info/RECORD +64 -0
  60. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
  61. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
  62. torch_rechub-0.0.3.dist-info/METADATA +0 -177
  63. torch_rechub-0.0.3.dist-info/RECORD +0 -55
  64. torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
torch_rechub/models/matching/dssm_facebook.py
@@ -1,79 +1,77 @@
- """
- Date: create on 24/05/2022
- References:
-     paper: (KDD'2020) Embedding-based Retrieval in Facebook Search
-     url: https://arxiv.org/abs/2006.11632
- Authors: Mincai Lai, laimincai@shanghaitech.edu.cn
- """
-
- import torch
- import torch.nn.functional as F
- from ...basic.layers import MLP, EmbeddingLayer
-
-
- class FaceBookDSSM(torch.nn.Module):
-     """Embedding-based Retrieval in Facebook Search
-     It's a DSSM match model trained by hinge loss on pair-wise samples.
-
-     Args:
-         user_features (list[Feature Class]): training by the user tower module.
-         pos_item_features (list[Feature Class]): positive sample features, training by the item tower module.
-         neg_item_features (list[Feature Class]): negative sample features, training by the item tower module.
-         temperature (float): temperature factor for similarity score, default to 1.0.
-         user_params (dict): the params of the User Tower module, keys include: `{"dims":list, "activation":str, "dropout":float, "output_layer":bool}`.
-         item_params (dict): the params of the Item Tower module, keys include: `{"dims":list, "activation":str, "dropout":float, "output_layer":bool}`.
-     """
-
-     def __init__(self,
-                  user_features,
-                  pos_item_features,
-                  neg_item_features,
-                  user_params,
-                  item_params,
-                  temperature=1.0):
-         super().__init__()
-         self.user_features = user_features
-         self.pos_item_features = pos_item_features
-         self.neg_item_features = neg_item_features
-         self.temperature = temperature
-         self.user_dims = sum([fea.embed_dim for fea in user_features])
-         self.item_dims = sum([fea.embed_dim for fea in pos_item_features])
-
-         self.embedding = EmbeddingLayer(user_features + pos_item_features + neg_item_features)
-         self.user_mlp = MLP(self.user_dims, output_layer=False, **user_params)
-         self.item_mlp = MLP(self.item_dims, output_layer=False, **item_params)
-         self.mode = None
-
-     def forward(self, x):
-         user_embedding = self.user_tower(x)
-         pos_item_embedding, neg_item_embedding = self.item_tower(x)
-         if self.mode == "user":
-             return user_embedding
-         if self.mode == "item":
-             return pos_item_embedding
-
-         # calculate cosine score
-         pos_score = torch.mul(user_embedding, pos_item_embedding).sum(dim=1)
-         neg_score = torch.mul(user_embedding, neg_item_embedding).sum(dim=1)
-
-         return pos_score, neg_score
-
-     def user_tower(self, x):
-         if self.mode == "item":
-             return None
-         input_user = self.embedding(x, self.user_features, squeeze_dim=True)  #[batch_size, num_features*deep_dims]
-         user_embedding = self.user_mlp(input_user)  #[batch_size, user_params["dims"][-1]]
-         user_embedding = F.normalize(user_embedding, p=2, dim=1)
-         return user_embedding
-
-     def item_tower(self, x):
-         if self.mode == "user":
-             return None, None
-         input_item_pos = self.embedding(x, self.pos_item_features, squeeze_dim=True)
-         if self.mode == "item":  #inference embedding mode, None is just a placeholder
-             return self.item_mlp(input_item_pos), None
-         input_item_neg = self.embedding(x, self.neg_item_features, squeeze_dim=True)
-         pos_embedding, neg_embedding = self.item_mlp(input_item_pos), self.item_mlp(input_item_neg)
-         pos_embedding = F.normalize(pos_embedding, p=2, dim=1)
-         neg_embedding = F.normalize(neg_embedding, p=2, dim=1)
-         return pos_embedding, neg_embedding
+ """
+ Date: create on 24/05/2022
+ References:
+     paper: (KDD'2020) Embedding-based Retrieval in Facebook Search
+     url: https://arxiv.org/abs/2006.11632
+ Authors: Mincai Lai, laimincai@shanghaitech.edu.cn
+ """
+
+ import torch
+ import torch.nn.functional as F
+
+ from ...basic.layers import MLP, EmbeddingLayer
+
+
+ class FaceBookDSSM(torch.nn.Module):
+     """Embedding-based Retrieval in Facebook Search
+     It's a DSSM match model trained by hinge loss on pair-wise samples.
+
+     Args:
+         user_features (list[Feature Class]): training by the user tower module.
+         pos_item_features (list[Feature Class]): positive sample features, training by the item tower module.
+         neg_item_features (list[Feature Class]): negative sample features, training by the item tower module.
+         temperature (float): temperature factor for similarity score, default to 1.0.
+         user_params (dict): the params of the User Tower module, keys include: `{"dims":list, "activation":str, "dropout":float, "output_layer":bool}`.
+         item_params (dict): the params of the Item Tower module, keys include: `{"dims":list, "activation":str, "dropout":float, "output_layer":bool}`.
+     """
+
+     def __init__(self, user_features, pos_item_features, neg_item_features, user_params, item_params, temperature=1.0):
+         super().__init__()
+         self.user_features = user_features
+         self.pos_item_features = pos_item_features
+         self.neg_item_features = neg_item_features
+         self.temperature = temperature
+         self.user_dims = sum([fea.embed_dim for fea in user_features])
+         self.item_dims = sum([fea.embed_dim for fea in pos_item_features])
+
+         self.embedding = EmbeddingLayer(user_features + pos_item_features + neg_item_features)
+         self.user_mlp = MLP(self.user_dims, output_layer=False, **user_params)
+         self.item_mlp = MLP(self.item_dims, output_layer=False, **item_params)
+         self.mode = None
+
+     def forward(self, x):
+         user_embedding = self.user_tower(x)
+         pos_item_embedding, neg_item_embedding = self.item_tower(x)
+         if self.mode == "user":
+             return user_embedding
+         if self.mode == "item":
+             return pos_item_embedding
+
+
+         # calculate cosine score
+         pos_score = torch.mul(user_embedding, pos_item_embedding).sum(dim=1)
+         neg_score = torch.mul(user_embedding, neg_item_embedding).sum(dim=1)
+
+         return pos_score, neg_score
+
+     def user_tower(self, x):
+         if self.mode == "item":
+             return None
+         # [batch_size, num_features*deep_dims]
+         input_user = self.embedding(x, self.user_features, squeeze_dim=True)
+         # [batch_size, user_params["dims"][-1]]
+         user_embedding = self.user_mlp(input_user)
+         user_embedding = F.normalize(user_embedding, p=2, dim=1)
+         return user_embedding
+
+     def item_tower(self, x):
+         if self.mode == "user":
+             return None, None
+         input_item_pos = self.embedding(x, self.pos_item_features, squeeze_dim=True)
+         if self.mode == "item":  # inference embedding mode, None is just a placeholder
+             return self.item_mlp(input_item_pos), None
+         input_item_neg = self.embedding(x, self.neg_item_features, squeeze_dim=True)
+         pos_embedding, neg_embedding = self.item_mlp(input_item_pos), self.item_mlp(input_item_neg)
+         pos_embedding = F.normalize(pos_embedding, p=2, dim=1)
+         neg_embedding = F.normalize(neg_embedding, p=2, dim=1)
+         return pos_embedding, neg_embedding
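In training mode the model above returns a (pos_score, neg_score) pair; in "user"/"item" mode it returns the corresponding tower embedding. A minimal usage sketch (hypothetical feature names, vocabulary sizes, and margin; the SparseFeature constructor and import paths follow the package's feature classes listed in this diff, so treat this as an illustration rather than the package's canonical example):

import torch
from torch_rechub.basic.features import SparseFeature
from torch_rechub.models.matching import FaceBookDSSM

# Hypothetical features: one id feature per tower.
user_features = [SparseFeature("user_id", vocab_size=1000, embed_dim=16)]
pos_item_features = [SparseFeature("item_id", vocab_size=5000, embed_dim=16)]
neg_item_features = [SparseFeature("neg_item_id", vocab_size=5000, embed_dim=16)]

model = FaceBookDSSM(user_features, pos_item_features, neg_item_features,
                     user_params={"dims": [64, 32]}, item_params={"dims": [64, 32]})

# A toy batch: one LongTensor of ids per feature name.
x = {
    "user_id": torch.randint(0, 1000, (8,)),
    "item_id": torch.randint(0, 5000, (8,)),
    "neg_item_id": torch.randint(0, 5000, (8,)),
}
pos_score, neg_score = model(x)  # two [batch_size] similarity scores

# Pair-wise hinge loss: push pos_score above neg_score by a margin.
loss = torch.clamp(0.1 - (pos_score - neg_score), min=0).mean()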
torch_rechub/models/matching/dssm_senet.py
@@ -1,14 +1,15 @@
  """
  Date: create on 12/19/2024
- References: 
+ References:
  url: https://zhuanlan.zhihu.com/p/358779957
  Authors: @1985312383
  """

  import torch
  import torch.nn.functional as F
+
+ from ...basic.features import SequenceFeature, SparseFeature
  from ...basic.layers import MLP, EmbeddingLayer, SENETLayer
- from ...basic.features import SparseFeature, SequenceFeature


  class DSSM(torch.nn.Module):
@@ -33,8 +34,8 @@ class DSSM(torch.nn.Module):
          self.embedding = EmbeddingLayer(user_features + item_features)
          self.user_mlp = MLP(self.user_dims, output_layer=False, **user_params)
          self.item_mlp = MLP(self.item_dims, output_layer=False, **item_params)
-         self.user_num_features = len([fea.embed_dim for fea in self.user_features if isinstance(fea, SparseFeature) or isinstance(fea, SequenceFeature) and fea.shared_with == None])
-         self.item_num_features = len([fea.embed_dim for fea in self.item_features if isinstance(fea, SparseFeature) or isinstance(fea, SequenceFeature) and fea.shared_with == None])
+         self.user_num_features = len([fea.embed_dim for fea in self.user_features if isinstance(fea, SparseFeature) or isinstance(fea, SequenceFeature) and fea.shared_with is None])
+         self.item_num_features = len([fea.embed_dim for fea in self.item_features if isinstance(fea, SparseFeature) or isinstance(fea, SequenceFeature) and fea.shared_with is None])
          self.user_senet = SENETLayer(self.user_num_features)
          self.item_senet = SENETLayer(self.item_num_features)
          self.mode = None
@@ -47,7 +48,8 @@ class DSSM(torch.nn.Module):
          if self.mode == "item":
              return item_embedding

-         # calculate cosine score
+
+         # calculate cosine score
          y = torch.mul(user_embedding, item_embedding).sum(dim=1)
          y = y / self.temperature
          return torch.sigmoid(y)
@@ -55,21 +57,31 @@ class DSSM(torch.nn.Module):
      def user_tower(self, x):
          if self.mode == "item":
              return None
-         input_user = self.embedding(x, self.user_features, squeeze_dim=True)  #[batch_size, num_features * embed_dim]
-         input_user = input_user.view(input_user.size(0), self.user_num_features, -1)  #[batch_size, num_features, embed_dim]
-         input_user = self.user_senet(input_user)  #[batch_size, num_features, embed_dim]
-         input_user = input_user.view(input_user.size(0), -1)  #[batch_size, num_features * embed_dim]
-         user_embedding = self.user_mlp(input_user)  #[batch_size, user_params["dims"][-1]]
+         # [batch_size, num_features * embed_dim]
+         input_user = self.embedding(x, self.user_features, squeeze_dim=True)
+         # [batch_size, num_features, embed_dim]
+         input_user = input_user.view(input_user.size(0), self.user_num_features, -1)
+         # [batch_size, num_features, embed_dim]
+         input_user = self.user_senet(input_user)
+         # [batch_size, num_features * embed_dim]
+         input_user = input_user.view(input_user.size(0), -1)
+         # [batch_size, user_params["dims"][-1]]
+         user_embedding = self.user_mlp(input_user)
          user_embedding = F.normalize(user_embedding, p=2, dim=1)  # L2 normalize
          return user_embedding

      def item_tower(self, x):
          if self.mode == "user":
              return None
-         input_item = self.embedding(x, self.item_features, squeeze_dim=True)  #[batch_size, num_features * embed_dim]
-         input_item = input_item.view(input_item.size(0), self.item_num_features, -1)  #[batch_size, num_features, embed_dim]
-         input_item = self.item_senet(input_item)  #[batch_size, num_features, embed_dim]
-         input_item = input_item.view(input_item.size(0), -1)  #[batch_size, num_features * embed_dim]
-         item_embedding = self.item_mlp(input_item)  #[batch_size, item_params["dims"][-1]]
+         # [batch_size, num_features * embed_dim]
+         input_item = self.embedding(x, self.item_features, squeeze_dim=True)
+         # [batch_size, num_features, embed_dim]
+         input_item = input_item.view(input_item.size(0), self.item_num_features, -1)
+         # [batch_size, num_features, embed_dim]
+         input_item = self.item_senet(input_item)
+         # [batch_size, num_features * embed_dim]
+         input_item = input_item.view(input_item.size(0), -1)
+         # [batch_size, item_params["dims"][-1]]
+         item_embedding = self.item_mlp(input_item)
          item_embedding = F.normalize(item_embedding, p=2, dim=1)
-         return item_embedding
+         return item_embedding
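The SENet variant differs from the plain DSSM only in the reshape-reweight-flatten steps above: the concatenated field embeddings are viewed as [batch_size, num_features, embed_dim], each field is re-scaled by a learned importance weight, and the result is flattened back before the tower MLP (with a temperature-scaled cosine score at the end). SENETLayer's internals are not part of this diff; the sketch below shows a standard squeeze-and-excitation re-weighting in the style of FiBiNET, as an assumption about what such a layer does:

import torch
import torch.nn as nn

class SENetBlock(nn.Module):
    # A sketch of SENet-style field re-weighting; the package's SENETLayer
    # internals are not shown in this diff and may differ in details.
    def __init__(self, num_fields, reduction=3):
        super().__init__()
        reduced = max(1, num_fields // reduction)
        self.excite = nn.Sequential(nn.Linear(num_fields, reduced), nn.ReLU(),
                                    nn.Linear(reduced, num_fields), nn.ReLU())

    def forward(self, x):                 # x: [batch_size, num_fields, embed_dim]
        z = x.mean(dim=-1)                # squeeze: one scalar per field
        weights = self.excite(z)          # excitation: one weight per field
        return x * weights.unsqueeze(-1)  # re-weight each field's embedding

# Mirrors the tower code above: flatten -> fields -> SENet -> flatten.
emb = torch.randn(8, 5, 16)                     # [batch_size, num_features, embed_dim]
reweighted = SENetBlock(num_fields=5)(emb)
flat = reweighted.view(reweighted.size(0), -1)  # [batch_size, num_features * embed_dim]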
torch_rechub/models/matching/gru4rec.py
@@ -1,87 +1,85 @@
- """
- Date: create on 03/06/2022
- References:
-     paper: SESSION-BASED RECOMMENDATIONS WITH RECURRENT NEURAL NETWORKS
-     url: http://arxiv.org/abs/1511.06939
- Authors: Kai Wang, 306178200@qq.com
- """
-
- import torch
-
- from ...basic.layers import MLP, EmbeddingLayer
- from torch import nn
- import torch.nn.functional as F
-
-
- class GRU4Rec(torch.nn.Module):
-     """The match model from the `Session-based Recommendations with Recurrent Neural Networks` paper.
-     It's a DSSM match model trained by global softmax loss on list-wise samples.
-     Note that in the original paper there is no item DNN tower; the item embedding is trained directly.
-
-     Args:
-         user_features (list[Feature Class]): training by the user tower module.
-         history_features (list[Feature Class]): training history.
-         item_features (list[Feature Class]): training by the embedding table, it's the item id feature.
-         neg_item_feature (list[Feature Class]): training by the embedding table, it's the negative items id feature.
-         user_params (dict): the params of the User Tower module, keys include: `{"dims":list, "activation":str, "dropout":float, "output_layer":bool}`.
-         temperature (float): temperature factor for similarity score, default to 1.0.
-     """
-
-     def __init__(self, user_features, history_features, item_features, neg_item_feature, user_params, temperature=1.0):
-         super().__init__()
-         self.user_features = user_features
-         self.item_features = item_features
-         self.history_features = history_features
-         self.neg_item_feature = neg_item_feature
-         self.temperature = temperature
-         self.user_dims = sum([fea.embed_dim for fea in user_features+history_features])
-
-         self.embedding = EmbeddingLayer(user_features + item_features + history_features)
-         self.gru = nn.GRU(input_size = history_features[0].embed_dim,
-                           hidden_size = history_features[0].embed_dim,
-                           num_layers = user_params.get('num_layers',2),
-                           batch_first = True,
-                           bias = False)
-         self.user_mlp = MLP(self.user_dims, output_layer=False, **user_params)
-         self.mode = None
-
-     def forward(self, x):
-         user_embedding = self.user_tower(x)
-         item_embedding = self.item_tower(x)
-         if self.mode == "user":
-             return user_embedding
-         if self.mode == "item":
-             return item_embedding
-
-         y = torch.mul(user_embedding, item_embedding).sum(dim=1)
-
-         return y
-
-     def user_tower(self, x):
-         if self.mode == "item":
-             return None
-         input_user = self.embedding(x, self.user_features, squeeze_dim=True)  #[batch_size, num_features*deep_dims]
-
-         history_emb = self.embedding(x, self.history_features).squeeze(1)
-         _, history_emb = self.gru(history_emb)
-         history_emb = history_emb[-1]
-
-         input_user = torch.cat([input_user,history_emb],dim=-1)
-
-         user_embedding = self.user_mlp(input_user).unsqueeze(1)  #[batch_size, 1, embed_dim]
-         user_embedding = F.normalize(user_embedding, p=2, dim=-1)  # L2 normalize
-         if self.mode == "user":
-             return user_embedding.squeeze(1)  #inference embedding mode -> [batch_size, embed_dim]
-         return user_embedding
-
-     def item_tower(self, x):
-         if self.mode == "user":
-             return None
-         pos_embedding = self.embedding(x, self.item_features, squeeze_dim=False)  #[batch_size, 1, embed_dim]
-         pos_embedding = F.normalize(pos_embedding, p=2, dim=-1)  # L2 normalize
-         if self.mode == "item":  #inference embedding mode
-             return pos_embedding.squeeze(1)  #[batch_size, embed_dim]
-         neg_embeddings = self.embedding(x, self.neg_item_feature,
-                                         squeeze_dim=False).squeeze(1)  #[batch_size, n_neg_items, embed_dim]
-         neg_embeddings = F.normalize(neg_embeddings, p=2, dim=-1)  # L2 normalize
-         return torch.cat((pos_embedding, neg_embeddings), dim=1)  #[batch_size, 1+n_neg_items, embed_dim]
+ """
+ Date: create on 03/06/2022
+ References:
+     paper: SESSION-BASED RECOMMENDATIONS WITH RECURRENT NEURAL NETWORKS
+     url: http://arxiv.org/abs/1511.06939
+ Authors: Kai Wang, 306178200@qq.com
+ """
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from ...basic.layers import MLP, EmbeddingLayer
+
+
+ class GRU4Rec(torch.nn.Module):
+     """The match model from the `Session-based Recommendations with Recurrent Neural Networks` paper.
+     It's a DSSM match model trained by global softmax loss on list-wise samples.
+     Note that in the original paper there is no item DNN tower; the item embedding is trained directly.
+
+     Args:
+         user_features (list[Feature Class]): training by the user tower module.
+         history_features (list[Feature Class]): training history.
+         item_features (list[Feature Class]): training by the embedding table, it's the item id feature.
+         neg_item_feature (list[Feature Class]): training by the embedding table, it's the negative items id feature.
+         user_params (dict): the params of the User Tower module, keys include: `{"dims":list, "activation":str, "dropout":float, "output_layer":bool}`.
+         temperature (float): temperature factor for similarity score, default to 1.0.
+     """
+
+     def __init__(self, user_features, history_features, item_features, neg_item_feature, user_params, temperature=1.0):
+         super().__init__()
+         self.user_features = user_features
+         self.item_features = item_features
+         self.history_features = history_features
+         self.neg_item_feature = neg_item_feature
+         self.temperature = temperature
+         self.user_dims = sum([fea.embed_dim for fea in user_features + history_features])
+
+         self.embedding = EmbeddingLayer(user_features + item_features + history_features)
+         self.gru = nn.GRU(input_size=history_features[0].embed_dim, hidden_size=history_features[0].embed_dim, num_layers=user_params.get('num_layers', 2), batch_first=True, bias=False)
+         self.user_mlp = MLP(self.user_dims, output_layer=False, **user_params)
+         self.mode = None
+
+     def forward(self, x):
+         user_embedding = self.user_tower(x)
+         item_embedding = self.item_tower(x)
+         if self.mode == "user":
+             return user_embedding
+         if self.mode == "item":
+             return item_embedding
+
+         y = torch.mul(user_embedding, item_embedding).sum(dim=1)
+
+         return y
+
+     def user_tower(self, x):
+         if self.mode == "item":
+             return None
+         # [batch_size, num_features*deep_dims]
+         input_user = self.embedding(x, self.user_features, squeeze_dim=True)
+
+         history_emb = self.embedding(x, self.history_features).squeeze(1)
+         _, history_emb = self.gru(history_emb)
+         history_emb = history_emb[-1]
+
+         input_user = torch.cat([input_user, history_emb], dim=-1)
+
+         user_embedding = self.user_mlp(input_user).unsqueeze(1)  # [batch_size, 1, embed_dim]
+         user_embedding = F.normalize(user_embedding, p=2, dim=-1)  # L2 normalize
+         if self.mode == "user":
+             # inference embedding mode -> [batch_size, embed_dim]
+             return user_embedding.squeeze(1)
+         return user_embedding
+
+     def item_tower(self, x):
+         if self.mode == "user":
+             return None
+         pos_embedding = self.embedding(x, self.item_features, squeeze_dim=False)  # [batch_size, 1, embed_dim]
+         pos_embedding = F.normalize(pos_embedding, p=2, dim=-1)  # L2 normalize
+         if self.mode == "item":  # inference embedding mode
+             return pos_embedding.squeeze(1)  # [batch_size, embed_dim]
+         neg_embeddings = self.embedding(x, self.neg_item_feature, squeeze_dim=False).squeeze(1)  # [batch_size, n_neg_items, embed_dim]
+         neg_embeddings = F.normalize(neg_embeddings, p=2, dim=-1)  # L2 normalize
+         # [batch_size, 1+n_neg_items, embed_dim]
+         return torch.cat((pos_embedding, neg_embeddings), dim=1)
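One detail worth spelling out in user_tower above: nn.GRU returns (output, h_n), and the code keeps h_n[-1], the final hidden state of the top GRU layer, as the session representation. A minimal sketch with assumed toy shapes:

import torch
from torch import nn

# h_n has shape [num_layers, batch_size, hidden_size], so h_n[-1] is the final
# hidden state of the top layer: one vector per user session.
gru = nn.GRU(input_size=16, hidden_size=16, num_layers=2, batch_first=True, bias=False)
history_emb = torch.randn(8, 50, 16)   # [batch_size, seq_len, embed_dim]
_, h_n = gru(history_emb)              # h_n: [2, 8, 16]
session_repr = h_n[-1]                 # [8, 16], concatenated with the user features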
torch_rechub/models/matching/mind.py
@@ -1,101 +1,103 @@
- """
- Date: create on 08/06/2022
- References:
-     paper: Multi-Interest Network with Dynamic Routing
-     url: https://arxiv.org/pdf/1904.08030v1
-     code: https://github.com/ShiningCosmos/pytorch_ComiRec/blob/main/MIND.py
- Authors: Kai Wang, 306178200@qq.com
- """
-
- import torch
-
- from ...basic.layers import MLP, EmbeddingLayer, MultiInterestSA, CapsuleNetwork
- from torch import nn
- import torch.nn.functional as F
-
-
- class MIND(torch.nn.Module):
-     """The match model mentioned in `Multi-Interest Network with Dynamic Routing` paper.
-     It's a ComirecDR match model trained by global softmax loss on list-wise samples.
-     Note that in the original paper there is no item DNN tower; the item embedding is trained directly.
-
-     Args:
-         user_features (list[Feature Class]): training by the user tower module.
-         history_features (list[Feature Class]): training history.
-         item_features (list[Feature Class]): training by the embedding table, it's the item id feature.
-         neg_item_feature (list[Feature Class]): training by the embedding table, it's the negative items id feature.
-         max_length (int): max sequence length of input item sequence.
-         temperature (float): temperature factor for similarity score, default to 1.0.
-         interest_num (int): interest num.
-     """
-
-     def __init__(self, user_features, history_features, item_features, neg_item_feature, max_length, temperature=1.0, interest_num=4):
-         super().__init__()
-         self.user_features = user_features
-         self.item_features = item_features
-         self.history_features = history_features
-         self.neg_item_feature = neg_item_feature
-         self.temperature = temperature
-         self.interest_num = interest_num
-         self.max_length = max_length
-         self.user_dims = sum([fea.embed_dim for fea in user_features+history_features])
-
-         self.embedding = EmbeddingLayer(user_features + item_features + history_features)
-         self.capsule = CapsuleNetwork(self.history_features[0].embed_dim,self.max_length,bilinear_type=0,interest_num=self.interest_num)
-         self.convert_user_weight = nn.Parameter(torch.rand(self.user_dims, self.history_features[0].embed_dim), requires_grad=True)
-         self.mode = None
-
-     def forward(self, x):
-         user_embedding = self.user_tower(x)
-         item_embedding = self.item_tower(x)
-         if self.mode == "user":
-             return user_embedding
-         if self.mode == "item":
-             return item_embedding
-
-         pos_item_embedding = item_embedding[:,0,:]
-         dot_res = torch.bmm(user_embedding, pos_item_embedding.squeeze(1).unsqueeze(-1))
-         k_index = torch.argmax(dot_res, dim=1)
-         best_interest_emb = torch.rand(user_embedding.shape[0], user_embedding.shape[2]).to(user_embedding.device)
-         for k in range(user_embedding.shape[0]):
-             best_interest_emb[k, :] = user_embedding[k, k_index[k], :]
-         best_interest_emb = best_interest_emb.unsqueeze(1)
-
-         y = torch.mul(best_interest_emb, item_embedding).sum(dim=1)
-         return y
-
-     def user_tower(self, x):
-         if self.mode == "item":
-             return None
-         input_user = self.embedding(x, self.user_features, squeeze_dim=True).unsqueeze(1)  #[batch_size, 1, num_features*deep_dims]
-         input_user = input_user.expand([input_user.shape[0], self.interest_num, input_user.shape[-1]])
-
-         history_emb = self.embedding(x, self.history_features).squeeze(1)
-         mask = self.gen_mask(x)
-         multi_interest_emb = self.capsule(history_emb,mask)
-
-         input_user = torch.cat([input_user,multi_interest_emb],dim=-1)
-
-         # user_embedding = self.user_mlp(input_user).unsqueeze(1) #[batch_size, interest_num, embed_dim]
-         user_embedding = torch.matmul(input_user,self.convert_user_weight)
-         user_embedding = F.normalize(user_embedding, p=2, dim=-1)  # L2 normalize
-         if self.mode == "user":
-             return user_embedding  #inference embedding mode -> [batch_size, interest_num, embed_dim]
-         return user_embedding
-
-     def item_tower(self, x):
-         if self.mode == "user":
-             return None
-         pos_embedding = self.embedding(x, self.item_features, squeeze_dim=False)  #[batch_size, 1, embed_dim]
-         pos_embedding = F.normalize(pos_embedding, p=2, dim=-1)  # L2 normalize
-         if self.mode == "item":  #inference embedding mode
-             return pos_embedding.squeeze(1)  #[batch_size, embed_dim]
-         neg_embeddings = self.embedding(x, self.neg_item_feature,
-                                         squeeze_dim=False).squeeze(1)  #[batch_size, n_neg_items, embed_dim]
-         neg_embeddings = F.normalize(neg_embeddings, p=2, dim=-1)  # L2 normalize
-         return torch.cat((pos_embedding, neg_embeddings), dim=1)  #[batch_size, 1+n_neg_items, embed_dim]
-
-     def gen_mask(self, x):
-         his_list = x[self.history_features[0].name]
-         mask = (his_list > 0).long()
-         return mask
+ """
+ Date: create on 08/06/2022
+ References:
+     paper: Multi-Interest Network with Dynamic Routing
+     url: https://arxiv.org/pdf/1904.08030v1
+     code: https://github.com/ShiningCosmos/pytorch_ComiRec/blob/main/MIND.py
+ Authors: Kai Wang, 306178200@qq.com
+ """
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from ...basic.layers import MLP, CapsuleNetwork, EmbeddingLayer, MultiInterestSA
+
+
+ class MIND(torch.nn.Module):
+     """The match model mentioned in `Multi-Interest Network with Dynamic Routing` paper.
+     It's a ComirecDR match model trained by global softmax loss on list-wise samples.
+     Note that in the original paper there is no item DNN tower; the item embedding is trained directly.
+
+     Args:
+         user_features (list[Feature Class]): training by the user tower module.
+         history_features (list[Feature Class]): training history.
+         item_features (list[Feature Class]): training by the embedding table, it's the item id feature.
+         neg_item_feature (list[Feature Class]): training by the embedding table, it's the negative items id feature.
+         max_length (int): max sequence length of input item sequence.
+         temperature (float): temperature factor for similarity score, default to 1.0.
+         interest_num (int): interest num.
+     """
+
+     def __init__(self, user_features, history_features, item_features, neg_item_feature, max_length, temperature=1.0, interest_num=4):
+         super().__init__()
+         self.user_features = user_features
+         self.item_features = item_features
+         self.history_features = history_features
+         self.neg_item_feature = neg_item_feature
+         self.temperature = temperature
+         self.interest_num = interest_num
+         self.max_length = max_length
+         self.user_dims = sum([fea.embed_dim for fea in user_features + history_features])
+
+         self.embedding = EmbeddingLayer(user_features + item_features + history_features)
+         self.capsule = CapsuleNetwork(self.history_features[0].embed_dim, self.max_length, bilinear_type=0, interest_num=self.interest_num)
+         self.convert_user_weight = nn.Parameter(torch.rand(self.user_dims, self.history_features[0].embed_dim), requires_grad=True)
+         self.mode = None
+
+     def forward(self, x):
+         user_embedding = self.user_tower(x)
+         item_embedding = self.item_tower(x)
+         if self.mode == "user":
+             return user_embedding
+         if self.mode == "item":
+             return item_embedding
+
+         pos_item_embedding = item_embedding[:, 0, :]
+         dot_res = torch.bmm(user_embedding, pos_item_embedding.squeeze(1).unsqueeze(-1))
+         k_index = torch.argmax(dot_res, dim=1)
+         best_interest_emb = torch.rand(user_embedding.shape[0], user_embedding.shape[2]).to(user_embedding.device)
+         for k in range(user_embedding.shape[0]):
+             best_interest_emb[k, :] = user_embedding[k, k_index[k], :]
+         best_interest_emb = best_interest_emb.unsqueeze(1)
+
+         y = torch.mul(best_interest_emb, item_embedding).sum(dim=1)
+         return y
+
+     def user_tower(self, x):
+         if self.mode == "item":
+             return None
+         input_user = self.embedding(x, self.user_features, squeeze_dim=True).unsqueeze(1)  # [batch_size, 1, num_features*deep_dims]
+         input_user = input_user.expand([input_user.shape[0], self.interest_num, input_user.shape[-1]])
+
+         history_emb = self.embedding(x, self.history_features).squeeze(1)
+         mask = self.gen_mask(x)
+         multi_interest_emb = self.capsule(history_emb, mask)
+
+         input_user = torch.cat([input_user, multi_interest_emb], dim=-1)
+
+         # user_embedding = self.user_mlp(input_user).unsqueeze(1)
+         # #[batch_size, interest_num, embed_dim]
+         user_embedding = torch.matmul(input_user, self.convert_user_weight)
+         user_embedding = F.normalize(user_embedding, p=2, dim=-1)  # L2 normalize
+         if self.mode == "user":
+             # inference embedding mode -> [batch_size, interest_num, embed_dim]
+             return user_embedding
+         return user_embedding
+
+     def item_tower(self, x):
+         if self.mode == "user":
+             return None
+         pos_embedding = self.embedding(x, self.item_features, squeeze_dim=False)  # [batch_size, 1, embed_dim]
+         pos_embedding = F.normalize(pos_embedding, p=2, dim=-1)  # L2 normalize
+         if self.mode == "item":  # inference embedding mode
+             return pos_embedding.squeeze(1)  # [batch_size, embed_dim]
+         neg_embeddings = self.embedding(x, self.neg_item_feature, squeeze_dim=False).squeeze(1)  # [batch_size, n_neg_items, embed_dim]
+         neg_embeddings = F.normalize(neg_embeddings, p=2, dim=-1)  # L2 normalize
+         # [batch_size, 1+n_neg_items, embed_dim]
+         return torch.cat((pos_embedding, neg_embeddings), dim=1)
+
+     def gen_mask(self, x):
+         his_list = x[self.history_features[0].name]
+         mask = (his_list > 0).long()
+         return mask
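The per-sample loop in MIND.forward selects, for each user, the interest vector most aligned with the positive item before computing the training score. An equivalent vectorized sketch (illustrative shapes; not the package's code):

import torch

batch_size, interest_num, embed_dim = 8, 4, 16
user_embedding = torch.randn(batch_size, interest_num, embed_dim)
pos_item_embedding = torch.randn(batch_size, embed_dim)

# Dot product of every interest with the positive item: [B, K, 1] -> [B, K].
dot_res = torch.bmm(user_embedding, pos_item_embedding.unsqueeze(-1)).squeeze(-1)
k_index = torch.argmax(dot_res, dim=1)                                  # [B]
best_interest_emb = user_embedding[torch.arange(batch_size), k_index]   # [B, D]
best_interest_emb = best_interest_emb.unsqueeze(1)                      # [B, 1, D]

Advanced indexing replaces both the Python loop and the torch.rand placeholder buffer while producing the same selection.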