torch-rechub 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/basic/activation.py +54 -52
- torch_rechub/basic/callback.py +32 -32
- torch_rechub/basic/features.py +94 -57
- torch_rechub/basic/initializers.py +92 -0
- torch_rechub/basic/layers.py +720 -240
- torch_rechub/basic/loss_func.py +34 -0
- torch_rechub/basic/metaoptimizer.py +72 -0
- torch_rechub/basic/metric.py +250 -0
- torch_rechub/models/matching/__init__.py +11 -0
- torch_rechub/models/matching/comirec.py +188 -0
- torch_rechub/models/matching/dssm.py +66 -0
- torch_rechub/models/matching/dssm_facebook.py +79 -0
- torch_rechub/models/matching/dssm_senet.py +75 -0
- torch_rechub/models/matching/gru4rec.py +87 -0
- torch_rechub/models/matching/mind.py +101 -0
- torch_rechub/models/matching/narm.py +76 -0
- torch_rechub/models/matching/sasrec.py +140 -0
- torch_rechub/models/matching/sine.py +151 -0
- torch_rechub/models/matching/stamp.py +83 -0
- torch_rechub/models/matching/youtube_dnn.py +71 -0
- torch_rechub/models/matching/youtube_sbc.py +98 -0
- torch_rechub/models/multi_task/__init__.py +5 -4
- torch_rechub/models/multi_task/aitm.py +84 -0
- torch_rechub/models/multi_task/esmm.py +55 -45
- torch_rechub/models/multi_task/mmoe.py +58 -52
- torch_rechub/models/multi_task/ple.py +130 -104
- torch_rechub/models/multi_task/shared_bottom.py +45 -44
- torch_rechub/models/ranking/__init__.py +11 -3
- torch_rechub/models/ranking/afm.py +63 -0
- torch_rechub/models/ranking/bst.py +63 -0
- torch_rechub/models/ranking/dcn.py +38 -0
- torch_rechub/models/ranking/dcn_v2.py +69 -0
- torch_rechub/models/ranking/deepffm.py +123 -0
- torch_rechub/models/ranking/deepfm.py +41 -41
- torch_rechub/models/ranking/dien.py +191 -0
- torch_rechub/models/ranking/din.py +91 -81
- torch_rechub/models/ranking/edcn.py +117 -0
- torch_rechub/models/ranking/fibinet.py +50 -0
- torch_rechub/models/ranking/widedeep.py +41 -41
- torch_rechub/trainers/__init__.py +2 -1
- torch_rechub/trainers/{trainer.py → ctr_trainer.py} +128 -111
- torch_rechub/trainers/match_trainer.py +170 -0
- torch_rechub/trainers/mtl_trainer.py +206 -144
- torch_rechub/utils/__init__.py +0 -0
- torch_rechub/utils/data.py +360 -0
- torch_rechub/utils/match.py +274 -0
- torch_rechub/utils/mtl.py +126 -0
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/LICENSE +21 -21
- torch_rechub-0.0.3.dist-info/METADATA +177 -0
- torch_rechub-0.0.3.dist-info/RECORD +55 -0
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/WHEEL +1 -1
- torch_rechub/basic/utils.py +0 -168
- torch_rechub-0.0.1.dist-info/METADATA +0 -105
- torch_rechub-0.0.1.dist-info/RECORD +0 -26
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.functional as F
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class HingeLoss(torch.nn.Module):
    """Hinge Loss for pairwise learning.
    reference: https://github.com/ustcml/RecStudio/blob/main/recstudio/model/loss_func.py

    Args:
        margin (float): margin enforced between positive and negative scores (default: 2).
        num_items (int, optional): total number of candidate items. When given, the
            loss is re-weighted by the estimated rank of the positive item
            (WARP-style weighting); when ``None`` a plain hinge loss is returned.
    """

    def __init__(self, margin=2, num_items=None):
        super().__init__()
        self.margin = margin
        self.n_items = num_items

    def forward(self, pos_score, neg_score):
        """Compute the hinge loss.

        Args:
            pos_score (torch.Tensor): scores of positive items, shape ``(batch,)``.
            neg_score (torch.Tensor): scores of sampled negatives, last dim enumerates
                the negatives per example.

        Returns:
            torch.Tensor: scalar mean loss.
        """
        # hinge on the hardest (highest-scoring) negative per example
        loss = torch.maximum(torch.max(neg_score, dim=-1).values - pos_score + self.margin, torch.tensor([0]).type_as(pos_score))
        if self.n_items is not None:
            impostors = neg_score - pos_score.view(-1, 1) + self.margin > 0
            # cast the boolean mask to float: torch.mean is undefined on bool
            # tensors and raises a RuntimeError in PyTorch
            rank = torch.mean(impostors.float(), -1) * self.n_items
            return torch.mean(loss * torch.log(rank + 1))
        return torch.mean(loss)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class BPRLoss(torch.nn.Module):
    """Bayesian Personalized Ranking loss.

    Maximizes the log-likelihood that a positive item scores higher than a
    sampled negative: ``-log(sigmoid(pos_score - neg_score))``, averaged over
    the last dimension.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pos_score, neg_score):
        """Return the mean BPR loss over the last dimension."""
        score_diff = pos_score - neg_score
        pairwise_loss = -torch.log(torch.sigmoid(score_diff))
        return torch.mean(pairwise_loss, dim=-1)
        #loss = -torch.mean(F.logsigmoid(pos_score - torch.max(neg_score, dim=-1))) need v1.10
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
"""The metaoptimizer module, it provides a class MetaBalance
|
|
2
|
+
MetaBalance is used to scale the gradient and balance the gradient of each task
|
|
3
|
+
Authors: Qida Dong, dongjidan@126.com
|
|
4
|
+
"""
|
|
5
|
+
import torch
|
|
6
|
+
from torch.optim.optimizer import Optimizer
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MetaBalance(Optimizer):
    """MetaBalance Optimizer

    This method is used to scale the gradient and balance the gradient of each task.

    Args:
        parameters (list): the parameters of model
        relax_factor (float, optional): the relax factor of gradient scaling (default: 0.7)
        beta (float, optional): the coefficient of moving average (default: 0.9)
    """

    def __init__(self, parameters, relax_factor=0.7, beta=0.9):

        if relax_factor < 0. or relax_factor >= 1.:
            raise ValueError(f'Invalid relax_factor: {relax_factor}, it should be 0. <= relax_factor < 1.')
        if beta < 0. or beta >= 1.:
            raise ValueError(f'Invalid beta: {beta}, it should be 0. <= beta < 1.')
        # hyper-parameters are stored as the per-group defaults dict expected
        # by torch.optim.Optimizer
        rel_beta_dict = {'relax_factor': relax_factor, 'beta': beta}
        super(MetaBalance, self).__init__(parameters, rel_beta_dict)

    @torch.no_grad()
    def step(self, losses):
        """Backward each task loss in turn and rescale its gradients.

        For every task the gradient magnitude is tracked with a moving
        average (``gp.norms``); auxiliary-task gradients (idx > 0) are scaled
        toward the magnitude of the first task's gradient, then all task
        gradients are summed into ``state['sum_gradient']``, which is written
        back to ``gp.grad`` after the last loss.

        Args:
            losses (list[torch.Tensor]): one scalar loss per task; the first
                entry acts as the main task whose gradient norm anchors the
                scaling of the others.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """

        for idx, loss in enumerate(losses):
            # retain_graph so the shared graph survives multiple backward passes
            loss.backward(retain_graph=True)
            for group in self.param_groups:
                for gp in group['params']:
                    if gp.grad is None:
                        # print('breaking')
                        # NOTE(review): `break` skips the REST of this group as
                        # soon as one param has no grad — verify `continue`
                        # was not intended here.
                        break
                    if gp.grad.is_sparse:
                        raise RuntimeError('MetaBalance does not support sparse gradients')
                    # store the result of moving average
                    state = self.state[gp]
                    if len(state) == 0:
                        # lazily create one norm accumulator per task
                        for i in range(len(losses)):
                            if i == 0:
                                gp.norms = [0]
                            else:
                                gp.norms.append(0)
                    # calculate the moving average
                    beta = group['beta']
                    gp.norms[idx] = gp.norms[idx] * beta + (1 - beta) * torch.norm(gp.grad)
                    # scale the auxiliary gradient toward the main task's
                    # gradient magnitude (norms[0]); relax_factor blends the
                    # scaled and the raw gradient
                    relax_factor = group['relax_factor']
                    gp.grad = gp.grad * gp.norms[0] / (gp.norms[idx] + 1e-5) * relax_factor + gp.grad * (1. - relax_factor)
                    # store the gradient of each auxiliary task in state
                    if idx == 0:
                        state['sum_gradient'] = torch.zeros_like(gp.data)
                        state['sum_gradient'] += gp.grad
                    else:
                        state['sum_gradient'] += gp.grad

                    if gp.grad is not None:
                        # drop this task's gradient so the next backward starts clean
                        gp.grad.detach_()
                        gp.grad.zero_()
                    if idx == len(losses) - 1:
                        # after the last task, expose the accumulated gradient
                        gp.grad = state['sum_gradient']
|
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
"""The metric module, it is used to provide some metrics for recommenders.
|
|
2
|
+
Available function:
|
|
3
|
+
- auc_score: compute AUC
|
|
4
|
+
- gauc_score: compute GAUC
|
|
5
|
+
- log_loss: compute LogLoss
|
|
6
|
+
- topk_metrics: compute topk metrics contains 'ndcg', 'mrr', 'recall', 'hit'
|
|
7
|
+
Authors: Qida Dong, dongjidan@126.com
|
|
8
|
+
"""
|
|
9
|
+
from sklearn.metrics import roc_auc_score
|
|
10
|
+
import numpy as np
|
|
11
|
+
from collections import defaultdict
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def auc_score(y_true, y_pred):
    """Compute AUC (area under the ROC curve).

    Args:
        y_true (array): binary ground-truth labels.
        y_pred (array): predicted scores.

    Return:
        float: ROC-AUC, delegated to ``sklearn.metrics.roc_auc_score``.
    """

    return roc_auc_score(y_true, y_pred)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_user_pred(y_true, y_pred, users):
    """divide the result into different group by user id

    Args:
        y_true (array): all true labels of the data
        y_pred (array): the predicted score
        users (array): user id

    Return:
        user_pred (dict): {userid: values}, key is user id and value is the labels and scores of each user
    """
    grouped = defaultdict(lambda: {'y_true': [], 'y_pred': []})
    for label, score, uid in zip(y_true, y_pred, users):
        grouped[uid]['y_true'].append(label)
        grouped[uid]['y_pred'].append(score)
    return dict(grouped)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def gauc_score(y_true, y_pred, users, weights=None):
    """compute GAUC

    Args:
        y_true (array): dim(N, ), all true labels of the data
        y_pred (array): dim(N, ), the predicted score
        users (array): dim(N, ), user id
        weights (dict): {userid: weight_value}, per-group weights; when None,
            each user is weighted by how many rows belong to that user
    Return:
        score: float, GAUC
    """
    assert len(y_true) == len(y_pred) and len(y_true) == len(users)

    grouped = get_user_pred(y_true, y_pred, users)
    weighted_sum = 0
    total_weight = 0
    for uid, record in grouped.items():
        per_user_auc = auc_score(record['y_true'], record['y_pred'])
        if weights is None:
            w = len(record['y_true'])
        else:
            w = weights[uid]
        weighted_sum += per_user_auc * w
        total_weight += w
    return weighted_sum / total_weight
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def ndcg_score(y_true, y_pred, topKs=None):
    """Return the formatted NDCG@k values computed by ``topk_metrics``."""
    topKs = [5] if topKs is None else topKs
    return topk_metrics(y_true, y_pred, topKs)['NDCG']
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def hit_score(y_true, y_pred, topKs=None):
    """Return the formatted Hit@k values computed by ``topk_metrics``."""
    topKs = [5] if topKs is None else topKs
    return topk_metrics(y_true, y_pred, topKs)['Hit']
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def mrr_score(y_true, y_pred, topKs=None):
    """Return the formatted MRR@k values computed by ``topk_metrics``."""
    topKs = [5] if topKs is None else topKs
    return topk_metrics(y_true, y_pred, topKs)['MRR']
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def recall_score(y_true, y_pred, topKs=None):
    """Return the formatted Recall@k values computed by ``topk_metrics``."""
    topKs = [5] if topKs is None else topKs
    return topk_metrics(y_true, y_pred, topKs)['Recall']
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def precision_score(y_true, y_pred, topKs=None):
    """Return the formatted Precision@k values computed by ``topk_metrics``."""
    topKs = [5] if topKs is None else topKs
    return topk_metrics(y_true, y_pred, topKs)['Precision']
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def topk_metrics(y_true, y_pred, topKs=None):
    """choice topk metrics and compute it
    the metrics contains 'ndcg', 'mrr', 'recall', 'precision' and 'hit'

    Args:
        y_true (dict): {userid, item_ids}, the key is user id and the value is the list that contains the items the user interacted
        y_pred (dict): {userid, item_ids}, the key is user id and the value is the list that contains the items recommended
        topKs (list or tuple): if you want to get top5 and top10, topKs=(5, 10)

    Return:
        results (dict): {metric_name: metric_values}, it contains five metrics, 'ndcg', 'recall', 'mrr', 'hit', 'precision';
            each value is a list of human-readable strings like 'NDCG@5: 0.1234', one per entry in topKs.

    NOTE(review): assumes every y_pred[u] holds at least max(topKs) items — an
    IndexError is raised otherwise; verify against the callers.
    """
    if topKs is None:
        topKs = [5]
    assert len(y_true) == len(y_pred)

    if not isinstance(topKs, (tuple, list)):
        raise ValueError('topKs wrong, it should be tuple or list')

    # align predictions and ground truth by iterating y_true's keys
    pred_array = []
    true_array = []
    for u in y_true.keys():
        pred_array.append(y_pred[u])
        true_array.append(y_true[u])
    ndcg_result = []
    mrr_result = []
    hit_result = []
    precision_result = []
    recall_result = []
    for idx in range(len(topKs)):
        # per-cutoff accumulators over all users
        ndcgs = 0
        mrrs = 0
        hits = 0
        precisions = 0
        recalls = 0
        gts = 0  # total number of ground-truth items across users
        for i in range(len(true_array)):
            if len(true_array[i]) != 0:
                mrr_tmp = 0
                mrr_flag = True  # only the FIRST hit contributes to MRR
                hit_tmp = 0
                dcg_tmp = 0
                idcg_tmp = 0
                for j in range(topKs[idx]):
                    if pred_array[i][j] in true_array[i]:
                        hit_tmp += 1.
                        if mrr_flag:
                            mrr_flag = False
                            mrr_tmp = 1. / (1 + j)
                        dcg_tmp += 1. / (np.log2(j + 2))
                    # ideal DCG: the first min(k, |truth|) ranks are all hits
                    if j < len(true_array[i]):
                        idcg_tmp += 1. / (np.log2(j + 2))
                gts += len(true_array[i])
                hits += hit_tmp
                mrrs += mrr_tmp
                recalls += hit_tmp / len(true_array[i])
                precisions += hit_tmp / topKs[idx]
                if idcg_tmp != 0:
                    ndcgs += dcg_tmp / idcg_tmp
        # Hit here is hits / total ground-truth items (a micro average);
        # the other metrics are averaged over len(pred_array), which also
        # counts users with empty ground truth — NOTE(review): confirm this
        # denominator is intended.
        hit_result.append(round(hits / gts, 4))
        mrr_result.append(round(mrrs / len(pred_array), 4))
        recall_result.append(round(recalls / len(pred_array), 4))
        precision_result.append(round(precisions / len(pred_array), 4))
        ndcg_result.append(round(ndcgs / len(pred_array), 4))

    # format each metric as 'Name@k: value' strings
    results = defaultdict(list)
    for idx in range(len(topKs)):

        output = f'NDCG@{topKs[idx]}: {ndcg_result[idx]}'
        results['NDCG'].append(output)

        output = f'MRR@{topKs[idx]}: {mrr_result[idx]}'
        results['MRR'].append(output)

        output = f'Recall@{topKs[idx]}: {recall_result[idx]}'
        results['Recall'].append(output)

        output = f'Hit@{topKs[idx]}: {hit_result[idx]}'
        results['Hit'].append(output)

        output = f'Precision@{topKs[idx]}: {precision_result[idx]}'
        results['Precision'].append(output)
    return results
|
|
192
|
+
|
|
193
|
+
def log_loss(y_true, y_pred, eps=1e-15):
    """Compute the binary cross-entropy (LogLoss).

    Args:
        y_true (array): binary ground-truth labels (0/1).
        y_pred (array): predicted probabilities in [0, 1].
        eps (float): predictions are clipped into [eps, 1 - eps] so that
            scores of exactly 0 or 1 do not produce log(0) = -inf.

    Return:
        float: the mean negative log-likelihood.
    """
    # clip to avoid -inf / nan from log(0) on saturated predictions
    y_pred = np.clip(y_pred, eps, 1 - eps)
    score = y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)
    return -score.sum() / len(y_true)
|
|
196
|
+
|
|
197
|
+
def Coverage(y_pred, all_items, topKs=None):
    """compute the coverage

    This method measures the diversity of the recommended items
    and the ability to explore the long-tailed items.

    Arg:
        y_pred (dict): {userid, item_ids}, the key is user id and the value is the list that contains the items recommended
        all_items (set): all unique items
        topKs (list): cutoffs to evaluate, defaults to [5]
    Return:
        result (list[str]): one 'Coverage@k: value' string per cutoff
    """
    topKs = [5] if topKs is None else topKs
    result = []
    n_all = len(all_items)
    for k in topKs:
        recommended = set()
        for items in y_pred.values():
            recommended.update(items[:k])
        ratio = round(len(recommended) * 1. / n_all, 4)
        result.append(f'Coverage@{k}: {ratio}')
    return result
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
# print(Coverage({'0':[0,1,2],'1':[1,3,4]}, {0,1,2,3,4,5}, [2,3]))
|
|
222
|
+
|
|
223
|
+
# pred = np.array([ 0.3, 0.2, 0.5, 0.9, 0.7, 0.31, 0.8, 0.1, 0.4, 0.6])
|
|
224
|
+
# label = np.array([ 1, 0, 0, 1, 0, 0, 1, 0, 0, 1])
|
|
225
|
+
# users_id = np.array([ 2, 1, 0, 2, 1, 0, 0, 2, 1, 1])
|
|
226
|
+
|
|
227
|
+
# print('auc: ', auc_score(label, pred))
|
|
228
|
+
# print('gauc: ', gauc_score(label, pred, users_id))
|
|
229
|
+
# print('log_loss: ', log_loss(label, pred))
|
|
230
|
+
|
|
231
|
+
# for mt in ['ndcg', 'mrr', 'recall', 'hit','s']:
|
|
232
|
+
# tm = topk_metrics(y_true, y_pred, users_id, 3, metric_type=mt)
|
|
233
|
+
# print(f'{mt}: {tm}')
|
|
234
|
+
# y_pred = {'0': [0, 1], '1': [0, 1], '2': [2, 3]}
|
|
235
|
+
# y_true = {'0': [1, 2], '1': [0, 1, 2], '2': [2, 3]}
|
|
236
|
+
# out = topk_metrics(y_true, y_pred, topKs=(1,2))
|
|
237
|
+
# ndcgs = ndcg_score(y_true,y_pred, topKs=(1,2))
|
|
238
|
+
# print(out)
|
|
239
|
+
# print(ndcgs)
|
|
240
|
+
|
|
241
|
+
# ground_truth, match_res = np.load("C:\\Users\\dongj\\Desktop/res.npy", allow_pickle=True)
|
|
242
|
+
# print(len(ground_truth),len(match_res))
|
|
243
|
+
# out = topk_metrics(y_true=ground_truth, y_pred=match_res, topKs=[50])
|
|
244
|
+
# print(out)
|
|
245
|
+
|
|
246
|
+
if __name__ == "__main__":
    # quick smoke test of topk_metrics on a tiny hand-built example
    demo_pred = {'0': [0, 1], '1': [0, 1], '2': [2, 3]}
    demo_true = {'0': [1, 2], '1': [0, 1, 2], '2': [2, 3]}
    print(topk_metrics(demo_true, demo_pred, topKs=(1, 2)))
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from .dssm import DSSM
|
|
2
|
+
from .youtube_dnn import YoutubeDNN
|
|
3
|
+
from .youtube_sbc import YoutubeSBC
|
|
4
|
+
from .dssm_facebook import FaceBookDSSM
|
|
5
|
+
from .gru4rec import GRU4Rec
|
|
6
|
+
from .comirec import ComirecSA, ComirecDR
|
|
7
|
+
from .mind import MIND
|
|
8
|
+
from .narm import NARM
|
|
9
|
+
from .stamp import STAMP
|
|
10
|
+
from .sasrec import SASRec
|
|
11
|
+
from .sine import SINE
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Date: create on 07/06/2022
|
|
3
|
+
References:
|
|
4
|
+
paper: Controllable Multi-Interest Framework for Recommendation
|
|
5
|
+
url: https://arxiv.org/pdf/2005.09347.pdf
|
|
6
|
+
code: https://github.com/ShiningCosmos/pytorch_ComiRec/blob/main/ComiRec.py
|
|
7
|
+
Authors: Kai Wang, 306178200@qq.com
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from ...basic.layers import MLP, EmbeddingLayer, MultiInterestSA, CapsuleNetwork
|
|
13
|
+
from torch import nn
|
|
14
|
+
import torch.nn.functional as F
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ComirecSA(torch.nn.Module):
    """The match model mentioned in `Controllable Multi-Interest Framework for Recommendation` paper.
    It's a ComirecSA match model trained by global softmax loss on list-wise samples.
    Note in origin paper, it's without item dnn tower and train item embedding directly.

    Args:
        user_features (list[Feature Class]): training by the user tower module.
        history_features (list[Feature Class]): training history
        item_features (list[Feature Class]): training by the embedding table, it's the item id feature.
        neg_item_feature (list[Feature Class]): training by the embedding table, it's the negative items id feature.
        temperature (float): temperature factor for similarity score, default to 1.0.
        interest_num (int): interest num
    """

    def __init__(self, user_features, history_features, item_features, neg_item_feature, temperature=1.0, interest_num=4):
        super().__init__()
        self.user_features = user_features
        self.item_features = item_features
        self.history_features = history_features
        self.neg_item_feature = neg_item_feature
        self.temperature = temperature
        self.interest_num = interest_num
        # total embedding width fed into the user projection
        self.user_dims = sum([fea.embed_dim for fea in user_features+history_features])

        self.embedding = EmbeddingLayer(user_features + item_features + history_features)
        # self-attentive multi-interest extractor over the history sequence
        self.multi_interest_sa = MultiInterestSA(embedding_dim=self.history_features[0].embed_dim, interest_num=self.interest_num)
        # linear projection from concatenated user+interest features to item embedding space
        self.convert_user_weight = nn.Parameter(torch.rand(self.user_dims, self.history_features[0].embed_dim), requires_grad=True)
        # self.mode is set externally by the trainer: None (train) / "user" / "item" (inference)
        self.mode = None

    def forward(self, x):
        user_embedding = self.user_tower(x)
        item_embedding = self.item_tower(x)
        if self.mode == "user":
            return user_embedding
        if self.mode == "item":
            return item_embedding

        # pick, per sample, the interest vector most similar to the positive item
        pos_item_embedding = item_embedding[:,0,:]
        dot_res = torch.bmm(user_embedding, pos_item_embedding.squeeze(1).unsqueeze(-1))
        k_index = torch.argmax(dot_res, dim=1)
        # torch.rand only pre-allocates the buffer; every row is overwritten below
        best_interest_emb = torch.rand(user_embedding.shape[0], user_embedding.shape[2]).to(user_embedding.device)
        for k in range(user_embedding.shape[0]):
            best_interest_emb[k, :] = user_embedding[k, k_index[k], :]
        best_interest_emb = best_interest_emb.unsqueeze(1)

        # NOTE(review): summing over dim=1 reduces across the item axis of
        # [batch, 1+n_neg, embed_dim]; per-item logits would need dim=-1 — confirm.
        y = torch.mul(best_interest_emb, item_embedding).sum(dim=1)

        return y

    def user_tower(self, x):
        if self.mode == "item":
            return None
        input_user = self.embedding(x, self.user_features, squeeze_dim=True).unsqueeze(1) #[batch_size, num_features*deep_dims]
        # replicate the static user features once per interest
        input_user = input_user.expand([input_user.shape[0], self.interest_num, input_user.shape[-1]])

        history_emb = self.embedding(x, self.history_features).squeeze(1)
        mask = self.gen_mask(x)
        mask = mask.unsqueeze(-1).float()
        multi_interest_emb = self.multi_interest_sa(history_emb,mask)

        input_user = torch.cat([input_user,multi_interest_emb],dim=-1)

        # user_embedding = self.user_mlp(input_user).unsqueeze(1) #[batch_size, interest_num, embed_dim]
        user_embedding = torch.matmul(input_user,self.convert_user_weight)
        user_embedding = F.normalize(user_embedding, p=2, dim=-1) # L2 normalize
        if self.mode == "user":
            return user_embedding #inference embedding mode -> [batch_size, interest_num, embed_dim]
        return user_embedding

    def item_tower(self, x):
        if self.mode == "user":
            return None
        pos_embedding = self.embedding(x, self.item_features, squeeze_dim=False) #[batch_size, 1, embed_dim]
        pos_embedding = F.normalize(pos_embedding, p=2, dim=-1) # L2 normalize
        if self.mode == "item": #inference embedding mode
            return pos_embedding.squeeze(1) #[batch_size, embed_dim]
        neg_embeddings = self.embedding(x, self.neg_item_feature,
                                        squeeze_dim=False).squeeze(1) #[batch_size, n_neg_items, embed_dim]
        neg_embeddings = F.normalize(neg_embeddings, p=2, dim=-1) # L2 normalize
        return torch.cat((pos_embedding, neg_embeddings), dim=1) #[batch_size, 1+n_neg_items, embed_dim]

    def gen_mask(self, x):
        # padding positions (id 0) are masked out; assumes 0 is reserved for padding — TODO confirm
        his_list = x[self.history_features[0].name]
        mask = (his_list > 0).long()
        return mask
|
|
102
|
+
|
|
103
|
+
class ComirecDR(torch.nn.Module):
    """The match model mentioned in `Controllable Multi-Interest Framework for Recommendation` paper.
    It's a ComirecDR match model trained by global softmax loss on list-wise samples.
    Note in origin paper, it's without item dnn tower and train item embedding directly.

    Args:
        user_features (list[Feature Class]): training by the user tower module.
        history_features (list[Feature Class]): training history
        item_features (list[Feature Class]): training by the embedding table, it's the item id feature.
        neg_item_feature (list[Feature Class]): training by the embedding table, it's the negative items id feature.
        max_length (int): max sequence length of input item sequence
        temperature (float): temperature factor for similarity score, default to 1.0.
        interest_num (int): interest num
    """

    def __init__(self, user_features, history_features, item_features, neg_item_feature, max_length, temperature=1.0, interest_num=4):
        super().__init__()
        self.user_features = user_features
        self.item_features = item_features
        self.history_features = history_features
        self.neg_item_feature = neg_item_feature
        self.temperature = temperature
        self.interest_num = interest_num
        self.max_length = max_length
        # total embedding width fed into the user projection
        self.user_dims = sum([fea.embed_dim for fea in user_features+history_features])

        self.embedding = EmbeddingLayer(user_features + item_features + history_features)
        # dynamic-routing capsule network extracts interest_num interest vectors
        self.capsule = CapsuleNetwork(self.history_features[0].embed_dim,self.max_length,bilinear_type=2,interest_num=self.interest_num)
        # linear projection from concatenated user+interest features to item embedding space
        self.convert_user_weight = nn.Parameter(torch.rand(self.user_dims, self.history_features[0].embed_dim), requires_grad=True)
        # self.mode is set externally by the trainer: None (train) / "user" / "item" (inference)
        self.mode = None

    def forward(self, x):
        user_embedding = self.user_tower(x)
        item_embedding = self.item_tower(x)
        if self.mode == "user":
            return user_embedding
        if self.mode == "item":
            return item_embedding

        # pick, per sample, the interest vector most similar to the positive item
        pos_item_embedding = item_embedding[:,0,:]
        dot_res = torch.bmm(user_embedding, pos_item_embedding.squeeze(1).unsqueeze(-1))
        k_index = torch.argmax(dot_res, dim=1)
        # torch.rand only pre-allocates the buffer; every row is overwritten below
        best_interest_emb = torch.rand(user_embedding.shape[0], user_embedding.shape[2]).to(user_embedding.device)
        for k in range(user_embedding.shape[0]):
            best_interest_emb[k, :] = user_embedding[k, k_index[k], :]
        best_interest_emb = best_interest_emb.unsqueeze(1)

        # NOTE(review): summing over dim=1 reduces across the item axis of
        # [batch, 1+n_neg, embed_dim]; per-item logits would need dim=-1 — confirm.
        y = torch.mul(best_interest_emb, item_embedding).sum(dim=1)

        return y

    def user_tower(self, x):
        if self.mode == "item":
            return None
        input_user = self.embedding(x, self.user_features, squeeze_dim=True).unsqueeze(1) #[batch_size, num_features*deep_dims]
        # replicate the static user features once per interest
        input_user = input_user.expand([input_user.shape[0], self.interest_num, input_user.shape[-1]])

        history_emb = self.embedding(x, self.history_features).squeeze(1)
        mask = self.gen_mask(x)
        multi_interest_emb = self.capsule(history_emb,mask)

        input_user = torch.cat([input_user,multi_interest_emb],dim=-1)

        # user_embedding = self.user_mlp(input_user).unsqueeze(1) #[batch_size, interest_num, embed_dim]
        user_embedding = torch.matmul(input_user,self.convert_user_weight)
        user_embedding = F.normalize(user_embedding, p=2, dim=-1) # L2 normalize
        if self.mode == "user":
            return user_embedding #inference embedding mode -> [batch_size, interest_num, embed_dim]
        return user_embedding

    def item_tower(self, x):
        if self.mode == "user":
            return None
        pos_embedding = self.embedding(x, self.item_features, squeeze_dim=False) #[batch_size, 1, embed_dim]
        pos_embedding = F.normalize(pos_embedding, p=2, dim=-1) # L2 normalize
        if self.mode == "item": #inference embedding mode
            return pos_embedding.squeeze(1) #[batch_size, embed_dim]
        neg_embeddings = self.embedding(x, self.neg_item_feature,
                                        squeeze_dim=False).squeeze(1) #[batch_size, n_neg_items, embed_dim]
        neg_embeddings = F.normalize(neg_embeddings, p=2, dim=-1) # L2 normalize
        return torch.cat((pos_embedding, neg_embeddings), dim=1) #[batch_size, 1+n_neg_items, embed_dim]

    def gen_mask(self, x):
        # padding positions (id 0) are masked out; assumes 0 is reserved for padding — TODO confirm
        his_list = x[self.history_features[0].name]
        mask = (his_list > 0).long()
        return mask
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Date: create on 12/05/2022, update on 20/05/2022
|
|
3
|
+
References:
|
|
4
|
+
paper: (CIKM'2013) Learning Deep Structured Semantic Models for Web Search using Clickthrough Data
|
|
5
|
+
url: https://posenhuang.github.io/papers/cikm2013_DSSM_fullversion.pdf
|
|
6
|
+
code: https://github.com/bbruceyuan/DeepMatch-Torch/blob/main/deepmatch_torch/models/dssm.py
|
|
7
|
+
Authors: Mincai Lai, laimincai@shanghaitech.edu.cn
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
import torch.nn.functional as F
|
|
12
|
+
from ...basic.layers import MLP, EmbeddingLayer
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class DSSM(torch.nn.Module):
    """Deep Structured Semantic Model

    Two-tower architecture: a user MLP tower and an item MLP tower whose
    L2-normalized outputs are compared by inner product (cosine similarity).

    Args:
        user_features (list[Feature Class]): training by the user tower module.
        item_features (list[Feature Class]): training by the item tower module.
        temperature (float): temperature factor for similarity score, default to 1.0.
        user_params (dict): the params of the User Tower module, keys include:`{"dims":list, "activation":str, "dropout":float, "output_layer":bool`}.
        item_params (dict): the params of the Item Tower module, keys include:`{"dims":list, "activation":str, "dropout":float, "output_layer":bool`}.
    """

    def __init__(self, user_features, item_features, user_params, item_params, temperature=1.0):
        super().__init__()
        self.user_features = user_features
        self.item_features = item_features
        self.temperature = temperature
        # input widths of the two towers
        self.user_dims = sum(fea.embed_dim for fea in user_features)
        self.item_dims = sum(fea.embed_dim for fea in item_features)

        self.embedding = EmbeddingLayer(user_features + item_features)
        self.user_mlp = MLP(self.user_dims, output_layer=False, **user_params)
        self.item_mlp = MLP(self.item_dims, output_layer=False, **item_params)
        # set externally by the trainer: None (train) / "user" / "item" (inference)
        self.mode = None

    def forward(self, x):
        user_embedding = self.user_tower(x)
        item_embedding = self.item_tower(x)
        if self.mode == "user":
            return user_embedding
        if self.mode == "item":
            return item_embedding

        # calculate cosine score (both embeddings are L2-normalized)
        y = (user_embedding * item_embedding).sum(dim=1)
        # y = y / self.temperature
        return torch.sigmoid(y)

    def user_tower(self, x):
        """Embed + project the user features; None in item-inference mode."""
        if self.mode == "item":
            return None
        tower_input = self.embedding(x, self.user_features, squeeze_dim=True)  #[batch_size, num_features*deep_dims]
        projected = self.user_mlp(tower_input)  #[batch_size, user_params["dims"][-1]]
        return F.normalize(projected, p=2, dim=1)  # L2 normalize

    def item_tower(self, x):
        """Embed + project the item features; None in user-inference mode."""
        if self.mode == "user":
            return None
        tower_input = self.embedding(x, self.item_features, squeeze_dim=True)  #[batch_size, num_features*embed_dim]
        projected = self.item_mlp(tower_input)  #[batch_size, item_params["dims"][-1]]
        return F.normalize(projected, p=2, dim=1)  # L2 normalize
|