nextrec 0.3.6__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +244 -113
- nextrec/basic/loggers.py +62 -43
- nextrec/basic/metrics.py +268 -119
- nextrec/basic/model.py +1373 -443
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +498 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +42 -24
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +303 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +106 -40
- nextrec/models/match/dssm.py +82 -69
- nextrec/models/match/dssm_v2.py +72 -58
- nextrec/models/match/mind.py +175 -108
- nextrec/models/match/sdm.py +104 -88
- nextrec/models/match/youtube_dnn.py +73 -60
- nextrec/models/multi_task/esmm.py +53 -39
- nextrec/models/multi_task/mmoe.py +70 -47
- nextrec/models/multi_task/ple.py +107 -50
- nextrec/models/multi_task/poso.py +121 -41
- nextrec/models/multi_task/share_bottom.py +54 -38
- nextrec/models/ranking/afm.py +172 -45
- nextrec/models/ranking/autoint.py +84 -61
- nextrec/models/ranking/dcn.py +59 -42
- nextrec/models/ranking/dcn_v2.py +64 -23
- nextrec/models/ranking/deepfm.py +36 -26
- nextrec/models/ranking/dien.py +158 -102
- nextrec/models/ranking/din.py +88 -60
- nextrec/models/ranking/fibinet.py +55 -35
- nextrec/models/ranking/fm.py +32 -26
- nextrec/models/ranking/masknet.py +95 -34
- nextrec/models/ranking/pnn.py +34 -31
- nextrec/models/ranking/widedeep.py +37 -29
- nextrec/models/ranking/xdeepfm.py +63 -41
- nextrec/utils/__init__.py +61 -32
- nextrec/utils/config.py +490 -0
- nextrec/utils/device.py +52 -12
- nextrec/utils/distributed.py +141 -0
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +32 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +531 -0
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
- nextrec-0.4.2.dist-info/RECORD +69 -0
- nextrec-0.4.2.dist-info/entry_points.txt +2 -0
- nextrec-0.3.6.dist-info/RECORD +0 -64
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/models/match/mind.py
CHANGED
@@ -6,6 +6,7 @@ Reference:
 [1] Li C, Liu Z, Wu M, et al. Multi-interest network with dynamic routing for recommendation at Tmall[C]
 //Proceedings of the 28th ACM international conference on information and knowledge management. 2019: 2615-2623.
 """
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -15,6 +16,7 @@ from nextrec.basic.model import BaseMatchModel
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 from nextrec.basic.layers import MLP, EmbeddingLayer

+
 class MultiInterestSA(nn.Module):
     """Multi-interest self-attention extractor from MIND (Li et al., 2019)."""

@@ -22,19 +24,25 @@ class MultiInterestSA(nn.Module):
         super(MultiInterestSA, self).__init__()
         self.embedding_dim = embedding_dim
         self.interest_num = interest_num
-        if hidden_dim
+        if hidden_dim is None:
             self.hidden_dim = self.embedding_dim * 4
-        self.W1 = torch.nn.Parameter(
-
-
+        self.W1 = torch.nn.Parameter(
+            torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True
+        )
+        self.W2 = torch.nn.Parameter(
+            torch.rand(self.hidden_dim, self.interest_num), requires_grad=True
+        )
+        self.W3 = torch.nn.Parameter(
+            torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True
+        )

     def forward(self, seq_emb, mask=None):
-        H = torch.einsum(
-        if mask
-            A = torch.einsum(
+        H = torch.einsum("bse, ed -> bsd", seq_emb, self.W1).tanh()
+        if mask is not None:
+            A = torch.einsum("bsd, dk -> bsk", H, self.W2) + -1.0e9 * (1 - mask.float())
             A = F.softmax(A, dim=1)
         else:
-            A = F.softmax(torch.einsum(
+            A = F.softmax(torch.einsum("bsd, dk -> bsk", H, self.W2), dim=1)
         A = A.permute(0, 2, 1)
         multi_interest_emb = torch.matmul(A, seq_emb)
         return multi_interest_emb
@@ -43,7 +51,15 @@ class MultiInterestSA(nn.Module):
 class CapsuleNetwork(nn.Module):
     """Dynamic routing capsule network used in MIND (Li et al., 2019)."""

-    def __init__(
+    def __init__(
+        self,
+        embedding_dim,
+        seq_len,
+        bilinear_type=2,
+        interest_num=4,
+        routing_times=3,
+        relu_layer=False,
+    ):
         super(CapsuleNetwork, self).__init__()
         self.embedding_dim = embedding_dim  # h
         self.seq_len = seq_len  # s
@@ -53,13 +69,24 @@ class CapsuleNetwork(nn.Module):

         self.relu_layer = relu_layer
         self.stop_grad = True
-        self.relu = nn.Sequential(
+        self.relu = nn.Sequential(
+            nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU()
+        )
         if self.bilinear_type == 0:  # MIND
             self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
         elif self.bilinear_type == 1:
-            self.linear = nn.Linear(
+            self.linear = nn.Linear(
+                self.embedding_dim, self.embedding_dim * self.interest_num, bias=False
+            )
         else:
-            self.w = nn.Parameter(
+            self.w = nn.Parameter(
+                torch.Tensor(
+                    1,
+                    self.seq_len,
+                    self.interest_num * self.embedding_dim,
+                    self.embedding_dim,
+                )
+            )
             nn.init.xavier_uniform_(self.w)

     def forward(self, item_eb, mask):
@@ -70,11 +97,15 @@ class CapsuleNetwork(nn.Module):
             item_eb_hat = self.linear(item_eb)
         else:
             u = torch.unsqueeze(item_eb, dim=2)
-            item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)
+            item_eb_hat = torch.sum(self.w[:, : self.seq_len, :, :] * u, dim=3)

-        item_eb_hat = torch.reshape(
+        item_eb_hat = torch.reshape(
+            item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim)
+        )
         item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
-        item_eb_hat = torch.reshape(
+        item_eb_hat = torch.reshape(
+            item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim)
+        )

         if self.stop_grad:
             item_eb_hat_iter = item_eb_hat.detach()
@@ -82,34 +113,47 @@ class CapsuleNetwork(nn.Module):
             item_eb_hat_iter = item_eb_hat

         if self.bilinear_type > 0:
-            capsule_weight = torch.zeros(
-
-
-
-
+            capsule_weight = torch.zeros(
+                item_eb_hat.shape[0],
+                self.interest_num,
+                self.seq_len,
+                device=item_eb.device,
+                requires_grad=False,
+            )
         else:
-            capsule_weight = torch.randn(
-
-
-
-
+            capsule_weight = torch.randn(
+                item_eb_hat.shape[0],
+                self.interest_num,
+                self.seq_len,
+                device=item_eb.device,
+                requires_grad=False,
+            )

         for i in range(self.routing_times):  # dynamic routing, repeated 3 times
             atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
             paddings = torch.zeros_like(atten_mask, dtype=torch.float)

             capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
-            capsule_softmax_weight = torch.where(
+            capsule_softmax_weight = torch.where(
+                torch.eq(atten_mask, 0), paddings, capsule_softmax_weight
+            )
             capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)

             if i < 2:
-                interest_capsule = torch.matmul(
+                interest_capsule = torch.matmul(
+                    capsule_softmax_weight, item_eb_hat_iter
+                )
                 cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
                 scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
                 interest_capsule = scalar_factor * interest_capsule

-                delta_weight = torch.matmul(
-
+                delta_weight = torch.matmul(
+                    item_eb_hat_iter,
+                    torch.transpose(interest_capsule, 2, 3).contiguous(),
+                )
+                delta_weight = torch.reshape(
+                    delta_weight, (-1, self.interest_num, self.seq_len)
+                )
                 capsule_weight = capsule_weight + delta_weight
             else:
                 interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat)
@@ -117,7 +161,9 @@ class CapsuleNetwork(nn.Module):
                 scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
                 interest_capsule = scalar_factor * interest_capsule

-        interest_capsule = torch.reshape(
+        interest_capsule = torch.reshape(
+            interest_capsule, (-1, self.interest_num, self.embedding_dim)
+        )

         if self.relu_layer:
             interest_capsule = self.relu(interest_capsule)
@@ -129,45 +175,52 @@ class MIND(BaseMatchModel):
     @property
     def model_name(self) -> str:
         return "MIND"
-
+
     @property
     def support_training_modes(self) -> list[str]:
         """MIND only supports pointwise training mode"""
-        return [
-
-    def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        return ["pointwise"]
+
+    def __init__(
+        self,
+        user_dense_features: list[DenseFeature] | None = None,
+        user_sparse_features: list[SparseFeature] | None = None,
+        user_sequence_features: list[SequenceFeature] | None = None,
+        item_dense_features: list[DenseFeature] | None = None,
+        item_sparse_features: list[SparseFeature] | None = None,
+        item_sequence_features: list[SequenceFeature] | None = None,
+        embedding_dim: int = 64,
+        num_interests: int = 4,
+        capsule_bilinear_type: int = 2,
+        routing_times: int = 3,
+        relu_layer: bool = False,
+        item_dnn_hidden_units: list[int] = [256, 128],
+        dnn_activation: str = "relu",
+        dnn_dropout: float = 0.0,
+        training_mode: Literal["pointwise", "pairwise", "listwise"] = "pointwise",
+        num_negative_samples: int = 100,
+        temperature: float = 1.0,
+        similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
+        device: str = "cpu",
+        embedding_l1_reg: float = 0.0,
+        dense_l1_reg: float = 0.0,
+        embedding_l2_reg: float = 0.0,
+        dense_l2_reg: float = 0.0,
+        early_stop_patience: int = 20,
+        optimizer: str | torch.optim.Optimizer = "adam",
+        optimizer_params: dict | None = None,
+        scheduler: (
+            str
+            | torch.optim.lr_scheduler._LRScheduler
+            | type[torch.optim.lr_scheduler._LRScheduler]
+            | None
+        ) = None,
+        scheduler_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        **kwargs,
+    ):
+
         super(MIND, self).__init__(
             user_dense_features=user_dense_features,
             user_sparse_features=user_sparse_features,
@@ -184,10 +237,9 @@ class MIND(BaseMatchModel):
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-
-            **kwargs
+            **kwargs,
         )
-
+
         self.embedding_dim = embedding_dim
         self.num_interests = num_interests
         self.item_dnn_hidden_units = item_dnn_hidden_units
@@ -199,16 +251,20 @@
             user_features.extend(user_sparse_features)
         if user_sequence_features:
             user_features.extend(user_sequence_features)
-
+
         if len(user_features) > 0:
             self.user_embedding = EmbeddingLayer(user_features)
-
+
         if not user_sequence_features or len(user_sequence_features) == 0:
             raise ValueError("MIND requires at least one user sequence feature")
-
-        seq_max_len =
+
+        seq_max_len = (
+            user_sequence_features[0].max_len
+            if user_sequence_features[0].max_len
+            else 50
+        )
         seq_embedding_dim = user_sequence_features[0].embedding_dim
-
+
         # Capsule Network for multi-interest extraction
         self.capsule_network = CapsuleNetwork(
             embedding_dim=seq_embedding_dim,
@@ -216,15 +272,17 @@
             bilinear_type=capsule_bilinear_type,
             interest_num=num_interests,
             routing_times=routing_times,
-            relu_layer=relu_layer
+            relu_layer=relu_layer,
         )
-
+
         if seq_embedding_dim != embedding_dim:
-            self.interest_projection = nn.Linear(
+            self.interest_projection = nn.Linear(
+                seq_embedding_dim, embedding_dim, bias=False
+            )
             nn.init.xavier_uniform_(self.interest_projection.weight)
         else:
             self.interest_projection = None
-
+
         # Item tower
         item_features = []
         if item_dense_features:
@@ -233,10 +291,10 @@
             item_features.extend(item_sparse_features)
         if item_sequence_features:
             item_features.extend(item_sequence_features)
-
+
         if len(item_features) > 0:
             self.item_embedding = EmbeddingLayer(item_features)
-
+
         item_input_dim = 0
         for feat in item_dense_features or []:
             item_input_dim += 1
@@ -244,7 +302,7 @@
             item_input_dim += feat.embedding_dim
         for feat in item_sequence_features or []:
             item_input_dim += feat.embedding_dim
-
+
         # Item DNN
         if len(item_dnn_hidden_units) > 0:
             item_dnn_units = item_dnn_hidden_units + [embedding_dim]
@@ -253,20 +311,19 @@
                 dims=item_dnn_units,
                 output_layer=False,
                 dropout=dnn_dropout,
-                activation=dnn_activation
+                activation=dnn_activation,
             )
         else:
             self.item_dnn = None
-
+
         self.register_regularization_weights(
-            embedding_attr=
-            include_modules=['capsule_network']
+            embedding_attr="user_embedding", include_modules=["capsule_network"]
         )
         self.register_regularization_weights(
-            embedding_attr=
-            include_modules=[
+            embedding_attr="item_embedding",
+            include_modules=["item_dnn"] if self.item_dnn else [],
         )
-
+
         self.compile(
             optimizer=optimizer,
             optimizer_params=optimizer_params,
@@ -277,11 +334,11 @@
         )

         self.to(device)
-
+
     def user_tower(self, user_input: dict) -> torch.Tensor:
         """
         User tower with multi-interest extraction
