torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +54 -54
- torch_rechub/basic/callback.py +33 -33
- torch_rechub/basic/features.py +87 -94
- torch_rechub/basic/initializers.py +92 -92
- torch_rechub/basic/layers.py +994 -720
- torch_rechub/basic/loss_func.py +223 -34
- torch_rechub/basic/metaoptimizer.py +76 -72
- torch_rechub/basic/metric.py +251 -250
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -11
- torch_rechub/models/matching/comirec.py +193 -188
- torch_rechub/models/matching/dssm.py +72 -66
- torch_rechub/models/matching/dssm_facebook.py +77 -79
- torch_rechub/models/matching/dssm_senet.py +28 -16
- torch_rechub/models/matching/gru4rec.py +85 -87
- torch_rechub/models/matching/mind.py +103 -101
- torch_rechub/models/matching/narm.py +82 -76
- torch_rechub/models/matching/sasrec.py +143 -140
- torch_rechub/models/matching/sine.py +148 -151
- torch_rechub/models/matching/stamp.py +81 -83
- torch_rechub/models/matching/youtube_dnn.py +75 -71
- torch_rechub/models/matching/youtube_sbc.py +98 -98
- torch_rechub/models/multi_task/__init__.py +7 -5
- torch_rechub/models/multi_task/aitm.py +83 -84
- torch_rechub/models/multi_task/esmm.py +56 -55
- torch_rechub/models/multi_task/mmoe.py +58 -58
- torch_rechub/models/multi_task/ple.py +116 -130
- torch_rechub/models/multi_task/shared_bottom.py +45 -45
- torch_rechub/models/ranking/__init__.py +14 -11
- torch_rechub/models/ranking/afm.py +65 -63
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -63
- torch_rechub/models/ranking/dcn.py +38 -38
- torch_rechub/models/ranking/dcn_v2.py +59 -69
- torch_rechub/models/ranking/deepffm.py +131 -123
- torch_rechub/models/ranking/deepfm.py +43 -42
- torch_rechub/models/ranking/dien.py +191 -191
- torch_rechub/models/ranking/din.py +93 -91
- torch_rechub/models/ranking/edcn.py +101 -117
- torch_rechub/models/ranking/fibinet.py +42 -50
- torch_rechub/models/ranking/widedeep.py +41 -41
- torch_rechub/trainers/__init__.py +4 -3
- torch_rechub/trainers/ctr_trainer.py +288 -128
- torch_rechub/trainers/match_trainer.py +336 -170
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +356 -207
- torch_rechub/trainers/seq_trainer.py +427 -0
- torch_rechub/utils/data.py +492 -360
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -274
- torch_rechub/utils/model_utils.py +233 -0
- torch_rechub/utils/mtl.py +136 -126
- torch_rechub/utils/onnx_export.py +220 -0
- torch_rechub/utils/visualization.py +271 -0
- torch_rechub-0.0.5.dist-info/METADATA +402 -0
- torch_rechub-0.0.5.dist-info/RECORD +64 -0
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
- torch_rechub-0.0.3.dist-info/METADATA +0 -177
- torch_rechub-0.0.3.dist-info/RECORD +0 -55
- torch_rechub-0.0.3.dist-info/top_level.txt +0 -1

--- torch_rechub/models/matching/narm.py (torch-rechub 0.0.3)
+++ torch_rechub/models/matching/narm.py (torch-rechub 0.0.5)
@@ -1,76 +1,82 @@
[76 removed lines not shown]
+"""
+Date: created on 06/09/2022
+References:
+    paper: Neural Attentive Session-based Recommendation
+    url: http://arxiv.org/abs/1711.04725
+    official Theano implementation: https://github.com/lijingsdu/sessionRec_NARM
+    another Pytorch implementation: https://github.com/Wang-Shuo/Neural-Attentive-Session-Based-Recommendation-PyTorch
+Authors: Bo Kang, klinux@live.com
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.utils.rnn as rnn_utils
+from torch import sigmoid
+from torch.nn import GRU, Dropout, Embedding, Parameter
+
+
+class NARM(nn.Module):
+
+    def __init__(self, item_history_feature, hidden_dim, emb_dropout_p, session_rep_dropout_p):
+        super(NARM, self).__init__()
+
+        # item embedding layer
+        self.item_history_feature = item_history_feature
+        self.item_emb = Embedding(item_history_feature.vocab_size, item_history_feature.embed_dim, padding_idx=0)
+
+        # embedding dropout layer
+        self.emb_dropout = Dropout(emb_dropout_p)
+
+        # gru unit
+        self.gru = GRU(input_size=item_history_feature.embed_dim, hidden_size=hidden_dim)
+
+        # attention projection matrices
+        self.a_1, self.a_2 = Parameter(torch.randn(hidden_dim, hidden_dim)), Parameter(torch.randn(hidden_dim, hidden_dim))
+
+        # attention context vector
+        self.v = Parameter(torch.randn(hidden_dim, 1))
+
+        # session representation dropout layer
+        self.session_rep_dropout = Dropout(session_rep_dropout_p)
+
+        # bilinear projection matrix
+        self.b = Parameter(torch.randn(item_history_feature.embed_dim, hidden_dim * 2))
+
+    def forward(self, input_dict):
+        # Eq. 1-4, index item embeddings and pass through gru
+        # # Fetch the embeddings for items in the session
+        input = input_dict[self.item_history_feature.name]
+        value_mask = (input != 0)
+        value_counts = value_mask.sum(dim=1, keepdim=False).to("cpu").detach()
+        embs = rnn_utils.pack_padded_sequence(self.emb_dropout(self.item_emb(input)), value_counts, batch_first=True, enforce_sorted=False)
+
+        # # compute hidden states at each time step
+        h, h_t = self.gru(embs)
+        h_t = h_t.permute(1, 0, 2)
+        h, _ = rnn_utils.pad_packed_sequence(h, batch_first=True)
+
+        # Eq. 5, set last hidden state of gru as the output of the global
+        # encoder
+        c_g = h_t.squeeze(1)
+
+        # Eq. 8, compute similarity between final hidden state and previous
+        # hidden states
+        q = sigmoid(h_t @ self.a_1.T + h @ self.a_2.T) @ self.v
+
+        # Eq. 7, compute attention
+        alpha = torch.exp(q) * value_mask.unsqueeze(-1)
+        alpha /= alpha.sum(dim=1, keepdim=True)
+
+        # Eq. 6, compute the output of the local encoder
+        c_l = (alpha * h).sum(1)
+
+        # Eq. 9, compute session representation by concatenating user
+        # sequential behavior (global) and main purpose in the current session
+        # (local)
+        c = self.session_rep_dropout(torch.hstack((c_g, c_l)))
+
+        # Eq. 10, compute bilinear similarity between current session and each
+        # candidate items
+        s = c @ self.b.T @ self.item_emb.weight.T
+
+        return s
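
To make the rewritten narm.py above easier to follow, here is a minimal usage sketch. It is not part of the package diff: the feature object is a stand-in (`SimpleNamespace`) exposing only the attributes NARM reads (`name`, `vocab_size`, `embed_dim`), and the hyper-parameters are arbitrary.

    # Minimal NARM usage sketch -- illustration only, not taken from the package.
    # The feature object is a stand-in exposing the attributes NARM actually reads.
    import torch
    from types import SimpleNamespace

    from torch_rechub.models.matching.narm import NARM

    hist_feature = SimpleNamespace(name='hist_item', vocab_size=100, embed_dim=16)
    model = NARM(item_history_feature=hist_feature, hidden_dim=32,
                 emb_dropout_p=0.25, session_rep_dropout_p=0.5)

    # Two sessions padded with item id 0; the longest one fills the tensor,
    # so the packed GRU output pads back to the same width as the mask.
    batch = {'hist_item': torch.tensor([[5, 3, 8, 11],
                                        [2, 9, 0, 0]])}
    scores = model(batch)   # bilinear similarity to every candidate item
    print(scores.shape)     # torch.Size([2, 100]) -> (batch_size, vocab_size)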

--- torch_rechub/models/matching/sasrec.py (torch-rechub 0.0.3)
+++ torch_rechub/models/matching/sasrec.py (torch-rechub 0.0.5)
@@ -1,140 +1,143 @@
[140 removed lines not shown]
+"""
+Date: create on 2022/5/8, update on 2022/5/8
+References:
+    paper: (ICDM'2018) Self-attentive sequential recommendation
+    url: https://arxiv.org/pdf/1808.09781.pdf
+    code: https://github.com/kang205/SASRec
+Authors: Yuchen Wang, 615922749@qq.com
+"""
+import numpy as np
+import torch
+import torch.nn as nn
+
+from torch_rechub.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from torch_rechub.basic.layers import MLP, EmbeddingLayer
+
+
+class SASRec(torch.nn.Module):
+    """SASRec: Self-Attentive Sequential Recommendation
+    Args:
+        features (list): the list of `Feature Class`. In sasrec, the features list needs to have three elements in order: user historical behavior sequence features, positive sample sequence, and negative sample sequence.
+        max_len: The length of the sequence feature.
+        num_blocks: The number of stacks of attention modules.
+        num_heads: The number of heads in MultiheadAttention.
+
+    """
+
+    def __init__(
+        self,
+        features,
+        max_len=50,
+        dropout_rate=0.5,
+        num_blocks=2,
+        num_heads=1,
+    ):
+        super(SASRec, self).__init__()
+
+        self.features = features
+
+        self.item_num = self.features[0].vocab_size
+        self.embed_dim = self.features[0].embed_dim
+
+        self.item_emb = EmbeddingLayer(self.features)
+        self.position_emb = torch.nn.Embedding(max_len, self.embed_dim)
+        self.emb_dropout = torch.nn.Dropout(p=dropout_rate)
+
+        self.attention_layernorms = torch.nn.ModuleList()
+        self.attention_layers = torch.nn.ModuleList()
+        self.forward_layernorms = torch.nn.ModuleList()
+        self.forward_layers = torch.nn.ModuleList()
+        self.last_layernorm = torch.nn.LayerNorm(self.embed_dim, eps=1e-8)
+
+        for _ in range(num_blocks):
+            new_attn_layernorm = torch.nn.LayerNorm(self.embed_dim, eps=1e-8)
+            self.attention_layernorms.append(new_attn_layernorm)
+
+            new_attn_layer = torch.nn.MultiheadAttention(self.embed_dim, num_heads, dropout_rate)
+            self.attention_layers.append(new_attn_layer)
+
+            new_fwd_layernorm = torch.nn.LayerNorm(self.embed_dim, eps=1e-8)
+            self.forward_layernorms.append(new_fwd_layernorm)
+
+            new_fwd_layer = PointWiseFeedForward(self.embed_dim, dropout_rate)
+            self.forward_layers.append(new_fwd_layer)
+
+    def seq_forward(self, x, embed_x_feature):
+        x = x['seq']
+
+        embed_x_feature *= self.features[0].embed_dim**0.5
+        embed_x_feature = embed_x_feature.squeeze()  # (batch_size, max_len, embed_dim)
+
+        positions = np.tile(np.array(range(x.shape[1])), [x.shape[0], 1])
+
+        embed_x_feature += self.position_emb(torch.LongTensor(positions))
+        embed_x_feature = self.emb_dropout(embed_x_feature)
+
+        timeline_mask = torch.BoolTensor(x == 0)
+        embed_x_feature *= ~timeline_mask.unsqueeze(-1)
+
+        attention_mask = ~torch.tril(torch.ones((embed_x_feature.shape[1], embed_x_feature.shape[1]), dtype=torch.bool))
+
+        for i in range(len(self.attention_layers)):
+            embed_x_feature = torch.transpose(embed_x_feature, 0, 1)
+            Q = self.attention_layernorms[i](embed_x_feature)
+            mha_outputs, _ = self.attention_layers[i](Q, embed_x_feature, embed_x_feature, attn_mask=attention_mask)
+
+            embed_x_feature = Q + mha_outputs
+            embed_x_feature = torch.transpose(embed_x_feature, 0, 1)
+
+            embed_x_feature = self.forward_layernorms[i](embed_x_feature)
+            embed_x_feature = self.forward_layers[i](embed_x_feature)
+            embed_x_feature *= ~timeline_mask.unsqueeze(-1)
+
+        seq_output = self.last_layernorm(embed_x_feature)
+
+        return seq_output
+
+    def forward(self, x):
+        # (batch_size, 3, max_len, embed_dim)
+        embedding = self.item_emb(x, self.features)
+        # (batch_size, max_len, embed_dim)
+        seq_embed, pos_embed, neg_embed = embedding[:, 0], embedding[:, 1], embedding[:, 2]
+
+        # (batch_size, max_len, embed_dim)
+        seq_output = self.seq_forward(x, seq_embed)
+
+        pos_logits = (seq_output * pos_embed).sum(dim=-1)
+        neg_logits = (seq_output * neg_embed).sum(dim=-1)  # (batch_size, max_len)
+
+        return pos_logits, neg_logits
+
+
+class PointWiseFeedForward(torch.nn.Module):
+
+    def __init__(self, hidden_units, dropout_rate):
+        super(PointWiseFeedForward, self).__init__()
+
+        self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
+        self.dropout1 = torch.nn.Dropout(p=dropout_rate)
+        self.relu = torch.nn.ReLU()
+        self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
+        self.dropout2 = torch.nn.Dropout(p=dropout_rate)
+
+    def forward(self, inputs):
+        outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2))))))
+        outputs = outputs.transpose(-1, -2)
+        outputs += inputs
+        return outputs
+
+
+if __name__ == '__main__':
+    seq = SequenceFeature('seq', vocab_size=17, embed_dim=7, pooling='concat')
+    pos = SequenceFeature('pos', vocab_size=17, embed_dim=7, pooling='concat', shared_with='seq')
+    neg = SequenceFeature('neg', vocab_size=17, embed_dim=7, pooling='concat', shared_with='seq')
+
+    seq = [seq, pos, neg]
+
+    hist_seq = torch.tensor([[1, 2, 3, 4], [2, 3, 7, 8]])
+    pos_seq = hist_seq
+    neg_seq = hist_seq
+
+    x = {'seq': hist_seq, 'pos': pos_seq, 'neg': neg_seq}
+    model = SASRec(features=seq)
+    print('out', model(x))