nextrec-0.3.6-py3-none-any.whl → nextrec-0.4.2-py3-none-any.whl
This diff covers publicly available package versions as released to their public registry. It is provided for informational purposes only and reflects the changes between those versions as published.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +244 -113
- nextrec/basic/loggers.py +62 -43
- nextrec/basic/metrics.py +268 -119
- nextrec/basic/model.py +1373 -443
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +498 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +42 -24
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +303 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +106 -40
- nextrec/models/match/dssm.py +82 -69
- nextrec/models/match/dssm_v2.py +72 -58
- nextrec/models/match/mind.py +175 -108
- nextrec/models/match/sdm.py +104 -88
- nextrec/models/match/youtube_dnn.py +73 -60
- nextrec/models/multi_task/esmm.py +53 -39
- nextrec/models/multi_task/mmoe.py +70 -47
- nextrec/models/multi_task/ple.py +107 -50
- nextrec/models/multi_task/poso.py +121 -41
- nextrec/models/multi_task/share_bottom.py +54 -38
- nextrec/models/ranking/afm.py +172 -45
- nextrec/models/ranking/autoint.py +84 -61
- nextrec/models/ranking/dcn.py +59 -42
- nextrec/models/ranking/dcn_v2.py +64 -23
- nextrec/models/ranking/deepfm.py +36 -26
- nextrec/models/ranking/dien.py +158 -102
- nextrec/models/ranking/din.py +88 -60
- nextrec/models/ranking/fibinet.py +55 -35
- nextrec/models/ranking/fm.py +32 -26
- nextrec/models/ranking/masknet.py +95 -34
- nextrec/models/ranking/pnn.py +34 -31
- nextrec/models/ranking/widedeep.py +37 -29
- nextrec/models/ranking/xdeepfm.py +63 -41
- nextrec/utils/__init__.py +61 -32
- nextrec/utils/config.py +490 -0
- nextrec/utils/device.py +52 -12
- nextrec/utils/distributed.py +141 -0
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +32 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +531 -0
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
- nextrec-0.4.2.dist-info/RECORD +69 -0
- nextrec-0.4.2.dist-info/entry_points.txt +2 -0
- nextrec-0.3.6.dist-info/RECORD +0 -64
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
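Five entries appear to be new in the 0.4.2 RECORD: nextrec/cli.py, nextrec/utils/config.py, nextrec/utils/distributed.py, nextrec/utils/synthetic_data.py, and entry_points.txt, which suggests this release adds a command-line interface alongside config, distributed-training, and synthetic-data utilities. A minimal sketch (standard library only) for confirming which build you are actually comparing against; the expected output assumes you have already upgraded to 0.4.2:

from importlib.metadata import version

# Prints the installed nextrec version; expect "0.4.2" after upgrading.
print(version("nextrec"))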
nextrec/models/ranking/din.py
CHANGED
@@ -12,7 +12,12 @@ import torch
 import torch.nn as nn

 from nextrec.basic.model import BaseModel
-from nextrec.basic.layers import
+from nextrec.basic.layers import (
+    EmbeddingLayer,
+    MLP,
+    AttentionPoolingLayer,
+    PredictionLayer,
+)
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature


@@ -22,78 +27,91 @@ class DIN(BaseModel):
         return "DIN"

     @property
-    def
+    def default_task(self):
         return "binary"
-
-    def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature],
+        sparse_features: list[SparseFeature],
+        sequence_features: list[SequenceFeature],
+        mlp_params: dict,
+        attention_hidden_units: list[int] = [80, 40],
+        attention_activation: str = "sigmoid",
+        attention_use_softmax: bool = True,
+        target: list[str] = [],
+        task: str | list[str] | None = None,
+        optimizer: str = "adam",
+        optimizer_params: dict = {},
+        loss: str | nn.Module | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        device: str = "cpu",
+        embedding_l1_reg=1e-6,
+        dense_l1_reg=1e-5,
+        embedding_l2_reg=1e-5,
+        dense_l2_reg=1e-4,
+        **kwargs,
+    ):
+
         super(DIN, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=self.
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-
-            **kwargs
+            **kwargs,
         )

         self.loss = loss
         if self.loss is None:
             self.loss = "bce"
-
+
         # Features classification
         # DIN requires: candidate item + user behavior sequence + other features
         if len(sequence_features) == 0:
-            raise ValueError(
-
+            raise ValueError(
+                "DIN requires at least one sequence feature for user behavior history"
+            )
+
         self.behavior_feature = sequence_features[0]  # User behavior sequence
-        self.candidate_feature =
-
+        self.candidate_feature = (
+            sparse_features[-1] if sparse_features else None
+        )  # Candidate item
+
         # Other features (excluding behavior sequence in final concatenation)
-        self.other_sparse_features =
+        self.other_sparse_features = (
+            sparse_features[:-1] if self.candidate_feature else sparse_features
+        )
         self.dense_features_list = dense_features
-
+
         # All features for embedding
         self.all_features = dense_features + sparse_features + sequence_features

         # Embedding layer
         self.embedding = EmbeddingLayer(features=self.all_features)
-
+
         # Attention layer for behavior sequence
         behavior_emb_dim = self.behavior_feature.embedding_dim
         self.candidate_attention_proj = None
-        if
-            self.
+        if (
+            self.candidate_feature is not None
+            and self.candidate_feature.embedding_dim != behavior_emb_dim
+        ):
+            self.candidate_attention_proj = nn.Linear(
+                self.candidate_feature.embedding_dim, behavior_emb_dim
+            )
         self.attention = AttentionPoolingLayer(
             embedding_dim=behavior_emb_dim,
             hidden_units=attention_hidden_units,
             activation=attention_activation,
-            use_softmax=attention_use_softmax
+            use_softmax=attention_use_softmax,
         )
-
+
         # Calculate MLP input dimension
         # candidate + attention_pooled_behavior + other_sparse + dense
         mlp_input_dim = 0
@@ -101,16 +119,18 @@ class DIN(BaseModel):
             mlp_input_dim += self.candidate_feature.embedding_dim
         mlp_input_dim += behavior_emb_dim  # attention pooled
         mlp_input_dim += sum([f.embedding_dim for f in self.other_sparse_features])
-        mlp_input_dim += sum(
-
+        mlp_input_dim += sum(
+            [getattr(f, "embedding_dim", 1) or 1 for f in dense_features]
+        )
+
         # MLP for final prediction
         self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
-        self.prediction_layer = PredictionLayer(task_type=self.
+        self.prediction_layer = PredictionLayer(task_type=self.task)

         # Register regularization weights
         self.register_regularization_weights(
-            embedding_attr=
-            include_modules=[
+            embedding_attr="embedding",
+            include_modules=["attention", "mlp", "candidate_attention_proj"],
         )

         self.compile(
@@ -123,61 +143,69 @@ class DIN(BaseModel):
     def forward(self, x):
         # Get candidate item embedding
         if self.candidate_feature:
-            candidate_emb = self.embedding.embed_dict[
+            candidate_emb = self.embedding.embed_dict[
+                self.candidate_feature.embedding_name
+            ](
                 x[self.candidate_feature.name].long()
             )  # [B, emb_dim]
         else:
             candidate_emb = None
-
+
         # Get behavior sequence embedding
         behavior_seq = x[self.behavior_feature.name].long()  # [B, seq_len]
         behavior_emb = self.embedding.embed_dict[self.behavior_feature.embedding_name](
             behavior_seq
         )  # [B, seq_len, emb_dim]
-
+
         # Create mask for padding
         if self.behavior_feature.padding_idx is not None:
-            mask = (
+            mask = (
+                (behavior_seq != self.behavior_feature.padding_idx)
+                .unsqueeze(-1)
+                .float()
+            )
         else:
             mask = (behavior_seq != 0).unsqueeze(-1).float()
-
+
         # Apply attention pooling
         if candidate_emb is not None:
             candidate_query = candidate_emb
             if self.candidate_attention_proj is not None:
                 candidate_query = self.candidate_attention_proj(candidate_query)
             pooled_behavior = self.attention(
-                query=candidate_query,
-                keys=behavior_emb,
-                mask=mask
+                query=candidate_query, keys=behavior_emb, mask=mask
             )  # [B, emb_dim]
         else:
             # If no candidate, use mean pooling
-            pooled_behavior = torch.sum(behavior_emb * mask, dim=1) / (
-
+            pooled_behavior = torch.sum(behavior_emb * mask, dim=1) / (
+                mask.sum(dim=1) + 1e-9
+            )
+
         # Get other features
         other_embeddings = []
-
+
         if candidate_emb is not None:
             other_embeddings.append(candidate_emb)
-
+
         other_embeddings.append(pooled_behavior)
-
+
         # Other sparse features
         for feat in self.other_sparse_features:
-            feat_emb = self.embedding.embed_dict[feat.embedding_name](
+            feat_emb = self.embedding.embed_dict[feat.embedding_name](
+                x[feat.name].long()
+            )
             other_embeddings.append(feat_emb)
-
+
         # Dense features
         for feat in self.dense_features_list:
             val = x[feat.name].float()
             if val.dim() == 1:
                 val = val.unsqueeze(1)
             other_embeddings.append(val)
-
+
         # Concatenate all features
         concat_input = torch.cat(other_embeddings, dim=-1)  # [B, total_dim]
-
+
         # MLP prediction
         y = self.mlp(concat_input)  # [B, 1]
         return self.prediction_layer(y)
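The new forward pass masks padded positions in the behavior sequence and, when no candidate item is available, falls back to masked mean pooling. A minimal standalone sketch of that fallback, assuming plain torch tensors with made-up shapes (the 1e-9 epsilon and the padding-index test mirror the diff above):

import torch

# Two users, behavior sequences of length 4, embedding dim 8; 0 is the padding index,
# matching the else-branch of the mask construction in the diff.
behavior_seq = torch.tensor([[3, 7, 2, 0], [5, 0, 0, 0]])
behavior_emb = torch.randn(2, 4, 8)

mask = (behavior_seq != 0).unsqueeze(-1).float()  # [B, seq_len, 1]
pooled_behavior = torch.sum(behavior_emb * mask, dim=1) / (
    mask.sum(dim=1) + 1e-9  # epsilon guards against all-padding sequences
)  # [B, emb_dim]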
nextrec/models/ranking/fibinet.py
CHANGED

@@ -28,62 +28,72 @@ class FiBiNET(BaseModel):
         return "FiBiNET"

     @property
-    def
+    def default_task(self):
         return "binary"
-
-    def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature] | list = [],
+        sparse_features: list[SparseFeature] | list = [],
+        sequence_features: list[SequenceFeature] | list = [],
+        mlp_params: dict = {},
+        bilinear_type: str = "field_interaction",
+        senet_reduction: int = 3,
+        target: list[str] | list = [],
+        task: str | list[str] | None = None,
+        optimizer: str = "adam",
+        optimizer_params: dict = {},
+        loss: str | nn.Module | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        device: str = "cpu",
+        embedding_l1_reg=1e-6,
+        dense_l1_reg=1e-5,
+        embedding_l2_reg=1e-5,
+        dense_l2_reg=1e-4,
+        **kwargs,
+    ):
+
         super(FiBiNET, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=self.
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-
-            **kwargs
+            **kwargs,
         )

         self.loss = loss
         if self.loss is None:
             self.loss = "bce"
-
+
         self.linear_features = sparse_features + sequence_features
         self.deep_features = dense_features + sparse_features + sequence_features
         self.interaction_features = sparse_features + sequence_features

         if len(self.interaction_features) < 2:
-            raise ValueError(
+            raise ValueError(
+                "FiBiNET requires at least two sparse/sequence features for interactions."
+            )

         self.embedding = EmbeddingLayer(features=self.deep_features)

         self.num_fields = len(self.interaction_features)
         self.embedding_dim = self.interaction_features[0].embedding_dim
-        if any(
-
-
-
+        if any(
+            f.embedding_dim != self.embedding_dim for f in self.interaction_features
+        ):
+            raise ValueError(
+                "All interaction features must share the same embedding_dim in FiBiNET."
+            )
+
+        self.senet = SENETLayer(
+            num_fields=self.num_fields, reduction_ratio=senet_reduction
+        )
         self.bilinear_standard = BiLinearInteractionLayer(
             input_dim=self.embedding_dim,
             num_fields=self.num_fields,
@@ -101,12 +111,18 @@ class FiBiNET(BaseModel):
         num_pairs = self.num_fields * (self.num_fields - 1) // 2
         interaction_dim = num_pairs * self.embedding_dim * 2
         self.mlp = MLP(input_dim=interaction_dim, **mlp_params)
-        self.prediction_layer = PredictionLayer(task_type=self.
+        self.prediction_layer = PredictionLayer(task_type=self.default_task)

         # Register regularization weights
         self.register_regularization_weights(
-            embedding_attr=
-            include_modules=[
+            embedding_attr="embedding",
+            include_modules=[
+                "linear",
+                "senet",
+                "bilinear_standard",
+                "bilinear_senet",
+                "mlp",
+            ],
         )

         self.compile(
@@ -117,10 +133,14 @@ class FiBiNET(BaseModel):
         )

     def forward(self, x):
-        input_linear = self.embedding(
+        input_linear = self.embedding(
+            x=x, features=self.linear_features, squeeze_dim=True
+        )
         y_linear = self.linear(input_linear)

-        field_emb = self.embedding(
+        field_emb = self.embedding(
+            x=x, features=self.interaction_features, squeeze_dim=False
+        )
         senet_emb = self.senet(field_emb)

         bilinear_standard = self.bilinear_standard(field_emb).flatten(start_dim=1)
nextrec/models/ranking/fm.py
CHANGED
@@ -9,7 +9,12 @@ Reference:
 import torch.nn as nn

 from nextrec.basic.model import BaseModel
-from nextrec.basic.layers import
+from nextrec.basic.layers import (
+    EmbeddingLayer,
+    FM as FMInteraction,
+    LR,
+    PredictionLayer,
+)
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature


@@ -19,44 +24,46 @@ class FM(BaseModel):
         return "FM"

     @property
-    def
+    def default_task(self):
         return "binary"
-
-    def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature] | list = [],
+        sparse_features: list[SparseFeature] | list = [],
+        sequence_features: list[SequenceFeature] | list = [],
+        target: list[str] | list = [],
+        task: str | list[str] | None = None,
+        optimizer: str = "adam",
+        optimizer_params: dict = {},
+        loss: str | nn.Module | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        device: str = "cpu",
+        embedding_l1_reg=1e-6,
+        dense_l1_reg=1e-5,
+        embedding_l2_reg=1e-5,
+        dense_l2_reg=1e-4,
+        **kwargs,
+    ):
+
         super(FM, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=self.
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-
-            **kwargs
+            **kwargs,
        )

         self.loss = loss
         if self.loss is None:
             self.loss = "bce"
-
+
         self.fm_features = sparse_features + sequence_features
         if len(self.fm_features) == 0:
             raise ValueError("FM requires at least one sparse or sequence feature.")
@@ -66,12 +73,11 @@ class FM(BaseModel):
         fm_input_dim = sum([f.embedding_dim for f in self.fm_features])
         self.linear = LR(fm_input_dim)
         self.fm = FMInteraction(reduce_sum=True)
-        self.prediction_layer = PredictionLayer(task_type=self.
+        self.prediction_layer = PredictionLayer(task_type=self.task)

         # Register regularization weights
         self.register_regularization_weights(
-            embedding_attr=
-            include_modules=['linear']
+            embedding_attr="embedding", include_modules=["linear"]
         )

         self.compile(