torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +54 -54
- torch_rechub/basic/callback.py +33 -33
- torch_rechub/basic/features.py +87 -94
- torch_rechub/basic/initializers.py +92 -92
- torch_rechub/basic/layers.py +994 -720
- torch_rechub/basic/loss_func.py +223 -34
- torch_rechub/basic/metaoptimizer.py +76 -72
- torch_rechub/basic/metric.py +251 -250
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -11
- torch_rechub/models/matching/comirec.py +193 -188
- torch_rechub/models/matching/dssm.py +72 -66
- torch_rechub/models/matching/dssm_facebook.py +77 -79
- torch_rechub/models/matching/dssm_senet.py +28 -16
- torch_rechub/models/matching/gru4rec.py +85 -87
- torch_rechub/models/matching/mind.py +103 -101
- torch_rechub/models/matching/narm.py +82 -76
- torch_rechub/models/matching/sasrec.py +143 -140
- torch_rechub/models/matching/sine.py +148 -151
- torch_rechub/models/matching/stamp.py +81 -83
- torch_rechub/models/matching/youtube_dnn.py +75 -71
- torch_rechub/models/matching/youtube_sbc.py +98 -98
- torch_rechub/models/multi_task/__init__.py +7 -5
- torch_rechub/models/multi_task/aitm.py +83 -84
- torch_rechub/models/multi_task/esmm.py +56 -55
- torch_rechub/models/multi_task/mmoe.py +58 -58
- torch_rechub/models/multi_task/ple.py +116 -130
- torch_rechub/models/multi_task/shared_bottom.py +45 -45
- torch_rechub/models/ranking/__init__.py +14 -11
- torch_rechub/models/ranking/afm.py +65 -63
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -63
- torch_rechub/models/ranking/dcn.py +38 -38
- torch_rechub/models/ranking/dcn_v2.py +59 -69
- torch_rechub/models/ranking/deepffm.py +131 -123
- torch_rechub/models/ranking/deepfm.py +43 -42
- torch_rechub/models/ranking/dien.py +191 -191
- torch_rechub/models/ranking/din.py +93 -91
- torch_rechub/models/ranking/edcn.py +101 -117
- torch_rechub/models/ranking/fibinet.py +42 -50
- torch_rechub/models/ranking/widedeep.py +41 -41
- torch_rechub/trainers/__init__.py +4 -3
- torch_rechub/trainers/ctr_trainer.py +288 -128
- torch_rechub/trainers/match_trainer.py +336 -170
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +356 -207
- torch_rechub/trainers/seq_trainer.py +427 -0
- torch_rechub/utils/data.py +492 -360
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -274
- torch_rechub/utils/model_utils.py +233 -0
- torch_rechub/utils/mtl.py +136 -126
- torch_rechub/utils/onnx_export.py +220 -0
- torch_rechub/utils/visualization.py +271 -0
- torch_rechub-0.0.5.dist-info/METADATA +402 -0
- torch_rechub-0.0.5.dist-info/RECORD +64 -0
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
- torch_rechub-0.0.3.dist-info/METADATA +0 -177
- torch_rechub-0.0.3.dist-info/RECORD +0 -55
- torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
torch_rechub/models/matching/comirec.py

```diff
@@ -1,188 +1,193 @@
-"""
-Date: create on 07/06/2022
-References:
-    paper: Controllable Multi-Interest Framework for Recommendation
-    url: https://arxiv.org/pdf/2005.09347.pdf
-    code: https://github.com/ShiningCosmos/pytorch_ComiRec/blob/main/ComiRec.py
-Authors: Kai Wang, 306178200@qq.com
-"""
-
-import torch
-
-from
 ⋮ (removed lines 13-16 did not survive this view)
-class ComirecSA(torch.nn.Module):
-    """The match model mentioned in `Controllable Multi-Interest Framework for Recommendation` paper.
-    It's a ComirecSA match model trained by global softmax loss on list-wise samples.
-    Note in origin paper, it's without item dnn tower and train item embedding directly.
-
-    Args:
-        user_features (list[Feature Class]): training by the user tower module.
-        history_features (list[Feature Class]): training history
-        item_features (list[Feature Class]): training by the embedding table, it's the item id feature.
-        neg_item_feature (list[Feature Class]): training by the embedding table, it's the negative items id feature.
-        temperature (float): temperature factor for similarity score, default to 1.0.
-        interest_num (int): interest num
-    """
-
-    def __init__(self, user_features, history_features, item_features, neg_item_feature, temperature=1.0, interest_num=4):
-        super().__init__()
-        self.user_features = user_features
-        self.item_features = item_features
-        self.history_features = history_features
-        self.neg_item_feature = neg_item_feature
-        self.temperature = temperature
-        self.interest_num = interest_num
-        self.user_dims = sum([fea.embed_dim for fea in user_features+history_features])
-
-        self.embedding = EmbeddingLayer(user_features + item_features + history_features)
-        self.multi_interest_sa = MultiInterestSA(embedding_dim=self.history_features[0].embed_dim, interest_num=self.interest_num)
-        self.convert_user_weight = nn.Parameter(torch.rand(self.user_dims, self.history_features[0].embed_dim), requires_grad=True)
-        self.mode = None
-
-    def forward(self, x):
-        user_embedding = self.user_tower(x)
-        item_embedding = self.item_tower(x)
-        if self.mode == "user":
-            return user_embedding
-        if self.mode == "item":
-            return item_embedding
-
-        pos_item_embedding = item_embedding[:,0
-        dot_res = torch.bmm(user_embedding, pos_item_embedding.squeeze(1).unsqueeze(-1))
-        k_index = torch.argmax(dot_res, dim=1)
-        best_interest_emb = torch.rand(user_embedding.shape[0], user_embedding.shape[2]).to(user_embedding.device)
-        for k in range(user_embedding.shape[0]):
-            best_interest_emb[k, :] = user_embedding[k, k_index[k], :]
-        best_interest_emb = best_interest_emb.unsqueeze(1)
-
-        y = torch.mul(best_interest_emb, item_embedding).sum(dim=1)
-
-        return y
-
-    def user_tower(self, x):
-        if self.mode == "item":
-            return None
-        input_user = self.embedding(x, self.user_features, squeeze_dim=True).unsqueeze(1) #[batch_size, num_features*deep_dims]
-        input_user = input_user.expand([input_user.shape[0], self.interest_num, input_user.shape[-1]])
-
-        history_emb = self.embedding(x, self.history_features).squeeze(1)
-        mask = self.gen_mask(x)
-        mask = mask.unsqueeze(-1).float()
-        multi_interest_emb = self.multi_interest_sa(history_emb,mask)
-
-        input_user = torch.cat([input_user,multi_interest_emb],dim=-1)
-
-        # user_embedding = self.user_mlp(input_user).unsqueeze(1)
 ⋮ (removed lines 80-188 survive only as fragments and are omitted)
+"""
+Date: create on 07/06/2022
+References:
+    paper: Controllable Multi-Interest Framework for Recommendation
+    url: https://arxiv.org/pdf/2005.09347.pdf
+    code: https://github.com/ShiningCosmos/pytorch_ComiRec/blob/main/ComiRec.py
+Authors: Kai Wang, 306178200@qq.com
+"""
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...basic.layers import MLP, CapsuleNetwork, EmbeddingLayer, MultiInterestSA
+
+
+class ComirecSA(torch.nn.Module):
+    """The match model mentioned in `Controllable Multi-Interest Framework for Recommendation` paper.
+    It's a ComirecSA match model trained by global softmax loss on list-wise samples.
+    Note in origin paper, it's without item dnn tower and train item embedding directly.
+
+    Args:
+        user_features (list[Feature Class]): training by the user tower module.
+        history_features (list[Feature Class]): training history
+        item_features (list[Feature Class]): training by the embedding table, it's the item id feature.
+        neg_item_feature (list[Feature Class]): training by the embedding table, it's the negative items id feature.
+        temperature (float): temperature factor for similarity score, default to 1.0.
+        interest_num (int): interest num
+    """
+
+    def __init__(self, user_features, history_features, item_features, neg_item_feature, temperature=1.0, interest_num=4):
+        super().__init__()
+        self.user_features = user_features
+        self.item_features = item_features
+        self.history_features = history_features
+        self.neg_item_feature = neg_item_feature
+        self.temperature = temperature
+        self.interest_num = interest_num
+        self.user_dims = sum([fea.embed_dim for fea in user_features + history_features])
+
+        self.embedding = EmbeddingLayer(user_features + item_features + history_features)
+        self.multi_interest_sa = MultiInterestSA(embedding_dim=self.history_features[0].embed_dim, interest_num=self.interest_num)
+        self.convert_user_weight = nn.Parameter(torch.rand(self.user_dims, self.history_features[0].embed_dim), requires_grad=True)
+        self.mode = None
+
+    def forward(self, x):
+        user_embedding = self.user_tower(x)
+        item_embedding = self.item_tower(x)
+        if self.mode == "user":
+            return user_embedding
+        if self.mode == "item":
+            return item_embedding
+
+        pos_item_embedding = item_embedding[:, 0, :]
+        dot_res = torch.bmm(user_embedding, pos_item_embedding.squeeze(1).unsqueeze(-1))
+        k_index = torch.argmax(dot_res, dim=1)
+        best_interest_emb = torch.rand(user_embedding.shape[0], user_embedding.shape[2]).to(user_embedding.device)
+        for k in range(user_embedding.shape[0]):
+            best_interest_emb[k, :] = user_embedding[k, k_index[k], :]
+        best_interest_emb = best_interest_emb.unsqueeze(1)
+
+        y = torch.mul(best_interest_emb, item_embedding).sum(dim=1)
+
+        return y
+
+    def user_tower(self, x):
+        if self.mode == "item":
+            return None
+        input_user = self.embedding(x, self.user_features, squeeze_dim=True).unsqueeze(1)  # [batch_size, num_features*deep_dims]
+        input_user = input_user.expand([input_user.shape[0], self.interest_num, input_user.shape[-1]])
+
+        history_emb = self.embedding(x, self.history_features).squeeze(1)
+        mask = self.gen_mask(x)
+        mask = mask.unsqueeze(-1).float()
+        multi_interest_emb = self.multi_interest_sa(history_emb, mask)
+
+        input_user = torch.cat([input_user, multi_interest_emb], dim=-1)
+
+        # user_embedding = self.user_mlp(input_user).unsqueeze(1)
+        # #[batch_size, interest_num, embed_dim]
+        user_embedding = torch.matmul(input_user, self.convert_user_weight)
+        user_embedding = F.normalize(user_embedding, p=2, dim=-1)  # L2 normalize
+        if self.mode == "user":
+            # inference embedding mode -> [batch_size, interest_num, embed_dim]
+            return user_embedding
+        return user_embedding
+
+    def item_tower(self, x):
+        if self.mode == "user":
+            return None
+        pos_embedding = self.embedding(x, self.item_features, squeeze_dim=False)  # [batch_size, 1, embed_dim]
+        pos_embedding = F.normalize(pos_embedding, p=2, dim=-1)  # L2 normalize
+        if self.mode == "item":  # inference embedding mode
+            return pos_embedding.squeeze(1)  # [batch_size, embed_dim]
+        neg_embeddings = self.embedding(x, self.neg_item_feature, squeeze_dim=False).squeeze(1)  # [batch_size, n_neg_items, embed_dim]
+        neg_embeddings = F.normalize(neg_embeddings, p=2, dim=-1)  # L2 normalize
+        # [batch_size, 1+n_neg_items, embed_dim]
+        return torch.cat((pos_embedding, neg_embeddings), dim=1)
+
+    def gen_mask(self, x):
+        his_list = x[self.history_features[0].name]
+        mask = (his_list > 0).long()
+        return mask
+
+
+class ComirecDR(torch.nn.Module):
+    """The match model mentioned in `Controllable Multi-Interest Framework for Recommendation` paper.
+    It's a ComirecDR match model trained by global softmax loss on list-wise samples.
+    Note in origin paper, it's without item dnn tower and train item embedding directly.
+
+    Args:
+        user_features (list[Feature Class]): training by the user tower module.
+        history_features (list[Feature Class]): training history
+        item_features (list[Feature Class]): training by the embedding table, it's the item id feature.
+        neg_item_feature (list[Feature Class]): training by the embedding table, it's the negative items id feature.
+        max_length (int): max sequence length of input item sequence
+        temperature (float): temperature factor for similarity score, default to 1.0.
+        interest_num (int): interest num
+    """
+
+    def __init__(self, user_features, history_features, item_features, neg_item_feature, max_length, temperature=1.0, interest_num=4):
+        super().__init__()
+        self.user_features = user_features
+        self.item_features = item_features
+        self.history_features = history_features
+        self.neg_item_feature = neg_item_feature
+        self.temperature = temperature
+        self.interest_num = interest_num
+        self.max_length = max_length
+        self.user_dims = sum([fea.embed_dim for fea in user_features + history_features])
+
+        self.embedding = EmbeddingLayer(user_features + item_features + history_features)
+        self.capsule = CapsuleNetwork(self.history_features[0].embed_dim, self.max_length, bilinear_type=2, interest_num=self.interest_num)
+        self.convert_user_weight = nn.Parameter(torch.rand(self.user_dims, self.history_features[0].embed_dim), requires_grad=True)
+        self.mode = None
+
+    def forward(self, x):
+        user_embedding = self.user_tower(x)
+        item_embedding = self.item_tower(x)
+        if self.mode == "user":
+            return user_embedding
+        if self.mode == "item":
+            return item_embedding
+
+        pos_item_embedding = item_embedding[:, 0, :]
+        dot_res = torch.bmm(user_embedding, pos_item_embedding.squeeze(1).unsqueeze(-1))
+        k_index = torch.argmax(dot_res, dim=1)
+        best_interest_emb = torch.rand(user_embedding.shape[0], user_embedding.shape[2]).to(user_embedding.device)
+        for k in range(user_embedding.shape[0]):
+            best_interest_emb[k, :] = user_embedding[k, k_index[k], :]
+        best_interest_emb = best_interest_emb.unsqueeze(1)
+
+        y = torch.mul(best_interest_emb, item_embedding).sum(dim=1)
+
+        return y
+
+    def user_tower(self, x):
+        if self.mode == "item":
+            return None
+        input_user = self.embedding(x, self.user_features, squeeze_dim=True).unsqueeze(1)  # [batch_size, num_features*deep_dims]
+        input_user = input_user.expand([input_user.shape[0], self.interest_num, input_user.shape[-1]])
+
+        history_emb = self.embedding(x, self.history_features).squeeze(1)
+        mask = self.gen_mask(x)
+        multi_interest_emb = self.capsule(history_emb, mask)
+
+        input_user = torch.cat([input_user, multi_interest_emb], dim=-1)
+
+        # user_embedding = self.user_mlp(input_user).unsqueeze(1)
+        # #[batch_size, interest_num, embed_dim]
+        user_embedding = torch.matmul(input_user, self.convert_user_weight)
+        user_embedding = F.normalize(user_embedding, p=2, dim=-1)  # L2 normalize
+        if self.mode == "user":
+            # inference embedding mode -> [batch_size, interest_num, embed_dim]
+            return user_embedding
+        return user_embedding
+
+    def item_tower(self, x):
+        if self.mode == "user":
+            return None
+        pos_embedding = self.embedding(x, self.item_features, squeeze_dim=False)  # [batch_size, 1, embed_dim]
+        pos_embedding = F.normalize(pos_embedding, p=2, dim=-1)  # L2 normalize
+        if self.mode == "item":  # inference embedding mode
+            return pos_embedding.squeeze(1)  # [batch_size, embed_dim]
+        neg_embeddings = self.embedding(x, self.neg_item_feature, squeeze_dim=False).squeeze(1)  # [batch_size, n_neg_items, embed_dim]
+        neg_embeddings = F.normalize(neg_embeddings, p=2, dim=-1)  # L2 normalize
+        # [batch_size, 1+n_neg_items, embed_dim]
+        return torch.cat((pos_embedding, neg_embeddings), dim=1)
+
+    def gen_mask(self, x):
+        his_list = x[self.history_features[0].name]
+        mask = (his_list > 0).long()
+        return mask
```
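The 0.0.5 version fills in the previously truncated tower logic: the user tower projects the concatenated profile and multi-interest embeddings through `convert_user_weight` and L2-normalizes per interest, while the item tower L2-normalizes the positive and sampled-negative item embeddings before concatenating them. Below is a minimal usage sketch; the `SparseFeature`/`SequenceFeature` setup (including the `pooling="concat"` and `shared_with` arguments) and the `torch_rechub.models.matching` import path are assumptions based on the package's feature classes, not part of this diff.

```python
import torch

# Assumed imports: feature classes from torch_rechub.basic.features,
# ComirecSA re-exported from torch_rechub.models.matching.
from torch_rechub.basic.features import SequenceFeature, SparseFeature
from torch_rechub.models.matching import ComirecSA

embed_dim = 16
user_features = [SparseFeature("user_id", vocab_size=100, embed_dim=embed_dim)]
history_features = [SequenceFeature("hist_item_id", vocab_size=1000, embed_dim=embed_dim,
                                    pooling="concat", shared_with="item_id")]
item_features = [SparseFeature("item_id", vocab_size=1000, embed_dim=embed_dim)]
neg_item_feature = [SequenceFeature("neg_items", vocab_size=1000, embed_dim=embed_dim,
                                    pooling="concat", shared_with="item_id")]

model = ComirecSA(user_features, history_features, item_features,
                  neg_item_feature, interest_num=4)

batch = {
    "user_id": torch.randint(1, 100, (8,)),
    "item_id": torch.randint(1, 1000, (8,)),          # positive item id
    "hist_item_id": torch.randint(0, 1000, (8, 20)),  # 0 is padding; see gen_mask
    "neg_items": torch.randint(1, 1000, (8, 3)),      # 3 sampled negatives
}

scores = model(batch)  # training scores of the argmax-selected interest vs. pos+neg items

model.mode = "item"
item_emb = model(batch)  # [8, 16] item embeddings to index for ANN retrieval
model.mode = "user"
user_emb = model(batch)  # [8, 4, 16] one embedding per interest for serving
```

Note the argmax routing in `forward`: each positive item is scored against all `interest_num` user vectors via `torch.bmm`, and only the best-matching interest is trained against that sample. At serving time each interest vector is typically queried against the item index separately, which is why `mode = "user"` returns all interests rather than one.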
torch_rechub/models/matching/dssm.py

```diff
@@ -1,66 +1,72 @@
-"""
-Date: create on 12/05/2022, update on 20/05/2022
-References:
-    paper: (CIKM'2013) Learning Deep Structured Semantic Models for Web Search using Clickthrough Data
-    url: https://posenhuang.github.io/papers/cikm2013_DSSM_fullversion.pdf
-    code: https://github.com/bbruceyuan/DeepMatch-Torch/blob/main/deepmatch_torch/models/dssm.py
-Authors: Mincai Lai, laimincai@shanghaitech.edu.cn
-"""
-
-import torch
-import torch.nn.functional as F
 ⋮ (removed lines 12-66 survive only as fragments and are omitted)
+"""
+Date: create on 12/05/2022, update on 20/05/2022
+References:
+    paper: (CIKM'2013) Learning Deep Structured Semantic Models for Web Search using Clickthrough Data
+    url: https://posenhuang.github.io/papers/cikm2013_DSSM_fullversion.pdf
+    code: https://github.com/bbruceyuan/DeepMatch-Torch/blob/main/deepmatch_torch/models/dssm.py
+Authors: Mincai Lai, laimincai@shanghaitech.edu.cn
+"""
+
+import torch
+import torch.nn.functional as F
+
+from ...basic.layers import MLP, EmbeddingLayer
+
+
+class DSSM(torch.nn.Module):
+    """Deep Structured Semantic Model
+
+    Args:
+        user_features (list[Feature Class]): training by the user tower module.
+        item_features (list[Feature Class]): training by the item tower module.
+        temperature (float): temperature factor for similarity score, default to 1.0.
+        user_params (dict): the params of the User Tower module, keys include:`{"dims":list, "activation":str, "dropout":float, "output_layer":bool`}.
+        item_params (dict): the params of the Item Tower module, keys include:`{"dims":list, "activation":str, "dropout":float, "output_layer":bool`}.
+    """
+
+    def __init__(self, user_features, item_features, user_params, item_params, temperature=1.0):
+        super().__init__()
+        self.user_features = user_features
+        self.item_features = item_features
+        self.temperature = temperature
+        self.user_dims = sum([fea.embed_dim for fea in user_features])
+        self.item_dims = sum([fea.embed_dim for fea in item_features])
+
+        self.embedding = EmbeddingLayer(user_features + item_features)
+        self.user_mlp = MLP(self.user_dims, output_layer=False, **user_params)
+        self.item_mlp = MLP(self.item_dims, output_layer=False, **item_params)
+        self.mode = None
+
+    def forward(self, x):
+        user_embedding = self.user_tower(x)
+        item_embedding = self.item_tower(x)
+        if self.mode == "user":
+            return user_embedding
+        if self.mode == "item":
+            return item_embedding
+
+
+        # calculate cosine score
+        y = torch.mul(user_embedding, item_embedding).sum(dim=1)
+        # y = y / self.temperature
+        return torch.sigmoid(y)
+
+    def user_tower(self, x):
+        if self.mode == "item":
+            return None
+        # [batch_size, num_features*deep_dims]
+        input_user = self.embedding(x, self.user_features, squeeze_dim=True)
+        # [batch_size, user_params["dims"][-1]]
+        user_embedding = self.user_mlp(input_user)
+        user_embedding = F.normalize(user_embedding, p=2, dim=1)  # L2 normalize
+        return user_embedding
+
+    def item_tower(self, x):
+        if self.mode == "user":
+            return None
+        # [batch_size, num_features*embed_dim]
+        input_item = self.embedding(x, self.item_features, squeeze_dim=True)
+        # [batch_size, item_params["dims"][-1]]
+        item_embedding = self.item_mlp(input_item)
+        item_embedding = F.normalize(item_embedding, p=2, dim=1)
+        return item_embedding
```