torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +54 -54
  3. torch_rechub/basic/callback.py +33 -33
  4. torch_rechub/basic/features.py +87 -94
  5. torch_rechub/basic/initializers.py +92 -92
  6. torch_rechub/basic/layers.py +994 -720
  7. torch_rechub/basic/loss_func.py +223 -34
  8. torch_rechub/basic/metaoptimizer.py +76 -72
  9. torch_rechub/basic/metric.py +251 -250
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -11
  14. torch_rechub/models/matching/comirec.py +193 -188
  15. torch_rechub/models/matching/dssm.py +72 -66
  16. torch_rechub/models/matching/dssm_facebook.py +77 -79
  17. torch_rechub/models/matching/dssm_senet.py +28 -16
  18. torch_rechub/models/matching/gru4rec.py +85 -87
  19. torch_rechub/models/matching/mind.py +103 -101
  20. torch_rechub/models/matching/narm.py +82 -76
  21. torch_rechub/models/matching/sasrec.py +143 -140
  22. torch_rechub/models/matching/sine.py +148 -151
  23. torch_rechub/models/matching/stamp.py +81 -83
  24. torch_rechub/models/matching/youtube_dnn.py +75 -71
  25. torch_rechub/models/matching/youtube_sbc.py +98 -98
  26. torch_rechub/models/multi_task/__init__.py +7 -5
  27. torch_rechub/models/multi_task/aitm.py +83 -84
  28. torch_rechub/models/multi_task/esmm.py +56 -55
  29. torch_rechub/models/multi_task/mmoe.py +58 -58
  30. torch_rechub/models/multi_task/ple.py +116 -130
  31. torch_rechub/models/multi_task/shared_bottom.py +45 -45
  32. torch_rechub/models/ranking/__init__.py +14 -11
  33. torch_rechub/models/ranking/afm.py +65 -63
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -63
  36. torch_rechub/models/ranking/dcn.py +38 -38
  37. torch_rechub/models/ranking/dcn_v2.py +59 -69
  38. torch_rechub/models/ranking/deepffm.py +131 -123
  39. torch_rechub/models/ranking/deepfm.py +43 -42
  40. torch_rechub/models/ranking/dien.py +191 -191
  41. torch_rechub/models/ranking/din.py +93 -91
  42. torch_rechub/models/ranking/edcn.py +101 -117
  43. torch_rechub/models/ranking/fibinet.py +42 -50
  44. torch_rechub/models/ranking/widedeep.py +41 -41
  45. torch_rechub/trainers/__init__.py +4 -3
  46. torch_rechub/trainers/ctr_trainer.py +288 -128
  47. torch_rechub/trainers/match_trainer.py +336 -170
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +356 -207
  50. torch_rechub/trainers/seq_trainer.py +427 -0
  51. torch_rechub/utils/data.py +492 -360
  52. torch_rechub/utils/hstu_utils.py +198 -0
  53. torch_rechub/utils/match.py +457 -274
  54. torch_rechub/utils/model_utils.py +233 -0
  55. torch_rechub/utils/mtl.py +136 -126
  56. torch_rechub/utils/onnx_export.py +220 -0
  57. torch_rechub/utils/visualization.py +271 -0
  58. torch_rechub-0.0.5.dist-info/METADATA +402 -0
  59. torch_rechub-0.0.5.dist-info/RECORD +64 -0
  60. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
  61. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
  62. torch_rechub-0.0.3.dist-info/METADATA +0 -177
  63. torch_rechub-0.0.3.dist-info/RECORD +0 -55
  64. torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
