nextrec 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +250 -112
- nextrec/basic/loggers.py +63 -44
- nextrec/basic/metrics.py +270 -120
- nextrec/basic/model.py +1084 -402
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +492 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +51 -45
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +273 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +103 -38
- nextrec/models/match/dssm.py +82 -68
- nextrec/models/match/dssm_v2.py +72 -57
- nextrec/models/match/mind.py +175 -107
- nextrec/models/match/sdm.py +104 -87
- nextrec/models/match/youtube_dnn.py +73 -59
- nextrec/models/multi_task/esmm.py +69 -46
- nextrec/models/multi_task/mmoe.py +91 -53
- nextrec/models/multi_task/ple.py +117 -58
- nextrec/models/multi_task/poso.py +163 -55
- nextrec/models/multi_task/share_bottom.py +63 -36
- nextrec/models/ranking/afm.py +80 -45
- nextrec/models/ranking/autoint.py +74 -57
- nextrec/models/ranking/dcn.py +110 -48
- nextrec/models/ranking/dcn_v2.py +265 -45
- nextrec/models/ranking/deepfm.py +39 -24
- nextrec/models/ranking/dien.py +335 -146
- nextrec/models/ranking/din.py +158 -92
- nextrec/models/ranking/fibinet.py +134 -52
- nextrec/models/ranking/fm.py +68 -26
- nextrec/models/ranking/masknet.py +95 -33
- nextrec/models/ranking/pnn.py +128 -58
- nextrec/models/ranking/widedeep.py +40 -28
- nextrec/models/ranking/xdeepfm.py +67 -40
- nextrec/utils/__init__.py +59 -34
- nextrec/utils/config.py +496 -0
- nextrec/utils/device.py +30 -20
- nextrec/utils/distributed.py +36 -9
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +33 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/model.py +22 -0
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +283 -165
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
- nextrec-0.4.3.dist-info/RECORD +69 -0
- nextrec-0.4.3.dist-info/entry_points.txt +2 -0
- nextrec-0.4.1.dist-info/RECORD +0 -66
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
nextrec/models/match/dssm.py
CHANGED

@@ -6,9 +6,10 @@ Reference:
     [1] Huang P S, He X, Gao J, et al. Learning deep structured semantic models for web search using clickthrough data[C]
     //Proceedings of the 22nd ACM international conference on Information & Knowledge Management. 2013: 2333-2338.
 """
+
 import torch
 import torch.nn as nn
-from typing import
+from typing import Literal
 
 from nextrec.basic.model import BaseMatchModel
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
@@ -18,45 +19,52 @@ from nextrec.basic.layers import MLP, EmbeddingLayer
 class DSSM(BaseMatchModel):
     """
     Deep Structured Semantic Model
-
+
     Dual-tower model that encodes user and item features separately and
     computes similarity via cosine or dot product.
     """
-
+
     @property
     def model_name(self) -> str:
         return "DSSM"
-
-    def __init__(
-        [old parameter list not captured in this diff view]
+
+    def __init__(
+        self,
+        user_dense_features: list[DenseFeature] | None = None,
+        user_sparse_features: list[SparseFeature] | None = None,
+        user_sequence_features: list[SequenceFeature] | None = None,
+        item_dense_features: list[DenseFeature] | None = None,
+        item_sparse_features: list[SparseFeature] | None = None,
+        item_sequence_features: list[SequenceFeature] | None = None,
+        user_dnn_hidden_units: list[int] = [256, 128, 64],
+        item_dnn_hidden_units: list[int] = [256, 128, 64],
+        embedding_dim: int = 64,
+        dnn_activation: str = "relu",
+        dnn_dropout: float = 0.0,
+        training_mode: Literal["pointwise", "pairwise", "listwise"] = "pointwise",
+        num_negative_samples: int = 4,
+        temperature: float = 1.0,
+        similarity_metric: Literal["dot", "cosine", "euclidean"] = "cosine",
+        device: str = "cpu",
+        embedding_l1_reg: float = 0.0,
+        dense_l1_reg: float = 0.0,
+        embedding_l2_reg: float = 0.0,
+        dense_l2_reg: float = 0.0,
+        early_stop_patience: int = 20,
+        optimizer: str | torch.optim.Optimizer = "adam",
+        optimizer_params: dict | None = None,
+        scheduler: (
+            str
+            | torch.optim.lr_scheduler._LRScheduler
+            | type[torch.optim.lr_scheduler._LRScheduler]
+            | None
+        ) = None,
+        scheduler_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        **kwargs,
+    ):
+
         super(DSSM, self).__init__(
             user_dense_features=user_dense_features,
             user_sparse_features=user_sparse_features,
@@ -73,13 +81,13 @@ class DSSM(BaseMatchModel):
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            **kwargs
+            **kwargs,
         )
-
+
         self.embedding_dim = embedding_dim
         self.user_dnn_hidden_units = user_dnn_hidden_units
         self.item_dnn_hidden_units = item_dnn_hidden_units
-
+
         # User tower embedding layer
         user_features = []
         if user_dense_features:
@@ -88,10 +96,10 @@ class DSSM(BaseMatchModel):
             user_features.extend(user_sparse_features)
         if user_sequence_features:
             user_features.extend(user_sequence_features)
-
+
         if len(user_features) > 0:
             self.user_embedding = EmbeddingLayer(user_features)
-
+
         # Compute user tower input dimension
         user_input_dim = 0
         for feat in user_dense_features or []:
@@ -100,7 +108,7 @@ class DSSM(BaseMatchModel):
             user_input_dim += feat.embedding_dim
         for feat in user_sequence_features or []:
             user_input_dim += feat.embedding_dim
-
+
         # User DNN
         user_dnn_units = user_dnn_hidden_units + [embedding_dim]
         self.user_dnn = MLP(
@@ -108,9 +116,9 @@ class DSSM(BaseMatchModel):
             dims=user_dnn_units,
             output_layer=False,
             dropout=dnn_dropout,
-            activation=dnn_activation
+            activation=dnn_activation,
         )
-
+
         # Item tower embedding layer
         item_features = []
         if item_dense_features:
@@ -119,10 +127,10 @@ class DSSM(BaseMatchModel):
             item_features.extend(item_sparse_features)
         if item_sequence_features:
             item_features.extend(item_sequence_features)
-
+
         if len(item_features) > 0:
             self.item_embedding = EmbeddingLayer(item_features)
-
+
         # Compute item tower input dimension
         item_input_dim = 0
         for feat in item_dense_features or []:
@@ -131,7 +139,7 @@ class DSSM(BaseMatchModel):
             item_input_dim += feat.embedding_dim
         for feat in item_sequence_features or []:
             item_input_dim += feat.embedding_dim
-
+
         # Item DNN
         item_dnn_units = item_dnn_hidden_units + [embedding_dim]
         self.item_dnn = MLP(
@@ -139,18 +147,16 @@ class DSSM(BaseMatchModel):
             dims=item_dnn_units,
             output_layer=False,
             dropout=dnn_dropout,
-            activation=dnn_activation
+            activation=dnn_activation,
         )
-
+
         self.register_regularization_weights(
-            embedding_attr=
-            include_modules=['user_dnn']
+            embedding_attr="user_embedding", include_modules=["user_dnn"]
         )
         self.register_regularization_weights(
-            embedding_attr=
-            include_modules=['item_dnn']
+            embedding_attr="item_embedding", include_modules=["item_dnn"]
         )
-
+
         if optimizer_params is None:
             optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
 
@@ -164,45 +170,53 @@ class DSSM(BaseMatchModel):
         )
 
         self.to(device)
-
+
     def user_tower(self, user_input: dict) -> torch.Tensor:
         """
         User tower encodes user features into embeddings.
-
+
         Args:
             user_input: user feature dict
-
+
         Returns:
             user_emb: [batch_size, embedding_dim]
         """
-        all_user_features =
+        all_user_features = (
+            self.user_dense_features
+            + self.user_sparse_features
+            + self.user_sequence_features
+        )
         user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)
-
+
         user_emb = self.user_dnn(user_emb)
-
+
         # L2 normalize for cosine similarity
-        if self.similarity_metric ==
+        if self.similarity_metric == "cosine":
            user_emb = torch.nn.functional.normalize(user_emb, p=2, dim=1)
-
+
         return user_emb
-
+
     def item_tower(self, item_input: dict) -> torch.Tensor:
         """
         Item tower encodes item features into embeddings.
-
+
         Args:
             item_input: item feature dict
-
+
         Returns:
             item_emb: [batch_size, embedding_dim] or [batch_size, num_items, embedding_dim]
         """
-        all_item_features =
+        all_item_features = (
+            self.item_dense_features
+            + self.item_sparse_features
+            + self.item_sequence_features
+        )
        item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
-
+
         item_emb = self.item_dnn(item_emb)
-
+
         # L2 normalize for cosine similarity
-        if self.similarity_metric ==
+        if self.similarity_metric == "cosine":
             item_emb = torch.nn.functional.normalize(item_emb, p=2, dim=1)
-
+
         return item_emb
nextrec/models/match/dssm_v2.py
CHANGED

@@ -5,6 +5,7 @@ Author:
 Reference:
     DSSM v2 - DSSM with pairwise training using BPR loss
 """
+
 import torch
 import torch.nn as nn
 from typing import Literal
@@ -18,40 +19,48 @@ class DSSM_v2(BaseMatchModel):
     """
     DSSM with Pairwise Training
     """
+
     @property
     def model_name(self) -> str:
         return "DSSM_v2"
-
-    def __init__(
-        [old parameter list not captured in this diff view]
+
+    def __init__(
+        self,
+        user_dense_features: list[DenseFeature] | None = None,
+        user_sparse_features: list[SparseFeature] | None = None,
+        user_sequence_features: list[SequenceFeature] | None = None,
+        item_dense_features: list[DenseFeature] | None = None,
+        item_sparse_features: list[SparseFeature] | None = None,
+        item_sequence_features: list[SequenceFeature] | None = None,
+        user_dnn_hidden_units: list[int] = [256, 128, 64],
+        item_dnn_hidden_units: list[int] = [256, 128, 64],
+        embedding_dim: int = 64,
+        dnn_activation: str = "relu",
+        dnn_dropout: float = 0.0,
+        training_mode: Literal["pointwise", "pairwise", "listwise"] = "pairwise",
+        num_negative_samples: int = 4,
+        temperature: float = 1.0,
+        similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
+        device: str = "cpu",
+        embedding_l1_reg: float = 0.0,
+        dense_l1_reg: float = 0.0,
+        embedding_l2_reg: float = 0.0,
+        dense_l2_reg: float = 0.0,
+        early_stop_patience: int = 20,
+        optimizer: str | torch.optim.Optimizer = "adam",
+        optimizer_params: dict | None = None,
+        scheduler: (
+            str
+            | torch.optim.lr_scheduler._LRScheduler
+            | type[torch.optim.lr_scheduler._LRScheduler]
+            | None
+        ) = None,
+        scheduler_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        **kwargs,
+    ):
+
         super(DSSM_v2, self).__init__(
             user_dense_features=user_dense_features,
             user_sparse_features=user_sparse_features,
@@ -68,13 +77,13 @@ class DSSM_v2(BaseMatchModel):
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            **kwargs
+            **kwargs,
         )
-
+
         self.embedding_dim = embedding_dim
         self.user_dnn_hidden_units = user_dnn_hidden_units
         self.item_dnn_hidden_units = item_dnn_hidden_units
-
+
         # User tower
         user_features = []
         if user_dense_features:
@@ -83,10 +92,10 @@ class DSSM_v2(BaseMatchModel):
             user_features.extend(user_sparse_features)
         if user_sequence_features:
             user_features.extend(user_sequence_features)
-
+
         if len(user_features) > 0:
             self.user_embedding = EmbeddingLayer(user_features)
-
+
         user_input_dim = 0
         for feat in user_dense_features or []:
             user_input_dim += 1
@@ -94,16 +103,16 @@ class DSSM_v2(BaseMatchModel):
             user_input_dim += feat.embedding_dim
         for feat in user_sequence_features or []:
             user_input_dim += feat.embedding_dim
-
+
         user_dnn_units = user_dnn_hidden_units + [embedding_dim]
         self.user_dnn = MLP(
             input_dim=user_input_dim,
             dims=user_dnn_units,
             output_layer=False,
             dropout=dnn_dropout,
-            activation=dnn_activation
+            activation=dnn_activation,
         )
-
+
         # Item tower
         item_features = []
         if item_dense_features:
@@ -112,10 +121,10 @@ class DSSM_v2(BaseMatchModel):
             item_features.extend(item_sparse_features)
         if item_sequence_features:
             item_features.extend(item_sequence_features)
-
+
         if len(item_features) > 0:
             self.item_embedding = EmbeddingLayer(item_features)
-
+
         item_input_dim = 0
         for feat in item_dense_features or []:
             item_input_dim += 1
@@ -123,25 +132,23 @@ class DSSM_v2(BaseMatchModel):
             item_input_dim += feat.embedding_dim
         for feat in item_sequence_features or []:
             item_input_dim += feat.embedding_dim
-
+
         item_dnn_units = item_dnn_hidden_units + [embedding_dim]
         self.item_dnn = MLP(
             input_dim=item_input_dim,
             dims=item_dnn_units,
             output_layer=False,
             dropout=dnn_dropout,
-            activation=dnn_activation
+            activation=dnn_activation,
         )
-
+
         self.register_regularization_weights(
-            embedding_attr=
-            include_modules=['user_dnn']
+            embedding_attr="user_embedding", include_modules=["user_dnn"]
         )
         self.register_regularization_weights(
-            embedding_attr=
-            include_modules=['item_dnn']
+            embedding_attr="item_embedding", include_modules=["item_dnn"]
        )
-
+
         if optimizer_params is None:
             optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
 
@@ -155,25 +162,33 @@ class DSSM_v2(BaseMatchModel):
         )
 
         self.to(device)
-
+
     def user_tower(self, user_input: dict) -> torch.Tensor:
         """User tower"""
-        all_user_features =
+        all_user_features = (
+            self.user_dense_features
+            + self.user_sparse_features
+            + self.user_sequence_features
+        )
         user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)
         user_emb = self.user_dnn(user_emb)
-
+
         # Normalization for better pairwise training
         user_emb = torch.nn.functional.normalize(user_emb, p=2, dim=1)
-
+
         return user_emb
-
+
     def item_tower(self, item_input: dict) -> torch.Tensor:
         """Item tower"""
-        all_item_features =
+        all_item_features = (
+            self.item_dense_features
+            + self.item_sparse_features
+            + self.item_sequence_features
+        )
         item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
         item_emb = self.item_dnn(item_emb)
-
+
         # Normalization for better pairwise training
         item_emb = torch.nn.functional.normalize(item_emb, p=2, dim=1)
-
+
         return item_emb
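dssm_v2.py describes itself as DSSM with pairwise training using BPR loss. For reference, here is a minimal sketch of the standard BPR objective in plain PyTorch; it is illustrative only, and nextrec/loss/pairwise.py may implement it differently.

```python
import torch
import torch.nn.functional as F

def bpr_loss(pos_scores: torch.Tensor, neg_scores: torch.Tensor) -> torch.Tensor:
    """Bayesian Personalized Ranking: push positive scores above negative ones.

    pos_scores: [batch] similarity of (user, positive item)
    neg_scores: [batch] or [batch, num_negative_samples] for sampled negatives
    """
    if neg_scores.dim() > pos_scores.dim():
        pos_scores = pos_scores.unsqueeze(-1)  # broadcast over sampled negatives
    # -log sigmoid(margin); logsigmoid is numerically stabler than log(sigmoid(...)).
    return -F.logsigmoid(pos_scores - neg_scores).mean()
```

The scores themselves would come from the two towers, e.g. dot products between the `user_tower` output and the positive and negative `item_tower` outputs.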