nextrec 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +220 -106
- nextrec/basic/loggers.py +62 -43
- nextrec/basic/metrics.py +268 -119
- nextrec/basic/model.py +1082 -400
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +498 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +51 -45
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +272 -95
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +103 -38
- nextrec/models/match/dssm.py +82 -68
- nextrec/models/match/dssm_v2.py +72 -57
- nextrec/models/match/mind.py +175 -107
- nextrec/models/match/sdm.py +104 -87
- nextrec/models/match/youtube_dnn.py +73 -59
- nextrec/models/multi_task/esmm.py +53 -37
- nextrec/models/multi_task/mmoe.py +64 -45
- nextrec/models/multi_task/ple.py +101 -48
- nextrec/models/multi_task/poso.py +113 -36
- nextrec/models/multi_task/share_bottom.py +48 -35
- nextrec/models/ranking/afm.py +72 -37
- nextrec/models/ranking/autoint.py +72 -55
- nextrec/models/ranking/dcn.py +55 -35
- nextrec/models/ranking/dcn_v2.py +64 -23
- nextrec/models/ranking/deepfm.py +32 -22
- nextrec/models/ranking/dien.py +155 -99
- nextrec/models/ranking/din.py +85 -57
- nextrec/models/ranking/fibinet.py +52 -32
- nextrec/models/ranking/fm.py +29 -23
- nextrec/models/ranking/masknet.py +91 -29
- nextrec/models/ranking/pnn.py +31 -28
- nextrec/models/ranking/widedeep.py +34 -26
- nextrec/models/ranking/xdeepfm.py +60 -38
- nextrec/utils/__init__.py +59 -34
- nextrec/utils/config.py +490 -0
- nextrec/utils/device.py +30 -20
- nextrec/utils/distributed.py +36 -9
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +32 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +283 -165
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/METADATA +4 -4
- nextrec-0.4.2.dist-info/RECORD +69 -0
- nextrec-0.4.2.dist-info/entry_points.txt +2 -0
- nextrec-0.4.1.dist-info/RECORD +0 -66
- {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
- {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/models/match/sdm.py
CHANGED
|
@@ -6,6 +6,7 @@ Reference:
|
|
|
6
6
|
[1] Ying H, Zhuang F, Zhang F, et al. Sequential recommender system based on hierarchical attention networks[C]
|
|
7
7
|
//IJCAI. 2018: 3926-3932.
|
|
8
8
|
"""
|
|
9
|
+
|
|
9
10
|
import torch
|
|
10
11
|
import torch.nn as nn
|
|
11
12
|
import torch.nn.functional as F
|
|
@@ -20,46 +21,53 @@ class SDM(BaseMatchModel):
|
|
|
20
21
|
@property
|
|
21
22
|
def model_name(self) -> str:
|
|
22
23
|
return "SDM"
|
|
23
|
-
|
|
24
|
+
|
|
24
25
|
@property
|
|
25
26
|
def support_training_modes(self) -> list[str]:
|
|
26
|
-
return [
|
|
27
|
-
|
|
28
|
-
def __init__(
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
27
|
+
return ["pointwise"]
|
|
28
|
+
|
|
29
|
+
def __init__(
|
|
30
|
+
self,
|
|
31
|
+
user_dense_features: list[DenseFeature] | None = None,
|
|
32
|
+
user_sparse_features: list[SparseFeature] | None = None,
|
|
33
|
+
user_sequence_features: list[SequenceFeature] | None = None,
|
|
34
|
+
item_dense_features: list[DenseFeature] | None = None,
|
|
35
|
+
item_sparse_features: list[SparseFeature] | None = None,
|
|
36
|
+
item_sequence_features: list[SequenceFeature] | None = None,
|
|
37
|
+
embedding_dim: int = 64,
|
|
38
|
+
rnn_type: Literal["GRU", "LSTM"] = "GRU",
|
|
39
|
+
rnn_hidden_size: int = 64,
|
|
40
|
+
rnn_num_layers: int = 1,
|
|
41
|
+
rnn_dropout: float = 0.0,
|
|
42
|
+
use_short_term: bool = True,
|
|
43
|
+
use_long_term: bool = True,
|
|
44
|
+
item_dnn_hidden_units: list[int] = [256, 128],
|
|
45
|
+
dnn_activation: str = "relu",
|
|
46
|
+
dnn_dropout: float = 0.0,
|
|
47
|
+
training_mode: Literal["pointwise", "pairwise", "listwise"] = "pointwise",
|
|
48
|
+
num_negative_samples: int = 4,
|
|
49
|
+
temperature: float = 1.0,
|
|
50
|
+
similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
|
|
51
|
+
device: str = "cpu",
|
|
52
|
+
embedding_l1_reg: float = 0.0,
|
|
53
|
+
dense_l1_reg: float = 0.0,
|
|
54
|
+
embedding_l2_reg: float = 0.0,
|
|
55
|
+
dense_l2_reg: float = 0.0,
|
|
56
|
+
early_stop_patience: int = 20,
|
|
57
|
+
optimizer: str | torch.optim.Optimizer = "adam",
|
|
58
|
+
optimizer_params: dict | None = None,
|
|
59
|
+
scheduler: (
|
|
60
|
+
str
|
|
61
|
+
| torch.optim.lr_scheduler._LRScheduler
|
|
62
|
+
| type[torch.optim.lr_scheduler._LRScheduler]
|
|
63
|
+
| None
|
|
64
|
+
) = None,
|
|
65
|
+
scheduler_params: dict | None = None,
|
|
66
|
+
loss: str | nn.Module | list[str | nn.Module] | None = "bce",
|
|
67
|
+
loss_params: dict | list[dict] | None = None,
|
|
68
|
+
**kwargs,
|
|
69
|
+
):
|
|
70
|
+
|
|
63
71
|
super(SDM, self).__init__(
|
|
64
72
|
user_dense_features=user_dense_features,
|
|
65
73
|
user_sparse_features=user_sparse_features,
|
|
@@ -76,16 +84,16 @@ class SDM(BaseMatchModel):
|
|
|
76
84
|
dense_l1_reg=dense_l1_reg,
|
|
77
85
|
embedding_l2_reg=embedding_l2_reg,
|
|
78
86
|
dense_l2_reg=dense_l2_reg,
|
|
79
|
-
**kwargs
|
|
87
|
+
**kwargs,
|
|
80
88
|
)
|
|
81
|
-
|
|
89
|
+
|
|
82
90
|
self.embedding_dim = embedding_dim
|
|
83
91
|
self.rnn_type = rnn_type
|
|
84
92
|
self.rnn_hidden_size = rnn_hidden_size
|
|
85
93
|
self.use_short_term = use_short_term
|
|
86
94
|
self.use_long_term = use_long_term
|
|
87
95
|
self.item_dnn_hidden_units = item_dnn_hidden_units
|
|
88
|
-
|
|
96
|
+
|
|
89
97
|
# User tower
|
|
90
98
|
user_features = []
|
|
91
99
|
if user_dense_features:
|
|
@@ -94,54 +102,54 @@ class SDM(BaseMatchModel):
|
|
|
94
102
|
user_features.extend(user_sparse_features)
|
|
95
103
|
if user_sequence_features:
|
|
96
104
|
user_features.extend(user_sequence_features)
|
|
97
|
-
|
|
105
|
+
|
|
98
106
|
if len(user_features) > 0:
|
|
99
107
|
self.user_embedding = EmbeddingLayer(user_features)
|
|
100
|
-
|
|
108
|
+
|
|
101
109
|
if not user_sequence_features or len(user_sequence_features) == 0:
|
|
102
110
|
raise ValueError("SDM requires at least one user sequence feature")
|
|
103
|
-
|
|
111
|
+
|
|
104
112
|
seq_emb_dim = user_sequence_features[0].embedding_dim
|
|
105
|
-
|
|
106
|
-
if rnn_type ==
|
|
113
|
+
|
|
114
|
+
if rnn_type == "GRU":
|
|
107
115
|
self.rnn = nn.GRU(
|
|
108
116
|
input_size=seq_emb_dim,
|
|
109
117
|
hidden_size=rnn_hidden_size,
|
|
110
118
|
num_layers=rnn_num_layers,
|
|
111
119
|
batch_first=True,
|
|
112
|
-
dropout=rnn_dropout if rnn_num_layers > 1 else 0.0
|
|
120
|
+
dropout=rnn_dropout if rnn_num_layers > 1 else 0.0,
|
|
113
121
|
)
|
|
114
|
-
elif rnn_type ==
|
|
122
|
+
elif rnn_type == "LSTM":
|
|
115
123
|
self.rnn = nn.LSTM(
|
|
116
124
|
input_size=seq_emb_dim,
|
|
117
125
|
hidden_size=rnn_hidden_size,
|
|
118
126
|
num_layers=rnn_num_layers,
|
|
119
127
|
batch_first=True,
|
|
120
|
-
dropout=rnn_dropout if rnn_num_layers > 1 else 0.0
|
|
128
|
+
dropout=rnn_dropout if rnn_num_layers > 1 else 0.0,
|
|
121
129
|
)
|
|
122
130
|
else:
|
|
123
131
|
raise ValueError(f"Unknown RNN type: {rnn_type}")
|
|
124
|
-
|
|
132
|
+
|
|
125
133
|
user_final_dim = 0
|
|
126
134
|
if use_long_term:
|
|
127
|
-
user_final_dim += rnn_hidden_size
|
|
135
|
+
user_final_dim += rnn_hidden_size
|
|
128
136
|
if use_short_term:
|
|
129
|
-
user_final_dim += seq_emb_dim
|
|
130
|
-
|
|
137
|
+
user_final_dim += seq_emb_dim
|
|
138
|
+
|
|
131
139
|
for feat in user_dense_features or []:
|
|
132
140
|
user_final_dim += 1
|
|
133
141
|
for feat in user_sparse_features or []:
|
|
134
142
|
user_final_dim += feat.embedding_dim
|
|
135
|
-
|
|
143
|
+
|
|
136
144
|
# User DNN to final embedding
|
|
137
145
|
self.user_dnn = MLP(
|
|
138
146
|
input_dim=user_final_dim,
|
|
139
147
|
dims=[rnn_hidden_size * 2, embedding_dim],
|
|
140
148
|
output_layer=False,
|
|
141
149
|
dropout=dnn_dropout,
|
|
142
|
-
activation=dnn_activation
|
|
150
|
+
activation=dnn_activation,
|
|
143
151
|
)
|
|
144
|
-
|
|
152
|
+
|
|
145
153
|
# Item tower
|
|
146
154
|
item_features = []
|
|
147
155
|
if item_dense_features:
|
|
@@ -150,10 +158,10 @@ class SDM(BaseMatchModel):
|
|
|
150
158
|
item_features.extend(item_sparse_features)
|
|
151
159
|
if item_sequence_features:
|
|
152
160
|
item_features.extend(item_sequence_features)
|
|
153
|
-
|
|
161
|
+
|
|
154
162
|
if len(item_features) > 0:
|
|
155
163
|
self.item_embedding = EmbeddingLayer(item_features)
|
|
156
|
-
|
|
164
|
+
|
|
157
165
|
item_input_dim = 0
|
|
158
166
|
for feat in item_dense_features or []:
|
|
159
167
|
item_input_dim += 1
|
|
@@ -161,7 +169,7 @@ class SDM(BaseMatchModel):
|
|
|
161
169
|
item_input_dim += feat.embedding_dim
|
|
162
170
|
for feat in item_sequence_features or []:
|
|
163
171
|
item_input_dim += feat.embedding_dim
|
|
164
|
-
|
|
172
|
+
|
|
165
173
|
# Item DNN
|
|
166
174
|
if len(item_dnn_hidden_units) > 0:
|
|
167
175
|
item_dnn_units = item_dnn_hidden_units + [embedding_dim]
|
|
@@ -170,20 +178,19 @@ class SDM(BaseMatchModel):
|
|
|
170
178
|
dims=item_dnn_units,
|
|
171
179
|
output_layer=False,
|
|
172
180
|
dropout=dnn_dropout,
|
|
173
|
-
activation=dnn_activation
|
|
181
|
+
activation=dnn_activation,
|
|
174
182
|
)
|
|
175
183
|
else:
|
|
176
184
|
self.item_dnn = None
|
|
177
|
-
|
|
185
|
+
|
|
178
186
|
self.register_regularization_weights(
|
|
179
|
-
embedding_attr=
|
|
180
|
-
include_modules=['rnn', 'user_dnn']
|
|
187
|
+
embedding_attr="user_embedding", include_modules=["rnn", "user_dnn"]
|
|
181
188
|
)
|
|
182
189
|
self.register_regularization_weights(
|
|
183
|
-
embedding_attr=
|
|
184
|
-
include_modules=[
|
|
190
|
+
embedding_attr="item_embedding",
|
|
191
|
+
include_modules=["item_dnn"] if self.item_dnn else [],
|
|
185
192
|
)
|
|
186
|
-
|
|
193
|
+
|
|
187
194
|
self.compile(
|
|
188
195
|
optimizer=optimizer,
|
|
189
196
|
optimizer_params=optimizer_params,
|
|
@@ -194,38 +201,44 @@ class SDM(BaseMatchModel):
|
|
|
194
201
|
)
|
|
195
202
|
|
|
196
203
|
self.to(device)
|
|
197
|
-
|
|
204
|
+
|
|
198
205
|
def user_tower(self, user_input: dict) -> torch.Tensor:
|
|
199
206
|
seq_feature = self.user_sequence_features[0]
|
|
200
207
|
seq_input = user_input[seq_feature.name]
|
|
201
|
-
|
|
208
|
+
|
|
202
209
|
embed = self.user_embedding.embed_dict[seq_feature.embedding_name]
|
|
203
210
|
seq_emb = embed(seq_input.long()) # [batch_size, seq_len, seq_emb_dim]
|
|
204
|
-
|
|
205
|
-
if self.rnn_type ==
|
|
206
|
-
rnn_output, hidden = self.rnn(
|
|
207
|
-
|
|
211
|
+
|
|
212
|
+
if self.rnn_type == "GRU":
|
|
213
|
+
rnn_output, hidden = self.rnn(
|
|
214
|
+
seq_emb
|
|
215
|
+
) # hidden: [num_layers, batch, hidden_size]
|
|
216
|
+
elif self.rnn_type == "LSTM":
|
|
208
217
|
rnn_output, (hidden, cell) = self.rnn(seq_emb)
|
|
209
|
-
|
|
218
|
+
|
|
210
219
|
features_list = []
|
|
211
|
-
|
|
220
|
+
|
|
212
221
|
if self.use_long_term:
|
|
213
222
|
if self.rnn.num_layers > 1:
|
|
214
223
|
long_term = hidden[-1, :, :] # [batch_size, hidden_size]
|
|
215
224
|
else:
|
|
216
225
|
long_term = hidden.squeeze(0) # [batch_size, hidden_size]
|
|
217
226
|
features_list.append(long_term)
|
|
218
|
-
|
|
227
|
+
|
|
219
228
|
if self.use_short_term:
|
|
220
|
-
mask = (
|
|
229
|
+
mask = (
|
|
230
|
+
seq_input != seq_feature.padding_idx
|
|
231
|
+
).float() # [batch_size, seq_len]
|
|
221
232
|
seq_lengths = mask.sum(dim=1).long() - 1 # [batch_size]
|
|
222
233
|
seq_lengths = torch.clamp(seq_lengths, min=0)
|
|
223
|
-
|
|
234
|
+
|
|
224
235
|
batch_size = seq_emb.size(0)
|
|
225
236
|
batch_indices = torch.arange(batch_size, device=seq_emb.device)
|
|
226
|
-
short_term = seq_emb[
|
|
237
|
+
short_term = seq_emb[
|
|
238
|
+
batch_indices, seq_lengths, :
|
|
239
|
+
] # [batch_size, seq_emb_dim]
|
|
227
240
|
features_list.append(short_term)
|
|
228
|
-
|
|
241
|
+
|
|
229
242
|
if self.user_dense_features:
|
|
230
243
|
dense_features = []
|
|
231
244
|
for feat in self.user_dense_features:
|
|
@@ -236,7 +249,7 @@ class SDM(BaseMatchModel):
|
|
|
236
249
|
dense_features.append(val)
|
|
237
250
|
if dense_features:
|
|
238
251
|
features_list.append(torch.cat(dense_features, dim=1))
|
|
239
|
-
|
|
252
|
+
|
|
240
253
|
if self.user_sparse_features:
|
|
241
254
|
sparse_features = []
|
|
242
255
|
for feat in self.user_sparse_features:
|
|
@@ -246,22 +259,26 @@ class SDM(BaseMatchModel):
|
|
|
246
259
|
sparse_features.append(sparse_emb)
|
|
247
260
|
if sparse_features:
|
|
248
261
|
features_list.append(torch.cat(sparse_features, dim=1))
|
|
249
|
-
|
|
262
|
+
|
|
250
263
|
user_features = torch.cat(features_list, dim=1)
|
|
251
264
|
user_emb = self.user_dnn(user_features)
|
|
252
265
|
user_emb = F.normalize(user_emb, p=2, dim=1)
|
|
253
|
-
|
|
266
|
+
|
|
254
267
|
return user_emb
|
|
255
|
-
|
|
268
|
+
|
|
256
269
|
def item_tower(self, item_input: dict) -> torch.Tensor:
|
|
257
270
|
"""Item tower"""
|
|
258
|
-
all_item_features =
|
|
271
|
+
all_item_features = (
|
|
272
|
+
self.item_dense_features
|
|
273
|
+
+ self.item_sparse_features
|
|
274
|
+
+ self.item_sequence_features
|
|
275
|
+
)
|
|
259
276
|
item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
|
|
260
|
-
|
|
277
|
+
|
|
261
278
|
if self.item_dnn is not None:
|
|
262
279
|
item_emb = self.item_dnn(item_emb)
|
|
263
|
-
|
|
280
|
+
|
|
264
281
|
# L2 normalization
|
|
265
282
|
item_emb = F.normalize(item_emb, p=2, dim=1)
|
|
266
|
-
|
|
283
|
+
|
|
267
284
|
return item_emb
|
|
nextrec/models/match/youtube_dnn.py
CHANGED
|
@@ -6,13 +6,14 @@ Reference:
|
|
|
6
6
|
[1] Covington P, Adams J, Sargin E. Deep neural networks for youtube recommendations[C]
|
|
7
7
|
//Proceedings of the 10th ACM conference on recommender systems. 2016: 191-198.
|
|
8
8
|
"""
|
|
9
|
+
|
|
9
10
|
import torch
|
|
10
11
|
import torch.nn as nn
|
|
11
12
|
from typing import Literal
|
|
12
13
|
|
|
13
14
|
from nextrec.basic.model import BaseMatchModel
|
|
14
15
|
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
|
|
15
|
-
from nextrec.basic.layers import MLP, EmbeddingLayer
|
|
16
|
+
from nextrec.basic.layers import MLP, EmbeddingLayer
|
|
16
17
|
|
|
17
18
|
|
|
18
19
|
class YoutubeDNN(BaseMatchModel):
|
|
@@ -22,41 +23,48 @@ class YoutubeDNN(BaseMatchModel):
|
|
|
22
23
|
Item tower: item features -> item embedding.
|
|
23
24
|
Training usually uses listwise / sampled softmax style objectives.
|
|
24
25
|
"""
|
|
25
|
-
|
|
26
|
+
|
|
26
27
|
@property
|
|
27
28
|
def model_name(self) -> str:
|
|
28
29
|
return "YouTubeDNN"
|
|
29
|
-
|
|
30
|
-
def __init__(
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
30
|
+
|
|
31
|
+
def __init__(
|
|
32
|
+
self,
|
|
33
|
+
user_dense_features: list[DenseFeature] | None = None,
|
|
34
|
+
user_sparse_features: list[SparseFeature] | None = None,
|
|
35
|
+
user_sequence_features: list[SequenceFeature] | None = None,
|
|
36
|
+
item_dense_features: list[DenseFeature] | None = None,
|
|
37
|
+
item_sparse_features: list[SparseFeature] | None = None,
|
|
38
|
+
item_sequence_features: list[SequenceFeature] | None = None,
|
|
39
|
+
user_dnn_hidden_units: list[int] = [256, 128, 64],
|
|
40
|
+
item_dnn_hidden_units: list[int] = [256, 128, 64],
|
|
41
|
+
embedding_dim: int = 64,
|
|
42
|
+
dnn_activation: str = "relu",
|
|
43
|
+
dnn_dropout: float = 0.0,
|
|
44
|
+
training_mode: Literal["pointwise", "pairwise", "listwise"] = "listwise",
|
|
45
|
+
num_negative_samples: int = 100,
|
|
46
|
+
temperature: float = 1.0,
|
|
47
|
+
similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
|
|
48
|
+
device: str = "cpu",
|
|
49
|
+
embedding_l1_reg: float = 0.0,
|
|
50
|
+
dense_l1_reg: float = 0.0,
|
|
51
|
+
embedding_l2_reg: float = 0.0,
|
|
52
|
+
dense_l2_reg: float = 0.0,
|
|
53
|
+
early_stop_patience: int = 20,
|
|
54
|
+
optimizer: str | torch.optim.Optimizer = "adam",
|
|
55
|
+
optimizer_params: dict | None = None,
|
|
56
|
+
scheduler: (
|
|
57
|
+
str
|
|
58
|
+
| torch.optim.lr_scheduler._LRScheduler
|
|
59
|
+
| type[torch.optim.lr_scheduler._LRScheduler]
|
|
60
|
+
| None
|
|
61
|
+
) = None,
|
|
62
|
+
scheduler_params: dict | None = None,
|
|
63
|
+
loss: str | nn.Module | list[str | nn.Module] | None = "bce",
|
|
64
|
+
loss_params: dict | list[dict] | None = None,
|
|
65
|
+
**kwargs,
|
|
66
|
+
):
|
|
67
|
+
|
|
60
68
|
super(YoutubeDNN, self).__init__(
|
|
61
69
|
user_dense_features=user_dense_features,
|
|
62
70
|
user_sparse_features=user_sparse_features,
|
|
@@ -73,13 +81,13 @@ class YoutubeDNN(BaseMatchModel):
|
|
|
73
81
|
dense_l1_reg=dense_l1_reg,
|
|
74
82
|
embedding_l2_reg=embedding_l2_reg,
|
|
75
83
|
dense_l2_reg=dense_l2_reg,
|
|
76
|
-
**kwargs
|
|
84
|
+
**kwargs,
|
|
77
85
|
)
|
|
78
|
-
|
|
86
|
+
|
|
79
87
|
self.embedding_dim = embedding_dim
|
|
80
88
|
self.user_dnn_hidden_units = user_dnn_hidden_units
|
|
81
89
|
self.item_dnn_hidden_units = item_dnn_hidden_units
|
|
82
|
-
|
|
90
|
+
|
|
83
91
|
# User tower
|
|
84
92
|
user_features = []
|
|
85
93
|
if user_dense_features:
|
|
@@ -88,10 +96,10 @@ class YoutubeDNN(BaseMatchModel):
|
|
|
88
96
|
user_features.extend(user_sparse_features)
|
|
89
97
|
if user_sequence_features:
|
|
90
98
|
user_features.extend(user_sequence_features)
|
|
91
|
-
|
|
99
|
+
|
|
92
100
|
if len(user_features) > 0:
|
|
93
101
|
self.user_embedding = EmbeddingLayer(user_features)
|
|
94
|
-
|
|
102
|
+
|
|
95
103
|
user_input_dim = 0
|
|
96
104
|
for feat in user_dense_features or []:
|
|
97
105
|
user_input_dim += 1
|
|
@@ -100,16 +108,16 @@ class YoutubeDNN(BaseMatchModel):
|
|
|
100
108
|
for feat in user_sequence_features or []:
|
|
101
109
|
# Sequence features are pooled before entering the DNN
|
|
102
110
|
user_input_dim += feat.embedding_dim
|
|
103
|
-
|
|
111
|
+
|
|
104
112
|
user_dnn_units = user_dnn_hidden_units + [embedding_dim]
|
|
105
113
|
self.user_dnn = MLP(
|
|
106
114
|
input_dim=user_input_dim,
|
|
107
115
|
dims=user_dnn_units,
|
|
108
116
|
output_layer=False,
|
|
109
117
|
dropout=dnn_dropout,
|
|
110
|
-
activation=dnn_activation
|
|
118
|
+
activation=dnn_activation,
|
|
111
119
|
)
|
|
112
|
-
|
|
120
|
+
|
|
113
121
|
# Item tower
|
|
114
122
|
item_features = []
|
|
115
123
|
if item_dense_features:
|
|
@@ -118,10 +126,10 @@ class YoutubeDNN(BaseMatchModel):
|
|
|
118
126
|
item_features.extend(item_sparse_features)
|
|
119
127
|
if item_sequence_features:
|
|
120
128
|
item_features.extend(item_sequence_features)
|
|
121
|
-
|
|
129
|
+
|
|
122
130
|
if len(item_features) > 0:
|
|
123
131
|
self.item_embedding = EmbeddingLayer(item_features)
|
|
124
|
-
|
|
132
|
+
|
|
125
133
|
item_input_dim = 0
|
|
126
134
|
for feat in item_dense_features or []:
|
|
127
135
|
item_input_dim += 1
|
|
@@ -129,25 +137,23 @@ class YoutubeDNN(BaseMatchModel):
|
|
|
129
137
|
item_input_dim += feat.embedding_dim
|
|
130
138
|
for feat in item_sequence_features or []:
|
|
131
139
|
item_input_dim += feat.embedding_dim
|
|
132
|
-
|
|
140
|
+
|
|
133
141
|
item_dnn_units = item_dnn_hidden_units + [embedding_dim]
|
|
134
142
|
self.item_dnn = MLP(
|
|
135
143
|
input_dim=item_input_dim,
|
|
136
144
|
dims=item_dnn_units,
|
|
137
145
|
output_layer=False,
|
|
138
146
|
dropout=dnn_dropout,
|
|
139
|
-
activation=dnn_activation
|
|
147
|
+
activation=dnn_activation,
|
|
140
148
|
)
|
|
141
|
-
|
|
149
|
+
|
|
142
150
|
self.register_regularization_weights(
|
|
143
|
-
embedding_attr=
|
|
144
|
-
include_modules=['user_dnn']
|
|
151
|
+
embedding_attr="user_embedding", include_modules=["user_dnn"]
|
|
145
152
|
)
|
|
146
153
|
self.register_regularization_weights(
|
|
147
|
-
embedding_attr=
|
|
148
|
-
include_modules=['item_dnn']
|
|
154
|
+
embedding_attr="item_embedding", include_modules=["item_dnn"]
|
|
149
155
|
)
|
|
150
|
-
|
|
156
|
+
|
|
151
157
|
self.compile(
|
|
152
158
|
optimizer=optimizer,
|
|
153
159
|
optimizer_params=optimizer_params,
|
|
@@ -158,27 +164,35 @@ class YoutubeDNN(BaseMatchModel):
|
|
|
158
164
|
)
|
|
159
165
|
|
|
160
166
|
self.to(device)
|
|
161
|
-
|
|
167
|
+
|
|
162
168
|
def user_tower(self, user_input: dict) -> torch.Tensor:
|
|
163
169
|
"""
|
|
164
170
|
User tower to encode historical behavior sequences and user features.
|
|
165
171
|
"""
|
|
166
|
-
all_user_features =
|
|
172
|
+
all_user_features = (
|
|
173
|
+
self.user_dense_features
|
|
174
|
+
+ self.user_sparse_features
|
|
175
|
+
+ self.user_sequence_features
|
|
176
|
+
)
|
|
167
177
|
user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)
|
|
168
178
|
user_emb = self.user_dnn(user_emb)
|
|
169
|
-
|
|
179
|
+
|
|
170
180
|
# L2 normalization
|
|
171
181
|
user_emb = torch.nn.functional.normalize(user_emb, p=2, dim=1)
|
|
172
|
-
|
|
182
|
+
|
|
173
183
|
return user_emb
|
|
174
|
-
|
|
184
|
+
|
|
175
185
|
def item_tower(self, item_input: dict) -> torch.Tensor:
|
|
176
186
|
"""Item tower"""
|
|
177
|
-
all_item_features =
|
|
187
|
+
all_item_features = (
|
|
188
|
+
self.item_dense_features
|
|
189
|
+
+ self.item_sparse_features
|
|
190
|
+
+ self.item_sequence_features
|
|
191
|
+
)
|
|
178
192
|
item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
|
|
179
193
|
item_emb = self.item_dnn(item_emb)
|
|
180
|
-
|
|
194
|
+
|
|
181
195
|
# L2 normalization
|
|
182
196
|
item_emb = torch.nn.functional.normalize(item_emb, p=2, dim=1)
|
|
183
|
-
|
|
197
|
+
|
|
184
198
|
return item_emb
|