nextrec 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +220 -106
- nextrec/basic/loggers.py +62 -43
- nextrec/basic/metrics.py +268 -119
- nextrec/basic/model.py +1082 -400
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +498 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +51 -45
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +272 -95
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +103 -38
- nextrec/models/match/dssm.py +82 -68
- nextrec/models/match/dssm_v2.py +72 -57
- nextrec/models/match/mind.py +175 -107
- nextrec/models/match/sdm.py +104 -87
- nextrec/models/match/youtube_dnn.py +73 -59
- nextrec/models/multi_task/esmm.py +53 -37
- nextrec/models/multi_task/mmoe.py +64 -45
- nextrec/models/multi_task/ple.py +101 -48
- nextrec/models/multi_task/poso.py +113 -36
- nextrec/models/multi_task/share_bottom.py +48 -35
- nextrec/models/ranking/afm.py +72 -37
- nextrec/models/ranking/autoint.py +72 -55
- nextrec/models/ranking/dcn.py +55 -35
- nextrec/models/ranking/dcn_v2.py +64 -23
- nextrec/models/ranking/deepfm.py +32 -22
- nextrec/models/ranking/dien.py +155 -99
- nextrec/models/ranking/din.py +85 -57
- nextrec/models/ranking/fibinet.py +52 -32
- nextrec/models/ranking/fm.py +29 -23
- nextrec/models/ranking/masknet.py +91 -29
- nextrec/models/ranking/pnn.py +31 -28
- nextrec/models/ranking/widedeep.py +34 -26
- nextrec/models/ranking/xdeepfm.py +60 -38
- nextrec/utils/__init__.py +59 -34
- nextrec/utils/config.py +490 -0
- nextrec/utils/device.py +30 -20
- nextrec/utils/distributed.py +36 -9
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +32 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +283 -165
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/METADATA +4 -4
- nextrec-0.4.2.dist-info/RECORD +69 -0
- nextrec-0.4.2.dist-info/entry_points.txt +2 -0
- nextrec-0.4.1.dist-info/RECORD +0 -66
- {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
- {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/models/ranking/din.py
CHANGED
|
@@ -12,7 +12,12 @@ import torch
|
|
|
12
12
|
import torch.nn as nn
|
|
13
13
|
|
|
14
14
|
from nextrec.basic.model import BaseModel
|
|
15
|
-
from nextrec.basic.layers import
|
|
15
|
+
from nextrec.basic.layers import (
|
|
16
|
+
EmbeddingLayer,
|
|
17
|
+
MLP,
|
|
18
|
+
AttentionPoolingLayer,
|
|
19
|
+
PredictionLayer,
|
|
20
|
+
)
|
|
16
21
|
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
|
|
17
22
|
|
|
18
23
|
|
|
@@ -24,28 +29,30 @@ class DIN(BaseModel):
|
|
|
24
29
|
@property
|
|
25
30
|
def default_task(self):
|
|
26
31
|
return "binary"
|
|
27
|
-
|
|
28
|
-
def __init__(
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
32
|
+
|
|
33
|
+
def __init__(
|
|
34
|
+
self,
|
|
35
|
+
dense_features: list[DenseFeature],
|
|
36
|
+
sparse_features: list[SparseFeature],
|
|
37
|
+
sequence_features: list[SequenceFeature],
|
|
38
|
+
mlp_params: dict,
|
|
39
|
+
attention_hidden_units: list[int] = [80, 40],
|
|
40
|
+
attention_activation: str = "sigmoid",
|
|
41
|
+
attention_use_softmax: bool = True,
|
|
42
|
+
target: list[str] = [],
|
|
43
|
+
task: str | list[str] | None = None,
|
|
44
|
+
optimizer: str = "adam",
|
|
45
|
+
optimizer_params: dict = {},
|
|
46
|
+
loss: str | nn.Module | None = "bce",
|
|
47
|
+
loss_params: dict | list[dict] | None = None,
|
|
48
|
+
device: str = "cpu",
|
|
49
|
+
embedding_l1_reg=1e-6,
|
|
50
|
+
dense_l1_reg=1e-5,
|
|
51
|
+
embedding_l2_reg=1e-5,
|
|
52
|
+
dense_l2_reg=1e-4,
|
|
53
|
+
**kwargs,
|
|
54
|
+
):
|
|
55
|
+
|
|
49
56
|
super(DIN, self).__init__(
|
|
50
57
|
dense_features=dense_features,
|
|
51
58
|
sparse_features=sparse_features,
|
|
@@ -57,43 +64,54 @@ class DIN(BaseModel):
|
|
|
57
64
|
dense_l1_reg=dense_l1_reg,
|
|
58
65
|
embedding_l2_reg=embedding_l2_reg,
|
|
59
66
|
dense_l2_reg=dense_l2_reg,
|
|
60
|
-
**kwargs
|
|
67
|
+
**kwargs,
|
|
61
68
|
)
|
|
62
69
|
|
|
63
70
|
self.loss = loss
|
|
64
71
|
if self.loss is None:
|
|
65
72
|
self.loss = "bce"
|
|
66
|
-
|
|
73
|
+
|
|
67
74
|
# Features classification
|
|
68
75
|
# DIN requires: candidate item + user behavior sequence + other features
|
|
69
76
|
if len(sequence_features) == 0:
|
|
70
|
-
raise ValueError(
|
|
71
|
-
|
|
77
|
+
raise ValueError(
|
|
78
|
+
"DIN requires at least one sequence feature for user behavior history"
|
|
79
|
+
)
|
|
80
|
+
|
|
72
81
|
self.behavior_feature = sequence_features[0] # User behavior sequence
|
|
73
|
-
self.candidate_feature =
|
|
74
|
-
|
|
82
|
+
self.candidate_feature = (
|
|
83
|
+
sparse_features[-1] if sparse_features else None
|
|
84
|
+
) # Candidate item
|
|
85
|
+
|
|
75
86
|
# Other features (excluding behavior sequence in final concatenation)
|
|
76
|
-
self.other_sparse_features =
|
|
87
|
+
self.other_sparse_features = (
|
|
88
|
+
sparse_features[:-1] if self.candidate_feature else sparse_features
|
|
89
|
+
)
|
|
77
90
|
self.dense_features_list = dense_features
|
|
78
|
-
|
|
91
|
+
|
|
79
92
|
# All features for embedding
|
|
80
93
|
self.all_features = dense_features + sparse_features + sequence_features
|
|
81
94
|
|
|
82
95
|
# Embedding layer
|
|
83
96
|
self.embedding = EmbeddingLayer(features=self.all_features)
|
|
84
|
-
|
|
97
|
+
|
|
85
98
|
# Attention layer for behavior sequence
|
|
86
99
|
behavior_emb_dim = self.behavior_feature.embedding_dim
|
|
87
100
|
self.candidate_attention_proj = None
|
|
88
|
-
if
|
|
89
|
-
self.
|
|
101
|
+
if (
|
|
102
|
+
self.candidate_feature is not None
|
|
103
|
+
and self.candidate_feature.embedding_dim != behavior_emb_dim
|
|
104
|
+
):
|
|
105
|
+
self.candidate_attention_proj = nn.Linear(
|
|
106
|
+
self.candidate_feature.embedding_dim, behavior_emb_dim
|
|
107
|
+
)
|
|
90
108
|
self.attention = AttentionPoolingLayer(
|
|
91
109
|
embedding_dim=behavior_emb_dim,
|
|
92
110
|
hidden_units=attention_hidden_units,
|
|
93
111
|
activation=attention_activation,
|
|
94
|
-
use_softmax=attention_use_softmax
|
|
112
|
+
use_softmax=attention_use_softmax,
|
|
95
113
|
)
|
|
96
|
-
|
|
114
|
+
|
|
97
115
|
# Calculate MLP input dimension
|
|
98
116
|
# candidate + attention_pooled_behavior + other_sparse + dense
|
|
99
117
|
mlp_input_dim = 0
|
|
@@ -101,16 +119,18 @@ class DIN(BaseModel):
|
|
|
101
119
|
mlp_input_dim += self.candidate_feature.embedding_dim
|
|
102
120
|
mlp_input_dim += behavior_emb_dim # attention pooled
|
|
103
121
|
mlp_input_dim += sum([f.embedding_dim for f in self.other_sparse_features])
|
|
104
|
-
mlp_input_dim += sum(
|
|
105
|
-
|
|
122
|
+
mlp_input_dim += sum(
|
|
123
|
+
[getattr(f, "embedding_dim", 1) or 1 for f in dense_features]
|
|
124
|
+
)
|
|
125
|
+
|
|
106
126
|
# MLP for final prediction
|
|
107
127
|
self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
|
|
108
128
|
self.prediction_layer = PredictionLayer(task_type=self.task)
|
|
109
129
|
|
|
110
130
|
# Register regularization weights
|
|
111
131
|
self.register_regularization_weights(
|
|
112
|
-
embedding_attr=
|
|
113
|
-
include_modules=[
|
|
132
|
+
embedding_attr="embedding",
|
|
133
|
+
include_modules=["attention", "mlp", "candidate_attention_proj"],
|
|
114
134
|
)
|
|
115
135
|
|
|
116
136
|
self.compile(
|
|
@@ -123,61 +143,69 @@ class DIN(BaseModel):
|
|
|
123
143
|
def forward(self, x):
|
|
124
144
|
# Get candidate item embedding
|
|
125
145
|
if self.candidate_feature:
|
|
126
|
-
candidate_emb = self.embedding.embed_dict[
|
|
146
|
+
candidate_emb = self.embedding.embed_dict[
|
|
147
|
+
self.candidate_feature.embedding_name
|
|
148
|
+
](
|
|
127
149
|
x[self.candidate_feature.name].long()
|
|
128
150
|
) # [B, emb_dim]
|
|
129
151
|
else:
|
|
130
152
|
candidate_emb = None
|
|
131
|
-
|
|
153
|
+
|
|
132
154
|
# Get behavior sequence embedding
|
|
133
155
|
behavior_seq = x[self.behavior_feature.name].long() # [B, seq_len]
|
|
134
156
|
behavior_emb = self.embedding.embed_dict[self.behavior_feature.embedding_name](
|
|
135
157
|
behavior_seq
|
|
136
158
|
) # [B, seq_len, emb_dim]
|
|
137
|
-
|
|
159
|
+
|
|
138
160
|
# Create mask for padding
|
|
139
161
|
if self.behavior_feature.padding_idx is not None:
|
|
140
|
-
mask = (
|
|
162
|
+
mask = (
|
|
163
|
+
(behavior_seq != self.behavior_feature.padding_idx)
|
|
164
|
+
.unsqueeze(-1)
|
|
165
|
+
.float()
|
|
166
|
+
)
|
|
141
167
|
else:
|
|
142
168
|
mask = (behavior_seq != 0).unsqueeze(-1).float()
|
|
143
|
-
|
|
169
|
+
|
|
144
170
|
# Apply attention pooling
|
|
145
171
|
if candidate_emb is not None:
|
|
146
172
|
candidate_query = candidate_emb
|
|
147
173
|
if self.candidate_attention_proj is not None:
|
|
148
174
|
candidate_query = self.candidate_attention_proj(candidate_query)
|
|
149
175
|
pooled_behavior = self.attention(
|
|
150
|
-
query=candidate_query,
|
|
151
|
-
keys=behavior_emb,
|
|
152
|
-
mask=mask
|
|
176
|
+
query=candidate_query, keys=behavior_emb, mask=mask
|
|
153
177
|
) # [B, emb_dim]
|
|
154
178
|
else:
|
|
155
179
|
# If no candidate, use mean pooling
|
|
156
|
-
pooled_behavior = torch.sum(behavior_emb * mask, dim=1) / (
|
|
157
|
-
|
|
180
|
+
pooled_behavior = torch.sum(behavior_emb * mask, dim=1) / (
|
|
181
|
+
mask.sum(dim=1) + 1e-9
|
|
182
|
+
)
|
|
183
|
+
|
|
158
184
|
# Get other features
|
|
159
185
|
other_embeddings = []
|
|
160
|
-
|
|
186
|
+
|
|
161
187
|
if candidate_emb is not None:
|
|
162
188
|
other_embeddings.append(candidate_emb)
|
|
163
|
-
|
|
189
|
+
|
|
164
190
|
other_embeddings.append(pooled_behavior)
|
|
165
|
-
|
|
191
|
+
|
|
166
192
|
# Other sparse features
|
|
167
193
|
for feat in self.other_sparse_features:
|
|
168
|
-
feat_emb = self.embedding.embed_dict[feat.embedding_name](
|
|
194
|
+
feat_emb = self.embedding.embed_dict[feat.embedding_name](
|
|
195
|
+
x[feat.name].long()
|
|
196
|
+
)
|
|
169
197
|
other_embeddings.append(feat_emb)
|
|
170
|
-
|
|
198
|
+
|
|
171
199
|
# Dense features
|
|
172
200
|
for feat in self.dense_features_list:
|
|
173
201
|
val = x[feat.name].float()
|
|
174
202
|
if val.dim() == 1:
|
|
175
203
|
val = val.unsqueeze(1)
|
|
176
204
|
other_embeddings.append(val)
|
|
177
|
-
|
|
205
|
+
|
|
178
206
|
# Concatenate all features
|
|
179
207
|
concat_input = torch.cat(other_embeddings, dim=-1) # [B, total_dim]
|
|
180
|
-
|
|
208
|
+
|
|
181
209
|
# MLP prediction
|
|
182
210
|
y = self.mlp(concat_input) # [B, 1]
|
|
183
211
|
return self.prediction_layer(y)
|
|
@@ -30,27 +30,29 @@ class FiBiNET(BaseModel):
|
|
|
30
30
|
@property
|
|
31
31
|
def default_task(self):
|
|
32
32
|
return "binary"
|
|
33
|
-
|
|
34
|
-
def __init__(
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
33
|
+
|
|
34
|
+
def __init__(
|
|
35
|
+
self,
|
|
36
|
+
dense_features: list[DenseFeature] | list = [],
|
|
37
|
+
sparse_features: list[SparseFeature] | list = [],
|
|
38
|
+
sequence_features: list[SequenceFeature] | list = [],
|
|
39
|
+
mlp_params: dict = {},
|
|
40
|
+
bilinear_type: str = "field_interaction",
|
|
41
|
+
senet_reduction: int = 3,
|
|
42
|
+
target: list[str] | list = [],
|
|
43
|
+
task: str | list[str] | None = None,
|
|
44
|
+
optimizer: str = "adam",
|
|
45
|
+
optimizer_params: dict = {},
|
|
46
|
+
loss: str | nn.Module | None = "bce",
|
|
47
|
+
loss_params: dict | list[dict] | None = None,
|
|
48
|
+
device: str = "cpu",
|
|
49
|
+
embedding_l1_reg=1e-6,
|
|
50
|
+
dense_l1_reg=1e-5,
|
|
51
|
+
embedding_l2_reg=1e-5,
|
|
52
|
+
dense_l2_reg=1e-4,
|
|
53
|
+
**kwargs,
|
|
54
|
+
):
|
|
55
|
+
|
|
54
56
|
super(FiBiNET, self).__init__(
|
|
55
57
|
dense_features=dense_features,
|
|
56
58
|
sparse_features=sparse_features,
|
|
@@ -62,28 +64,36 @@ class FiBiNET(BaseModel):
|
|
|
62
64
|
dense_l1_reg=dense_l1_reg,
|
|
63
65
|
embedding_l2_reg=embedding_l2_reg,
|
|
64
66
|
dense_l2_reg=dense_l2_reg,
|
|
65
|
-
**kwargs
|
|
67
|
+
**kwargs,
|
|
66
68
|
)
|
|
67
69
|
|
|
68
70
|
self.loss = loss
|
|
69
71
|
if self.loss is None:
|
|
70
72
|
self.loss = "bce"
|
|
71
|
-
|
|
73
|
+
|
|
72
74
|
self.linear_features = sparse_features + sequence_features
|
|
73
75
|
self.deep_features = dense_features + sparse_features + sequence_features
|
|
74
76
|
self.interaction_features = sparse_features + sequence_features
|
|
75
77
|
|
|
76
78
|
if len(self.interaction_features) < 2:
|
|
77
|
-
raise ValueError(
|
|
79
|
+
raise ValueError(
|
|
80
|
+
"FiBiNET requires at least two sparse/sequence features for interactions."
|
|
81
|
+
)
|
|
78
82
|
|
|
79
83
|
self.embedding = EmbeddingLayer(features=self.deep_features)
|
|
80
84
|
|
|
81
85
|
self.num_fields = len(self.interaction_features)
|
|
82
86
|
self.embedding_dim = self.interaction_features[0].embedding_dim
|
|
83
|
-
if any(
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
+
if any(
|
|
88
|
+
f.embedding_dim != self.embedding_dim for f in self.interaction_features
|
|
89
|
+
):
|
|
90
|
+
raise ValueError(
|
|
91
|
+
"All interaction features must share the same embedding_dim in FiBiNET."
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
self.senet = SENETLayer(
|
|
95
|
+
num_fields=self.num_fields, reduction_ratio=senet_reduction
|
|
96
|
+
)
|
|
87
97
|
self.bilinear_standard = BiLinearInteractionLayer(
|
|
88
98
|
input_dim=self.embedding_dim,
|
|
89
99
|
num_fields=self.num_fields,
|
|
@@ -105,8 +115,14 @@ class FiBiNET(BaseModel):
|
|
|
105
115
|
|
|
106
116
|
# Register regularization weights
|
|
107
117
|
self.register_regularization_weights(
|
|
108
|
-
embedding_attr=
|
|
109
|
-
include_modules=[
|
|
118
|
+
embedding_attr="embedding",
|
|
119
|
+
include_modules=[
|
|
120
|
+
"linear",
|
|
121
|
+
"senet",
|
|
122
|
+
"bilinear_standard",
|
|
123
|
+
"bilinear_senet",
|
|
124
|
+
"mlp",
|
|
125
|
+
],
|
|
110
126
|
)
|
|
111
127
|
|
|
112
128
|
self.compile(
|
|
@@ -117,10 +133,14 @@ class FiBiNET(BaseModel):
|
|
|
117
133
|
)
|
|
118
134
|
|
|
119
135
|
def forward(self, x):
|
|
120
|
-
input_linear = self.embedding(
|
|
136
|
+
input_linear = self.embedding(
|
|
137
|
+
x=x, features=self.linear_features, squeeze_dim=True
|
|
138
|
+
)
|
|
121
139
|
y_linear = self.linear(input_linear)
|
|
122
140
|
|
|
123
|
-
field_emb = self.embedding(
|
|
141
|
+
field_emb = self.embedding(
|
|
142
|
+
x=x, features=self.interaction_features, squeeze_dim=False
|
|
143
|
+
)
|
|
124
144
|
senet_emb = self.senet(field_emb)
|
|
125
145
|
|
|
126
146
|
bilinear_standard = self.bilinear_standard(field_emb).flatten(start_dim=1)
|
nextrec/models/ranking/fm.py
CHANGED
|
@@ -9,7 +9,12 @@ Reference:
|
|
|
9
9
|
import torch.nn as nn
|
|
10
10
|
|
|
11
11
|
from nextrec.basic.model import BaseModel
|
|
12
|
-
from nextrec.basic.layers import
|
|
12
|
+
from nextrec.basic.layers import (
|
|
13
|
+
EmbeddingLayer,
|
|
14
|
+
FM as FMInteraction,
|
|
15
|
+
LR,
|
|
16
|
+
PredictionLayer,
|
|
17
|
+
)
|
|
13
18
|
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
|
|
14
19
|
|
|
15
20
|
|
|
@@ -21,24 +26,26 @@ class FM(BaseModel):
|
|
|
21
26
|
@property
|
|
22
27
|
def default_task(self):
|
|
23
28
|
return "binary"
|
|
24
|
-
|
|
25
|
-
def __init__(
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
29
|
+
|
|
30
|
+
def __init__(
|
|
31
|
+
self,
|
|
32
|
+
dense_features: list[DenseFeature] | list = [],
|
|
33
|
+
sparse_features: list[SparseFeature] | list = [],
|
|
34
|
+
sequence_features: list[SequenceFeature] | list = [],
|
|
35
|
+
target: list[str] | list = [],
|
|
36
|
+
task: str | list[str] | None = None,
|
|
37
|
+
optimizer: str = "adam",
|
|
38
|
+
optimizer_params: dict = {},
|
|
39
|
+
loss: str | nn.Module | None = "bce",
|
|
40
|
+
loss_params: dict | list[dict] | None = None,
|
|
41
|
+
device: str = "cpu",
|
|
42
|
+
embedding_l1_reg=1e-6,
|
|
43
|
+
dense_l1_reg=1e-5,
|
|
44
|
+
embedding_l2_reg=1e-5,
|
|
45
|
+
dense_l2_reg=1e-4,
|
|
46
|
+
**kwargs,
|
|
47
|
+
):
|
|
48
|
+
|
|
42
49
|
super(FM, self).__init__(
|
|
43
50
|
dense_features=dense_features,
|
|
44
51
|
sparse_features=sparse_features,
|
|
@@ -50,13 +57,13 @@ class FM(BaseModel):
|
|
|
50
57
|
dense_l1_reg=dense_l1_reg,
|
|
51
58
|
embedding_l2_reg=embedding_l2_reg,
|
|
52
59
|
dense_l2_reg=dense_l2_reg,
|
|
53
|
-
**kwargs
|
|
60
|
+
**kwargs,
|
|
54
61
|
)
|
|
55
62
|
|
|
56
63
|
self.loss = loss
|
|
57
64
|
if self.loss is None:
|
|
58
65
|
self.loss = "bce"
|
|
59
|
-
|
|
66
|
+
|
|
60
67
|
self.fm_features = sparse_features + sequence_features
|
|
61
68
|
if len(self.fm_features) == 0:
|
|
62
69
|
raise ValueError("FM requires at least one sparse or sequence feature.")
|
|
@@ -70,8 +77,7 @@ class FM(BaseModel):
|
|
|
70
77
|
|
|
71
78
|
# Register regularization weights
|
|
72
79
|
self.register_regularization_weights(
|
|
73
|
-
embedding_attr=
|
|
74
|
-
include_modules=['linear']
|
|
80
|
+
embedding_attr="embedding", include_modules=["linear"]
|
|
75
81
|
)
|
|
76
82
|
|
|
77
83
|
self.compile(
|
|
@@ -69,12 +69,13 @@ class InstanceGuidedMask(nn.Module):
|
|
|
69
69
|
self.fc2 = nn.Linear(hidden_dim, output_dim)
|
|
70
70
|
|
|
71
71
|
def forward(self, v_emb_flat: torch.Tensor) -> torch.Tensor:
|
|
72
|
-
# v_emb_flat: [batch, features count * embedding_dim]
|
|
72
|
+
# v_emb_flat: [batch, features count * embedding_dim]
|
|
73
73
|
x = self.fc1(v_emb_flat)
|
|
74
74
|
x = F.relu(x)
|
|
75
75
|
v_mask = self.fc2(x)
|
|
76
76
|
return v_mask
|
|
77
77
|
|
|
78
|
+
|
|
78
79
|
class MaskBlockOnEmbedding(nn.Module):
|
|
79
80
|
def __init__(
|
|
80
81
|
self,
|
|
@@ -86,20 +87,28 @@ class MaskBlockOnEmbedding(nn.Module):
|
|
|
86
87
|
super().__init__()
|
|
87
88
|
self.num_fields = num_fields
|
|
88
89
|
self.embedding_dim = embedding_dim
|
|
89
|
-
self.input_dim =
|
|
90
|
+
self.input_dim = (
|
|
91
|
+
num_fields * embedding_dim
|
|
92
|
+
) # input_dim = features count * embedding_dim
|
|
90
93
|
self.ln_emb = nn.LayerNorm(embedding_dim)
|
|
91
|
-
self.mask_gen = InstanceGuidedMask(
|
|
94
|
+
self.mask_gen = InstanceGuidedMask(
|
|
95
|
+
input_dim=self.input_dim,
|
|
96
|
+
hidden_dim=mask_hidden_dim,
|
|
97
|
+
output_dim=self.input_dim,
|
|
98
|
+
)
|
|
92
99
|
self.ffn = nn.Linear(self.input_dim, hidden_dim)
|
|
93
100
|
self.ln_hid = nn.LayerNorm(hidden_dim)
|
|
94
101
|
|
|
95
102
|
# different from MaskBlockOnHidden: input is field embeddings
|
|
96
|
-
def forward(
|
|
103
|
+
def forward(
|
|
104
|
+
self, field_emb: torch.Tensor, v_emb_flat: torch.Tensor
|
|
105
|
+
) -> torch.Tensor:
|
|
97
106
|
B = field_emb.size(0)
|
|
98
|
-
norm_emb = self.ln_emb(field_emb)
|
|
99
|
-
norm_emb_flat = norm_emb.view(B, -1)
|
|
100
|
-
v_mask = self.mask_gen(v_emb_flat)
|
|
101
|
-
v_masked_emb = v_mask * norm_emb_flat
|
|
102
|
-
hidden = self.ffn(v_masked_emb)
|
|
107
|
+
norm_emb = self.ln_emb(field_emb) # [B, features count, embedding_dim]
|
|
108
|
+
norm_emb_flat = norm_emb.view(B, -1) # [B, features count * embedding_dim]
|
|
109
|
+
v_mask = self.mask_gen(v_emb_flat) # [B, features count * embedding_dim]
|
|
110
|
+
v_masked_emb = v_mask * norm_emb_flat # [B, features count * embedding_dim]
|
|
111
|
+
hidden = self.ffn(v_masked_emb) # [B, hidden_dim]
|
|
103
112
|
hidden = self.ln_hid(hidden)
|
|
104
113
|
hidden = F.relu(hidden)
|
|
105
114
|
|
|
@@ -123,15 +132,21 @@ class MaskBlockOnHidden(nn.Module):
|
|
|
123
132
|
self.ln_input = nn.LayerNorm(hidden_dim)
|
|
124
133
|
self.ln_output = nn.LayerNorm(hidden_dim)
|
|
125
134
|
|
|
126
|
-
self.mask_gen = InstanceGuidedMask(
|
|
135
|
+
self.mask_gen = InstanceGuidedMask(
|
|
136
|
+
input_dim=self.v_emb_dim,
|
|
137
|
+
hidden_dim=mask_hidden_dim,
|
|
138
|
+
output_dim=hidden_dim,
|
|
139
|
+
)
|
|
127
140
|
self.ffn = nn.Linear(hidden_dim, hidden_dim)
|
|
128
141
|
|
|
129
142
|
# different from MaskBlockOnEmbedding: input is hidden representation
|
|
130
|
-
def forward(
|
|
131
|
-
|
|
143
|
+
def forward(
|
|
144
|
+
self, hidden_in: torch.Tensor, v_emb_flat: torch.Tensor
|
|
145
|
+
) -> torch.Tensor:
|
|
146
|
+
norm_hidden = self.ln_input(hidden_in)
|
|
132
147
|
v_mask = self.mask_gen(v_emb_flat)
|
|
133
|
-
v_masked_hid = v_mask * norm_hidden
|
|
134
|
-
out = self.ffn(v_masked_hid)
|
|
148
|
+
v_masked_hid = v_mask * norm_hidden
|
|
149
|
+
out = self.ffn(v_masked_hid)
|
|
135
150
|
out = self.ln_output(out)
|
|
136
151
|
out = F.relu(out)
|
|
137
152
|
return out
|
|
@@ -151,7 +166,7 @@ class MaskNet(BaseModel):
|
|
|
151
166
|
dense_features: list[DenseFeature] | None = None,
|
|
152
167
|
sparse_features: list[SparseFeature] | None = None,
|
|
153
168
|
sequence_features: list[SequenceFeature] | None = None,
|
|
154
|
-
model_type: str = "parallel",
|
|
169
|
+
model_type: str = "parallel", # "serial" or "parallel"
|
|
155
170
|
num_blocks: int = 3,
|
|
156
171
|
mask_hidden_dim: int = 64,
|
|
157
172
|
block_hidden_dim: int = 256,
|
|
@@ -199,50 +214,97 @@ class MaskNet(BaseModel):
|
|
|
199
214
|
self.sparse_features = sparse_features
|
|
200
215
|
self.sequence_features = sequence_features
|
|
201
216
|
self.mask_features = self.all_features # use all features for masking
|
|
202
|
-
assert
|
|
217
|
+
assert (
|
|
218
|
+
len(self.mask_features) > 0
|
|
219
|
+
), "MaskNet requires at least one feature for masking."
|
|
203
220
|
self.embedding = EmbeddingLayer(features=self.mask_features)
|
|
204
221
|
self.num_fields = len(self.mask_features)
|
|
205
222
|
self.embedding_dim = getattr(self.mask_features[0], "embedding_dim", None)
|
|
206
|
-
assert
|
|
223
|
+
assert (
|
|
224
|
+
self.embedding_dim is not None
|
|
225
|
+
), "MaskNet requires mask_features to have 'embedding_dim' defined."
|
|
207
226
|
|
|
208
227
|
for f in self.mask_features:
|
|
209
228
|
edim = getattr(f, "embedding_dim", None)
|
|
210
229
|
if edim is None or edim != self.embedding_dim:
|
|
211
|
-
raise ValueError(
|
|
230
|
+
raise ValueError(
|
|
231
|
+
f"MaskNet expects identical embedding_dim across all mask_features, but got {edim} for feature {getattr(f, 'name', type(f))}."
|
|
232
|
+
)
|
|
212
233
|
|
|
213
234
|
self.v_emb_dim = self.num_fields * self.embedding_dim
|
|
214
235
|
self.model_type = model_type.lower()
|
|
215
|
-
assert self.model_type in (
|
|
236
|
+
assert self.model_type in (
|
|
237
|
+
"serial",
|
|
238
|
+
"parallel",
|
|
239
|
+
), "model_type must be either 'serial' or 'parallel'."
|
|
216
240
|
|
|
217
241
|
self.num_blocks = max(1, num_blocks)
|
|
218
242
|
self.block_hidden_dim = block_hidden_dim
|
|
219
|
-
self.block_dropout =
|
|
243
|
+
self.block_dropout = (
|
|
244
|
+
nn.Dropout(block_dropout) if block_dropout > 0 else nn.Identity()
|
|
245
|
+
)
|
|
220
246
|
|
|
221
247
|
if self.model_type == "serial":
|
|
222
|
-
self.first_block = MaskBlockOnEmbedding(
|
|
248
|
+
self.first_block = MaskBlockOnEmbedding(
|
|
249
|
+
num_fields=self.num_fields,
|
|
250
|
+
embedding_dim=self.embedding_dim,
|
|
251
|
+
mask_hidden_dim=mask_hidden_dim,
|
|
252
|
+
hidden_dim=block_hidden_dim,
|
|
253
|
+
)
|
|
223
254
|
self.hidden_blocks = nn.ModuleList(
|
|
224
|
-
[
|
|
255
|
+
[
|
|
256
|
+
MaskBlockOnHidden(
|
|
257
|
+
num_fields=self.num_fields,
|
|
258
|
+
embedding_dim=self.embedding_dim,
|
|
259
|
+
mask_hidden_dim=mask_hidden_dim,
|
|
260
|
+
hidden_dim=block_hidden_dim,
|
|
261
|
+
)
|
|
262
|
+
for _ in range(self.num_blocks - 1)
|
|
263
|
+
]
|
|
264
|
+
)
|
|
225
265
|
self.mask_blocks = nn.ModuleList([self.first_block, *self.hidden_blocks])
|
|
226
266
|
self.output_layer = nn.Linear(block_hidden_dim, 1)
|
|
227
267
|
self.final_mlp = None
|
|
228
268
|
|
|
229
269
|
else: # parallel
|
|
230
|
-
self.mask_blocks = nn.ModuleList(
|
|
231
|
-
|
|
270
|
+
self.mask_blocks = nn.ModuleList(
|
|
271
|
+
[
|
|
272
|
+
MaskBlockOnEmbedding(
|
|
273
|
+
num_fields=self.num_fields,
|
|
274
|
+
embedding_dim=self.embedding_dim,
|
|
275
|
+
mask_hidden_dim=mask_hidden_dim,
|
|
276
|
+
hidden_dim=block_hidden_dim,
|
|
277
|
+
)
|
|
278
|
+
for _ in range(self.num_blocks)
|
|
279
|
+
]
|
|
280
|
+
)
|
|
281
|
+
self.final_mlp = MLP(
|
|
282
|
+
input_dim=self.num_blocks * block_hidden_dim, **mlp_params
|
|
283
|
+
)
|
|
232
284
|
self.output_layer = None
|
|
233
285
|
self.prediction_layer = PredictionLayer(task_type=self.task)
|
|
234
286
|
|
|
235
287
|
if self.model_type == "serial":
|
|
236
|
-
self.register_regularization_weights(
|
|
288
|
+
self.register_regularization_weights(
|
|
289
|
+
embedding_attr="embedding",
|
|
290
|
+
include_modules=["mask_blocks", "output_layer"],
|
|
291
|
+
)
|
|
237
292
|
# serial
|
|
238
293
|
else:
|
|
239
|
-
self.register_regularization_weights(
|
|
240
|
-
|
|
294
|
+
self.register_regularization_weights(
|
|
295
|
+
embedding_attr="embedding", include_modules=["mask_blocks", "final_mlp"]
|
|
296
|
+
)
|
|
297
|
+
self.compile(
|
|
298
|
+
optimizer=optimizer,
|
|
299
|
+
optimizer_params=optimizer_params,
|
|
300
|
+
loss=loss,
|
|
301
|
+
loss_params=loss_params,
|
|
302
|
+
)
|
|
241
303
|
|
|
242
304
|
def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
243
305
|
field_emb = self.embedding(x=x, features=self.mask_features, squeeze_dim=False)
|
|
244
306
|
B = field_emb.size(0)
|
|
245
|
-
v_emb_flat = field_emb.view(B, -1) # flattened embeddings
|
|
307
|
+
v_emb_flat = field_emb.view(B, -1) # flattened embeddings
|
|
246
308
|
|
|
247
309
|
if self.model_type == "parallel":
|
|
248
310
|
block_outputs = []
|
|
@@ -253,7 +315,7 @@ class MaskNet(BaseModel):
|
|
|
253
315
|
concat_hidden = torch.cat(block_outputs, dim=-1)
|
|
254
316
|
logit = self.final_mlp(concat_hidden) # [B, 1]
|
|
255
317
|
# serial
|
|
256
|
-
else:
|
|
318
|
+
else:
|
|
257
319
|
hidden = self.first_block(field_emb, v_emb_flat)
|
|
258
320
|
hidden = self.block_dropout(hidden)
|
|
259
321
|
for block in self.hidden_blocks:
|