torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +54 -54
- torch_rechub/basic/callback.py +33 -33
- torch_rechub/basic/features.py +87 -94
- torch_rechub/basic/initializers.py +92 -92
- torch_rechub/basic/layers.py +994 -720
- torch_rechub/basic/loss_func.py +223 -34
- torch_rechub/basic/metaoptimizer.py +76 -72
- torch_rechub/basic/metric.py +251 -250
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -11
- torch_rechub/models/matching/comirec.py +193 -188
- torch_rechub/models/matching/dssm.py +72 -66
- torch_rechub/models/matching/dssm_facebook.py +77 -79
- torch_rechub/models/matching/dssm_senet.py +28 -16
- torch_rechub/models/matching/gru4rec.py +85 -87
- torch_rechub/models/matching/mind.py +103 -101
- torch_rechub/models/matching/narm.py +82 -76
- torch_rechub/models/matching/sasrec.py +143 -140
- torch_rechub/models/matching/sine.py +148 -151
- torch_rechub/models/matching/stamp.py +81 -83
- torch_rechub/models/matching/youtube_dnn.py +75 -71
- torch_rechub/models/matching/youtube_sbc.py +98 -98
- torch_rechub/models/multi_task/__init__.py +7 -5
- torch_rechub/models/multi_task/aitm.py +83 -84
- torch_rechub/models/multi_task/esmm.py +56 -55
- torch_rechub/models/multi_task/mmoe.py +58 -58
- torch_rechub/models/multi_task/ple.py +116 -130
- torch_rechub/models/multi_task/shared_bottom.py +45 -45
- torch_rechub/models/ranking/__init__.py +14 -11
- torch_rechub/models/ranking/afm.py +65 -63
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -63
- torch_rechub/models/ranking/dcn.py +38 -38
- torch_rechub/models/ranking/dcn_v2.py +59 -69
- torch_rechub/models/ranking/deepffm.py +131 -123
- torch_rechub/models/ranking/deepfm.py +43 -42
- torch_rechub/models/ranking/dien.py +191 -191
- torch_rechub/models/ranking/din.py +93 -91
- torch_rechub/models/ranking/edcn.py +101 -117
- torch_rechub/models/ranking/fibinet.py +42 -50
- torch_rechub/models/ranking/widedeep.py +41 -41
- torch_rechub/trainers/__init__.py +4 -3
- torch_rechub/trainers/ctr_trainer.py +288 -128
- torch_rechub/trainers/match_trainer.py +336 -170
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +356 -207
- torch_rechub/trainers/seq_trainer.py +427 -0
- torch_rechub/utils/data.py +492 -360
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -274
- torch_rechub/utils/model_utils.py +233 -0
- torch_rechub/utils/mtl.py +136 -126
- torch_rechub/utils/onnx_export.py +220 -0
- torch_rechub/utils/visualization.py +271 -0
- torch_rechub-0.0.5.dist-info/METADATA +402 -0
- torch_rechub-0.0.5.dist-info/RECORD +64 -0
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
- torch_rechub-0.0.3.dist-info/METADATA +0 -177
- torch_rechub-0.0.3.dist-info/RECORD +0 -55
- torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
torch_rechub/models/matching/sine.py
@@ -1,151 +1,148 @@
(the removed 0.0.3 lines are truncated in this export; the 0.0.5 file follows)

"""
Date: created on 03/07/2022
References:
    paper: Sparse-Interest Network for Sequential Recommendation
    url: https://arxiv.org/abs/2102.09267
    code: https://github.com/Qiaoyut/SINE/blob/master/model.py
Authors: Bo Kang, klinux@live.com
"""

import torch
import torch.nn.functional as F
from torch import einsum


class SINE(torch.nn.Module):
    """The match model proposed in the `Sparse-Interest Network for Sequential Recommendation` paper.

    Args:
        history_features (list[str]): training history feature names, used to index the historical sequences in the input dictionary
        item_features (list[str]): item feature names, used to index the items in the input dictionary
        neg_item_features (list[str]): negative item feature names, used to index the negative items in the input dictionary
        num_items (int): number of items in the data
        embedding_dim (int): dimensionality of the embeddings
        hidden_dim (int): dimensionality of the hidden layer in the self-attention modules
        num_concept (int): number of concepts, also called conceptual prototypes
        num_intention (int): number of (user-specific) intentions selected out of the concepts
        seq_max_len (int): max sequence length of the input item sequence
        num_heads (int): number of attention heads in the self-attention modules, defaults to 1
        temperature (float): temperature factor in the similarity measure, defaults to 1.0
    """

    def __init__(self, history_features, item_features, neg_item_features, num_items, embedding_dim, hidden_dim, num_concept, num_intention, seq_max_len, num_heads=1, temperature=1.0):
        super().__init__()
        self.item_features = item_features
        self.history_features = history_features
        self.neg_item_features = neg_item_features
        self.temperature = temperature
        self.num_concept = num_concept
        self.num_intention = num_intention
        self.seq_max_len = seq_max_len

        std = 1e-4
        self.item_embedding = torch.nn.Embedding(num_items, embedding_dim)
        torch.nn.init.normal_(self.item_embedding.weight, 0, std)
        self.concept_embedding = torch.nn.Embedding(num_concept, embedding_dim)
        torch.nn.init.normal_(self.concept_embedding.weight, 0, std)
        self.position_embedding = torch.nn.Embedding(seq_max_len, embedding_dim)
        torch.nn.init.normal_(self.position_embedding.weight, 0, std)

        self.w_1 = torch.nn.Parameter(torch.rand(embedding_dim, hidden_dim), requires_grad=True)
        self.w_2 = torch.nn.Parameter(torch.rand(hidden_dim, num_heads), requires_grad=True)

        self.w_3 = torch.nn.Parameter(torch.rand(embedding_dim, embedding_dim), requires_grad=True)

        self.w_k1 = torch.nn.Parameter(torch.rand(embedding_dim, hidden_dim), requires_grad=True)
        self.w_k2 = torch.nn.Parameter(torch.rand(hidden_dim, num_intention), requires_grad=True)

        self.w_4 = torch.nn.Parameter(torch.rand(embedding_dim, hidden_dim), requires_grad=True)
        self.w_5 = torch.nn.Parameter(torch.rand(hidden_dim, num_heads), requires_grad=True)

        self.mode = None

    def forward(self, x):
        user_embedding = self.user_tower(x)
        item_embedding = self.item_tower(x)
        if self.mode == "user":
            return user_embedding
        if self.mode == "item":
            return item_embedding

        y = torch.mul(user_embedding, item_embedding).sum(dim=-1)

        # # compute covariance regularizer
        # M = torch.cov(self.concept_embedding.weight, correction=0)
        # l_c = (torch.norm(M, p='fro')**2 - torch.norm(torch.diag(M), p='fro')**2) / 2

        return y

    def user_tower(self, x):
        if self.mode == "item":
            return None

        # sparse interests extraction
        # # user-specific historical item embedding X^u
        hist_item = x[self.history_features[0]]
        x_u = self.item_embedding(hist_item) + \
            self.position_embedding.weight.unsqueeze(0)
        x_u_mask = (x[self.history_features[0]] > 0).long()

        # # user-specific conceptual prototypes C^u
        # ## attention a
        h_1 = einsum('bse, ed -> bsd', x_u, self.w_1).tanh()
        a_hist = F.softmax(einsum('bsd, dh -> bsh', h_1, self.w_2) + -1.e9 * (1 - x_u_mask.unsqueeze(-1).float()), dim=1)

        # ## virtual concept vector z_u
        z_u = einsum("bse, bsh -> be", x_u, a_hist)

        # ## similarity between the user's concept vector and all conceptual prototypes s^u
        s_u = einsum("be, te -> bt", z_u, self.concept_embedding.weight)
        s_u_top_k = torch.topk(s_u, self.num_intention)

        # ## final C^u
        c_u = einsum("bk, bke -> bke", torch.sigmoid(s_u_top_k.values), self.concept_embedding(s_u_top_k.indices))

        # # user intention assignment P_{k|t}
        p_u = F.softmax(einsum("bse, bke -> bks", F.normalize(x_u @ self.w_3, dim=-1), F.normalize(c_u, p=2, dim=-1)), dim=1)

        # # attention weighting P_{t|k}
        h_2 = einsum('bse, ed -> bsd', x_u, self.w_k1).tanh()
        a_concept_k = F.softmax(einsum('bsd, dk -> bsk', h_2, self.w_k2) + -1.e9 * (1 - x_u_mask.unsqueeze(-1).float()), dim=1)

        # # multiple interests encoding \phi_\theta^k(x^u)
        phi_u = einsum("bks, bse -> bke", p_u * a_concept_k.permute(0, 2, 1), x_u)

        # adaptive interest aggregation
        # # intention-aware input behavior \hat{X^u}
        x_u_hat = einsum('bks, bke -> bse', p_u, c_u)

        # # user's next intention C^u_{apt}
        h_3 = einsum('bse, ed -> bsd', x_u_hat, self.w_4).tanh()
        c_u_apt = F.normalize(einsum("bs, bse -> be", F.softmax(einsum('bsd, dh -> bsh', h_3, self.w_5).reshape(-1, self.seq_max_len) + -1.e9 * (1 - x_u_mask.float()), dim=1), x_u_hat), -1)

        # # aggregation weights e_k^u
        e_u = F.softmax(einsum('be, bke -> bk', c_u_apt, phi_u) / self.temperature, dim=1)

        # final user representation v^u
        v_u = einsum('bk, bke -> be', e_u, phi_u)

        if self.mode == "user":
            return v_u
        return v_u.unsqueeze(1)

    def item_tower(self, x):
        if self.mode == "user":
            return None
        pos_embedding = self.item_embedding(x[self.item_features[0]]).unsqueeze(1)
        if self.mode == "item":  # inference embedding mode
            return pos_embedding.squeeze(1)  # [batch_size, embed_dim]
        neg_embeddings = self.item_embedding(x[self.neg_item_features[0]]).squeeze(1)  # [batch_size, n_neg_items, embed_dim]

        # [batch_size, 1+n_neg_items, embed_dim]
        return torch.cat((pos_embedding, neg_embeddings), dim=1)

    def gen_mask(self, x):
        his_list = x[self.history_features[0].name]
        mask = (his_list > 0).long()
        return mask
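A minimal smoke test for the SINE module above; the feature names, batch shapes and hyper-parameter values are illustrative assumptions, not fixtures shipped with the package:

    import torch

    model = SINE(history_features=["hist_item_id"], item_features=["item_id"],
                 neg_item_features=["neg_item_id"], num_items=1000,
                 embedding_dim=16, hidden_dim=32, num_concept=8,
                 num_intention=4, seq_max_len=20)

    batch = {
        "hist_item_id": torch.randint(1, 1000, (4, 20)),  # padded history; 0 marks padding
        "item_id": torch.randint(1, 1000, (4,)),          # one positive item per sample
        "neg_item_id": torch.randint(1, 1000, (4, 3)),    # 3 sampled negatives per sample
    }
    scores = model(batch)  # [4, 1 + 3]: scores of the positive item and the negatives

Note that the history tensor's length must equal seq_max_len, since the position embedding table is added to the whole sequence.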
torch_rechub/models/matching/stamp.py
@@ -1,83 +1,81 @@
(the removed 0.0.3 lines are truncated in this export; the 0.0.5 file follows)

"""
Date: created on 17/09/2022
References:
    paper: STAMP: Short-Term Attention/Memory Priority Model for Session-based Recommendation
    url: https://dl.acm.org/doi/10.1145/3219819.3219950
    official Tensorflow implementation: https://github.com/uestcnlp/STAMP
Authors: Bo Kang, klinux@live.com
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class STAMP(nn.Module):

    def __init__(self, item_history_feature, weight_std, emb_std):
        super(STAMP, self).__init__()

        # item embedding layer
        self.item_history_feature = item_history_feature
        n_items, item_emb_dim = item_history_feature.vocab_size, item_history_feature.embed_dim
        self.item_emb = nn.Embedding(n_items, item_emb_dim, padding_idx=0)

        # weights and biases for attention computation
        self.w_0 = nn.Parameter(torch.zeros(item_emb_dim, 1))
        self.w_1_t = nn.Parameter(torch.zeros(item_emb_dim, item_emb_dim))
        self.w_2_t = nn.Parameter(torch.zeros(item_emb_dim, item_emb_dim))
        self.w_3_t = nn.Parameter(torch.zeros(item_emb_dim, item_emb_dim))
        self.b_a = nn.Parameter(torch.zeros(item_emb_dim))
        self._init_parameter_weights(weight_std)

        # mlp layers
        self.f_s = nn.Sequential(nn.Tanh(), nn.Linear(item_emb_dim, item_emb_dim))
        self.f_t = nn.Sequential(nn.Tanh(), nn.Linear(item_emb_dim, item_emb_dim))
        self.emb_std = emb_std
        self.apply(self._init_module_weights)

    def _init_parameter_weights(self, weight_std):
        nn.init.normal_(self.w_0, std=weight_std)
        nn.init.normal_(self.w_1_t, std=weight_std)
        nn.init.normal_(self.w_2_t, std=weight_std)
        nn.init.normal_(self.w_3_t, std=weight_std)

    def _init_module_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(std=self.emb_std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(std=self.emb_std)

    def forward(self, input_dict):
        # Index the embeddings for the items in the session
        input = input_dict[self.item_history_feature.name]
        value_mask = (input != 0).unsqueeze(-1)
        value_counts = value_mask.sum(dim=1, keepdim=True).squeeze(-1)
        item_emb_batch = self.item_emb(input) * value_mask

        # Index the embeddings of the latest clicked items
        x_t = self.item_emb(torch.gather(input, 1, value_counts - 1))

        # Eq. 2, user's general interest in the current session
        m_s = ((item_emb_batch).sum(1) / value_counts).unsqueeze(1)

        # Eq. 7, compute attention coefficient
        a = F.normalize(torch.exp(torch.sigmoid(item_emb_batch @ self.w_1_t + x_t @ self.w_2_t + m_s @ self.w_3_t + self.b_a) @ self.w_0) * value_mask, p=1, dim=1)

        # Eq. 8, compute user's attention-based interests
        m_a = (a * item_emb_batch).sum(1) + m_s.squeeze(1)

        # Eq. 3, compute the output state of the general interest
        h_s = self.f_s(m_a)

        # Eq. 9, compute the output state of the short-term interest
        h_t = self.f_t(x_t).squeeze(1)

        # Eq. 4, compute candidate scores
        z = h_s * h_t @ self.item_emb.weight.T

        return z
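A minimal smoke test for the STAMP module above. The model only needs an object exposing .name, .vocab_size and .embed_dim (a SequenceFeature in the package); a SimpleNamespace stand-in keeps the sketch self-contained, and the std values are illustrative assumptions:

    from types import SimpleNamespace

    import torch

    feat = SimpleNamespace(name="hist_item_id", vocab_size=1000, embed_dim=16)
    model = STAMP(item_history_feature=feat, weight_std=0.05, emb_std=0.002)

    # Sessions are left-aligned and right-padded with 0 (the padding index),
    # so every row must contain at least one non-zero item.
    sessions = torch.tensor([[3, 7, 42, 0, 0],
                             [9, 1, 0, 0, 0]])
    scores = model({"hist_item_id": sessions})  # [2, 1000] scores over the whole catalog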
torch_rechub/models/matching/youtube_dnn.py
@@ -1,71 +1,75 @@
(the removed 0.0.3 lines are truncated in this export; the 0.0.5 file follows)

"""
Date: create on 23/05/2022
References:
    paper: (RecSys'2016) Deep Neural Networks for YouTube Recommendations
    url: https://dl.acm.org/doi/10.1145/2959100.2959190
Authors: Mincai Lai, laimincai@shanghaitech.edu.cn
"""

import torch
import torch.nn.functional as F

from ...basic.layers import MLP, EmbeddingLayer


class YoutubeDNN(torch.nn.Module):
    """The match model mentioned in the `Deep Neural Networks for YouTube Recommendations` paper.
    It is a DSSM-style match model trained with a global softmax loss on list-wise samples.
    Note that the original paper has no item DNN tower and trains the item embedding directly.

    Args:
        user_features (list[Feature Class]): features trained by the user tower module.
        item_features (list[Feature Class]): the item id feature, trained by the embedding table.
        neg_item_feature (list[Feature Class]): the negative item id feature, trained by the embedding table.
        user_params (dict): the params of the user tower module, keys include: `{"dims": list, "activation": str, "dropout": float, "output_layer": bool}`.
        temperature (float): temperature factor for the similarity score, defaults to 1.0.
    """

    def __init__(self, user_features, item_features, neg_item_feature, user_params, temperature=1.0):
        super().__init__()
        self.user_features = user_features
        self.item_features = item_features
        self.neg_item_feature = neg_item_feature
        self.temperature = temperature
        self.user_dims = sum([fea.embed_dim for fea in user_features])
        self.embedding = EmbeddingLayer(user_features + item_features)
        self.user_mlp = MLP(self.user_dims, output_layer=False, **user_params)
        self.mode = None

    def forward(self, x):
        user_embedding = self.user_tower(x)
        item_embedding = self.item_tower(x)
        if self.mode == "user":
            return user_embedding
        if self.mode == "item":
            return item_embedding

        # calculate cosine score
        y = torch.mul(user_embedding, item_embedding).sum(dim=2)
        y = y / self.temperature
        return y

    def user_tower(self, x):
        if self.mode == "item":
            return None
        # [batch_size, num_features*deep_dims]
        input_user = self.embedding(x, self.user_features, squeeze_dim=True)
        user_embedding = self.user_mlp(input_user).unsqueeze(1)  # [batch_size, 1, embed_dim]
        user_embedding = F.normalize(user_embedding, p=2, dim=2)
        if self.mode == "user":
            # inference embedding mode -> [batch_size, embed_dim]
            return user_embedding.squeeze(1)
        return user_embedding

    def item_tower(self, x):
        if self.mode == "user":
            return None
        pos_embedding = self.embedding(x, self.item_features, squeeze_dim=False)  # [batch_size, 1, embed_dim]
        pos_embedding = F.normalize(pos_embedding, p=2, dim=2)
        if self.mode == "item":  # inference embedding mode
            return pos_embedding.squeeze(1)  # [batch_size, embed_dim]
        neg_embeddings = self.embedding(x, self.neg_item_feature, squeeze_dim=False).squeeze(1)  # [batch_size, n_neg_items, embed_dim]
        neg_embeddings = F.normalize(neg_embeddings, p=2, dim=2)
        # [batch_size, 1+n_neg_items, embed_dim]
        return torch.cat((pos_embedding, neg_embeddings), dim=1)
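A hedged sketch of wiring YoutubeDNN for sampled-softmax training. The constructors mirror torch_rechub.basic.features (SparseFeature / SequenceFeature); the feature names, vocabulary sizes, pooling choice and hyper-parameters are illustrative assumptions, not values taken from the package's examples:

    from torch_rechub.basic.features import SparseFeature, SequenceFeature

    n_items, dim = 10000, 16
    user_features = [
        SparseFeature("user_id", vocab_size=5000, embed_dim=dim),
        SequenceFeature("hist_item_id", vocab_size=n_items, embed_dim=dim,
                        pooling="mean", shared_with="item_id"),
    ]
    item_features = [SparseFeature("item_id", vocab_size=n_items, embed_dim=dim)]
    neg_item_feature = [
        SequenceFeature("neg_items", vocab_size=n_items, embed_dim=dim,
                        pooling="concat", shared_with="item_id"),
    ]
    # The last user-tower dim must match the item embed_dim, since forward()
    # takes a dot product between the two towers.
    model = YoutubeDNN(user_features, item_features, neg_item_feature,
                       user_params={"dims": [128, 64, dim]}, temperature=0.02)
    # forward(x) returns [batch_size, 1 + n_neg_items] scores, which suit a
    # softmax/cross-entropy loss with the positive item as class 0.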