nextrec-0.3.6-py3-none-any.whl → nextrec-0.4.2-py3-none-any.whl
This diff shows the content changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +244 -113
- nextrec/basic/loggers.py +62 -43
- nextrec/basic/metrics.py +268 -119
- nextrec/basic/model.py +1373 -443
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +498 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +42 -24
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +303 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +106 -40
- nextrec/models/match/dssm.py +82 -69
- nextrec/models/match/dssm_v2.py +72 -58
- nextrec/models/match/mind.py +175 -108
- nextrec/models/match/sdm.py +104 -88
- nextrec/models/match/youtube_dnn.py +73 -60
- nextrec/models/multi_task/esmm.py +53 -39
- nextrec/models/multi_task/mmoe.py +70 -47
- nextrec/models/multi_task/ple.py +107 -50
- nextrec/models/multi_task/poso.py +121 -41
- nextrec/models/multi_task/share_bottom.py +54 -38
- nextrec/models/ranking/afm.py +172 -45
- nextrec/models/ranking/autoint.py +84 -61
- nextrec/models/ranking/dcn.py +59 -42
- nextrec/models/ranking/dcn_v2.py +64 -23
- nextrec/models/ranking/deepfm.py +36 -26
- nextrec/models/ranking/dien.py +158 -102
- nextrec/models/ranking/din.py +88 -60
- nextrec/models/ranking/fibinet.py +55 -35
- nextrec/models/ranking/fm.py +32 -26
- nextrec/models/ranking/masknet.py +95 -34
- nextrec/models/ranking/pnn.py +34 -31
- nextrec/models/ranking/widedeep.py +37 -29
- nextrec/models/ranking/xdeepfm.py +63 -41
- nextrec/utils/__init__.py +61 -32
- nextrec/utils/config.py +490 -0
- nextrec/utils/device.py +52 -12
- nextrec/utils/distributed.py +141 -0
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +32 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +531 -0
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
- nextrec-0.4.2.dist-info/RECORD +69 -0
- nextrec-0.4.2.dist-info/entry_points.txt +2 -0
- nextrec-0.3.6.dist-info/RECORD +0 -64
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/models/match/dssm.py
CHANGED

@@ -6,9 +6,10 @@ Reference:
 [1] Huang P S, He X, Gao J, et al. Learning deep structured semantic models for web search using clickthrough data[C]
 //Proceedings of the 22nd ACM international conference on Information & Knowledge Management. 2013: 2333-2338.
 """
+
 import torch
 import torch.nn as nn
-from typing import
+from typing import Literal
 
 from nextrec.basic.model import BaseMatchModel
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
@@ -18,45 +19,52 @@ from nextrec.basic.layers import MLP, EmbeddingLayer
 class DSSM(BaseMatchModel):
     """
     Deep Structured Semantic Model
-
+
     Dual-tower model that encodes user and item features separately and
     computes similarity via cosine or dot product.
     """
-
+
     @property
     def model_name(self) -> str:
         return "DSSM"
-
-    def __init__(
-        … (old signature, original lines 31-59, elided in the source diff view)
+
+    def __init__(
+        self,
+        user_dense_features: list[DenseFeature] | None = None,
+        user_sparse_features: list[SparseFeature] | None = None,
+        user_sequence_features: list[SequenceFeature] | None = None,
+        item_dense_features: list[DenseFeature] | None = None,
+        item_sparse_features: list[SparseFeature] | None = None,
+        item_sequence_features: list[SequenceFeature] | None = None,
+        user_dnn_hidden_units: list[int] = [256, 128, 64],
+        item_dnn_hidden_units: list[int] = [256, 128, 64],
+        embedding_dim: int = 64,
+        dnn_activation: str = "relu",
+        dnn_dropout: float = 0.0,
+        training_mode: Literal["pointwise", "pairwise", "listwise"] = "pointwise",
+        num_negative_samples: int = 4,
+        temperature: float = 1.0,
+        similarity_metric: Literal["dot", "cosine", "euclidean"] = "cosine",
+        device: str = "cpu",
+        embedding_l1_reg: float = 0.0,
+        dense_l1_reg: float = 0.0,
+        embedding_l2_reg: float = 0.0,
+        dense_l2_reg: float = 0.0,
+        early_stop_patience: int = 20,
+        optimizer: str | torch.optim.Optimizer = "adam",
+        optimizer_params: dict | None = None,
+        scheduler: (
+            str
+            | torch.optim.lr_scheduler._LRScheduler
+            | type[torch.optim.lr_scheduler._LRScheduler]
+            | None
+        ) = None,
+        scheduler_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        **kwargs,
+    ):
+
         super(DSSM, self).__init__(
             user_dense_features=user_dense_features,
             user_sparse_features=user_sparse_features,
@@ -73,14 +81,13 @@ class DSSM(BaseMatchModel):
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-
-            **kwargs
+            **kwargs,
         )
-
+
         self.embedding_dim = embedding_dim
         self.user_dnn_hidden_units = user_dnn_hidden_units
         self.item_dnn_hidden_units = item_dnn_hidden_units
-
+
         # User tower embedding layer
         user_features = []
         if user_dense_features:
@@ -89,10 +96,10 @@ class DSSM(BaseMatchModel):
             user_features.extend(user_sparse_features)
         if user_sequence_features:
             user_features.extend(user_sequence_features)
-
+
         if len(user_features) > 0:
             self.user_embedding = EmbeddingLayer(user_features)
-
+
         # Compute user tower input dimension
         user_input_dim = 0
         for feat in user_dense_features or []:
@@ -101,7 +108,7 @@ class DSSM(BaseMatchModel):
             user_input_dim += feat.embedding_dim
         for feat in user_sequence_features or []:
             user_input_dim += feat.embedding_dim
-
+
         # User DNN
         user_dnn_units = user_dnn_hidden_units + [embedding_dim]
         self.user_dnn = MLP(
@@ -109,9 +116,9 @@ class DSSM(BaseMatchModel):
             dims=user_dnn_units,
             output_layer=False,
             dropout=dnn_dropout,
-            activation=dnn_activation
+            activation=dnn_activation,
         )
-
+
         # Item tower embedding layer
         item_features = []
         if item_dense_features:
@@ -120,10 +127,10 @@ class DSSM(BaseMatchModel):
             item_features.extend(item_sparse_features)
         if item_sequence_features:
             item_features.extend(item_sequence_features)
-
+
         if len(item_features) > 0:
             self.item_embedding = EmbeddingLayer(item_features)
-
+
         # Compute item tower input dimension
         item_input_dim = 0
         for feat in item_dense_features or []:
@@ -132,7 +139,7 @@ class DSSM(BaseMatchModel):
             item_input_dim += feat.embedding_dim
         for feat in item_sequence_features or []:
             item_input_dim += feat.embedding_dim
-
+
         # Item DNN
         item_dnn_units = item_dnn_hidden_units + [embedding_dim]
         self.item_dnn = MLP(
@@ -140,18 +147,16 @@ class DSSM(BaseMatchModel):
             dims=item_dnn_units,
             output_layer=False,
             dropout=dnn_dropout,
-            activation=dnn_activation
+            activation=dnn_activation,
         )
-
+
         self.register_regularization_weights(
-            embedding_attr=
-            include_modules=['user_dnn']
+            embedding_attr="user_embedding", include_modules=["user_dnn"]
         )
         self.register_regularization_weights(
-            embedding_attr=
-            include_modules=['item_dnn']
+            embedding_attr="item_embedding", include_modules=["item_dnn"]
         )
-
+
         if optimizer_params is None:
             optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
 
@@ -165,45 +170,53 @@ class DSSM(BaseMatchModel):
         )
 
         self.to(device)
-
+
     def user_tower(self, user_input: dict) -> torch.Tensor:
         """
         User tower encodes user features into embeddings.
-
+
         Args:
             user_input: user feature dict
-
+
         Returns:
             user_emb: [batch_size, embedding_dim]
         """
-        all_user_features =
+        all_user_features = (
+            self.user_dense_features
+            + self.user_sparse_features
+            + self.user_sequence_features
+        )
         user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)
-
+
         user_emb = self.user_dnn(user_emb)
-
+
         # L2 normalize for cosine similarity
-        if self.similarity_metric ==
+        if self.similarity_metric == "cosine":
            user_emb = torch.nn.functional.normalize(user_emb, p=2, dim=1)
-
+
         return user_emb
-
+
     def item_tower(self, item_input: dict) -> torch.Tensor:
         """
         Item tower encodes item features into embeddings.
-
+
         Args:
             item_input: item feature dict
-
+
         Returns:
             item_emb: [batch_size, embedding_dim] or [batch_size, num_items, embedding_dim]
         """
-        all_item_features =
+        all_item_features = (
+            self.item_dense_features
+            + self.item_sparse_features
+            + self.item_sequence_features
+        )
         item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
-
+
         item_emb = self.item_dnn(item_emb)
-
+
         # L2 normalize for cosine similarity
-        if self.similarity_metric ==
+        if self.similarity_metric == "cosine":
            item_emb = torch.nn.functional.normalize(item_emb, p=2, dim=1)
-
+
         return item_emb
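For orientation, here is how the 0.4.2 constructor shown above can be exercised. This is a minimal sketch: the `SparseFeature(name=..., vocab_size=..., embedding_dim=...)` arguments are hypothetical stand-ins, since the feature classes' constructors are not part of this diff; only the `DSSM` keyword arguments are taken from the signature above.

```python
from nextrec.basic.features import SparseFeature
from nextrec.models.match.dssm import DSSM

# Hypothetical feature definitions -- constructor arguments assumed, not shown in this diff.
user_feats = [SparseFeature(name="user_id", vocab_size=10_000, embedding_dim=16)]
item_feats = [SparseFeature(name="item_id", vocab_size=50_000, embedding_dim=16)]

model = DSSM(
    user_sparse_features=user_feats,
    item_sparse_features=item_feats,
    user_dnn_hidden_units=[256, 128, 64],  # defaults per the signature above
    item_dnn_hidden_units=[256, 128, 64],
    embedding_dim=64,
    training_mode="pointwise",             # Literal-typed option in the 0.4.2 signature
    similarity_metric="cosine",            # towers L2-normalize only when "cosine"
    device="cpu",
)
```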
nextrec/models/match/dssm_v2.py
CHANGED

@@ -5,6 +5,7 @@ Author:
 Reference:
     DSSM v2 - DSSM with pairwise training using BPR loss
 """
+
 import torch
 import torch.nn as nn
 from typing import Literal
@@ -18,40 +19,48 @@ class DSSM_v2(BaseMatchModel):
     """
     DSSM with Pairwise Training
     """
+
     @property
     def model_name(self) -> str:
         return "DSSM_v2"
-
-    def __init__(
-        … (old signature, original lines 26-54, elided in the source diff view)
+
+    def __init__(
+        self,
+        user_dense_features: list[DenseFeature] | None = None,
+        user_sparse_features: list[SparseFeature] | None = None,
+        user_sequence_features: list[SequenceFeature] | None = None,
+        item_dense_features: list[DenseFeature] | None = None,
+        item_sparse_features: list[SparseFeature] | None = None,
+        item_sequence_features: list[SequenceFeature] | None = None,
+        user_dnn_hidden_units: list[int] = [256, 128, 64],
+        item_dnn_hidden_units: list[int] = [256, 128, 64],
+        embedding_dim: int = 64,
+        dnn_activation: str = "relu",
+        dnn_dropout: float = 0.0,
+        training_mode: Literal["pointwise", "pairwise", "listwise"] = "pairwise",
+        num_negative_samples: int = 4,
+        temperature: float = 1.0,
+        similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
+        device: str = "cpu",
+        embedding_l1_reg: float = 0.0,
+        dense_l1_reg: float = 0.0,
+        embedding_l2_reg: float = 0.0,
+        dense_l2_reg: float = 0.0,
+        early_stop_patience: int = 20,
+        optimizer: str | torch.optim.Optimizer = "adam",
+        optimizer_params: dict | None = None,
+        scheduler: (
+            str
+            | torch.optim.lr_scheduler._LRScheduler
+            | type[torch.optim.lr_scheduler._LRScheduler]
+            | None
+        ) = None,
+        scheduler_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        **kwargs,
+    ):
+
         super(DSSM_v2, self).__init__(
             user_dense_features=user_dense_features,
             user_sparse_features=user_sparse_features,
@@ -68,14 +77,13 @@ class DSSM_v2(BaseMatchModel):
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-
-            **kwargs
+            **kwargs,
         )
-
+
         self.embedding_dim = embedding_dim
         self.user_dnn_hidden_units = user_dnn_hidden_units
         self.item_dnn_hidden_units = item_dnn_hidden_units
-
+
         # User tower
         user_features = []
         if user_dense_features:
@@ -84,10 +92,10 @@ class DSSM_v2(BaseMatchModel):
             user_features.extend(user_sparse_features)
         if user_sequence_features:
             user_features.extend(user_sequence_features)
-
+
         if len(user_features) > 0:
             self.user_embedding = EmbeddingLayer(user_features)
-
+
         user_input_dim = 0
         for feat in user_dense_features or []:
             user_input_dim += 1
@@ -95,16 +103,16 @@ class DSSM_v2(BaseMatchModel):
             user_input_dim += feat.embedding_dim
         for feat in user_sequence_features or []:
             user_input_dim += feat.embedding_dim
-
+
         user_dnn_units = user_dnn_hidden_units + [embedding_dim]
         self.user_dnn = MLP(
             input_dim=user_input_dim,
             dims=user_dnn_units,
             output_layer=False,
             dropout=dnn_dropout,
-            activation=dnn_activation
+            activation=dnn_activation,
         )
-
+
         # Item tower
         item_features = []
         if item_dense_features:
@@ -113,10 +121,10 @@ class DSSM_v2(BaseMatchModel):
             item_features.extend(item_sparse_features)
         if item_sequence_features:
             item_features.extend(item_sequence_features)
-
+
         if len(item_features) > 0:
             self.item_embedding = EmbeddingLayer(item_features)
-
+
         item_input_dim = 0
         for feat in item_dense_features or []:
             item_input_dim += 1
@@ -124,25 +132,23 @@ class DSSM_v2(BaseMatchModel):
             item_input_dim += feat.embedding_dim
         for feat in item_sequence_features or []:
             item_input_dim += feat.embedding_dim
-
+
         item_dnn_units = item_dnn_hidden_units + [embedding_dim]
         self.item_dnn = MLP(
             input_dim=item_input_dim,
             dims=item_dnn_units,
             output_layer=False,
             dropout=dnn_dropout,
-            activation=dnn_activation
+            activation=dnn_activation,
         )
-
+
         self.register_regularization_weights(
-            embedding_attr=
-            include_modules=['user_dnn']
+            embedding_attr="user_embedding", include_modules=["user_dnn"]
         )
         self.register_regularization_weights(
-            embedding_attr=
-            include_modules=['item_dnn']
+            embedding_attr="item_embedding", include_modules=["item_dnn"]
         )
-
+
         if optimizer_params is None:
             optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
 
@@ -156,25 +162,33 @@ class DSSM_v2(BaseMatchModel):
         )
 
         self.to(device)
-
+
     def user_tower(self, user_input: dict) -> torch.Tensor:
         """User tower"""
-        all_user_features =
+        all_user_features = (
+            self.user_dense_features
+            + self.user_sparse_features
+            + self.user_sequence_features
+        )
         user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)
         user_emb = self.user_dnn(user_emb)
-
+
         # Normalization for better pairwise training
         user_emb = torch.nn.functional.normalize(user_emb, p=2, dim=1)
-
+
         return user_emb
-
+
     def item_tower(self, item_input: dict) -> torch.Tensor:
         """Item tower"""
-        all_item_features =
+        all_item_features = (
+            self.item_dense_features
+            + self.item_sparse_features
+            + self.item_sequence_features
+        )
         item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
         item_emb = self.item_dnn(item_emb)
-
+
         # Normalization for better pairwise training
         item_emb = torch.nn.functional.normalize(item_emb, p=2, dim=1)
-
+
         return item_emb
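Unlike `DSSM`, both `DSSM_v2` towers normalize unconditionally, so with the default `similarity_metric="dot"` the score behaves like a cosine similarity. Below is a sketch of how the two tower outputs could be combined into a score; `model`, `user_batch`, and `item_batch` are hypothetical stand-ins, and the actual scoring path lives in `BaseMatchModel`, which is not part of this diff.

```python
import torch

# user_batch / item_batch: feature dicts keyed by feature name (hypothetical inputs).
user_emb = model.user_tower(user_batch)      # [batch_size, embedding_dim], unit-norm
item_emb = model.item_tower(item_batch)      # [batch_size, embedding_dim], unit-norm

# Dot product of unit vectors equals cosine similarity, bounded in [-1, 1].
scores = (user_emb * item_emb).sum(dim=-1)   # [batch_size]
```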