@@ -1,76 +1,82 @@
1
- """
2
- Date: created on 06/09/2022
3
- References:
4
- paper: Neural Attentive Session-based Recommendation
5
- url: http://arxiv.org/abs/1711.04725
6
- official Theano implementation: https://github.com/lijingsdu/sessionRec_NARM
7
- another Pytorch implementation: https://github.com/Wang-Shuo/Neural-Attentive-Session-Based-Recommendation-PyTorch
8
- Authors: Bo Kang, klinux@live.com
9
- """
10
-
11
- import torch
12
- import torch.nn as nn
13
- import torch.nn.utils.rnn as rnn_utils
14
- from torch import sigmoid
15
- from torch.nn import GRU, Embedding, Dropout, Parameter
16
-
17
-
18
- class NARM(nn.Module):
19
- def __init__(self, item_history_feature, hidden_dim, emb_dropout_p, session_rep_dropout_p):
20
- super(NARM, self).__init__()
21
-
22
- # item embedding layer
23
- self.item_history_feature = item_history_feature
24
- self.item_emb = Embedding(item_history_feature.vocab_size, item_history_feature.embed_dim, padding_idx=0)
25
-
26
- # embedding dropout layer
27
- self.emb_dropout = Dropout(emb_dropout_p)
28
-
29
- # gru unit
30
- self.gru = GRU(input_size=item_history_feature.embed_dim, hidden_size=hidden_dim)
31
-
32
- # attention projection matrices
33
- self.a_1, self.a_2 = Parameter(torch.randn(hidden_dim, hidden_dim)), Parameter(torch.randn(hidden_dim, hidden_dim))
34
-
35
- # attention context vector
36
- self.v = Parameter(torch.randn(hidden_dim, 1))
37
-
38
- # session representation dropout layer
39
- self.session_rep_dropout = Dropout(session_rep_dropout_p)
40
-
41
- # bilinear projection matrix
42
- self.b = Parameter(torch.randn(item_history_feature.embed_dim, hidden_dim * 2))
43
-
44
- def forward(self, input_dict):
45
- # Eq. 1-4, index item embeddings and pass through gru
46
- ## Fetch the embeddings for items in the session
47
- input = input_dict[self.item_history_feature.name]
48
- value_mask = (input != 0)
49
- value_counts = value_mask.sum(dim=1, keepdim=False).to("cpu").detach()
50
- embs = rnn_utils.pack_padded_sequence(self.emb_dropout(self.item_emb(input)), value_counts, batch_first=True, enforce_sorted=False)
51
-
52
- ## compute hidden states at each time step
53
- h, h_t = self.gru(embs)
54
- h_t = h_t.permute(1, 0, 2)
55
- h, _ = rnn_utils.pad_packed_sequence(h, batch_first=True)
56
-
57
- # Eq. 5, set last hidden state of gru as the output of the global encoder
58
- c_g = h_t.squeeze(1)
59
-
60
- # Eq. 8, compute similarity between final hidden state and previous hidden states
61
- q = sigmoid(h_t @ self.a_1.T + h @ self.a_2.T) @ self.v
62
-
63
- # Eq. 7, compute attention
64
- alpha = torch.exp(q) * value_mask.unsqueeze(-1)
65
- alpha /= alpha.sum(dim=1, keepdim=True)
66
-
67
- # Eq. 6, compute the output of the local encoder
68
- c_l = (alpha * h).sum(1)
69
-
70
- # Eq. 9, compute session representation by concatenating user sequential behavior (global) and main purpose in the current session (local)
71
- c = self.session_rep_dropout(torch.hstack((c_g, c_l)))
72
-
73
- # Eq. 10, compute bilinear similarity between current session and each candidate items
74
- s = c @ self.b.T @ self.item_emb.weight.T
75
-
76
- return s
1
+ """
2
+ Date: created on 06/09/2022
3
+ References:
4
+ paper: Neural Attentive Session-based Recommendation
5
+ url: http://arxiv.org/abs/1711.04725
6
+ official Theano implementation: https://github.com/lijingsdu/sessionRec_NARM
7
+ another Pytorch implementation: https://github.com/Wang-Shuo/Neural-Attentive-Session-Based-Recommendation-PyTorch
8
+ Authors: Bo Kang, klinux@live.com
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.utils.rnn as rnn_utils
14
+ from torch import sigmoid
15
+ from torch.nn import GRU, Dropout, Embedding, Parameter
16
+
17
+
18
+ class NARM(nn.Module):
19
+
20
+ def __init__(self, item_history_feature, hidden_dim, emb_dropout_p, session_rep_dropout_p):
21
+ super(NARM, self).__init__()
22
+
23
+ # item embedding layer
24
+ self.item_history_feature = item_history_feature
25
+ self.item_emb = Embedding(item_history_feature.vocab_size, item_history_feature.embed_dim, padding_idx=0)
26
+
27
+ # embedding dropout layer
28
+ self.emb_dropout = Dropout(emb_dropout_p)
29
+
30
+ # gru unit
31
+ self.gru = GRU(input_size=item_history_feature.embed_dim, hidden_size=hidden_dim)
32
+
33
+ # attention projection matrices
34
+ self.a_1, self.a_2 = Parameter(torch.randn(hidden_dim, hidden_dim)), Parameter(torch.randn(hidden_dim, hidden_dim))
35
+
36
+ # attention context vector
37
+ self.v = Parameter(torch.randn(hidden_dim, 1))
38
+
39
+ # session representation dropout layer
40
+ self.session_rep_dropout = Dropout(session_rep_dropout_p)
41
+
42
+ # bilinear projection matrix
43
+ self.b = Parameter(torch.randn(item_history_feature.embed_dim, hidden_dim * 2))
44
+
45
+ def forward(self, input_dict):
46
+ # Eq. 1-4, index item embeddings and pass through gru
47
+ # # Fetch the embeddings for items in the session
48
+ input = input_dict[self.item_history_feature.name]
49
+ value_mask = (input != 0)
50
+ value_counts = value_mask.sum(dim=1, keepdim=False).to("cpu").detach()
51
+ embs = rnn_utils.pack_padded_sequence(self.emb_dropout(self.item_emb(input)), value_counts, batch_first=True, enforce_sorted=False)
52
+
53
+ # # compute hidden states at each time step
54
+ h, h_t = self.gru(embs)
55
+ h_t = h_t.permute(1, 0, 2)
56
+ h, _ = rnn_utils.pad_packed_sequence(h, batch_first=True)
57
+
58
+ # Eq. 5, set last hidden state of gru as the output of the global
59
+ # encoder
60
+ c_g = h_t.squeeze(1)
61
+
62
+ # Eq. 8, compute similarity between final hidden state and previous
63
+ # hidden states
64
+ q = sigmoid(h_t @ self.a_1.T + h @ self.a_2.T) @ self.v
65
+
66
+ # Eq. 7, compute attention
67
+ alpha = torch.exp(q) * value_mask.unsqueeze(-1)
68
+ alpha /= alpha.sum(dim=1, keepdim=True)
69
+
70
+ # Eq. 6, compute the output of the local encoder
71
+ c_l = (alpha * h).sum(1)
72
+
73
+ # Eq. 9, compute session representation by concatenating user
74
+ # sequential behavior (global) and main purpose in the current session
75
+ # (local)
76
+ c = self.session_rep_dropout(torch.hstack((c_g, c_l)))
77
+
78
+ # Eq. 10, compute bilinear similarity between current session and each
79
+ # candidate items
80
+ s = c @ self.b.T @ self.item_emb.weight.T
81
+
82
+ return s
@@ -1,140 +1,143 @@
1
- """
2
- Date: create on 2022/5/8, update on 2022/5/8
3
- References:
4
- paper: (ICDM'2018) Self-attentive sequential recommendation
5
- url: https://arxiv.org/pdf/1808.09781.pdf
6
- code: https://github.com/kang205/SASRec
7
- Authors: Yuchen Wang, 615922749@qq.com
8
- """
9
- import numpy as np
10
- import torch
11
- import torch.nn as nn
12
-
13
- from torch_rechub.basic.features import DenseFeature, SparseFeature, SequenceFeature
14
- from torch_rechub.basic.layers import EmbeddingLayer, MLP
15
-
16
-
17
- class SASRec(torch.nn.Module):
18
- """SASRec: Self-Attentive Sequential Recommendation
19
- Args:
20
- features (list): the list of `Feature Class`. In sasrec, the features list needs to have three elements in order: user historical behavior sequence features, positive sample sequence, and negative sample sequence.
21
- max_len: The length of the sequence feature.
22
- num_blocks: The number of stacks of attention modules.
23
- num_heads: The number of heads in MultiheadAttention.
24
-
25
- """
26
- def __init__(self,
27
- features,
28
- max_len=50,
29
- dropout_rate=0.5,
30
- num_blocks=2,
31
- num_heads=1,
32
- ):
33
- super(SASRec, self).__init__()
34
-
35
- self.features = features
36
-
37
- self.item_num = self.features[0].vocab_size
38
- self.embed_dim = self.features[0].embed_dim
39
-
40
- self.item_emb = EmbeddingLayer(self.features)
41
- self.position_emb = torch.nn.Embedding(max_len, self.embed_dim)
42
- self.emb_dropout = torch.nn.Dropout(p=dropout_rate)
43
-
44
- self.attention_layernorms = torch.nn.ModuleList()
45
- self.attention_layers = torch.nn.ModuleList()
46
- self.forward_layernorms = torch.nn.ModuleList()
47
- self.forward_layers = torch.nn.ModuleList()
48
- self.last_layernorm = torch.nn.LayerNorm(self.embed_dim, eps=1e-8)
49
-
50
- for _ in range(num_blocks):
51
- new_attn_layernorm = torch.nn.LayerNorm(self.embed_dim, eps=1e-8)
52
- self.attention_layernorms.append(new_attn_layernorm)
53
-
54
- new_attn_layer = torch.nn.MultiheadAttention(self.embed_dim,
55
- num_heads,
56
- dropout_rate)
57
- self.attention_layers.append(new_attn_layer)
58
-
59
- new_fwd_layernorm = torch.nn.LayerNorm(self.embed_dim, eps=1e-8)
60
- self.forward_layernorms.append(new_fwd_layernorm)
61
-
62
- new_fwd_layer = PointWiseFeedForward(self.embed_dim, dropout_rate)
63
- self.forward_layers.append(new_fwd_layer)
64
-
65
- def seq_forward(self, x, embed_x_feature):
66
- x = x['seq']
67
-
68
- embed_x_feature *= self.features[0].embed_dim ** 0.5
69
- embed_x_feature = embed_x_feature.squeeze() # (bacth_size, max_len, embed_dim)
70
-
71
- positions = np.tile(np.array(range(x.shape[1])), [x.shape[0], 1])
72
-
73
- embed_x_feature += self.position_emb(torch.LongTensor(positions))
74
- embed_x_feature = self.emb_dropout(embed_x_feature)
75
-
76
- timeline_mask = torch.BoolTensor(x == 0)
77
- embed_x_feature *= ~timeline_mask.unsqueeze(-1)
78
-
79
- attention_mask = ~torch.tril(torch.ones((embed_x_feature.shape[1], embed_x_feature.shape[1]), dtype=torch.bool))
80
-
81
- for i in range(len(self.attention_layers)):
82
- embed_x_feature = torch.transpose(embed_x_feature, 0, 1)
83
- Q = self.attention_layernorms[i](embed_x_feature)
84
- mha_outputs, _ = self.attention_layers[i](Q, embed_x_feature, embed_x_feature,
85
- attn_mask=attention_mask)
86
-
87
- embed_x_feature = Q + mha_outputs
88
- embed_x_feature = torch.transpose(embed_x_feature, 0, 1)
89
-
90
- embed_x_feature = self.forward_layernorms[i](embed_x_feature)
91
- embed_x_feature = self.forward_layers[i](embed_x_feature)
92
- embed_x_feature *= ~timeline_mask.unsqueeze(-1)
93
-
94
- seq_output = self.last_layernorm(embed_x_feature)
95
-
96
- return seq_output
97
-
98
- def forward(self, x):
99
- embedding = self.item_emb(x, self.features) # (batch_size, 3, max_len, embed_dim)
100
- seq_embed, pos_embed, neg_embed = embedding[:, 0], embedding[:, 1], embedding[:, 2] # (batch_size, max_len, embed_dim)
101
-
102
- seq_output = self.seq_forward(x, seq_embed) # (batch_size, max_len, embed_dim)
103
-
104
- pos_logits = (seq_output * pos_embed).sum(dim=-1)
105
- neg_logits = (seq_output * neg_embed).sum(dim=-1) # (batch_size, max_len)
106
-
107
- return pos_logits, neg_logits
108
-
109
-
110
- class PointWiseFeedForward(torch.nn.Module):
111
- def __init__(self, hidden_units, dropout_rate):
112
- super(PointWiseFeedForward, self).__init__()
113
-
114
- self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
115
- self.dropout1 = torch.nn.Dropout(p=dropout_rate)
116
- self.relu = torch.nn.ReLU()
117
- self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
118
- self.dropout2 = torch.nn.Dropout(p=dropout_rate)
119
-
120
- def forward(self, inputs):
121
- outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2))))))
122
- outputs = outputs.transpose(-1, -2)
123
- outputs += inputs
124
- return outputs
125
-
126
-
127
- if __name__ == '__main__':
128
- seq = SequenceFeature('seq', vocab_size=17, embed_dim=7, pooling='concat')
129
- pos = SequenceFeature('pos', vocab_size=17, embed_dim=7, pooling='concat', shared_with='seq')
130
- neg = SequenceFeature('neg', vocab_size=17, embed_dim=7, pooling='concat', shared_with='seq')
131
-
132
- seq = [seq, pos, neg]
133
-
134
- hist_seq = torch.tensor([[1, 2, 3, 4], [2, 3, 7, 8]])
135
- pos_seq = hist_seq
136
- neg_seq = hist_seq
137
-
138
- x = {'seq': hist_seq, 'pos': pos_seq, 'neg': neg_seq}
139
- model = SASRec(features=seq)
140
- print('out', model(x))
1
+ """
2
+ Date: create on 2022/5/8, update on 2022/5/8
3
+ References:
4
+ paper: (ICDM'2018) Self-attentive sequential recommendation
5
+ url: https://arxiv.org/pdf/1808.09781.pdf
6
+ code: https://github.com/kang205/SASRec
7
+ Authors: Yuchen Wang, 615922749@qq.com
8
+ """
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from torch_rechub.basic.features import DenseFeature, SequenceFeature, SparseFeature
14
+ from torch_rechub.basic.layers import MLP, EmbeddingLayer
15
+
16
+
17
+ class SASRec(torch.nn.Module):
18
+ """SASRec: Self-Attentive Sequential Recommendation
19
+ Args:
20
+ features (list): the list of `Feature Class`. In sasrec, the features list needs to have three elements in order: user historical behavior sequence features, positive sample sequence, and negative sample sequence.
21
+ max_len: The length of the sequence feature.
22
+ num_blocks: The number of stacks of attention modules.
23
+ num_heads: The number of heads in MultiheadAttention.
24
+
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ features,
30
+ max_len=50,
31
+ dropout_rate=0.5,
32
+ num_blocks=2,
33
+ num_heads=1,
34
+ ):
35
+ super(SASRec, self).__init__()
36
+
37
+ self.features = features
38
+
39
+ self.item_num = self.features[0].vocab_size
40
+ self.embed_dim = self.features[0].embed_dim
41
+
42
+ self.item_emb = EmbeddingLayer(self.features)
43
+ self.position_emb = torch.nn.Embedding(max_len, self.embed_dim)
44
+ self.emb_dropout = torch.nn.Dropout(p=dropout_rate)
45
+
46
+ self.attention_layernorms = torch.nn.ModuleList()
47
+ self.attention_layers = torch.nn.ModuleList()
48
+ self.forward_layernorms = torch.nn.ModuleList()
49
+ self.forward_layers = torch.nn.ModuleList()
50
+ self.last_layernorm = torch.nn.LayerNorm(self.embed_dim, eps=1e-8)
51
+
52
+ for _ in range(num_blocks):
53
+ new_attn_layernorm = torch.nn.LayerNorm(self.embed_dim, eps=1e-8)
54
+ self.attention_layernorms.append(new_attn_layernorm)
55
+
56
+ new_attn_layer = torch.nn.MultiheadAttention(self.embed_dim, num_heads, dropout_rate)
57
+ self.attention_layers.append(new_attn_layer)
58
+
59
+ new_fwd_layernorm = torch.nn.LayerNorm(self.embed_dim, eps=1e-8)
60
+ self.forward_layernorms.append(new_fwd_layernorm)
61
+
62
+ new_fwd_layer = PointWiseFeedForward(self.embed_dim, dropout_rate)
63
+ self.forward_layers.append(new_fwd_layer)
64
+
65
+ def seq_forward(self, x, embed_x_feature):
66
+ x = x['seq']
67
+
68
+ embed_x_feature *= self.features[0].embed_dim**0.5
69
+ embed_x_feature = embed_x_feature.squeeze() # (bacth_size, max_len, embed_dim)
70
+
71
+ positions = np.tile(np.array(range(x.shape[1])), [x.shape[0], 1])
72
+
73
+ embed_x_feature += self.position_emb(torch.LongTensor(positions))
74
+ embed_x_feature = self.emb_dropout(embed_x_feature)
75
+
76
+ timeline_mask = torch.BoolTensor(x == 0)
77
+ embed_x_feature *= ~timeline_mask.unsqueeze(-1)
78
+
79
+ attention_mask = ~torch.tril(torch.ones((embed_x_feature.shape[1], embed_x_feature.shape[1]), dtype=torch.bool))
80
+
81
+ for i in range(len(self.attention_layers)):
82
+ embed_x_feature = torch.transpose(embed_x_feature, 0, 1)
83
+ Q = self.attention_layernorms[i](embed_x_feature)
84
+ mha_outputs, _ = self.attention_layers[i](Q, embed_x_feature, embed_x_feature, attn_mask=attention_mask)
85
+
86
+ embed_x_feature = Q + mha_outputs
87
+ embed_x_feature = torch.transpose(embed_x_feature, 0, 1)
88
+
89
+ embed_x_feature = self.forward_layernorms[i](embed_x_feature)
90
+ embed_x_feature = self.forward_layers[i](embed_x_feature)
91
+ embed_x_feature *= ~timeline_mask.unsqueeze(-1)
92
+
93
+ seq_output = self.last_layernorm(embed_x_feature)
94
+
95
+ return seq_output
96
+
97
+ def forward(self, x):
98
+ # (batch_size, 3, max_len, embed_dim)
99
+ embedding = self.item_emb(x, self.features)
100
+ # (batch_size, max_len, embed_dim)
101
+ seq_embed, pos_embed, neg_embed = embedding[:, 0], embedding[:, 1], embedding[:, 2]
102
+
103
+ # (batch_size, max_len, embed_dim)
104
+ seq_output = self.seq_forward(x, seq_embed)
105
+
106
+ pos_logits = (seq_output * pos_embed).sum(dim=-1)
107
+ neg_logits = (seq_output * neg_embed).sum(dim=-1) # (batch_size, max_len)
108
+
109
+ return pos_logits, neg_logits
110
+
111
+
112
+ class PointWiseFeedForward(torch.nn.Module):
113
+
114
+ def __init__(self, hidden_units, dropout_rate):
115
+ super(PointWiseFeedForward, self).__init__()
116
+
117
+ self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
118
+ self.dropout1 = torch.nn.Dropout(p=dropout_rate)
119
+ self.relu = torch.nn.ReLU()
120
+ self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
121
+ self.dropout2 = torch.nn.Dropout(p=dropout_rate)
122
+
123
+ def forward(self, inputs):
124
+ outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2))))))
125
+ outputs = outputs.transpose(-1, -2)
126
+ outputs += inputs
127
+ return outputs
128
+
129
+
130
+ if __name__ == '__main__':
131
+ seq = SequenceFeature('seq', vocab_size=17, embed_dim=7, pooling='concat')
132
+ pos = SequenceFeature('pos', vocab_size=17, embed_dim=7, pooling='concat', shared_with='seq')
133
+ neg = SequenceFeature('neg', vocab_size=17, embed_dim=7, pooling='concat', shared_with='seq')
134
+
135
+ seq = [seq, pos, neg]
136
+
137
+ hist_seq = torch.tensor([[1, 2, 3, 4], [2, 3, 7, 8]])
138
+ pos_seq = hist_seq
139
+ neg_seq = hist_seq
140
+
141
+ x = {'seq': hist_seq, 'pos': pos_seq, 'neg': neg_seq}
142
+ model = SASRec(features=seq)
143
+ print('out', model(x))