-
+
         Returns:
             user_interests: [batch_size, num_interests, embedding_dim]
         """
@@ -292,43 +349,53 @@
         seq_emb = embed(seq_input.long())  # [batch_size, seq_len, embedding_dim]

         mask = (seq_input != seq_feature.padding_idx).float()  # [batch_size, seq_len]
-
-        multi_interests = self.capsule_network(
-
+
+        multi_interests = self.capsule_network(
+            seq_emb, mask
+        )  # [batch_size, num_interests, seq_embedding_dim]
+
         if self.interest_projection is not None:
-            multi_interests = self.interest_projection(
-
+            multi_interests = self.interest_projection(
+                multi_interests
+            )  # [batch_size, num_interests, embedding_dim]
+
         # L2 normalization
         multi_interests = F.normalize(multi_interests, p=2, dim=-1)
-
+
         return multi_interests
-
+
     def item_tower(self, item_input: dict) -> torch.Tensor:
         """Item tower"""
-        all_item_features =
+        all_item_features = (
+            self.item_dense_features
+            + self.item_sparse_features
+            + self.item_sequence_features
+        )
         item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
-
+
         if self.item_dnn is not None:
             item_emb = self.item_dnn(item_emb)
-
+
         # L2 normalization
         item_emb = F.normalize(item_emb, p=2, dim=1)
-
+
         return item_emb
-
-    def compute_similarity(
+
+    def compute_similarity(
+        self, user_emb: torch.Tensor, item_emb: torch.Tensor
+    ) -> torch.Tensor:
         item_emb_expanded = item_emb.unsqueeze(1)
-
-        if self.similarity_metric ==
+
+        if self.similarity_metric == "dot":
             similarities = torch.sum(user_emb * item_emb_expanded, dim=-1)
-        elif self.similarity_metric ==
+        elif self.similarity_metric == "cosine":
             similarities = F.cosine_similarity(user_emb, item_emb_expanded, dim=-1)
-        elif self.similarity_metric ==
+        elif self.similarity_metric == "euclidean":
             similarities = -torch.sum((user_emb - item_emb_expanded) ** 2, dim=-1)
         else:
             raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")

         max_similarity, _ = torch.max(similarities, dim=1)  # [batch_size]
         max_similarity = max_similarity / self.temperature
-
+
         return max_similarity