torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (64)
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +54 -54
  3. torch_rechub/basic/callback.py +33 -33
  4. torch_rechub/basic/features.py +87 -94
  5. torch_rechub/basic/initializers.py +92 -92
  6. torch_rechub/basic/layers.py +994 -720
  7. torch_rechub/basic/loss_func.py +223 -34
  8. torch_rechub/basic/metaoptimizer.py +76 -72
  9. torch_rechub/basic/metric.py +251 -250
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -11
  14. torch_rechub/models/matching/comirec.py +193 -188
  15. torch_rechub/models/matching/dssm.py +72 -66
  16. torch_rechub/models/matching/dssm_facebook.py +77 -79
  17. torch_rechub/models/matching/dssm_senet.py +28 -16
  18. torch_rechub/models/matching/gru4rec.py +85 -87
  19. torch_rechub/models/matching/mind.py +103 -101
  20. torch_rechub/models/matching/narm.py +82 -76
  21. torch_rechub/models/matching/sasrec.py +143 -140
  22. torch_rechub/models/matching/sine.py +148 -151
  23. torch_rechub/models/matching/stamp.py +81 -83
  24. torch_rechub/models/matching/youtube_dnn.py +75 -71
  25. torch_rechub/models/matching/youtube_sbc.py +98 -98
  26. torch_rechub/models/multi_task/__init__.py +7 -5
  27. torch_rechub/models/multi_task/aitm.py +83 -84
  28. torch_rechub/models/multi_task/esmm.py +56 -55
  29. torch_rechub/models/multi_task/mmoe.py +58 -58
  30. torch_rechub/models/multi_task/ple.py +116 -130
  31. torch_rechub/models/multi_task/shared_bottom.py +45 -45
  32. torch_rechub/models/ranking/__init__.py +14 -11
  33. torch_rechub/models/ranking/afm.py +65 -63
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -63
  36. torch_rechub/models/ranking/dcn.py +38 -38
  37. torch_rechub/models/ranking/dcn_v2.py +59 -69
  38. torch_rechub/models/ranking/deepffm.py +131 -123
  39. torch_rechub/models/ranking/deepfm.py +43 -42
  40. torch_rechub/models/ranking/dien.py +191 -191
  41. torch_rechub/models/ranking/din.py +93 -91
  42. torch_rechub/models/ranking/edcn.py +101 -117
  43. torch_rechub/models/ranking/fibinet.py +42 -50
  44. torch_rechub/models/ranking/widedeep.py +41 -41
  45. torch_rechub/trainers/__init__.py +4 -3
  46. torch_rechub/trainers/ctr_trainer.py +288 -128
  47. torch_rechub/trainers/match_trainer.py +336 -170
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +356 -207
  50. torch_rechub/trainers/seq_trainer.py +427 -0
  51. torch_rechub/utils/data.py +492 -360
  52. torch_rechub/utils/hstu_utils.py +198 -0
  53. torch_rechub/utils/match.py +457 -274
  54. torch_rechub/utils/model_utils.py +233 -0
  55. torch_rechub/utils/mtl.py +136 -126
  56. torch_rechub/utils/onnx_export.py +220 -0
  57. torch_rechub/utils/visualization.py +271 -0
  58. torch_rechub-0.0.5.dist-info/METADATA +402 -0
  59. torch_rechub-0.0.5.dist-info/RECORD +64 -0
  60. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
  61. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
  62. torch_rechub-0.0.3.dist-info/METADATA +0 -177
  63. torch_rechub-0.0.3.dist-info/RECORD +0 -55
  64. torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
torch_rechub/models/ranking/deepfm.py
@@ -1,42 +1,43 @@
-"""
-Date: create on 22/04/2022
-References:
-    paper: (IJCAI'2017) DeepFM: A Factorization-Machine based Neural Network for CTR Prediction
-    url: https://arxiv.org/abs/1703.04247
-Authors: Mincai Lai, laimincai@shanghaitech.edu.cn
-"""
-
-import torch
-
-from ...basic.layers import FM, MLP, LR, EmbeddingLayer
-
-
-class DeepFM(torch.nn.Module):
-    """Deep Factorization Machine Model
-
-    Args:
-        deep_features (list): the list of `Feature Class` trained by the deep part module.
-        fm_features (list): the list of `Feature Class` trained by the FM part module.
-        mlp_params (dict): the params of the final MLP module; keys include `{"dims": list, "activation": str, "dropout": float, "output_layer": bool}`.
-    """
-
-    def __init__(self, deep_features, fm_features, mlp_params):
-        super(DeepFM, self).__init__()
-        self.deep_features = deep_features
-        self.fm_features = fm_features
-        self.deep_dims = sum([fea.embed_dim for fea in deep_features])
-        self.fm_dims = sum([fea.embed_dim for fea in fm_features])
-        self.linear = LR(self.fm_dims)  # 1st-order interaction
-        self.fm = FM(reduce_sum=True)  # 2nd-order interaction
-        self.embedding = EmbeddingLayer(deep_features + fm_features)
-        self.mlp = MLP(self.deep_dims, **mlp_params)
-
-    def forward(self, x):
-        input_deep = self.embedding(x, self.deep_features, squeeze_dim=True)  # [batch_size, deep_dims]
-        input_fm = self.embedding(x, self.fm_features, squeeze_dim=False)  # [batch_size, num_fields, embed_dim]
-
-        y_linear = self.linear(input_fm.flatten(start_dim=1))
-        y_fm = self.fm(input_fm)
-        y_deep = self.mlp(input_deep)  # [batch_size, 1]
-        y = y_linear + y_fm + y_deep
-        return torch.sigmoid(y.squeeze(1))
+"""
+Date: create on 22/04/2022
+References:
+    paper: (IJCAI'2017) DeepFM: A Factorization-Machine based Neural Network for CTR Prediction
+    url: https://arxiv.org/abs/1703.04247
+Authors: Mincai Lai, laimincai@shanghaitech.edu.cn
+"""
+
+import torch
+
+from ...basic.layers import FM, LR, MLP, EmbeddingLayer
+
+
+class DeepFM(torch.nn.Module):
+    """Deep Factorization Machine Model
+
+    Args:
+        deep_features (list): the list of `Feature Class` trained by the deep part module.
+        fm_features (list): the list of `Feature Class` trained by the FM part module.
+        mlp_params (dict): the params of the final MLP module; keys include `{"dims": list, "activation": str, "dropout": float, "output_layer": bool}`.
+    """
+
+    def __init__(self, deep_features, fm_features, mlp_params):
+        super(DeepFM, self).__init__()
+        self.deep_features = deep_features
+        self.fm_features = fm_features
+        self.deep_dims = sum([fea.embed_dim for fea in deep_features])
+        self.fm_dims = sum([fea.embed_dim for fea in fm_features])
+        self.linear = LR(self.fm_dims)  # 1st-order interaction
+        self.fm = FM(reduce_sum=True)  # 2nd-order interaction
+        self.embedding = EmbeddingLayer(deep_features + fm_features)
+        self.mlp = MLP(self.deep_dims, **mlp_params)
+
+    def forward(self, x):
+        input_deep = self.embedding(x, self.deep_features, squeeze_dim=True)  # [batch_size, deep_dims]
+        # [batch_size, num_fields, embed_dim]
+        input_fm = self.embedding(x, self.fm_features, squeeze_dim=False)
+
+        y_linear = self.linear(input_fm.flatten(start_dim=1))
+        y_fm = self.fm(input_fm)
+        y_deep = self.mlp(input_deep)  # [batch_size, 1]
+        y = y_linear + y_fm + y_deep
+        return torch.sigmoid(y.squeeze(1))
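The DeepFM constructor signature is unchanged between 0.0.3 and 0.0.5, so existing call sites keep working. For orientation, here is a minimal usage sketch; the feature definitions, vocabulary sizes, and batch contents below are illustrative assumptions rather than content of this diff, and it assumes the SparseFeature class from torch_rechub.basic.features and a dict-of-tensors model input, as used elsewhere in the library.

import torch
from torch_rechub.basic.features import SparseFeature
from torch_rechub.models.ranking import DeepFM

# Illustrative feature definitions (names, vocab sizes, and embed dims are made up).
deep_features = [SparseFeature("user_id", vocab_size=1000, embed_dim=16),
                 SparseFeature("user_age", vocab_size=10, embed_dim=16)]
fm_features = [SparseFeature("item_id", vocab_size=5000, embed_dim=16),
               SparseFeature("item_cate", vocab_size=100, embed_dim=16)]

model = DeepFM(
    deep_features=deep_features,
    fm_features=fm_features,
    # keys as documented in the class docstring above
    mlp_params={"dims": [256, 128], "activation": "relu", "dropout": 0.2, "output_layer": True},
)

# A toy batch of 8 samples: one LongTensor per feature name.
x = {
    "user_id": torch.randint(0, 1000, (8,)),
    "user_age": torch.randint(0, 10, (8,)),
    "item_id": torch.randint(0, 5000, (8,)),
    "item_cate": torch.randint(0, 100, (8,)),
}
y_pred = model(x)  # shape [8], click probabilities after the sigmoid

In practice training would go through the CTRTrainer in torch_rechub/trainers/ctr_trainer.py, which also changes substantially in this release (+288 -128 lines per the file list above).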
torch_rechub/models/ranking/dien.py
@@ -1,191 +1,191 @@
-"""
-Date: create on 01/05/2024
-References:
-    paper: (AAAI'2019) Deep Interest Evolution Network for Click-Through Rate Prediction
-    url: https://arxiv.org/pdf/1809.03672
-Authors: Tao Fan, thisisevy@foxmail.com
-"""
-
-import torch
-from torch import nn
-from torch.nn import Parameter, init
-
-from ...basic.layers import MLP, EmbeddingLayer
-
-
-class AUGRU(nn.Module):
-
-    def __init__(self, embed_dim):
-        super(AUGRU, self).__init__()
-        self.embed_dim = embed_dim
-        # initialize the AUGRU cell
-        self.augru_cell = AUGRU_Cell(self.embed_dim)
-
-    def forward(self, x, item):
-        '''
-        :param x: input sequence vectors, shape [batch_size, seq_lens, embed_dim]
-        :param item: target item vector
-        :return: outs: hidden vectors output by every AUGRU cell [batch_size, seq_lens, embed_dim]
-                 h: hidden vector output by the last AUGRU cell [batch_size, embed_dim]
-        '''
-        outs = []
-        h = None
-        # loop over the sequence; x.shape[1] is the sequence length
-        for i in range(x.shape[1]):
-            if h == None:
-                # initialize the hidden state h for the first step
-                h = Parameter(torch.rand(x.shape[0], self.embed_dim).to(x.device))
-            h = self.augru_cell(x[:, i], h, item)
-            outs.append(torch.unsqueeze(h, dim=1))
-        outs = torch.cat(outs, dim=1)
-        return outs, h
-
-
-# AUGRU cell
-class AUGRU_Cell(nn.Module):
-
-    def __init__(self, embed_dim):
-        """
-        :param embed_dim: dimension of the input vectors
-        """
-        super(AUGRU_Cell, self).__init__()
-
-        # initialize the update gate parameters
-        self.Wu = Parameter(torch.rand(embed_dim, embed_dim))
-        self.Uu = Parameter(torch.rand(embed_dim, embed_dim))
-        self.bu = init.xavier_uniform_(Parameter(torch.empty(1, embed_dim)))
-
-        # initialize the reset gate parameters
-        self.Wr = init.xavier_uniform_(Parameter(torch.empty(embed_dim, embed_dim)))
-        self.Ur = init.xavier_uniform_(Parameter(torch.empty(embed_dim, embed_dim)))
-        self.br = init.xavier_uniform_(Parameter(torch.empty(1, embed_dim)))
-
-        # initialize the parameters for computing h~
-        self.Wh = init.xavier_uniform_(Parameter(torch.empty(embed_dim, embed_dim)))
-        self.Uh = init.xavier_uniform_(Parameter(torch.empty(embed_dim, embed_dim)))
-        self.bh = init.xavier_uniform_(Parameter(torch.empty(1, embed_dim)))
-
-        # initialize the parameters for the attention computation
-        self.Wa = init.xavier_uniform_(Parameter(torch.empty(embed_dim, embed_dim)))
-
-    # attention computation
-    def attention(self, x, item):
-        '''
-        :param x: the t-th vector of the input sequence [batch_size, embed_dim]
-        :param item: target item vector [batch_size, embed_dim]
-        :return: attention weights [batch_size, 1]
-        '''
-        hW = torch.matmul(x, self.Wa)
-        hWi = torch.sum(hW * item, dim=1)
-        hWi = torch.unsqueeze(hWi, 1)
-        return torch.softmax(hWi, dim=1)
-
-    def forward(self, x, h_1, item):
-        '''
-        :param x: the t-th item vector of the input sequence [batch_size, embed_dim]
-        :param h_1: hidden vector output by the previous AUGRU cell [batch_size, embed_dim]
-        :param item: target item vector [batch_size, embed_dim]
-        :return: h, the hidden vector output by the current cell [batch_size, embed_dim]
-        '''
-        # [batch_size, embed_dim]
-        u = torch.sigmoid(torch.matmul(x, self.Wu) + torch.matmul(h_1, self.Uu) + self.bu)
-        # [batch_size, embed_dim]
-        r = torch.sigmoid(torch.matmul(x, self.Wr) + torch.matmul(h_1, self.Ur) + self.br)
-        # [batch_size, embed_dim]
-        h_hat = torch.tanh(torch.matmul(x, self.Wh) + r * torch.matmul(h_1, self.Uh) + self.bh)
-        # [batch_size, 1]
-        a = self.attention(x, item)
-        # [batch_size, embed_dim]
-        u_hat = a * u
-        # [batch_size, embed_dim]
-        h = (1 - u_hat) * h_1 + u_hat * h_hat
-        # [batch_size, embed_dim]
-        return h
-
-
-class DIEN(nn.Module):
-    """Deep Interest Evolution Network
-    Args:
-        features (list): the list of `Feature Class` trained by the MLP. These are the user profile and context features of the original paper, excluding history and target features.
-        history_features (list): the list of `Feature Class` for the user behaviour sequence features, e.g. item id sequence, shop id sequence; they are processed by the interest extractor (GRU) and interest evolving (AUGRU) layers.
-        target_features (list): the list of `Feature Class` for the target item, which performs target attention with the history features.
-        mlp_params (dict): the params of the final MLP module; keys include `{"dims": list, "activation": str, "dropout": float, "output_layer": bool}`.
-        history_labels (list): labels for the history sequence indicating whether each history item was clicked; values should be 0 or 1.
-        alpha (float): the weighting of the auxiliary loss.
-    """
-
-    def __init__(self, features, history_features, target_features, mlp_params, history_labels,
-                 alpha=0.2):
-        super().__init__()
-        self.alpha = alpha  # weight of the auxiliary loss
-        self.features = features
-        self.history_features = history_features
-        self.target_features = target_features
-        self.num_history_features = len(history_features)
-        self.all_dims = sum([fea.embed_dim for fea in features + history_features + target_features])
-        # self.GRU = nn.GRU(batch_first=True)
-        self.embedding = EmbeddingLayer(features + history_features + target_features)
-        self.interest_extractor_layers = nn.ModuleList(
-            [nn.GRU(fea.embed_dim, fea.embed_dim, batch_first=True) for fea in self.history_features])
-        self.interest_evolving_layers = nn.ModuleList(
-            [AUGRU(fea.embed_dim) for fea in self.history_features])
-
-        self.mlp = MLP(self.all_dims, activation="dice", **mlp_params)
-        self.history_labels = torch.Tensor(history_labels)
-        self.BCELoss = nn.BCELoss()
-        # # linear layer for the attention computation
-        # self.attention_liner = nn.Linear(self.embed_dim, t)
-        # # h in the AFM formula
-        # self.h = init.xavier_uniform_(Parameter(torch.empty(t, 1)))
-        # # p in the AFM formula
-        # self.p = init.xavier_uniform_(Parameter(torch.empty(self.embed_dim, 1)))
-
-    def auxiliary(self, outs, history_features, history_labels):
-        '''
-        :param history_features: vectors of the items in the history sequence [batch_size, len_seqs, dim]
-        :param outs: outputs of the interest-extractor GRU [batch_size, len_seqs, dim]
-        :param history_labels: labels of the items in the history sequence [batch_size, len_seqs, 1]
-        :return: auxiliary loss
-        '''
-        # [batch_size * len_seqs, dim]
-        history_features = history_features.reshape(-1, history_features.shape[2])
-        # [batch_size * len_seqs, dim]
-        outs = outs.reshape(-1, outs.shape[2])
-        # [batch_size * len_seqs]
-        out = torch.sum(outs * history_features, dim=1)
-        # [batch_size * len_seqs, 1]
-        out = torch.unsqueeze(torch.sigmoid(out), 1)
-        # [batch_size * len_seqs, 1]
-        history_labels = history_labels.reshape(-1, 1).float()
-        return self.BCELoss(out, history_labels)
-
-    def forward(self, x):
-        embed_x_features = self.embedding(x, self.features)  # (batch_size, num_features, emb_dim)
-        embed_x_history = self.embedding(
-            x, self.history_features)  # (batch_size, num_history_features, seq_length, emb_dim)
-        embed_x_target = self.embedding(x, self.target_features)  # (batch_size, num_target_features, emb_dim)
-
-        interest_extractor = []
-        auxi_loss = 0
-        for i in range(self.num_history_features):
-            outs, _ = self.interest_extractor_layers[i](embed_x_history[:, i, :, :])
-            # compute the auxiliary loss from the GRU outputs
-            auxi_loss += self.auxiliary(outs, embed_x_history[:, i, :, :], self.history_labels)
-            interest_extractor.append(outs.unsqueeze(1))  # (batch_size, 1, seq_length, emb_dim)
-        interest_extractor = torch.cat(interest_extractor,
-                                       dim=1)  # (batch_size, num_history_features, seq_length, emb_dim)
-        interest_evolving = []
-        for i in range(self.num_history_features):
-            _, h = self.interest_evolving_layers[i](interest_extractor[:, i, :, :], embed_x_target[:, i, :])
-            interest_evolving.append(h.unsqueeze(1))  # (batch_size, 1, emb_dim)
-        interest_evolving = torch.cat(interest_evolving, dim=1)  # (batch_size, num_history_features, emb_dim)
-
-        mlp_in = torch.cat([
-            interest_evolving.flatten(start_dim=1),
-            embed_x_target.flatten(start_dim=1),
-            embed_x_features.flatten(start_dim=1)
-        ],
-                           dim=1)  # (batch_size, N)
-        y = self.mlp(mlp_in)
-
-        return torch.sigmoid(y.squeeze(1)), self.alpha * auxi_loss
+"""
+Date: create on 01/05/2024
+References:
+    paper: (AAAI'2019) Deep Interest Evolution Network for Click-Through Rate Prediction
+    url: https://arxiv.org/pdf/1809.03672
+Authors: Tao Fan, thisisevy@foxmail.com
+"""
+
+import torch
+from torch import nn
+from torch.nn import Parameter, init
+
+from ...basic.layers import MLP, EmbeddingLayer
+
+
+class AUGRU(nn.Module):
+
+    def __init__(self, embed_dim):
+        super(AUGRU, self).__init__()
+        self.embed_dim = embed_dim
+        # initialize the AUGRU cell
+        self.augru_cell = AUGRU_Cell(self.embed_dim)
+
+    def forward(self, x, item):
+        '''
+        :param x: input sequence vectors, shape [batch_size, seq_lens, embed_dim]
+        :param item: target item vector
+        :return: outs: hidden vectors output by every AUGRU cell [batch_size, seq_lens, embed_dim]
+                 h: hidden vector output by the last AUGRU cell [batch_size, embed_dim]
+        '''
+        outs = []
+        h = None
+        # loop over the sequence; x.shape[1] is the sequence length
+        for i in range(x.shape[1]):
+            if h is None:
+                # initialize the hidden state h for the first step
+                h = Parameter(torch.rand(x.shape[0], self.embed_dim).to(x.device))
+            h = self.augru_cell(x[:, i], h, item)
+            outs.append(torch.unsqueeze(h, dim=1))
+        outs = torch.cat(outs, dim=1)
+        return outs, h
+
+
+# AUGRU cell
+class AUGRU_Cell(nn.Module):
+
+    def __init__(self, embed_dim):
+        """
+        :param embed_dim: dimension of the input vectors
+        """
+        super(AUGRU_Cell, self).__init__()
+
+        # initialize the update gate parameters
+        self.Wu = Parameter(torch.rand(embed_dim, embed_dim))
+        self.Uu = Parameter(torch.rand(embed_dim, embed_dim))
+        self.bu = init.xavier_uniform_(Parameter(torch.empty(1, embed_dim)))
+
+        # initialize the reset gate parameters
+        self.Wr = init.xavier_uniform_(Parameter(torch.empty(embed_dim, embed_dim)))
+        self.Ur = init.xavier_uniform_(Parameter(torch.empty(embed_dim, embed_dim)))
+        self.br = init.xavier_uniform_(Parameter(torch.empty(1, embed_dim)))
+
+        # initialize the parameters for computing h~
+        self.Wh = init.xavier_uniform_(Parameter(torch.empty(embed_dim, embed_dim)))
+        self.Uh = init.xavier_uniform_(Parameter(torch.empty(embed_dim, embed_dim)))
+        self.bh = init.xavier_uniform_(Parameter(torch.empty(1, embed_dim)))
+
+        # initialize the parameters for the attention computation
+        self.Wa = init.xavier_uniform_(Parameter(torch.empty(embed_dim, embed_dim)))
+
+
+    # attention computation
+
+    def attention(self, x, item):
+        '''
+        :param x: the t-th vector of the input sequence [batch_size, embed_dim]
+        :param item: target item vector [batch_size, embed_dim]
+        :return: attention weights [batch_size, 1]
+        '''
+        hW = torch.matmul(x, self.Wa)
+        hWi = torch.sum(hW * item, dim=1)
+        hWi = torch.unsqueeze(hWi, 1)
+        return torch.softmax(hWi, dim=1)
+
+    def forward(self, x, h_1, item):
+        '''
+        :param x: the t-th item vector of the input sequence [batch_size, embed_dim]
+        :param h_1: hidden vector output by the previous AUGRU cell [batch_size, embed_dim]
+        :param item: target item vector [batch_size, embed_dim]
+        :return: h, the hidden vector output by the current cell [batch_size, embed_dim]
+        '''
+        # [batch_size, embed_dim]
+        u = torch.sigmoid(torch.matmul(x, self.Wu) + torch.matmul(h_1, self.Uu) + self.bu)
+        # [batch_size, embed_dim]
+        r = torch.sigmoid(torch.matmul(x, self.Wr) + torch.matmul(h_1, self.Ur) + self.br)
+        # [batch_size, embed_dim]
+        h_hat = torch.tanh(torch.matmul(x, self.Wh) + r * torch.matmul(h_1, self.Uh) + self.bh)
+        # [batch_size, 1]
+        a = self.attention(x, item)
+        # [batch_size, embed_dim]
+        u_hat = a * u
+        # [batch_size, embed_dim]
+        h = (1 - u_hat) * h_1 + u_hat * h_hat
+        # [batch_size, embed_dim]
+        return h
+
+
+class DIEN(nn.Module):
+    """Deep Interest Evolution Network
+    Args:
+        features (list): the list of `Feature Class` trained by the MLP. These are the user profile and context features of the original paper, excluding history and target features.
+        history_features (list): the list of `Feature Class` for the user behaviour sequence features, e.g. item id sequence, shop id sequence; they are processed by the interest extractor (GRU) and interest evolving (AUGRU) layers.
+        target_features (list): the list of `Feature Class` for the target item, which performs target attention with the history features.
+        mlp_params (dict): the params of the final MLP module; keys include `{"dims": list, "activation": str, "dropout": float, "output_layer": bool}`.
+        history_labels (list): labels for the history sequence indicating whether each history item was clicked; values should be 0 or 1.
+        alpha (float): the weighting of the auxiliary loss.
+    """
+
+    def __init__(self, features, history_features, target_features, mlp_params, history_labels, alpha=0.2):
+        super().__init__()
+        self.alpha = alpha  # weight of the auxiliary loss
+        self.features = features
+        self.history_features = history_features
+        self.target_features = target_features
+        self.num_history_features = len(history_features)
+        self.all_dims = sum([fea.embed_dim for fea in features + history_features + target_features])
+        # self.GRU = nn.GRU(batch_first=True)
+        self.embedding = EmbeddingLayer(features + history_features + target_features)
+        self.interest_extractor_layers = nn.ModuleList([nn.GRU(fea.embed_dim, fea.embed_dim, batch_first=True) for fea in self.history_features])
+        self.interest_evolving_layers = nn.ModuleList([AUGRU(fea.embed_dim) for fea in self.history_features])
+
+        self.mlp = MLP(self.all_dims, activation="dice", **mlp_params)
+        self.history_labels = torch.Tensor(history_labels)
+        self.BCELoss = nn.BCELoss()
+
+
+        # # linear layer for the attention computation
+        # self.attention_liner = nn.Linear(self.embed_dim, t)
+        # # h in the AFM formula
+        # self.h = init.xavier_uniform_(Parameter(torch.empty(t, 1)))
+        # # p in the AFM formula
+        # self.p = init.xavier_uniform_(Parameter(torch.empty(self.embed_dim, 1)))
+
+    def auxiliary(self, outs, history_features, history_labels):
+        '''
+        :param history_features: vectors of the items in the history sequence [batch_size, len_seqs, dim]
+        :param outs: outputs of the interest-extractor GRU [batch_size, len_seqs, dim]
+        :param history_labels: labels of the items in the history sequence [batch_size, len_seqs, 1]
+        :return: auxiliary loss
+        '''
+        # [batch_size * len_seqs, dim]
+        history_features = history_features.reshape(-1, history_features.shape[2])
+        # [batch_size * len_seqs, dim]
+        outs = outs.reshape(-1, outs.shape[2])
+        # [batch_size * len_seqs]
+        out = torch.sum(outs * history_features, dim=1)
+        # [batch_size * len_seqs, 1]
+        out = torch.unsqueeze(torch.sigmoid(out), 1)
+        # [batch_size * len_seqs, 1]
+        history_labels = history_labels.reshape(-1, 1).float()
+        return self.BCELoss(out, history_labels)
+
+    def forward(self, x):
+        # (batch_size, num_features, emb_dim)
+        embed_x_features = self.embedding(x, self.features)
+        # (batch_size, num_history_features, seq_length, emb_dim)
+        embed_x_history = self.embedding(x, self.history_features)
+        # (batch_size, num_target_features, emb_dim)
+        embed_x_target = self.embedding(x, self.target_features)
+
+        interest_extractor = []
+        auxi_loss = 0
+        for i in range(self.num_history_features):
+            outs, _ = self.interest_extractor_layers[i](embed_x_history[:, i, :, :])
+            # compute the auxiliary loss from the GRU outputs
+            auxi_loss += self.auxiliary(outs, embed_x_history[:, i, :, :], self.history_labels)
+            # (batch_size, 1, seq_length, emb_dim)
+            interest_extractor.append(outs.unsqueeze(1))
+        # (batch_size, num_history_features, seq_length, emb_dim)
+        interest_extractor = torch.cat(interest_extractor, dim=1)
+        interest_evolving = []
+        for i in range(self.num_history_features):
+            _, h = self.interest_evolving_layers[i](interest_extractor[:, i, :, :], embed_x_target[:, i, :])
+            interest_evolving.append(h.unsqueeze(1))  # (batch_size, 1, emb_dim)
+        # (batch_size, num_history_features, emb_dim)
+        interest_evolving = torch.cat(interest_evolving, dim=1)
+
+        mlp_in = torch.cat([interest_evolving.flatten(start_dim=1), embed_x_target.flatten(start_dim=1), embed_x_features.flatten(start_dim=1)], dim=1)  # (batch_size, N)
+        y = self.mlp(mlp_in)
+
+        return torch.sigmoid(y.squeeze(1)), self.alpha * auxi_loss
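For readers comparing the AUGRU_Cell code against the DIEN paper: per time step t, with input item embedding x_t, previous hidden state h_{t-1}, and target item embedding e_target, the cell above computes (row-vector convention, matching torch.matmul(x, W)):

u_t = \sigma(x_t W_u + h_{t-1} U_u + b_u)
r_t = \sigma(x_t W_r + h_{t-1} U_r + b_r)
\tilde{h}_t = \tanh(x_t W_h + r_t \odot (h_{t-1} U_h) + b_h)
a_t = \mathrm{softmax}(\langle x_t W_a, e_{\mathrm{target}} \rangle), \quad \tilde{u}_t = a_t \cdot u_t
h_t = (1 - \tilde{u}_t) \odot h_{t-1} + \tilde{u}_t \odot \tilde{h}_t

The auxiliary head in DIEN.auxiliary scores each history step as \hat{p}_t = \sigma(\langle o_t, e_t \rangle), where o_t is the interest-extractor GRU output and e_t the embedded history item, and the model returns \alpha \cdot \mathrm{BCE}(\hat{p}_t, y_t) alongside the main prediction. None of this logic changes between 0.0.3 and 0.0.5; the diff is formatting only (yapf-style line collapsing, `h is None` instead of `h == None`, and comments moved onto their own lines).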