torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +54 -54
  3. torch_rechub/basic/callback.py +33 -33
  4. torch_rechub/basic/features.py +87 -94
  5. torch_rechub/basic/initializers.py +92 -92
  6. torch_rechub/basic/layers.py +994 -720
  7. torch_rechub/basic/loss_func.py +223 -34
  8. torch_rechub/basic/metaoptimizer.py +76 -72
  9. torch_rechub/basic/metric.py +251 -250
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -11
  14. torch_rechub/models/matching/comirec.py +193 -188
  15. torch_rechub/models/matching/dssm.py +72 -66
  16. torch_rechub/models/matching/dssm_facebook.py +77 -79
  17. torch_rechub/models/matching/dssm_senet.py +28 -16
  18. torch_rechub/models/matching/gru4rec.py +85 -87
  19. torch_rechub/models/matching/mind.py +103 -101
  20. torch_rechub/models/matching/narm.py +82 -76
  21. torch_rechub/models/matching/sasrec.py +143 -140
  22. torch_rechub/models/matching/sine.py +148 -151
  23. torch_rechub/models/matching/stamp.py +81 -83
  24. torch_rechub/models/matching/youtube_dnn.py +75 -71
  25. torch_rechub/models/matching/youtube_sbc.py +98 -98
  26. torch_rechub/models/multi_task/__init__.py +7 -5
  27. torch_rechub/models/multi_task/aitm.py +83 -84
  28. torch_rechub/models/multi_task/esmm.py +56 -55
  29. torch_rechub/models/multi_task/mmoe.py +58 -58
  30. torch_rechub/models/multi_task/ple.py +116 -130
  31. torch_rechub/models/multi_task/shared_bottom.py +45 -45
  32. torch_rechub/models/ranking/__init__.py +14 -11
  33. torch_rechub/models/ranking/afm.py +65 -63
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -63
  36. torch_rechub/models/ranking/dcn.py +38 -38
  37. torch_rechub/models/ranking/dcn_v2.py +59 -69
  38. torch_rechub/models/ranking/deepffm.py +131 -123
  39. torch_rechub/models/ranking/deepfm.py +43 -42
  40. torch_rechub/models/ranking/dien.py +191 -191
  41. torch_rechub/models/ranking/din.py +93 -91
  42. torch_rechub/models/ranking/edcn.py +101 -117
  43. torch_rechub/models/ranking/fibinet.py +42 -50
  44. torch_rechub/models/ranking/widedeep.py +41 -41
  45. torch_rechub/trainers/__init__.py +4 -3
  46. torch_rechub/trainers/ctr_trainer.py +288 -128
  47. torch_rechub/trainers/match_trainer.py +336 -170
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +356 -207
  50. torch_rechub/trainers/seq_trainer.py +427 -0
  51. torch_rechub/utils/data.py +492 -360
  52. torch_rechub/utils/hstu_utils.py +198 -0
  53. torch_rechub/utils/match.py +457 -274
  54. torch_rechub/utils/model_utils.py +233 -0
  55. torch_rechub/utils/mtl.py +136 -126
  56. torch_rechub/utils/onnx_export.py +220 -0
  57. torch_rechub/utils/visualization.py +271 -0
  58. torch_rechub-0.0.5.dist-info/METADATA +402 -0
  59. torch_rechub-0.0.5.dist-info/RECORD +64 -0
  60. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
  61. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
  62. torch_rechub-0.0.3.dist-info/METADATA +0 -177
  63. torch_rechub-0.0.3.dist-info/RECORD +0 -55
  64. torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
torch_rechub/models/matching/comirec.py
@@ -1,188 +1,193 @@
- """
- Date: create on 07/06/2022
- References:
-     paper: Controllable Multi-Interest Framework for Recommendation
-     url: https://arxiv.org/pdf/2005.09347.pdf
-     code: https://github.com/ShiningCosmos/pytorch_ComiRec/blob/main/ComiRec.py
- Authors: Kai Wang, 306178200@qq.com
- """
-
- import torch
-
- from ...basic.layers import MLP, EmbeddingLayer, MultiInterestSA, CapsuleNetwork
- from torch import nn
- import torch.nn.functional as F
-
-
- class ComirecSA(torch.nn.Module):
-     """The match model mentioned in `Controllable Multi-Interest Framework for Recommendation` paper.
-     It's a ComirecSA match model trained by global softmax loss on list-wise samples.
-     Note in origin paper, it's without item dnn tower and train item embedding directly.
-
-     Args:
-         user_features (list[Feature Class]): training by the user tower module.
-         history_features (list[Feature Class]): training history
-         item_features (list[Feature Class]): training by the embedding table, it's the item id feature.
-         neg_item_feature (list[Feature Class]): training by the embedding table, it's the negative items id feature.
-         temperature (float): temperature factor for similarity score, default to 1.0.
-         interest_num (int): interest num
-     """
-
-     def __init__(self, user_features, history_features, item_features, neg_item_feature, temperature=1.0, interest_num=4):
-         super().__init__()
-         self.user_features = user_features
-         self.item_features = item_features
-         self.history_features = history_features
-         self.neg_item_feature = neg_item_feature
-         self.temperature = temperature
-         self.interest_num = interest_num
-         self.user_dims = sum([fea.embed_dim for fea in user_features+history_features])
-
-         self.embedding = EmbeddingLayer(user_features + item_features + history_features)
-         self.multi_interest_sa = MultiInterestSA(embedding_dim=self.history_features[0].embed_dim, interest_num=self.interest_num)
-         self.convert_user_weight = nn.Parameter(torch.rand(self.user_dims, self.history_features[0].embed_dim), requires_grad=True)
-         self.mode = None
-
-     def forward(self, x):
-         user_embedding = self.user_tower(x)
-         item_embedding = self.item_tower(x)
-         if self.mode == "user":
-             return user_embedding
-         if self.mode == "item":
-             return item_embedding
-
-         pos_item_embedding = item_embedding[:,0,:]
-         dot_res = torch.bmm(user_embedding, pos_item_embedding.squeeze(1).unsqueeze(-1))
-         k_index = torch.argmax(dot_res, dim=1)
-         best_interest_emb = torch.rand(user_embedding.shape[0], user_embedding.shape[2]).to(user_embedding.device)
-         for k in range(user_embedding.shape[0]):
-             best_interest_emb[k, :] = user_embedding[k, k_index[k], :]
-         best_interest_emb = best_interest_emb.unsqueeze(1)
-
-         y = torch.mul(best_interest_emb, item_embedding).sum(dim=1)
-
-         return y
-
-     def user_tower(self, x):
-         if self.mode == "item":
-             return None
-         input_user = self.embedding(x, self.user_features, squeeze_dim=True).unsqueeze(1) #[batch_size, num_features*deep_dims]
-         input_user = input_user.expand([input_user.shape[0], self.interest_num, input_user.shape[-1]])
-
-         history_emb = self.embedding(x, self.history_features).squeeze(1)
-         mask = self.gen_mask(x)
-         mask = mask.unsqueeze(-1).float()
-         multi_interest_emb = self.multi_interest_sa(history_emb,mask)
-
-         input_user = torch.cat([input_user,multi_interest_emb],dim=-1)
-
-         # user_embedding = self.user_mlp(input_user).unsqueeze(1) #[batch_size, interest_num, embed_dim]
-         user_embedding = torch.matmul(input_user,self.convert_user_weight)
-         user_embedding = F.normalize(user_embedding, p=2, dim=-1) # L2 normalize
-         if self.mode == "user":
-             return user_embedding #inference embedding mode -> [batch_size, interest_num, embed_dim]
-         return user_embedding
-
-     def item_tower(self, x):
-         if self.mode == "user":
-             return None
-         pos_embedding = self.embedding(x, self.item_features, squeeze_dim=False) #[batch_size, 1, embed_dim]
-         pos_embedding = F.normalize(pos_embedding, p=2, dim=-1) # L2 normalize
-         if self.mode == "item": #inference embedding mode
-             return pos_embedding.squeeze(1) #[batch_size, embed_dim]
-         neg_embeddings = self.embedding(x, self.neg_item_feature,
-                                         squeeze_dim=False).squeeze(1) #[batch_size, n_neg_items, embed_dim]
-         neg_embeddings = F.normalize(neg_embeddings, p=2, dim=-1) # L2 normalize
-         return torch.cat((pos_embedding, neg_embeddings), dim=1) #[batch_size, 1+n_neg_items, embed_dim]
-
-     def gen_mask(self, x):
-         his_list = x[self.history_features[0].name]
-         mask = (his_list > 0).long()
-         return mask
-
- class ComirecDR(torch.nn.Module):
-     """The match model mentioned in `Controllable Multi-Interest Framework for Recommendation` paper.
-     It's a ComirecDR match model trained by global softmax loss on list-wise samples.
-     Note in origin paper, it's without item dnn tower and train item embedding directly.
-
-     Args:
-         user_features (list[Feature Class]): training by the user tower module.
-         history_features (list[Feature Class]): training history
-         item_features (list[Feature Class]): training by the embedding table, it's the item id feature.
-         neg_item_feature (list[Feature Class]): training by the embedding table, it's the negative items id feature.
-         max_length (int): max sequence length of input item sequence
-         temperature (float): temperature factor for similarity score, default to 1.0.
-         interest_num (int): interest num
-     """
-
-     def __init__(self, user_features, history_features, item_features, neg_item_feature, max_length, temperature=1.0, interest_num=4):
-         super().__init__()
-         self.user_features = user_features
-         self.item_features = item_features
-         self.history_features = history_features
-         self.neg_item_feature = neg_item_feature
-         self.temperature = temperature
-         self.interest_num = interest_num
-         self.max_length = max_length
-         self.user_dims = sum([fea.embed_dim for fea in user_features+history_features])
-
-         self.embedding = EmbeddingLayer(user_features + item_features + history_features)
-         self.capsule = CapsuleNetwork(self.history_features[0].embed_dim,self.max_length,bilinear_type=2,interest_num=self.interest_num)
-         self.convert_user_weight = nn.Parameter(torch.rand(self.user_dims, self.history_features[0].embed_dim), requires_grad=True)
-         self.mode = None
-
-     def forward(self, x):
-         user_embedding = self.user_tower(x)
-         item_embedding = self.item_tower(x)
-         if self.mode == "user":
-             return user_embedding
-         if self.mode == "item":
-             return item_embedding
-
-         pos_item_embedding = item_embedding[:,0,:]
-         dot_res = torch.bmm(user_embedding, pos_item_embedding.squeeze(1).unsqueeze(-1))
-         k_index = torch.argmax(dot_res, dim=1)
-         best_interest_emb = torch.rand(user_embedding.shape[0], user_embedding.shape[2]).to(user_embedding.device)
-         for k in range(user_embedding.shape[0]):
-             best_interest_emb[k, :] = user_embedding[k, k_index[k], :]
-         best_interest_emb = best_interest_emb.unsqueeze(1)
-
-         y = torch.mul(best_interest_emb, item_embedding).sum(dim=1)
-
-         return y
-
-     def user_tower(self, x):
-         if self.mode == "item":
-             return None
-         input_user = self.embedding(x, self.user_features, squeeze_dim=True).unsqueeze(1) #[batch_size, num_features*deep_dims]
-         input_user = input_user.expand([input_user.shape[0], self.interest_num, input_user.shape[-1]])
-
-         history_emb = self.embedding(x, self.history_features).squeeze(1)
-         mask = self.gen_mask(x)
-         multi_interest_emb = self.capsule(history_emb,mask)
-
-         input_user = torch.cat([input_user,multi_interest_emb],dim=-1)
-
-         # user_embedding = self.user_mlp(input_user).unsqueeze(1) #[batch_size, interest_num, embed_dim]
-         user_embedding = torch.matmul(input_user,self.convert_user_weight)
-         user_embedding = F.normalize(user_embedding, p=2, dim=-1) # L2 normalize
-         if self.mode == "user":
-             return user_embedding #inference embedding mode -> [batch_size, interest_num, embed_dim]
-         return user_embedding
-
-     def item_tower(self, x):
-         if self.mode == "user":
-             return None
-         pos_embedding = self.embedding(x, self.item_features, squeeze_dim=False) #[batch_size, 1, embed_dim]
-         pos_embedding = F.normalize(pos_embedding, p=2, dim=-1) # L2 normalize
-         if self.mode == "item": #inference embedding mode
-             return pos_embedding.squeeze(1) #[batch_size, embed_dim]
-         neg_embeddings = self.embedding(x, self.neg_item_feature,
-                                         squeeze_dim=False).squeeze(1) #[batch_size, n_neg_items, embed_dim]
-         neg_embeddings = F.normalize(neg_embeddings, p=2, dim=-1) # L2 normalize
-         return torch.cat((pos_embedding, neg_embeddings), dim=1) #[batch_size, 1+n_neg_items, embed_dim]
-
-     def gen_mask(self, x):
-         his_list = x[self.history_features[0].name]
-         mask = (his_list > 0).long()
-         return mask
+ """
+ Date: create on 07/06/2022
+ References:
+     paper: Controllable Multi-Interest Framework for Recommendation
+     url: https://arxiv.org/pdf/2005.09347.pdf
+     code: https://github.com/ShiningCosmos/pytorch_ComiRec/blob/main/ComiRec.py
+ Authors: Kai Wang, 306178200@qq.com
+ """
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from ...basic.layers import MLP, CapsuleNetwork, EmbeddingLayer, MultiInterestSA
+
+
+ class ComirecSA(torch.nn.Module):
+     """The match model mentioned in `Controllable Multi-Interest Framework for Recommendation` paper.
+     It's a ComirecSA match model trained by global softmax loss on list-wise samples.
+     Note in origin paper, it's without item dnn tower and train item embedding directly.
+
+     Args:
+         user_features (list[Feature Class]): training by the user tower module.
+         history_features (list[Feature Class]): training history
+         item_features (list[Feature Class]): training by the embedding table, it's the item id feature.
+         neg_item_feature (list[Feature Class]): training by the embedding table, it's the negative items id feature.
+         temperature (float): temperature factor for similarity score, default to 1.0.
+         interest_num (int): interest num
+     """
+
+     def __init__(self, user_features, history_features, item_features, neg_item_feature, temperature=1.0, interest_num=4):
+         super().__init__()
+         self.user_features = user_features
+         self.item_features = item_features
+         self.history_features = history_features
+         self.neg_item_feature = neg_item_feature
+         self.temperature = temperature
+         self.interest_num = interest_num
+         self.user_dims = sum([fea.embed_dim for fea in user_features + history_features])
+
+         self.embedding = EmbeddingLayer(user_features + item_features + history_features)
+         self.multi_interest_sa = MultiInterestSA(embedding_dim=self.history_features[0].embed_dim, interest_num=self.interest_num)
+         self.convert_user_weight = nn.Parameter(torch.rand(self.user_dims, self.history_features[0].embed_dim), requires_grad=True)
+         self.mode = None
+
+     def forward(self, x):
+         user_embedding = self.user_tower(x)
+         item_embedding = self.item_tower(x)
+         if self.mode == "user":
+             return user_embedding
+         if self.mode == "item":
+             return item_embedding
+
+         pos_item_embedding = item_embedding[:, 0, :]
+         dot_res = torch.bmm(user_embedding, pos_item_embedding.squeeze(1).unsqueeze(-1))
+         k_index = torch.argmax(dot_res, dim=1)
+         best_interest_emb = torch.rand(user_embedding.shape[0], user_embedding.shape[2]).to(user_embedding.device)
+         for k in range(user_embedding.shape[0]):
+             best_interest_emb[k, :] = user_embedding[k, k_index[k], :]
+         best_interest_emb = best_interest_emb.unsqueeze(1)
+
+         y = torch.mul(best_interest_emb, item_embedding).sum(dim=1)
+
+         return y
+
+     def user_tower(self, x):
+         if self.mode == "item":
+             return None
+         input_user = self.embedding(x, self.user_features, squeeze_dim=True).unsqueeze(1) # [batch_size, num_features*deep_dims]
+         input_user = input_user.expand([input_user.shape[0], self.interest_num, input_user.shape[-1]])
+
+         history_emb = self.embedding(x, self.history_features).squeeze(1)
+         mask = self.gen_mask(x)
+         mask = mask.unsqueeze(-1).float()
+         multi_interest_emb = self.multi_interest_sa(history_emb, mask)
+
+         input_user = torch.cat([input_user, multi_interest_emb], dim=-1)
+
+         # user_embedding = self.user_mlp(input_user).unsqueeze(1)
+         # #[batch_size, interest_num, embed_dim]
+         user_embedding = torch.matmul(input_user, self.convert_user_weight)
+         user_embedding = F.normalize(user_embedding, p=2, dim=-1) # L2 normalize
+         if self.mode == "user":
+             # inference embedding mode -> [batch_size, interest_num, embed_dim]
+             return user_embedding
+         return user_embedding
+
+     def item_tower(self, x):
+         if self.mode == "user":
+             return None
+         pos_embedding = self.embedding(x, self.item_features, squeeze_dim=False) # [batch_size, 1, embed_dim]
+         pos_embedding = F.normalize(pos_embedding, p=2, dim=-1) # L2 normalize
+         if self.mode == "item": # inference embedding mode
+             return pos_embedding.squeeze(1) # [batch_size, embed_dim]
+         neg_embeddings = self.embedding(x, self.neg_item_feature, squeeze_dim=False).squeeze(1) # [batch_size, n_neg_items, embed_dim]
+         neg_embeddings = F.normalize(neg_embeddings, p=2, dim=-1) # L2 normalize
+         # [batch_size, 1+n_neg_items, embed_dim]
+         return torch.cat((pos_embedding, neg_embeddings), dim=1)
+
+     def gen_mask(self, x):
+         his_list = x[self.history_features[0].name]
+         mask = (his_list > 0).long()
+         return mask
+
+
+ class ComirecDR(torch.nn.Module):
+     """The match model mentioned in `Controllable Multi-Interest Framework for Recommendation` paper.
+     It's a ComirecDR match model trained by global softmax loss on list-wise samples.
+     Note in origin paper, it's without item dnn tower and train item embedding directly.
+
+     Args:
+         user_features (list[Feature Class]): training by the user tower module.
+         history_features (list[Feature Class]): training history
+         item_features (list[Feature Class]): training by the embedding table, it's the item id feature.
+         neg_item_feature (list[Feature Class]): training by the embedding table, it's the negative items id feature.
+         max_length (int): max sequence length of input item sequence
+         temperature (float): temperature factor for similarity score, default to 1.0.
+         interest_num (int): interest num
+     """
+
+     def __init__(self, user_features, history_features, item_features, neg_item_feature, max_length, temperature=1.0, interest_num=4):
+         super().__init__()
+         self.user_features = user_features
+         self.item_features = item_features
+         self.history_features = history_features
+         self.neg_item_feature = neg_item_feature
+         self.temperature = temperature
+         self.interest_num = interest_num
+         self.max_length = max_length
+         self.user_dims = sum([fea.embed_dim for fea in user_features + history_features])
+
+         self.embedding = EmbeddingLayer(user_features + item_features + history_features)
+         self.capsule = CapsuleNetwork(self.history_features[0].embed_dim, self.max_length, bilinear_type=2, interest_num=self.interest_num)
+         self.convert_user_weight = nn.Parameter(torch.rand(self.user_dims, self.history_features[0].embed_dim), requires_grad=True)
+         self.mode = None
+
+     def forward(self, x):
+         user_embedding = self.user_tower(x)
+         item_embedding = self.item_tower(x)
+         if self.mode == "user":
+             return user_embedding
+         if self.mode == "item":
+             return item_embedding
+
+         pos_item_embedding = item_embedding[:, 0, :]
+         dot_res = torch.bmm(user_embedding, pos_item_embedding.squeeze(1).unsqueeze(-1))
+         k_index = torch.argmax(dot_res, dim=1)
+         best_interest_emb = torch.rand(user_embedding.shape[0], user_embedding.shape[2]).to(user_embedding.device)
+         for k in range(user_embedding.shape[0]):
+             best_interest_emb[k, :] = user_embedding[k, k_index[k], :]
+         best_interest_emb = best_interest_emb.unsqueeze(1)
+
+         y = torch.mul(best_interest_emb, item_embedding).sum(dim=1)
+
+         return y
+
+     def user_tower(self, x):
+         if self.mode == "item":
+             return None
+         input_user = self.embedding(x, self.user_features, squeeze_dim=True).unsqueeze(1) # [batch_size, num_features*deep_dims]
+         input_user = input_user.expand([input_user.shape[0], self.interest_num, input_user.shape[-1]])
+
+         history_emb = self.embedding(x, self.history_features).squeeze(1)
+         mask = self.gen_mask(x)
+         multi_interest_emb = self.capsule(history_emb, mask)
+
+         input_user = torch.cat([input_user, multi_interest_emb], dim=-1)
+
+         # user_embedding = self.user_mlp(input_user).unsqueeze(1)
+         # #[batch_size, interest_num, embed_dim]
+         user_embedding = torch.matmul(input_user, self.convert_user_weight)
+         user_embedding = F.normalize(user_embedding, p=2, dim=-1) # L2 normalize
+         if self.mode == "user":
+             # inference embedding mode -> [batch_size, interest_num, embed_dim]
+             return user_embedding
+         return user_embedding
+
+     def item_tower(self, x):
+         if self.mode == "user":
+             return None
+         pos_embedding = self.embedding(x, self.item_features, squeeze_dim=False) # [batch_size, 1, embed_dim]
+         pos_embedding = F.normalize(pos_embedding, p=2, dim=-1) # L2 normalize
+         if self.mode == "item": # inference embedding mode
+             return pos_embedding.squeeze(1) # [batch_size, embed_dim]
+         neg_embeddings = self.embedding(x, self.neg_item_feature, squeeze_dim=False).squeeze(1) # [batch_size, n_neg_items, embed_dim]
+         neg_embeddings = F.normalize(neg_embeddings, p=2, dim=-1) # L2 normalize
+         # [batch_size, 1+n_neg_items, embed_dim]
+         return torch.cat((pos_embedding, neg_embeddings), dim=1)
+
+     def gen_mask(self, x):
+         his_list = x[self.history_features[0].name]
+         mask = (his_list > 0).long()
+         return mask
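
The comirec.py hunk above is stylistic only: imports are regrouped and sorted, spacing is normalized around operators and after commas, and long inline shape comments move onto their own lines. The model logic is identical in both versions. The step both ComirecSA and ComirecDR share picks, for each sample, the interest vector that scores highest against the positive item, then scores every candidate item with it. Below is a minimal sketch of that step, not taken from the package; the shapes, variable names, and the final reduction axis are illustrative assumptions.

import torch

# Illustrative shapes: 4 samples, 4 interests, 8-dim embeddings, 3 negatives per positive.
batch_size, interest_num, embed_dim, n_neg = 4, 4, 8, 3
user_embedding = torch.randn(batch_size, interest_num, embed_dim)  # multi-interest user tower output
item_embedding = torch.randn(batch_size, 1 + n_neg, embed_dim)     # positive item first, then negatives

pos_item = item_embedding[:, 0, :]                           # [batch_size, embed_dim]
dot_res = torch.bmm(user_embedding, pos_item.unsqueeze(-1))  # [batch_size, interest_num, 1]
k_index = torch.argmax(dot_res, dim=1).squeeze(-1)           # index of the best interest per sample
# Advanced indexing gathers the winning interest without the per-sample Python loop used above.
best_interest = user_embedding[torch.arange(batch_size), k_index]           # [batch_size, embed_dim]
scores = torch.mul(best_interest.unsqueeze(1), item_embedding).sum(dim=-1)  # [batch_size, 1 + n_neg]
print(scores.shape)  # torch.Size([4, 4])

Unlike the file above, the sketch reduces over the embedding axis (dim=-1) so that each candidate item gets its own score; the loop-free gather is equivalent to the for-loop in the diff.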
torch_rechub/models/matching/dssm.py
@@ -1,66 +1,72 @@
- """
- Date: create on 12/05/2022, update on 20/05/2022
- References:
-     paper: (CIKM'2013) Learning Deep Structured Semantic Models for Web Search using Clickthrough Data
-     url: https://posenhuang.github.io/papers/cikm2013_DSSM_fullversion.pdf
-     code: https://github.com/bbruceyuan/DeepMatch-Torch/blob/main/deepmatch_torch/models/dssm.py
- Authors: Mincai Lai, laimincai@shanghaitech.edu.cn
- """
-
- import torch
- import torch.nn.functional as F
- from ...basic.layers import MLP, EmbeddingLayer
-
-
- class DSSM(torch.nn.Module):
-     """Deep Structured Semantic Model
-
-     Args:
-         user_features (list[Feature Class]): training by the user tower module.
-         item_features (list[Feature Class]): training by the item tower module.
-         temperature (float): temperature factor for similarity score, default to 1.0.
-         user_params (dict): the params of the User Tower module, keys include:`{"dims":list, "activation":str, "dropout":float, "output_layer":bool`}.
-         item_params (dict): the params of the Item Tower module, keys include:`{"dims":list, "activation":str, "dropout":float, "output_layer":bool`}.
-     """
-
-     def __init__(self, user_features, item_features, user_params, item_params, temperature=1.0):
-         super().__init__()
-         self.user_features = user_features
-         self.item_features = item_features
-         self.temperature = temperature
-         self.user_dims = sum([fea.embed_dim for fea in user_features])
-         self.item_dims = sum([fea.embed_dim for fea in item_features])
-
-         self.embedding = EmbeddingLayer(user_features + item_features)
-         self.user_mlp = MLP(self.user_dims, output_layer=False, **user_params)
-         self.item_mlp = MLP(self.item_dims, output_layer=False, **item_params)
-         self.mode = None
-
-     def forward(self, x):
-         user_embedding = self.user_tower(x)
-         item_embedding = self.item_tower(x)
-         if self.mode == "user":
-             return user_embedding
-         if self.mode == "item":
-             return item_embedding
-
-         # calculate cosine score
-         y = torch.mul(user_embedding, item_embedding).sum(dim=1)
-         # y = y / self.temperature
-         return torch.sigmoid(y)
-
-     def user_tower(self, x):
-         if self.mode == "item":
-             return None
-         input_user = self.embedding(x, self.user_features, squeeze_dim=True) #[batch_size, num_features*deep_dims]
-         user_embedding = self.user_mlp(input_user) #[batch_size, user_params["dims"][-1]]
-         user_embedding = F.normalize(user_embedding, p=2, dim=1) # L2 normalize
-         return user_embedding
-
-     def item_tower(self, x):
-         if self.mode == "user":
-             return None
-         input_item = self.embedding(x, self.item_features, squeeze_dim=True) #[batch_size, num_features*embed_dim]
-         item_embedding = self.item_mlp(input_item) #[batch_size, item_params["dims"][-1]]
-         item_embedding = F.normalize(item_embedding, p=2, dim=1)
-         return item_embedding
+ """
+ Date: create on 12/05/2022, update on 20/05/2022
+ References:
+     paper: (CIKM'2013) Learning Deep Structured Semantic Models for Web Search using Clickthrough Data
+     url: https://posenhuang.github.io/papers/cikm2013_DSSM_fullversion.pdf
+     code: https://github.com/bbruceyuan/DeepMatch-Torch/blob/main/deepmatch_torch/models/dssm.py
+ Authors: Mincai Lai, laimincai@shanghaitech.edu.cn
+ """
+
+ import torch
+ import torch.nn.functional as F
+
+ from ...basic.layers import MLP, EmbeddingLayer
+
+
+ class DSSM(torch.nn.Module):
+     """Deep Structured Semantic Model
+
+     Args:
+         user_features (list[Feature Class]): training by the user tower module.
+         item_features (list[Feature Class]): training by the item tower module.
+         temperature (float): temperature factor for similarity score, default to 1.0.
+         user_params (dict): the params of the User Tower module, keys include:`{"dims":list, "activation":str, "dropout":float, "output_layer":bool`}.
+         item_params (dict): the params of the Item Tower module, keys include:`{"dims":list, "activation":str, "dropout":float, "output_layer":bool`}.
+     """
+
+     def __init__(self, user_features, item_features, user_params, item_params, temperature=1.0):
+         super().__init__()
+         self.user_features = user_features
+         self.item_features = item_features
+         self.temperature = temperature
+         self.user_dims = sum([fea.embed_dim for fea in user_features])
+         self.item_dims = sum([fea.embed_dim for fea in item_features])
+
+         self.embedding = EmbeddingLayer(user_features + item_features)
+         self.user_mlp = MLP(self.user_dims, output_layer=False, **user_params)
+         self.item_mlp = MLP(self.item_dims, output_layer=False, **item_params)
+         self.mode = None
+
+     def forward(self, x):
+         user_embedding = self.user_tower(x)
+         item_embedding = self.item_tower(x)
+         if self.mode == "user":
+             return user_embedding
+         if self.mode == "item":
+             return item_embedding
+
+
+         # calculate cosine score
+         y = torch.mul(user_embedding, item_embedding).sum(dim=1)
+         # y = y / self.temperature
+         return torch.sigmoid(y)
+
+     def user_tower(self, x):
+         if self.mode == "item":
+             return None
+         # [batch_size, num_features*deep_dims]
+         input_user = self.embedding(x, self.user_features, squeeze_dim=True)
+         # [batch_size, user_params["dims"][-1]]
+         user_embedding = self.user_mlp(input_user)
+         user_embedding = F.normalize(user_embedding, p=2, dim=1) # L2 normalize
+         return user_embedding
+
+     def item_tower(self, x):
+         if self.mode == "user":
+             return None
+         # [batch_size, num_features*embed_dim]
+         input_item = self.embedding(x, self.item_features, squeeze_dim=True)
+         # [batch_size, item_params["dims"][-1]]
+         item_embedding = self.item_mlp(input_item)
+         item_embedding = F.normalize(item_embedding, p=2, dim=1)
+         return item_embedding
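
The dssm.py hunk is likewise cosmetic: the relative import gains a separating blank line and the inline shape comments move onto their own lines, while the two-tower scoring is untouched. Because both tower outputs are L2-normalized, the elementwise product summed over the feature axis is exactly their cosine similarity. A minimal sketch of that scoring path, with random tensors standing in for the user_mlp and item_mlp outputs (the shapes are illustrative assumptions):

import torch
import torch.nn.functional as F

batch_size, dim = 4, 16
# Stand-ins for the tower outputs; DSSM L2-normalizes both before scoring.
user_embedding = F.normalize(torch.randn(batch_size, dim), p=2, dim=1)
item_embedding = F.normalize(torch.randn(batch_size, dim), p=2, dim=1)

cosine = torch.mul(user_embedding, item_embedding).sum(dim=1)  # [batch_size], values in [-1, 1]
score = torch.sigmoid(cosine)  # bounded roughly within (0.27, 0.73) since cosine is in [-1, 1]
print(score.shape)  # torch.Size([4])

Note that the sigmoid of a cosine can never reach 0 or 1; the commented-out temperature division in the file (y = y / self.temperature) would widen that range, since dividing by a temperature below 1 scales the cosine up before the sigmoid.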