nextrec 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +220 -106
- nextrec/basic/loggers.py +62 -43
- nextrec/basic/metrics.py +268 -119
- nextrec/basic/model.py +1082 -400
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +498 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +51 -45
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +272 -95
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +103 -38
- nextrec/models/match/dssm.py +82 -68
- nextrec/models/match/dssm_v2.py +72 -57
- nextrec/models/match/mind.py +175 -107
- nextrec/models/match/sdm.py +104 -87
- nextrec/models/match/youtube_dnn.py +73 -59
- nextrec/models/multi_task/esmm.py +53 -37
- nextrec/models/multi_task/mmoe.py +64 -45
- nextrec/models/multi_task/ple.py +101 -48
- nextrec/models/multi_task/poso.py +113 -36
- nextrec/models/multi_task/share_bottom.py +48 -35
- nextrec/models/ranking/afm.py +72 -37
- nextrec/models/ranking/autoint.py +72 -55
- nextrec/models/ranking/dcn.py +55 -35
- nextrec/models/ranking/dcn_v2.py +64 -23
- nextrec/models/ranking/deepfm.py +32 -22
- nextrec/models/ranking/dien.py +155 -99
- nextrec/models/ranking/din.py +85 -57
- nextrec/models/ranking/fibinet.py +52 -32
- nextrec/models/ranking/fm.py +29 -23
- nextrec/models/ranking/masknet.py +91 -29
- nextrec/models/ranking/pnn.py +31 -28
- nextrec/models/ranking/widedeep.py +34 -26
- nextrec/models/ranking/xdeepfm.py +60 -38
- nextrec/utils/__init__.py +59 -34
- nextrec/utils/config.py +490 -0
- nextrec/utils/device.py +30 -20
- nextrec/utils/distributed.py +36 -9
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +32 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +283 -165
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/METADATA +4 -4
- nextrec-0.4.2.dist-info/RECORD +69 -0
- nextrec-0.4.2.dist-info/entry_points.txt +2 -0
- nextrec-0.4.1.dist-info/RECORD +0 -66
- {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
- {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/models/match/mind.py
CHANGED
|
@@ -6,6 +6,7 @@ Reference:
|
|
|
6
6
|
[1] Li C, Liu Z, Wu M, et al. Multi-interest network with dynamic routing for recommendation at Tmall[C]
|
|
7
7
|
//Proceedings of the 28th ACM international conference on information and knowledge management. 2019: 2615-2623.
|
|
8
8
|
"""
|
|
9
|
+
|
|
9
10
|
import torch
|
|
10
11
|
import torch.nn as nn
|
|
11
12
|
import torch.nn.functional as F
|
|
@@ -15,6 +16,7 @@ from nextrec.basic.model import BaseMatchModel
|
|
|
15
16
|
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
|
|
16
17
|
from nextrec.basic.layers import MLP, EmbeddingLayer
|
|
17
18
|
|
|
19
|
+
|
|
18
20
|
class MultiInterestSA(nn.Module):
|
|
19
21
|
"""Multi-interest self-attention extractor from MIND (Li et al., 2019)."""
|
|
20
22
|
|
|
@@ -22,19 +24,25 @@ class MultiInterestSA(nn.Module):
|
|
|
22
24
|
super(MultiInterestSA, self).__init__()
|
|
23
25
|
self.embedding_dim = embedding_dim
|
|
24
26
|
self.interest_num = interest_num
|
|
25
|
-
if hidden_dim
|
|
27
|
+
if hidden_dim is None:
|
|
26
28
|
self.hidden_dim = self.embedding_dim * 4
|
|
27
|
-
self.W1 = torch.nn.Parameter(
|
|
28
|
-
|
|
29
|
-
|
|
29
|
+
self.W1 = torch.nn.Parameter(
|
|
30
|
+
torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True
|
|
31
|
+
)
|
|
32
|
+
self.W2 = torch.nn.Parameter(
|
|
33
|
+
torch.rand(self.hidden_dim, self.interest_num), requires_grad=True
|
|
34
|
+
)
|
|
35
|
+
self.W3 = torch.nn.Parameter(
|
|
36
|
+
torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True
|
|
37
|
+
)
|
|
30
38
|
|
|
31
39
|
def forward(self, seq_emb, mask=None):
|
|
32
|
-
H = torch.einsum(
|
|
33
|
-
if mask
|
|
34
|
-
A = torch.einsum(
|
|
40
|
+
H = torch.einsum("bse, ed -> bsd", seq_emb, self.W1).tanh()
|
|
41
|
+
if mask is not None:
|
|
42
|
+
A = torch.einsum("bsd, dk -> bsk", H, self.W2) + -1.0e9 * (1 - mask.float())
|
|
35
43
|
A = F.softmax(A, dim=1)
|
|
36
44
|
else:
|
|
37
|
-
A = F.softmax(torch.einsum(
|
|
45
|
+
A = F.softmax(torch.einsum("bsd, dk -> bsk", H, self.W2), dim=1)
|
|
38
46
|
A = A.permute(0, 2, 1)
|
|
39
47
|
multi_interest_emb = torch.matmul(A, seq_emb)
|
|
40
48
|
return multi_interest_emb
|
|
@@ -43,7 +51,15 @@ class MultiInterestSA(nn.Module):
|
|
|
43
51
|
class CapsuleNetwork(nn.Module):
|
|
44
52
|
"""Dynamic routing capsule network used in MIND (Li et al., 2019)."""
|
|
45
53
|
|
|
46
|
-
def __init__(
|
|
54
|
+
def __init__(
|
|
55
|
+
self,
|
|
56
|
+
embedding_dim,
|
|
57
|
+
seq_len,
|
|
58
|
+
bilinear_type=2,
|
|
59
|
+
interest_num=4,
|
|
60
|
+
routing_times=3,
|
|
61
|
+
relu_layer=False,
|
|
62
|
+
):
|
|
47
63
|
super(CapsuleNetwork, self).__init__()
|
|
48
64
|
self.embedding_dim = embedding_dim # h
|
|
49
65
|
self.seq_len = seq_len # s
|
|
@@ -53,13 +69,24 @@ class CapsuleNetwork(nn.Module):
|
|
|
53
69
|
|
|
54
70
|
self.relu_layer = relu_layer
|
|
55
71
|
self.stop_grad = True
|
|
56
|
-
self.relu = nn.Sequential(
|
|
72
|
+
self.relu = nn.Sequential(
|
|
73
|
+
nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU()
|
|
74
|
+
)
|
|
57
75
|
if self.bilinear_type == 0: # MIND
|
|
58
76
|
self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
|
|
59
77
|
elif self.bilinear_type == 1:
|
|
60
|
-
self.linear = nn.Linear(
|
|
78
|
+
self.linear = nn.Linear(
|
|
79
|
+
self.embedding_dim, self.embedding_dim * self.interest_num, bias=False
|
|
80
|
+
)
|
|
61
81
|
else:
|
|
62
|
-
self.w = nn.Parameter(
|
|
82
|
+
self.w = nn.Parameter(
|
|
83
|
+
torch.Tensor(
|
|
84
|
+
1,
|
|
85
|
+
self.seq_len,
|
|
86
|
+
self.interest_num * self.embedding_dim,
|
|
87
|
+
self.embedding_dim,
|
|
88
|
+
)
|
|
89
|
+
)
|
|
63
90
|
nn.init.xavier_uniform_(self.w)
|
|
64
91
|
|
|
65
92
|
def forward(self, item_eb, mask):
|
|
@@ -70,11 +97,15 @@ class CapsuleNetwork(nn.Module):
|
|
|
70
97
|
item_eb_hat = self.linear(item_eb)
|
|
71
98
|
else:
|
|
72
99
|
u = torch.unsqueeze(item_eb, dim=2)
|
|
73
|
-
item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)
|
|
100
|
+
item_eb_hat = torch.sum(self.w[:, : self.seq_len, :, :] * u, dim=3)
|
|
74
101
|
|
|
75
|
-
item_eb_hat = torch.reshape(
|
|
102
|
+
item_eb_hat = torch.reshape(
|
|
103
|
+
item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim)
|
|
104
|
+
)
|
|
76
105
|
item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
|
|
77
|
-
item_eb_hat = torch.reshape(
|
|
106
|
+
item_eb_hat = torch.reshape(
|
|
107
|
+
item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim)
|
|
108
|
+
)
|
|
78
109
|
|
|
79
110
|
if self.stop_grad:
|
|
80
111
|
item_eb_hat_iter = item_eb_hat.detach()
|
|
@@ -82,34 +113,47 @@ class CapsuleNetwork(nn.Module):
|
|
|
82
113
|
item_eb_hat_iter = item_eb_hat
|
|
83
114
|
|
|
84
115
|
if self.bilinear_type > 0:
|
|
85
|
-
capsule_weight = torch.zeros(
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
116
|
+
capsule_weight = torch.zeros(
|
|
117
|
+
item_eb_hat.shape[0],
|
|
118
|
+
self.interest_num,
|
|
119
|
+
self.seq_len,
|
|
120
|
+
device=item_eb.device,
|
|
121
|
+
requires_grad=False,
|
|
122
|
+
)
|
|
90
123
|
else:
|
|
91
|
-
capsule_weight = torch.randn(
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
124
|
+
capsule_weight = torch.randn(
|
|
125
|
+
item_eb_hat.shape[0],
|
|
126
|
+
self.interest_num,
|
|
127
|
+
self.seq_len,
|
|
128
|
+
device=item_eb.device,
|
|
129
|
+
requires_grad=False,
|
|
130
|
+
)
|
|
96
131
|
|
|
97
132
|
for i in range(self.routing_times): # 动态路由传播3次
|
|
98
133
|
atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
|
|
99
134
|
paddings = torch.zeros_like(atten_mask, dtype=torch.float)
|
|
100
135
|
|
|
101
136
|
capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
|
|
102
|
-
capsule_softmax_weight = torch.where(
|
|
137
|
+
capsule_softmax_weight = torch.where(
|
|
138
|
+
torch.eq(atten_mask, 0), paddings, capsule_softmax_weight
|
|
139
|
+
)
|
|
103
140
|
capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)
|
|
104
141
|
|
|
105
142
|
if i < 2:
|
|
106
|
-
interest_capsule = torch.matmul(
|
|
143
|
+
interest_capsule = torch.matmul(
|
|
144
|
+
capsule_softmax_weight, item_eb_hat_iter
|
|
145
|
+
)
|
|
107
146
|
cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
|
|
108
147
|
scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
|
|
109
148
|
interest_capsule = scalar_factor * interest_capsule
|
|
110
149
|
|
|
111
|
-
delta_weight = torch.matmul(
|
|
112
|
-
|
|
150
|
+
delta_weight = torch.matmul(
|
|
151
|
+
item_eb_hat_iter,
|
|
152
|
+
torch.transpose(interest_capsule, 2, 3).contiguous(),
|
|
153
|
+
)
|
|
154
|
+
delta_weight = torch.reshape(
|
|
155
|
+
delta_weight, (-1, self.interest_num, self.seq_len)
|
|
156
|
+
)
|
|
113
157
|
capsule_weight = capsule_weight + delta_weight
|
|
114
158
|
else:
|
|
115
159
|
interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat)
|
|
@@ -117,7 +161,9 @@ class CapsuleNetwork(nn.Module):
|
|
|
117
161
|
scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
|
|
118
162
|
interest_capsule = scalar_factor * interest_capsule
|
|
119
163
|
|
|
120
|
-
interest_capsule = torch.reshape(
|
|
164
|
+
interest_capsule = torch.reshape(
|
|
165
|
+
interest_capsule, (-1, self.interest_num, self.embedding_dim)
|
|
166
|
+
)
|
|
121
167
|
|
|
122
168
|
if self.relu_layer:
|
|
123
169
|
interest_capsule = self.relu(interest_capsule)
|
|
@@ -129,45 +175,52 @@ class MIND(BaseMatchModel):
|
|
|
129
175
|
@property
|
|
130
176
|
def model_name(self) -> str:
|
|
131
177
|
return "MIND"
|
|
132
|
-
|
|
178
|
+
|
|
133
179
|
@property
|
|
134
180
|
def support_training_modes(self) -> list[str]:
|
|
135
181
|
"""MIND only supports pointwise training mode"""
|
|
136
|
-
return [
|
|
137
|
-
|
|
138
|
-
def __init__(
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
182
|
+
return ["pointwise"]
|
|
183
|
+
|
|
184
|
+
def __init__(
|
|
185
|
+
self,
|
|
186
|
+
user_dense_features: list[DenseFeature] | None = None,
|
|
187
|
+
user_sparse_features: list[SparseFeature] | None = None,
|
|
188
|
+
user_sequence_features: list[SequenceFeature] | None = None,
|
|
189
|
+
item_dense_features: list[DenseFeature] | None = None,
|
|
190
|
+
item_sparse_features: list[SparseFeature] | None = None,
|
|
191
|
+
item_sequence_features: list[SequenceFeature] | None = None,
|
|
192
|
+
embedding_dim: int = 64,
|
|
193
|
+
num_interests: int = 4,
|
|
194
|
+
capsule_bilinear_type: int = 2,
|
|
195
|
+
routing_times: int = 3,
|
|
196
|
+
relu_layer: bool = False,
|
|
197
|
+
item_dnn_hidden_units: list[int] = [256, 128],
|
|
198
|
+
dnn_activation: str = "relu",
|
|
199
|
+
dnn_dropout: float = 0.0,
|
|
200
|
+
training_mode: Literal["pointwise", "pairwise", "listwise"] = "pointwise",
|
|
201
|
+
num_negative_samples: int = 100,
|
|
202
|
+
temperature: float = 1.0,
|
|
203
|
+
similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
|
|
204
|
+
device: str = "cpu",
|
|
205
|
+
embedding_l1_reg: float = 0.0,
|
|
206
|
+
dense_l1_reg: float = 0.0,
|
|
207
|
+
embedding_l2_reg: float = 0.0,
|
|
208
|
+
dense_l2_reg: float = 0.0,
|
|
209
|
+
early_stop_patience: int = 20,
|
|
210
|
+
optimizer: str | torch.optim.Optimizer = "adam",
|
|
211
|
+
optimizer_params: dict | None = None,
|
|
212
|
+
scheduler: (
|
|
213
|
+
str
|
|
214
|
+
| torch.optim.lr_scheduler._LRScheduler
|
|
215
|
+
| type[torch.optim.lr_scheduler._LRScheduler]
|
|
216
|
+
| None
|
|
217
|
+
) = None,
|
|
218
|
+
scheduler_params: dict | None = None,
|
|
219
|
+
loss: str | nn.Module | list[str | nn.Module] | None = "bce",
|
|
220
|
+
loss_params: dict | list[dict] | None = None,
|
|
221
|
+
**kwargs,
|
|
222
|
+
):
|
|
223
|
+
|
|
171
224
|
super(MIND, self).__init__(
|
|
172
225
|
user_dense_features=user_dense_features,
|
|
173
226
|
user_sparse_features=user_sparse_features,
|
|
@@ -184,9 +237,9 @@ class MIND(BaseMatchModel):
|
|
|
184
237
|
dense_l1_reg=dense_l1_reg,
|
|
185
238
|
embedding_l2_reg=embedding_l2_reg,
|
|
186
239
|
dense_l2_reg=dense_l2_reg,
|
|
187
|
-
**kwargs
|
|
240
|
+
**kwargs,
|
|
188
241
|
)
|
|
189
|
-
|
|
242
|
+
|
|
190
243
|
self.embedding_dim = embedding_dim
|
|
191
244
|
self.num_interests = num_interests
|
|
192
245
|
self.item_dnn_hidden_units = item_dnn_hidden_units
|
|
@@ -198,16 +251,20 @@ class MIND(BaseMatchModel):
|
|
|
198
251
|
user_features.extend(user_sparse_features)
|
|
199
252
|
if user_sequence_features:
|
|
200
253
|
user_features.extend(user_sequence_features)
|
|
201
|
-
|
|
254
|
+
|
|
202
255
|
if len(user_features) > 0:
|
|
203
256
|
self.user_embedding = EmbeddingLayer(user_features)
|
|
204
|
-
|
|
257
|
+
|
|
205
258
|
if not user_sequence_features or len(user_sequence_features) == 0:
|
|
206
259
|
raise ValueError("MIND requires at least one user sequence feature")
|
|
207
|
-
|
|
208
|
-
seq_max_len =
|
|
260
|
+
|
|
261
|
+
seq_max_len = (
|
|
262
|
+
user_sequence_features[0].max_len
|
|
263
|
+
if user_sequence_features[0].max_len
|
|
264
|
+
else 50
|
|
265
|
+
)
|
|
209
266
|
seq_embedding_dim = user_sequence_features[0].embedding_dim
|
|
210
|
-
|
|
267
|
+
|
|
211
268
|
# Capsule Network for multi-interest extraction
|
|
212
269
|
self.capsule_network = CapsuleNetwork(
|
|
213
270
|
embedding_dim=seq_embedding_dim,
|
|
@@ -215,15 +272,17 @@ class MIND(BaseMatchModel):
|
|
|
215
272
|
bilinear_type=capsule_bilinear_type,
|
|
216
273
|
interest_num=num_interests,
|
|
217
274
|
routing_times=routing_times,
|
|
218
|
-
relu_layer=relu_layer
|
|
275
|
+
relu_layer=relu_layer,
|
|
219
276
|
)
|
|
220
|
-
|
|
277
|
+
|
|
221
278
|
if seq_embedding_dim != embedding_dim:
|
|
222
|
-
self.interest_projection = nn.Linear(
|
|
279
|
+
self.interest_projection = nn.Linear(
|
|
280
|
+
seq_embedding_dim, embedding_dim, bias=False
|
|
281
|
+
)
|
|
223
282
|
nn.init.xavier_uniform_(self.interest_projection.weight)
|
|
224
283
|
else:
|
|
225
284
|
self.interest_projection = None
|
|
226
|
-
|
|
285
|
+
|
|
227
286
|
# Item tower
|
|
228
287
|
item_features = []
|
|
229
288
|
if item_dense_features:
|
|
@@ -232,10 +291,10 @@ class MIND(BaseMatchModel):
|
|
|
232
291
|
item_features.extend(item_sparse_features)
|
|
233
292
|
if item_sequence_features:
|
|
234
293
|
item_features.extend(item_sequence_features)
|
|
235
|
-
|
|
294
|
+
|
|
236
295
|
if len(item_features) > 0:
|
|
237
296
|
self.item_embedding = EmbeddingLayer(item_features)
|
|
238
|
-
|
|
297
|
+
|
|
239
298
|
item_input_dim = 0
|
|
240
299
|
for feat in item_dense_features or []:
|
|
241
300
|
item_input_dim += 1
|
|
@@ -243,7 +302,7 @@ class MIND(BaseMatchModel):
|
|
|
243
302
|
item_input_dim += feat.embedding_dim
|
|
244
303
|
for feat in item_sequence_features or []:
|
|
245
304
|
item_input_dim += feat.embedding_dim
|
|
246
|
-
|
|
305
|
+
|
|
247
306
|
# Item DNN
|
|
248
307
|
if len(item_dnn_hidden_units) > 0:
|
|
249
308
|
item_dnn_units = item_dnn_hidden_units + [embedding_dim]
|
|
@@ -252,20 +311,19 @@ class MIND(BaseMatchModel):
|
|
|
252
311
|
dims=item_dnn_units,
|
|
253
312
|
output_layer=False,
|
|
254
313
|
dropout=dnn_dropout,
|
|
255
|
-
activation=dnn_activation
|
|
314
|
+
activation=dnn_activation,
|
|
256
315
|
)
|
|
257
316
|
else:
|
|
258
317
|
self.item_dnn = None
|
|
259
|
-
|
|
318
|
+
|
|
260
319
|
self.register_regularization_weights(
|
|
261
|
-
embedding_attr=
|
|
262
|
-
include_modules=['capsule_network']
|
|
320
|
+
embedding_attr="user_embedding", include_modules=["capsule_network"]
|
|
263
321
|
)
|
|
264
322
|
self.register_regularization_weights(
|
|
265
|
-
embedding_attr=
|
|
266
|
-
include_modules=[
|
|
323
|
+
embedding_attr="item_embedding",
|
|
324
|
+
include_modules=["item_dnn"] if self.item_dnn else [],
|
|
267
325
|
)
|
|
268
|
-
|
|
326
|
+
|
|
269
327
|
self.compile(
|
|
270
328
|
optimizer=optimizer,
|
|
271
329
|
optimizer_params=optimizer_params,
|
|
@@ -276,11 +334,11 @@ class MIND(BaseMatchModel):
|
|
|
276
334
|
)
|
|
277
335
|
|
|
278
336
|
self.to(device)
|
|
279
|
-
|
|
337
|
+
|
|
280
338
|
def user_tower(self, user_input: dict) -> torch.Tensor:
|
|
281
339
|
"""
|
|
282
340
|
User tower with multi-interest extraction
|
|
283
|
-
|
|
341
|
+
|
|
284
342
|
Returns:
|
|
285
343
|
user_interests: [batch_size, num_interests, embedding_dim]
|
|
286
344
|
"""
|
|
@@ -291,43 +349,53 @@ class MIND(BaseMatchModel):
|
|
|
291
349
|
seq_emb = embed(seq_input.long()) # [batch_size, seq_len, embedding_dim]
|
|
292
350
|
|
|
293
351
|
mask = (seq_input != seq_feature.padding_idx).float() # [batch_size, seq_len]
|
|
294
|
-
|
|
295
|
-
multi_interests = self.capsule_network(
|
|
296
|
-
|
|
352
|
+
|
|
353
|
+
multi_interests = self.capsule_network(
|
|
354
|
+
seq_emb, mask
|
|
355
|
+
) # [batch_size, num_interests, seq_embedding_dim]
|
|
356
|
+
|
|
297
357
|
if self.interest_projection is not None:
|
|
298
|
-
multi_interests = self.interest_projection(
|
|
299
|
-
|
|
358
|
+
multi_interests = self.interest_projection(
|
|
359
|
+
multi_interests
|
|
360
|
+
) # [batch_size, num_interests, embedding_dim]
|
|
361
|
+
|
|
300
362
|
# L2 normalization
|
|
301
363
|
multi_interests = F.normalize(multi_interests, p=2, dim=-1)
|
|
302
|
-
|
|
364
|
+
|
|
303
365
|
return multi_interests
|
|
304
|
-
|
|
366
|
+
|
|
305
367
|
def item_tower(self, item_input: dict) -> torch.Tensor:
|
|
306
368
|
"""Item tower"""
|
|
307
|
-
all_item_features =
|
|
369
|
+
all_item_features = (
|
|
370
|
+
self.item_dense_features
|
|
371
|
+
+ self.item_sparse_features
|
|
372
|
+
+ self.item_sequence_features
|
|
373
|
+
)
|
|
308
374
|
item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
|
|
309
|
-
|
|
375
|
+
|
|
310
376
|
if self.item_dnn is not None:
|
|
311
377
|
item_emb = self.item_dnn(item_emb)
|
|
312
|
-
|
|
378
|
+
|
|
313
379
|
# L2 normalization
|
|
314
380
|
item_emb = F.normalize(item_emb, p=2, dim=1)
|
|
315
|
-
|
|
381
|
+
|
|
316
382
|
return item_emb
|
|
317
|
-
|
|
318
|
-
def compute_similarity(
|
|
383
|
+
|
|
384
|
+
def compute_similarity(
|
|
385
|
+
self, user_emb: torch.Tensor, item_emb: torch.Tensor
|
|
386
|
+
) -> torch.Tensor:
|
|
319
387
|
item_emb_expanded = item_emb.unsqueeze(1)
|
|
320
|
-
|
|
321
|
-
if self.similarity_metric ==
|
|
388
|
+
|
|
389
|
+
if self.similarity_metric == "dot":
|
|
322
390
|
similarities = torch.sum(user_emb * item_emb_expanded, dim=-1)
|
|
323
|
-
elif self.similarity_metric ==
|
|
391
|
+
elif self.similarity_metric == "cosine":
|
|
324
392
|
similarities = F.cosine_similarity(user_emb, item_emb_expanded, dim=-1)
|
|
325
|
-
elif self.similarity_metric ==
|
|
393
|
+
elif self.similarity_metric == "euclidean":
|
|
326
394
|
similarities = -torch.sum((user_emb - item_emb_expanded) ** 2, dim=-1)
|
|
327
395
|
else:
|
|
328
396
|
raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
|
|
329
397
|
|
|
330
398
|
max_similarity, _ = torch.max(similarities, dim=1) # [batch_size]
|
|
331
399
|
max_similarity = max_similarity / self.temperature
|
|
332
|
-
|
|
400
|
+
|
|
333
401
|
return max_similarity
|