torch-rechub 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. torch_rechub/basic/activation.py +54 -52
  2. torch_rechub/basic/callback.py +32 -32
  3. torch_rechub/basic/features.py +94 -57
  4. torch_rechub/basic/initializers.py +92 -0
  5. torch_rechub/basic/layers.py +720 -240
  6. torch_rechub/basic/loss_func.py +34 -0
  7. torch_rechub/basic/metaoptimizer.py +72 -0
  8. torch_rechub/basic/metric.py +228 -0
  9. torch_rechub/models/matching/__init__.py +11 -0
  10. torch_rechub/models/matching/comirec.py +188 -0
  11. torch_rechub/models/matching/dssm.py +66 -0
  12. torch_rechub/models/matching/dssm_facebook.py +79 -0
  13. torch_rechub/models/matching/dssm_senet.py +75 -0
  14. torch_rechub/models/matching/gru4rec.py +87 -0
  15. torch_rechub/models/matching/mind.py +101 -0
  16. torch_rechub/models/matching/narm.py +76 -0
  17. torch_rechub/models/matching/sasrec.py +140 -0
  18. torch_rechub/models/matching/sine.py +151 -0
  19. torch_rechub/models/matching/stamp.py +83 -0
  20. torch_rechub/models/matching/youtube_dnn.py +71 -0
  21. torch_rechub/models/matching/youtube_sbc.py +98 -0
  22. torch_rechub/models/multi_task/__init__.py +5 -4
  23. torch_rechub/models/multi_task/aitm.py +84 -0
  24. torch_rechub/models/multi_task/esmm.py +55 -45
  25. torch_rechub/models/multi_task/mmoe.py +58 -52
  26. torch_rechub/models/multi_task/ple.py +130 -104
  27. torch_rechub/models/multi_task/shared_bottom.py +45 -44
  28. torch_rechub/models/ranking/__init__.py +11 -3
  29. torch_rechub/models/ranking/afm.py +63 -0
  30. torch_rechub/models/ranking/bst.py +63 -0
  31. torch_rechub/models/ranking/dcn.py +38 -0
  32. torch_rechub/models/ranking/dcn_v2.py +69 -0
  33. torch_rechub/models/ranking/deepffm.py +123 -0
  34. torch_rechub/models/ranking/deepfm.py +41 -41
  35. torch_rechub/models/ranking/dien.py +191 -0
  36. torch_rechub/models/ranking/din.py +91 -81
  37. torch_rechub/models/ranking/edcn.py +117 -0
  38. torch_rechub/models/ranking/fibinet.py +50 -0
  39. torch_rechub/models/ranking/widedeep.py +41 -41
  40. torch_rechub/trainers/__init__.py +2 -1
  41. torch_rechub/trainers/{trainer.py → ctr_trainer.py} +128 -111
  42. torch_rechub/trainers/match_trainer.py +170 -0
  43. torch_rechub/trainers/mtl_trainer.py +206 -144
  44. torch_rechub/utils/__init__.py +0 -0
  45. torch_rechub/utils/data.py +360 -0
  46. torch_rechub/utils/match.py +274 -0
  47. torch_rechub/utils/mtl.py +126 -0
  48. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/LICENSE +21 -21
  49. torch_rechub-0.0.3.dist-info/METADATA +177 -0
  50. torch_rechub-0.0.3.dist-info/RECORD +55 -0
  51. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/WHEEL +1 -1
  52. torch_rechub/basic/utils.py +0 -168
  53. torch_rechub-0.0.1.dist-info/METADATA +0 -105
  54. torch_rechub-0.0.1.dist-info/RECORD +0 -26
  55. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/top_level.txt +0 -0
torch_rechub/basic/loss_func.py
@@ -0,0 +1,34 @@
+ import torch
+ import torch.nn.functional as F
+
+
+ class HingeLoss(torch.nn.Module):
+     """Hinge loss for pairwise learning.
+     reference: https://github.com/ustcml/RecStudio/blob/main/recstudio/model/loss_func.py
+
+     """
+
+     def __init__(self, margin=2, num_items=None):
+         super().__init__()
+         self.margin = margin
+         self.n_items = num_items
+
+     def forward(self, pos_score, neg_score):
+         loss = torch.maximum(torch.max(neg_score, dim=-1).values - pos_score + self.margin, torch.tensor([0]).type_as(pos_score))
+         if self.n_items is not None:
+             impostors = neg_score - pos_score.view(-1, 1) + self.margin > 0
+             rank = torch.mean(impostors.float(), -1) * self.n_items  # cast the boolean mask to float before averaging
+             return torch.mean(loss * torch.log(rank + 1))
+         else:
+             return torch.mean(loss)
+
+
+ class BPRLoss(torch.nn.Module):
+
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, pos_score, neg_score):
+         loss = torch.mean(-(pos_score - neg_score).sigmoid().log(), dim=-1)
+         return loss
+         # loss = -torch.mean(F.logsigmoid(pos_score - torch.max(neg_score, dim=-1).values))  # needs torch >= 1.10
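A minimal usage sketch for these two losses (the batch size, the three sampled negatives per positive, and num_items=1000 are illustrative assumptions, not values from the package):

import torch

from torch_rechub.basic.loss_func import BPRLoss, HingeLoss

pos_score = torch.randn(4)     # one positive score per sample
neg_score = torch.randn(4, 3)  # three sampled negative scores per sample

hinge = HingeLoss(margin=2, num_items=1000)  # num_items enables the WARP-style rank weighting
print(hinge(pos_score, neg_score))           # scalar loss

bpr = BPRLoss()
# BPRLoss broadcasts the positive score against every negative and returns a per-sample loss
print(bpr(pos_score.unsqueeze(-1), neg_score).mean())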
torch_rechub/basic/metaoptimizer.py
@@ -0,0 +1,72 @@
+ """The metaoptimizer module provides the MetaBalance optimizer.
+ MetaBalance is used to scale and balance the gradient of each task.
+ Authors: Qida Dong, dongjidan@126.com
+ """
+ import torch
+ from torch.optim.optimizer import Optimizer
+
+
+ class MetaBalance(Optimizer):
+     """MetaBalance Optimizer
+     This method is used to scale and balance the gradient of each task.
+
+     Args:
+         parameters (list): the parameters of the model
+         relax_factor (float, optional): the relax factor of gradient scaling (default: 0.7)
+         beta (float, optional): the coefficient of the moving average (default: 0.9)
+     """
+
+     def __init__(self, parameters, relax_factor=0.7, beta=0.9):
+
+         if relax_factor < 0. or relax_factor >= 1.:
+             raise ValueError(f'Invalid relax_factor: {relax_factor}, it should be 0. <= relax_factor < 1.')
+         if beta < 0. or beta >= 1.:
+             raise ValueError(f'Invalid beta: {beta}, it should be 0. <= beta < 1.')
+         rel_beta_dict = {'relax_factor': relax_factor, 'beta': beta}
+         super(MetaBalance, self).__init__(parameters, rel_beta_dict)
+
+     @torch.no_grad()
+     def step(self, losses):
+         """Scale the gradient of each task and accumulate the balanced gradients.
+         Args:
+             losses (list): the loss of each task
+
+         Raises:
+             RuntimeError: if a parameter receives a sparse gradient
+         """
+
+         for idx, loss in enumerate(losses):
+             loss.backward(retain_graph=True)
+             for group in self.param_groups:
+                 for gp in group['params']:
+                     if gp.grad is None:
+                         # this parameter received no gradient from the current loss
+                         break
+                     if gp.grad.is_sparse:
+                         raise RuntimeError('MetaBalance does not support sparse gradients')
+                     # store the result of the moving average
+                     state = self.state[gp]
+                     if len(state) == 0:
+                         for i in range(len(losses)):
+                             if i == 0:
+                                 gp.norms = [0]
+                             else:
+                                 gp.norms.append(0)
+                     # calculate the moving average
+                     beta = group['beta']
+                     gp.norms[idx] = gp.norms[idx] * beta + (1 - beta) * torch.norm(gp.grad)
+                     # scale the auxiliary gradient
+                     relax_factor = group['relax_factor']
+                     gp.grad = gp.grad * gp.norms[0] / (gp.norms[idx] + 1e-5) * relax_factor + gp.grad * (1. - relax_factor)
+                     # store the gradient of each auxiliary task in state
+                     if idx == 0:
+                         state['sum_gradient'] = torch.zeros_like(gp.data)
+                         state['sum_gradient'] += gp.grad
+                     else:
+                         state['sum_gradient'] += gp.grad
+
+                     if gp.grad is not None:
+                         gp.grad.detach_()
+                         gp.grad.zero_()
+                     if idx == len(losses) - 1:
+                         gp.grad = state['sum_gradient']
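A sketch of how MetaBalance might be driven in a two-task loop. The toy model, the pairing with a plain Adam optimizer, and all shapes are assumptions for illustration, not the wiring of the package's MTLTrainer:

import torch

from torch_rechub.basic.metaoptimizer import MetaBalance

shared = torch.nn.Linear(16, 8)  # parameters shared by both tasks
head_a, head_b = torch.nn.Linear(8, 1), torch.nn.Linear(8, 1)

# MetaBalance only rebalances the shared gradients; a regular optimizer still applies the updates
metabalance = MetaBalance(list(shared.parameters()), relax_factor=0.7, beta=0.9)
adam = torch.optim.Adam([*shared.parameters(), *head_a.parameters(), *head_b.parameters()], lr=1e-3)

x = torch.randn(32, 16)
y_a, y_b = torch.rand(32, 1), torch.rand(32, 1)

h = shared(x)
loss_a = torch.nn.functional.mse_loss(torch.sigmoid(head_a(h)), y_a)
loss_b = torch.nn.functional.mse_loss(torch.sigmoid(head_b(h)), y_b)

adam.zero_grad()
metabalance.step([loss_a, loss_b])  # backpropagates each loss and writes balanced grads into shared weights
adam.step()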
torch_rechub/basic/metric.py
@@ -0,0 +1,228 @@
+ """The metric module provides some metrics for recommenders.
+ Available functions:
+ - auc_score: compute AUC
+ - gauc_score: compute GAUC
+ - log_loss: compute LogLoss
+ - topk_metrics: compute the top-k metrics 'ndcg', 'mrr', 'recall', 'precision' and 'hit'
+ Authors: Qida Dong, dongjidan@126.com
+ """
+ from sklearn.metrics import roc_auc_score
+ import numpy as np
+ from collections import defaultdict
+
+
+ def auc_score(y_true, y_pred):
+
+     return roc_auc_score(y_true, y_pred)
+
+
+ def get_user_pred(y_true, y_pred, users):
+     """divide the results into groups by user id
+
+     Args:
+         y_true (array): all true labels of the data
+         y_pred (array): the predicted scores
+         users (array): user ids
+
+     Return:
+         user_pred (dict): {userid: values}, key is the user id and value holds that user's labels and scores
+     """
+     user_pred = {}
+     for i, u in enumerate(users):
+         if u not in user_pred:
+             user_pred[u] = {'y_true': [y_true[i]], 'y_pred': [y_pred[i]]}
+         else:
+             user_pred[u]['y_true'].append(y_true[i])
+             user_pred[u]['y_pred'].append(y_pred[i])
+
+     return user_pred
+
+
+ def gauc_score(y_true, y_pred, users, weights=None):
+     """compute GAUC
+
+     Args:
+         y_true (array): dim(N, ), all true labels of the data
+         y_pred (array): dim(N, ), the predicted scores
+         users (array): dim(N, ), user ids
+         weights (dict): {userid: weight_value}, it contains the weight of each group.
+             if it is None, the weight equals the number
+             of times the user is recommended
+     Return:
+         score: float, GAUC
+     """
+     assert len(y_true) == len(y_pred) and len(y_true) == len(users)
+
+     user_pred = get_user_pred(y_true, y_pred, users)
+     score = 0
+     num = 0
+     for u in user_pred.keys():
+         auc = auc_score(user_pred[u]['y_true'], user_pred[u]['y_pred'])
+         if weights is None:
+             user_weight = len(user_pred[u]['y_true'])
+         else:
+             user_weight = weights[u]
+         auc *= user_weight
+         num += user_weight
+         score += auc
+     return score / num
+
+
+
+ def ndcg_score(y_true, y_pred, topKs=None):
+     if topKs is None:
+         topKs = [5]
+     result = topk_metrics(y_true, y_pred, topKs)
+     return result['NDCG']
+
+
+
+ def hit_score(y_true, y_pred, topKs=None):
+     if topKs is None:
+         topKs = [5]
+     result = topk_metrics(y_true, y_pred, topKs)
+     return result['Hit']
+
+ def mrr_score(y_true, y_pred, topKs=None):
+     if topKs is None:
+         topKs = [5]
+     result = topk_metrics(y_true, y_pred, topKs)
+     return result['MRR']
+
+
+ def recall_score(y_true, y_pred, topKs=None):
+     if topKs is None:
+         topKs = [5]
+     result = topk_metrics(y_true, y_pred, topKs)
+     return result['Recall']
+
+
+ def precision_score(y_true, y_pred, topKs=None):
+     if topKs is None:
+         topKs = [5]
+     result = topk_metrics(y_true, y_pred, topKs)
+     return result['Precision']
+
+
+ def topk_metrics(y_true, y_pred, topKs=None):
+     """compute the chosen top-k metrics
+     the metrics contain 'ndcg', 'mrr', 'recall', 'precision' and 'hit'
+
+     Args:
+         y_true (dict): {userid: item_ids}, the key is a user id and the value is the list of items the user interacted with
+         y_pred (dict): {userid: item_ids}, the key is a user id and the value is the list of recommended items
+         topKs (list or tuple): e.g. to get top5 and top10, pass topKs=(5, 10)
+
+     Return:
+         results (dict): {metric_name: metric_values}, it contains five metrics, 'ndcg', 'recall', 'mrr', 'hit', 'precision'
+
+     """
+     if topKs is None:
+         topKs = [5]
+     assert len(y_true) == len(y_pred)
+
+     if not isinstance(topKs, (tuple, list)):
+         raise ValueError('topKs is wrong, it should be a tuple or a list')
+
+     pred_array = []
+     true_array = []
+     for u in y_true.keys():
+         pred_array.append(y_pred[u])
+         true_array.append(y_true[u])
+     ndcg_result = []
+     mrr_result = []
+     hit_result = []
+     precision_result = []
+     recall_result = []
+     for idx in range(len(topKs)):
+         ndcgs = 0
+         mrrs = 0
+         hits = 0
+         precisions = 0
+         recalls = 0
+         gts = 0
+         for i in range(len(true_array)):
+             if len(true_array[i]) != 0:
+                 mrr_tmp = 0
+                 mrr_flag = True
+                 hit_tmp = 0
+                 dcg_tmp = 0
+                 idcg_tmp = 0
+                 for j in range(topKs[idx]):
+                     if pred_array[i][j] in true_array[i]:
+                         hit_tmp += 1.
+                         if mrr_flag:
+                             mrr_flag = False
+                             mrr_tmp = 1. / (1 + j)
+                         dcg_tmp += 1. / (np.log2(j + 2))
+                     if j < len(true_array[i]):
+                         idcg_tmp += 1. / (np.log2(j + 2))
+                 gts += len(true_array[i])
+                 hits += hit_tmp
+                 mrrs += mrr_tmp
+                 recalls += hit_tmp / len(true_array[i])
+                 precisions += hit_tmp / topKs[idx]
+                 if idcg_tmp != 0:
+                     ndcgs += dcg_tmp / idcg_tmp
+         hit_result.append(round(hits / gts, 4))
+         mrr_result.append(round(mrrs / len(pred_array), 4))
+         recall_result.append(round(recalls / len(pred_array), 4))
+         precision_result.append(round(precisions / len(pred_array), 4))
+         ndcg_result.append(round(ndcgs / len(pred_array), 4))
+
+     results = defaultdict(list)
+     for idx in range(len(topKs)):
+
+         output = f'NDCG@{topKs[idx]}: {ndcg_result[idx]}'
+         results['NDCG'].append(output)
+
+         output = f'MRR@{topKs[idx]}: {mrr_result[idx]}'
+         results['MRR'].append(output)
+
+         output = f'Recall@{topKs[idx]}: {recall_result[idx]}'
+         results['Recall'].append(output)
+
+         output = f'Hit@{topKs[idx]}: {hit_result[idx]}'
+         results['Hit'].append(output)
+
+         output = f'Precision@{topKs[idx]}: {precision_result[idx]}'
+         results['Precision'].append(output)
+     return results
+
+ def log_loss(y_true, y_pred):
+     score = y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)  # assumes 0 < y_pred < 1
+     return -score.sum() / len(y_true)
+
+ def Coverage(y_pred, all_items, topKs=None):
+     """compute the coverage
+     This method measures the diversity of the recommended items and the ability to explore long-tailed items
+     Args:
+         y_pred (dict): {userid: item_ids}, the key is a user id and the value is the list of recommended items
+         all_items (set): all unique items
+         topKs (list or tuple): e.g. to get top5 and top10, pass topKs=(5, 10)
+     Return:
+         result (list[str]): the list of coverage scores
+     """
+     if topKs is None:
+         topKs = [5]
+     result = []
+     for k in topKs:
+         rec_items = set([])
+         for u in y_pred.keys():
+             tmp_items = set(y_pred[u][:k])
+             rec_items = rec_items | tmp_items
+         score = len(rec_items) * 1. / len(all_items)
+         score = round(score, 4)
+         result.append(f'Coverage@{k}: {score}')
+     return result
+
+
+ # print(Coverage({'0': [0, 1, 2], '1': [1, 3, 4]}, {0, 1, 2, 3, 4, 5}, [2, 3]))
+
+
+ if __name__ == "__main__":
+     y_pred = {'0': [0, 1], '1': [0, 1], '2': [2, 3]}
+     y_true = {'0': [1, 2], '1': [0, 1, 2], '2': [2, 3]}
+     out = topk_metrics(y_true, y_pred, topKs=(1, 2))
+     print(out)
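A short usage sketch; the point-wise arrays below are taken from the module's own commented-out examples, and the ranking dicts mirror its __main__ block:

import numpy as np

from torch_rechub.basic.metric import auc_score, gauc_score, topk_metrics

label = np.array([1, 0, 0, 1, 0, 0, 1, 0, 0, 1])
pred = np.array([0.3, 0.2, 0.5, 0.9, 0.7, 0.31, 0.8, 0.1, 0.4, 0.6])
users = np.array([2, 1, 0, 2, 1, 0, 0, 2, 1, 1])

print('auc:', auc_score(label, pred))
print('gauc:', gauc_score(label, pred, users))  # per-user AUC weighted by impression count

# the top-k metrics take {user: ranked item list} dicts and return formatted strings,
# e.g. results['NDCG'] == ['NDCG@1: ...', 'NDCG@2: ...']
y_pred = {'0': [0, 1], '1': [0, 1], '2': [2, 3]}
y_true = {'0': [1, 2], '1': [0, 1, 2], '2': [2, 3]}
print(topk_metrics(y_true, y_pred, topKs=(1, 2)))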
torch_rechub/models/matching/__init__.py
@@ -0,0 +1,11 @@
+ from .dssm import DSSM
+ from .youtube_dnn import YoutubeDNN
+ from .youtube_sbc import YoutubeSBC
+ from .dssm_facebook import FaceBookDSSM
+ from .gru4rec import GRU4Rec
+ from .comirec import ComirecSA, ComirecDR
+ from .mind import MIND
+ from .narm import NARM
+ from .stamp import STAMP
+ from .sasrec import SASRec
+ from .sine import SINE
torch_rechub/models/matching/comirec.py
@@ -0,0 +1,188 @@
+ """
+ Date: create on 07/06/2022
+ References:
+     paper: Controllable Multi-Interest Framework for Recommendation
+     url: https://arxiv.org/pdf/2005.09347.pdf
+     code: https://github.com/ShiningCosmos/pytorch_ComiRec/blob/main/ComiRec.py
+ Authors: Kai Wang, 306178200@qq.com
+ """
+
+ import torch
+
+ from ...basic.layers import MLP, EmbeddingLayer, MultiInterestSA, CapsuleNetwork
+ from torch import nn
+ import torch.nn.functional as F
+
+
+ class ComirecSA(torch.nn.Module):
+     """The match model mentioned in `Controllable Multi-Interest Framework for Recommendation` paper.
+     It's a ComirecSA match model trained by a global softmax loss on list-wise samples.
+     Note that in the original paper there is no item DNN tower; the item embedding is trained directly.
+
+     Args:
+         user_features (list[Feature Class]): training by the user tower module.
+         history_features (list[Feature Class]): the user history sequence features.
+         item_features (list[Feature Class]): training by the embedding table, it's the item id feature.
+         neg_item_feature (list[Feature Class]): training by the embedding table, it's the negative item id feature.
+         temperature (float): temperature factor for the similarity score, default to 1.0.
+         interest_num (int): the number of user interests.
+     """
+
+     def __init__(self, user_features, history_features, item_features, neg_item_feature, temperature=1.0, interest_num=4):
+         super().__init__()
+         self.user_features = user_features
+         self.item_features = item_features
+         self.history_features = history_features
+         self.neg_item_feature = neg_item_feature
+         self.temperature = temperature
+         self.interest_num = interest_num
+         self.user_dims = sum([fea.embed_dim for fea in user_features + history_features])
+
+         self.embedding = EmbeddingLayer(user_features + item_features + history_features)
+         self.multi_interest_sa = MultiInterestSA(embedding_dim=self.history_features[0].embed_dim, interest_num=self.interest_num)
+         self.convert_user_weight = nn.Parameter(torch.rand(self.user_dims, self.history_features[0].embed_dim), requires_grad=True)
+         self.mode = None
+
+     def forward(self, x):
+         user_embedding = self.user_tower(x)
+         item_embedding = self.item_tower(x)
+         if self.mode == "user":
+             return user_embedding
+         if self.mode == "item":
+             return item_embedding
+
+         pos_item_embedding = item_embedding[:, 0, :]
+         dot_res = torch.bmm(user_embedding, pos_item_embedding.unsqueeze(-1))
+         k_index = torch.argmax(dot_res, dim=1)
+         best_interest_emb = torch.rand(user_embedding.shape[0], user_embedding.shape[2]).to(user_embedding.device)  # placeholder, filled below
+         for k in range(user_embedding.shape[0]):
+             best_interest_emb[k, :] = user_embedding[k, k_index[k], :]
+         best_interest_emb = best_interest_emb.unsqueeze(1)
+
+         y = torch.mul(best_interest_emb, item_embedding).sum(dim=-1)  #[batch_size, 1+n_neg_items], reduce over embed_dim
+
+         return y
+
+     def user_tower(self, x):
+         if self.mode == "item":
+             return None
+         input_user = self.embedding(x, self.user_features, squeeze_dim=True).unsqueeze(1)  #[batch_size, 1, num_features*deep_dims]
+         input_user = input_user.expand([input_user.shape[0], self.interest_num, input_user.shape[-1]])
+
+         history_emb = self.embedding(x, self.history_features).squeeze(1)  #[batch_size, seq_length, embed_dim]
+         mask = self.gen_mask(x)
+         mask = mask.unsqueeze(-1).float()
+         multi_interest_emb = self.multi_interest_sa(history_emb, mask)
+
+         input_user = torch.cat([input_user, multi_interest_emb], dim=-1)
+
+         # user_embedding = self.user_mlp(input_user).unsqueeze(1)  #[batch_size, interest_num, embed_dim]
+         user_embedding = torch.matmul(input_user, self.convert_user_weight)
+         user_embedding = F.normalize(user_embedding, p=2, dim=-1)  # L2 normalize
+         if self.mode == "user":
+             return user_embedding  # inference embedding mode -> [batch_size, interest_num, embed_dim]
+         return user_embedding
+
+     def item_tower(self, x):
+         if self.mode == "user":
+             return None
+         pos_embedding = self.embedding(x, self.item_features, squeeze_dim=False)  #[batch_size, 1, embed_dim]
+         pos_embedding = F.normalize(pos_embedding, p=2, dim=-1)  # L2 normalize
+         if self.mode == "item":  # inference embedding mode
+             return pos_embedding.squeeze(1)  #[batch_size, embed_dim]
+         neg_embeddings = self.embedding(x, self.neg_item_feature,
+                                         squeeze_dim=False).squeeze(1)  #[batch_size, n_neg_items, embed_dim]
+         neg_embeddings = F.normalize(neg_embeddings, p=2, dim=-1)  # L2 normalize
+         return torch.cat((pos_embedding, neg_embeddings), dim=1)  #[batch_size, 1+n_neg_items, embed_dim]
+
+     def gen_mask(self, x):
+         his_list = x[self.history_features[0].name]
+         mask = (his_list > 0).long()
+         return mask
+
+ class ComirecDR(torch.nn.Module):
+     """The match model mentioned in `Controllable Multi-Interest Framework for Recommendation` paper.
+     It's a ComirecDR match model trained by a global softmax loss on list-wise samples.
+     Note that in the original paper there is no item DNN tower; the item embedding is trained directly.
+
+     Args:
+         user_features (list[Feature Class]): training by the user tower module.
+         history_features (list[Feature Class]): the user history sequence features.
+         item_features (list[Feature Class]): training by the embedding table, it's the item id feature.
+         neg_item_feature (list[Feature Class]): training by the embedding table, it's the negative item id feature.
+         max_length (int): the max sequence length of the input item sequence.
+         temperature (float): temperature factor for the similarity score, default to 1.0.
+         interest_num (int): the number of user interests.
+     """
+
+     def __init__(self, user_features, history_features, item_features, neg_item_feature, max_length, temperature=1.0, interest_num=4):
+         super().__init__()
+         self.user_features = user_features
+         self.item_features = item_features
+         self.history_features = history_features
+         self.neg_item_feature = neg_item_feature
+         self.temperature = temperature
+         self.interest_num = interest_num
+         self.max_length = max_length
+         self.user_dims = sum([fea.embed_dim for fea in user_features + history_features])
+
+         self.embedding = EmbeddingLayer(user_features + item_features + history_features)
+         self.capsule = CapsuleNetwork(self.history_features[0].embed_dim, self.max_length, bilinear_type=2, interest_num=self.interest_num)
+         self.convert_user_weight = nn.Parameter(torch.rand(self.user_dims, self.history_features[0].embed_dim), requires_grad=True)
+         self.mode = None
+
+     def forward(self, x):
+         user_embedding = self.user_tower(x)
+         item_embedding = self.item_tower(x)
+         if self.mode == "user":
+             return user_embedding
+         if self.mode == "item":
+             return item_embedding
+
+         pos_item_embedding = item_embedding[:, 0, :]
+         dot_res = torch.bmm(user_embedding, pos_item_embedding.unsqueeze(-1))
+         k_index = torch.argmax(dot_res, dim=1)
+         best_interest_emb = torch.rand(user_embedding.shape[0], user_embedding.shape[2]).to(user_embedding.device)  # placeholder, filled below
+         for k in range(user_embedding.shape[0]):
+             best_interest_emb[k, :] = user_embedding[k, k_index[k], :]
+         best_interest_emb = best_interest_emb.unsqueeze(1)
+
+         y = torch.mul(best_interest_emb, item_embedding).sum(dim=-1)  #[batch_size, 1+n_neg_items], reduce over embed_dim
+
+         return y
+
+     def user_tower(self, x):
+         if self.mode == "item":
+             return None
+         input_user = self.embedding(x, self.user_features, squeeze_dim=True).unsqueeze(1)  #[batch_size, 1, num_features*deep_dims]
+         input_user = input_user.expand([input_user.shape[0], self.interest_num, input_user.shape[-1]])
+
+         history_emb = self.embedding(x, self.history_features).squeeze(1)  #[batch_size, seq_length, embed_dim]
+         mask = self.gen_mask(x)
+         multi_interest_emb = self.capsule(history_emb, mask)
+
+         input_user = torch.cat([input_user, multi_interest_emb], dim=-1)
+
+         # user_embedding = self.user_mlp(input_user).unsqueeze(1)  #[batch_size, interest_num, embed_dim]
+         user_embedding = torch.matmul(input_user, self.convert_user_weight)
+         user_embedding = F.normalize(user_embedding, p=2, dim=-1)  # L2 normalize
+         if self.mode == "user":
+             return user_embedding  # inference embedding mode -> [batch_size, interest_num, embed_dim]
+         return user_embedding
+
+     def item_tower(self, x):
+         if self.mode == "user":
+             return None
+         pos_embedding = self.embedding(x, self.item_features, squeeze_dim=False)  #[batch_size, 1, embed_dim]
+         pos_embedding = F.normalize(pos_embedding, p=2, dim=-1)  # L2 normalize
+         if self.mode == "item":  # inference embedding mode
+             return pos_embedding.squeeze(1)  #[batch_size, embed_dim]
+         neg_embeddings = self.embedding(x, self.neg_item_feature,
+                                         squeeze_dim=False).squeeze(1)  #[batch_size, n_neg_items, embed_dim]
+         neg_embeddings = F.normalize(neg_embeddings, p=2, dim=-1)  # L2 normalize
+         return torch.cat((pos_embedding, neg_embeddings), dim=1)  #[batch_size, 1+n_neg_items, embed_dim]
+
+     def gen_mask(self, x):
+         his_list = x[self.history_features[0].name]
+         mask = (his_list > 0).long()
+         return mask
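A standalone tensor sketch of the best-interest routing both classes use in forward (all shapes here are illustrative assumptions); it replaces the per-sample Python loop with torch.gather:

import torch

batch_size, interest_num, embed_dim, n_neg = 4, 4, 16, 2
user_embedding = torch.randn(batch_size, interest_num, embed_dim)  # multi-interest user vectors
item_embedding = torch.randn(batch_size, 1 + n_neg, embed_dim)     # positive item first, then negatives

# route each sample to the interest most similar to its positive item
pos_item_embedding = item_embedding[:, 0, :]
dot_res = torch.bmm(user_embedding, pos_item_embedding.unsqueeze(-1))  # [batch, interest_num, 1]
k_index = torch.argmax(dot_res, dim=1).squeeze(-1)                     # best interest index per sample

# vectorized equivalent of the model's `for k in range(...)` loop
gather_index = k_index.view(-1, 1, 1).expand(-1, 1, embed_dim)
best_interest_emb = torch.gather(user_embedding, 1, gather_index)      # [batch, 1, embed_dim]

y = torch.mul(best_interest_emb, item_embedding).sum(dim=-1)           # scores: [batch, 1 + n_neg]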
torch_rechub/models/matching/dssm.py
@@ -0,0 +1,66 @@
+ """
+ Date: create on 12/05/2022, update on 20/05/2022
+ References:
+     paper: (CIKM'2013) Learning Deep Structured Semantic Models for Web Search using Clickthrough Data
+     url: https://posenhuang.github.io/papers/cikm2013_DSSM_fullversion.pdf
+     code: https://github.com/bbruceyuan/DeepMatch-Torch/blob/main/deepmatch_torch/models/dssm.py
+ Authors: Mincai Lai, laimincai@shanghaitech.edu.cn
+ """
+
+ import torch
+ import torch.nn.functional as F
+ from ...basic.layers import MLP, EmbeddingLayer
+
+
+ class DSSM(torch.nn.Module):
+     """Deep Structured Semantic Model
+
+     Args:
+         user_features (list[Feature Class]): training by the user tower module.
+         item_features (list[Feature Class]): training by the item tower module.
+         temperature (float): temperature factor for the similarity score, default to 1.0.
+         user_params (dict): the params of the User Tower module, keys include: `{"dims": list, "activation": str, "dropout": float, "output_layer": bool}`.
+         item_params (dict): the params of the Item Tower module, keys include: `{"dims": list, "activation": str, "dropout": float, "output_layer": bool}`.
+     """
+
+     def __init__(self, user_features, item_features, user_params, item_params, temperature=1.0):
+         super().__init__()
+         self.user_features = user_features
+         self.item_features = item_features
+         self.temperature = temperature
+         self.user_dims = sum([fea.embed_dim for fea in user_features])
+         self.item_dims = sum([fea.embed_dim for fea in item_features])
+
+         self.embedding = EmbeddingLayer(user_features + item_features)
+         self.user_mlp = MLP(self.user_dims, output_layer=False, **user_params)
+         self.item_mlp = MLP(self.item_dims, output_layer=False, **item_params)
+         self.mode = None
+
+     def forward(self, x):
+         user_embedding = self.user_tower(x)
+         item_embedding = self.item_tower(x)
+         if self.mode == "user":
+             return user_embedding
+         if self.mode == "item":
+             return item_embedding
+
+         # calculate the cosine score (both embeddings are L2-normalized)
+         y = torch.mul(user_embedding, item_embedding).sum(dim=1)
+         # y = y / self.temperature  # temperature scaling is currently disabled
+         return torch.sigmoid(y)
+
+     def user_tower(self, x):
+         if self.mode == "item":
+             return None
+         input_user = self.embedding(x, self.user_features, squeeze_dim=True)  #[batch_size, num_features*deep_dims]
+         user_embedding = self.user_mlp(input_user)  #[batch_size, user_params["dims"][-1]]
+         user_embedding = F.normalize(user_embedding, p=2, dim=1)  # L2 normalize
+         return user_embedding
+
+     def item_tower(self, x):
+         if self.mode == "user":
+             return None
+         input_item = self.embedding(x, self.item_features, squeeze_dim=True)  #[batch_size, num_features*embed_dim]
+         item_embedding = self.item_mlp(input_item)  #[batch_size, item_params["dims"][-1]]
+         item_embedding = F.normalize(item_embedding, p=2, dim=1)  # L2 normalize
+         return item_embedding
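A minimal end-to-end sketch of the two-tower usage. SparseFeature and the dict-of-tensors input follow the package's features/EmbeddingLayer modules added in this release, but the exact signature, vocab sizes, and dims here are assumptions; note that user_params/item_params must not set output_layer, since DSSM already passes output_layer=False:

import torch

from torch_rechub.basic.features import SparseFeature
from torch_rechub.models.matching import DSSM

user_features = [SparseFeature("user_id", vocab_size=100, embed_dim=16)]
item_features = [SparseFeature("item_id", vocab_size=500, embed_dim=16)]
model = DSSM(user_features, item_features,
             user_params={"dims": [64, 32]}, item_params={"dims": [64, 32]})

x = {"user_id": torch.randint(0, 100, (8,)), "item_id": torch.randint(0, 500, (8,))}
print(model(x).shape)  # sigmoid of the dot product per pair -> torch.Size([8])

# for offline retrieval, export one tower at a time
model.mode = "user"
user_vectors = model(x)  # [8, 32], L2-normalized
model.mode = "item"
item_vectors = model(x)  # [8, 32], L2-normalized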