nextrec-0.3.5-py3-none-any.whl → nextrec-0.4.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. nextrec/__init__.py +0 -30
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/layers.py +32 -15
  4. nextrec/basic/loggers.py +1 -1
  5. nextrec/basic/model.py +440 -189
  6. nextrec/basic/session.py +4 -2
  7. nextrec/data/__init__.py +0 -25
  8. nextrec/data/data_processing.py +31 -19
  9. nextrec/data/dataloader.py +51 -16
  10. nextrec/models/generative/__init__.py +0 -5
  11. nextrec/models/generative/hstu.py +3 -2
  12. nextrec/models/match/__init__.py +0 -13
  13. nextrec/models/match/dssm.py +0 -1
  14. nextrec/models/match/dssm_v2.py +0 -1
  15. nextrec/models/match/mind.py +0 -1
  16. nextrec/models/match/sdm.py +0 -1
  17. nextrec/models/match/youtube_dnn.py +0 -1
  18. nextrec/models/multi_task/__init__.py +0 -0
  19. nextrec/models/multi_task/esmm.py +5 -7
  20. nextrec/models/multi_task/mmoe.py +10 -6
  21. nextrec/models/multi_task/ple.py +10 -6
  22. nextrec/models/multi_task/poso.py +9 -6
  23. nextrec/models/multi_task/share_bottom.py +10 -7
  24. nextrec/models/ranking/__init__.py +0 -27
  25. nextrec/models/ranking/afm.py +113 -21
  26. nextrec/models/ranking/autoint.py +15 -9
  27. nextrec/models/ranking/dcn.py +8 -11
  28. nextrec/models/ranking/deepfm.py +5 -5
  29. nextrec/models/ranking/dien.py +4 -4
  30. nextrec/models/ranking/din.py +4 -4
  31. nextrec/models/ranking/fibinet.py +4 -4
  32. nextrec/models/ranking/fm.py +4 -4
  33. nextrec/models/ranking/masknet.py +4 -5
  34. nextrec/models/ranking/pnn.py +4 -4
  35. nextrec/models/ranking/widedeep.py +4 -4
  36. nextrec/models/ranking/xdeepfm.py +4 -4
  37. nextrec/utils/__init__.py +7 -3
  38. nextrec/utils/device.py +32 -1
  39. nextrec/utils/distributed.py +114 -0
  40. nextrec/utils/synthetic_data.py +413 -0
  41. {nextrec-0.3.5.dist-info → nextrec-0.4.1.dist-info}/METADATA +15 -5
  42. nextrec-0.4.1.dist-info/RECORD +66 -0
  43. nextrec-0.3.5.dist-info/RECORD +0 -63
  44. {nextrec-0.3.5.dist-info → nextrec-0.4.1.dist-info}/WHEEL +0 -0
  45. {nextrec-0.3.5.dist-info → nextrec-0.4.1.dist-info}/licenses/LICENSE +0 -0
nextrec/models/ranking/__init__.py
@@ -1,27 +0,0 @@
- from .fm import FM
- from .afm import AFM
- from .masknet import MaskNet
- from .pnn import PNN
- from .deepfm import DeepFM
- from .autoint import AutoInt
- from .widedeep import WideDeep
- from .xdeepfm import xDeepFM
- from .dcn import DCN
- from .fibinet import FiBiNET
- from .din import DIN
- from .dien import DIEN
-
- __all__ = [
-     'DeepFM',
-     'AutoInt',
-     'WideDeep',
-     'xDeepFM',
-     'DCN',
-     'DIN',
-     'DIEN',
-     'FM',
-     'AFM',
-     'MaskNet',
-     'PNN',
-     'FiBiNET',
- ]
nextrec/models/ranking/afm.py
@@ -1,17 +1,46 @@
  """
  Date: create on 09/11/2025
- Author:
-     Yang Zhou,zyaztec@gmail.com
+ Checkpoint: edit on 06/12/2025
+ Author: Yang Zhou,zyaztec@gmail.com
  Reference:
-     [1] Xiao J, Ye H, He X, et al. Attentional factorization machines: Learning the weight of
-         feature interactions via attention networks[C]//IJCAI. 2017: 3119-3125.
+     [1] Xiao J, Ye H, He X, et al. Attentional factorization machines: Learning the weight of
+         feature interactions via attention networks[C]//IJCAI. 2017: 3119-3125.
+
+ Attentional Factorization Machine (AFM) builds on FM by learning an importance
+ weight for every second-order interaction instead of treating all pairs equally.
+ It retains FM’s linear (first-order) component for sparsity-friendly modeling,
+ while using an attention network to reweight the element-wise product of field
+ embeddings before aggregation.
+
+ In each forward pass:
+ (1) Embed each field and compute pairwise element-wise products v_i ⊙ v_j
+ (2) Pass interactions through an attention MLP (ReLU + projection) to score them
+ (3) Softmax-normalize scores to obtain interaction weights
+ (4) Weighted sum of interactions -> linear projection -> add FM first-order term
+
+ Key Advantages:
+ - Learns which feature pairs contribute most via attention weights
+ - Keeps FM efficiency and interpretability by preserving first-order terms
+ - Softmax-normalized reweighting reduces noise from uninformative interactions
+
+ AFM adds attention to FM's second-order interactions, learning an importance weight
+ for every feature pair while keeping FM's first-order term, which preserves
+ friendliness to sparse features and interpretability. Procedure:
+ (1) Embed each field and compute the element-wise product v_i ⊙ v_j for every feature pair
+ (2) Score each interaction with an attention MLP (ReLU + linear projection)
+ (3) Softmax-normalize the scores into interaction weights
+ (4) Sum the weighted interactions, apply a linear projection, and add the first-order term
+
+ Main advantages:
+ - Attention makes explicit which feature pairs matter most
+ - Preserves FM's efficiency and interpretability
+ - Softmax normalization dampens the influence of noisy interactions
  """

  import torch
  import torch.nn as nn

  from nextrec.basic.model import BaseModel
- from nextrec.basic.layers import EmbeddingLayer, LR, PredictionLayer
+ from nextrec.basic.layers import EmbeddingLayer, LR, PredictionLayer, InputMask
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature


@@ -21,7 +50,7 @@ class AFM(BaseModel):
          return "AFM"

      @property
-     def task_type(self):
+     def default_task(self):
          return "binary"

      def __init__(self,
@@ -31,6 +60,7 @@ class AFM(BaseModel):
                   attention_dim: int = 32,
                   attention_dropout: float = 0.0,
                   target: list[str] | list = [],
+                  task: str | list[str] | None = None,
                   optimizer: str = "adam",
                   optimizer_params: dict = {},
                   loss: str | nn.Module | None = "bce",
@@ -46,45 +76,64 @@ class AFM(BaseModel):
              sparse_features=sparse_features,
              sequence_features=sequence_features,
              target=target,
-             task=self.task_type,
+             task=task or self.default_task,
              device=device,
              embedding_l1_reg=embedding_l1_reg,
              dense_l1_reg=dense_l1_reg,
              embedding_l2_reg=embedding_l2_reg,
              dense_l2_reg=dense_l2_reg,
-             early_stop_patience=20,
              **kwargs
          )

-         self.loss = loss
-         if self.loss is None:
-             self.loss = "bce"
+         if target is None:
+             target = []
+         if optimizer_params is None:
+             optimizer_params = {}
+         if loss is None:
+             loss = "bce"

          self.fm_features = sparse_features + sequence_features
          if len(self.fm_features) < 2:
              raise ValueError("AFM requires at least two sparse/sequence features to build pairwise interactions.")

-         # Assume uniform embedding dimension across FM fields
+         # make sure all embedding dimensions are the same for FM features
          self.embedding_dim = self.fm_features[0].embedding_dim
          if any(f.embedding_dim != self.embedding_dim for f in self.fm_features):
              raise ValueError("All FM features must share the same embedding_dim for AFM.")

-         self.embedding = EmbeddingLayer(features=self.fm_features)
-
-         fm_input_dim = sum([f.embedding_dim for f in self.fm_features])
-         self.linear = LR(fm_input_dim)
+         self.embedding = EmbeddingLayer(features=self.fm_features)  # [Batch, Field, Dim]
+
+         # First-order terms: dense linear + one-hot embeddings
+         self.dense_features = list(dense_features)
+         dense_input_dim = sum([f.input_dim for f in self.dense_features])
+         self.linear_dense = nn.Linear(dense_input_dim, 1, bias=True) if dense_input_dim > 0 else None
+
+         # First-order term: sparse/sequence features one-hot
+         # **INFO**: the source paper does not use sequence features in its experiments,
+         # but we implement them here for completeness. If you want to follow the paper strictly,
+         # remove sequence features from fm_features.
+         self.first_order_embeddings = nn.ModuleDict()
+         for feature in self.fm_features:
+             if feature.embedding_name in self.first_order_embeddings:  # shared embedding
+                 continue
+             emb = nn.Embedding(num_embeddings=feature.vocab_size, embedding_dim=1, padding_idx=feature.padding_idx)  # equivalent to one-hot encoding weights
+             # nn.init.zeros_(emb.weight)
+             self.first_order_embeddings[feature.embedding_name] = emb

          self.attention_linear = nn.Linear(self.embedding_dim, attention_dim)
          self.attention_p = nn.Linear(attention_dim, 1, bias=False)
          self.attention_dropout = nn.Dropout(attention_dropout)
          self.output_projection = nn.Linear(self.embedding_dim, 1, bias=False)
-         self.prediction_layer = PredictionLayer(task_type=self.task_type)
+         self.prediction_layer = PredictionLayer(task_type=self.default_task)
+         self.input_mask = InputMask()

          # Register regularization weights
          self.register_regularization_weights(
              embedding_attr='embedding',
-             include_modules=['linear', 'attention_linear', 'attention_p', 'output_projection']
+             include_modules=['linear_dense', 'attention_linear', 'attention_p', 'output_projection']
          )
+         # add first-order embeddings to the embedding regularization list
+         self.embedding_params.extend(emb.weight for emb in self.first_order_embeddings.values())

          self.compile(
              optimizer=optimizer,
@@ -95,10 +144,53 @@

      def forward(self, x):
          field_emb = self.embedding(x=x, features=self.fm_features, squeeze_dim=False)  # [B, F, D]
-         input_linear = field_emb.flatten(start_dim=1)
-         y_linear = self.linear(input_linear)
+         batch_size = field_emb.size(0)
+         y_linear = torch.zeros(batch_size, 1, device=field_emb.device)
+
+         # First-order dense part
+         if self.linear_dense is not None:
+             dense_inputs = [x[f.name].float().view(batch_size, -1) for f in self.dense_features]
+             dense_stack = torch.cat(dense_inputs, dim=1) if dense_inputs else None
+             if dense_stack is not None:
+                 y_linear = y_linear + self.linear_dense(dense_stack)
+
+         # First-order sparse/sequence part
+         first_order_terms = []
+         for feature in self.fm_features:
+             emb = self.first_order_embeddings[feature.embedding_name]
+             if isinstance(feature, SparseFeature):
+                 term = emb(x[feature.name].long())  # [B, 1]
+             else:  # SequenceFeature
+                 seq_input = x[feature.name].long()  # [B, L]
+                 if feature.max_len is not None and seq_input.size(1) > feature.max_len:
+                     seq_input = seq_input[:, -feature.max_len :]
+                 mask = self.input_mask(x, feature, seq_input).squeeze(1)  # [B, L]
+                 seq_weight = emb(seq_input).squeeze(-1)  # [B, L]
+                 term = (seq_weight * mask).sum(dim=1, keepdim=True)  # [B, 1]
+             first_order_terms.append(term)
+         if first_order_terms:
+             y_linear = y_linear + torch.sum(torch.cat(first_order_terms, dim=1), dim=1, keepdim=True)

          interactions = []
+         feature_values = []
+         for feature in self.fm_features:
+             value = x.get(f"{feature.name}_value")
+             if value is not None:
+                 value = value.float()
+                 if value.dim() == 1:
+                     value = value.unsqueeze(-1)
+             else:
+                 if isinstance(feature, SequenceFeature):
+                     seq_input = x[feature.name].long()
+                     if feature.max_len is not None and seq_input.size(1) > feature.max_len:
+                         seq_input = seq_input[:, -feature.max_len :]
+                     value = self.input_mask(x, feature, seq_input).sum(dim=2)  # [B, 1]
+                 else:
+                     value = torch.ones(batch_size, 1, device=field_emb.device)
+             feature_values.append(value)
+         feature_values_tensor = torch.cat(feature_values, dim=1).unsqueeze(-1)  # [B, F, 1]
+         field_emb = field_emb * feature_values_tensor
+
          num_fields = field_emb.shape[1]
          for i in range(num_fields - 1):
              vi = field_emb[:, i, :]
@@ -107,7 +199,7 @@
                  interactions.append(vi * vj)

          pair_tensor = torch.stack(interactions, dim=1)  # [B, num_pairs, D]
-         attention_scores = torch.tanh(self.attention_linear(pair_tensor))
+         attention_scores = torch.relu(self.attention_linear(pair_tensor))
          attention_scores = self.attention_p(attention_scores)  # [B, num_pairs, 1]
          attention_weights = torch.softmax(attention_scores, dim=1)

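The docstring's steps (1)-(4) map onto the new forward pass above: pairwise element-wise products, a ReLU attention MLP, a softmax over pairs, and a weighted sum added to the first-order term. Below is a minimal standalone sketch of that attention-pooling core; the module and names are hypothetical and it deliberately ignores nextrec's first-order terms, feature values, and sequence masking.

import torch
import torch.nn as nn

class PairwiseAttentionPooling(nn.Module):
    """Toy AFM core: score every v_i * v_j pair, softmax the scores, pool."""

    def __init__(self, embedding_dim: int, attention_dim: int = 32):
        super().__init__()
        self.attention_linear = nn.Linear(embedding_dim, attention_dim)
        self.attention_p = nn.Linear(attention_dim, 1, bias=False)
        self.output_projection = nn.Linear(embedding_dim, 1, bias=False)

    def forward(self, field_emb: torch.Tensor) -> torch.Tensor:  # [B, F, D]
        num_fields = field_emb.size(1)
        i, j = torch.triu_indices(num_fields, num_fields, offset=1)          # all pairs with i < j
        pairs = field_emb[:, i, :] * field_emb[:, j, :]                      # step (1): [B, num_pairs, D]
        scores = self.attention_p(torch.relu(self.attention_linear(pairs)))  # step (2): [B, num_pairs, 1]
        weights = torch.softmax(scores, dim=1)                               # step (3)
        pooled = (weights * pairs).sum(dim=1)                                # step (4): [B, D]
        return self.output_projection(pooled)                                # [B, 1]; first-order term is added outside

# 2 samples, 4 fields, 8-dim embeddings -> 6 weighted pairs pooled into one logit
field_emb = torch.randn(2, 4, 8)
print(PairwiseAttentionPooling(embedding_dim=8)(field_emb).shape)  # torch.Size([2, 1])

With 4 fields there are 6 pairs; the softmax over dim=1 is what lets AFM down-weight uninformative interactions.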
nextrec/models/ranking/autoint.py
@@ -68,7 +68,7 @@ class AutoInt(BaseModel):
          return "AutoInt"

      @property
-     def task_type(self):
+     def default_task(self):
          return "binary"

      def __init__(self,
@@ -80,9 +80,10 @@ class AutoInt(BaseModel):
                   att_head_num: int = 2,
                   att_dropout: float = 0.0,
                   att_use_residual: bool = True,
-                  target: list[str] = [],
+                  target: list[str] | None = None,
+                  task: str | list[str] | None = None,
                   optimizer: str = "adam",
-                  optimizer_params: dict = {},
+                  optimizer_params: dict | None = None,
                   loss: str | nn.Module | None = "bce",
                   loss_params: dict | list[dict] | None = None,
                   device: str = 'cpu',
@@ -97,24 +98,29 @@ class AutoInt(BaseModel):
              sparse_features=sparse_features,
              sequence_features=sequence_features,
              target=target,
-             task=self.task_type,
+             task=task or self.default_task,
              device=device,
              embedding_l1_reg=embedding_l1_reg,
              dense_l1_reg=dense_l1_reg,
              embedding_l2_reg=embedding_l2_reg,
              dense_l2_reg=dense_l2_reg,
-             early_stop_patience=20,
              **kwargs
          )

-         self.loss = loss
-         if self.loss is None:
-             self.loss = "bce"
+         if target is None:
+             target = []
+         if optimizer_params is None:
+             optimizer_params = {}
+         if loss is None:
+             loss = "bce"

          self.att_layer_num = att_layer_num
          self.att_embedding_dim = att_embedding_dim

          # Use sparse and sequence features for interaction
+         # **INFO**: this differs from the original paper: we also include dense features.
+         # If you want to follow the paper strictly, set dense_features=[]
+         # or modify the code accordingly.
          self.interaction_features = dense_features + sparse_features + sequence_features

          # All features for embedding
@@ -147,7 +153,7 @@ class AutoInt(BaseModel):

          # Final prediction layer
          self.fc = nn.Linear(num_fields * att_embedding_dim, 1)
-         self.prediction_layer = PredictionLayer(task_type=self.task_type)
+         self.prediction_layer = PredictionLayer(task_type=self.default_task)

          # Register regularization weights
          self.register_regularization_weights(
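Every ranking model in this release follows the same constructor pattern: the `task_type` property is renamed to `default_task`, an optional `task` argument is added, and the base class is initialized with `task=task or self.default_task`, so existing call sites keep the old behavior while callers can now override the task. A minimal sketch of just that fallback follows; it is a stub class, not nextrec code, and whether a given task string or task list is actually supported still depends on `BaseModel` and `PredictionLayer`.

class RankingModelStub:
    """Stub mirroring the 0.4.1 constructor pattern; not part of nextrec."""

    @property
    def default_task(self) -> str:
        return "binary"  # the default every ranking model in this diff reports

    def __init__(self, task: str | list[str] | None = None):
        # Same fallback as `task=task or self.default_task` in the diff:
        # an explicit task wins, otherwise the model's default is used.
        self.task = task or self.default_task

print(RankingModelStub().task)                       # 'binary' (old behavior preserved)
print(RankingModelStub(task=["ctr", "cvr"]).task)    # ['ctr', 'cvr'] (new override)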
nextrec/models/ranking/dcn.py
@@ -25,15 +25,11 @@ class CrossNetwork(nn.Module):
          self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])

      def forward(self, x):
-         """
-         :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
-         """
          x0 = x
          for i in range(self.num_layers):
              xw = self.w[i](x)
              x = x0 * xw + self.b[i] + x
-         return x
-
+         return x  # [batch_size, input_dim]


  class DCN(BaseModel):
      @property
@@ -41,9 +37,9 @@ class DCN(BaseModel):
          return "DCN"

      @property
-     def task_type(self):
+     def default_task(self):
          return "binary"
-
+
      def __init__(self,
                   dense_features: list[DenseFeature],
                   sparse_features: list[SparseFeature],
@@ -51,6 +47,7 @@ class DCN(BaseModel):
                   cross_num: int = 3,
                   mlp_params: dict | None = None,
                   target: list[str] = [],
+                  task: str | list[str] | None = None,
                   optimizer: str = "adam",
                   optimizer_params: dict = {},
                   loss: str | nn.Module | None = "bce",
@@ -67,13 +64,12 @@ class DCN(BaseModel):
              sparse_features=sparse_features,
              sequence_features=sequence_features,
              target=target,
-             task=self.task_type,
+             task=task or self.default_task,
              device=device,
              embedding_l1_reg=embedding_l1_reg,
              dense_l1_reg=dense_l1_reg,
              embedding_l2_reg=embedding_l2_reg,
              dense_l2_reg=dense_l2_reg,
-             early_stop_patience=20,
              **kwargs
          )

@@ -99,14 +95,15 @@ class DCN(BaseModel):
          if mlp_params is not None:
              self.use_dnn = True
              self.mlp = MLP(input_dim=input_dim, **mlp_params)
+             deep_dim = self.mlp.output_dim
              # Final layer combines cross and deep
-             self.final_layer = nn.Linear(input_dim + 1, 1)  # +1 for MLP output
+             self.final_layer = nn.Linear(input_dim + deep_dim, 1)  # + deep_dim for MLP output
          else:
              self.use_dnn = False
              # Final layer only uses cross network output
              self.final_layer = nn.Linear(input_dim, 1)

-         self.prediction_layer = PredictionLayer(task_type=self.task_type)
+         self.prediction_layer = PredictionLayer(task_type=self.task)

          # Register regularization weights
          self.register_regularization_weights(
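Two details in the dcn.py hunks are worth calling out: CrossNetwork keeps the recurrence x_{l+1} = x_0 * (W_l x_l) + b_l + x_l on a flattened input, and the combined head is now sized as `input_dim + deep_dim`, where `deep_dim` is the MLP's actual output width, instead of the previous hard-coded `input_dim + 1`. Below is a standalone sketch of the cross-plus-deep composition under those dimensions; the classes are toy stand-ins, not nextrec's CrossNetwork or MLP, and they assume `self.w` is a list of bias-free linear maps as the loop suggests.

import torch
import torch.nn as nn

class CrossLayerStack(nn.Module):
    """Toy cross network: x_{l+1} = x_0 * (W_l x_l) + b_l + x_l on flat inputs."""

    def __init__(self, input_dim: int, num_layers: int = 3):
        super().__init__()
        self.w = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for _ in range(num_layers)])
        self.b = nn.ParameterList([nn.Parameter(torch.zeros(input_dim)) for _ in range(num_layers)])

    def forward(self, x):                    # [batch_size, input_dim]
        x0 = x
        for w, b in zip(self.w, self.b):
            x = x0 * w(x) + b + x            # element-wise interaction plus residual
        return x                             # [batch_size, input_dim]

# Combining the cross and deep branches: the final linear head must accept
# input_dim + deep_dim features, which is what the 0.4.1 change encodes.
input_dim, deep_dim, batch = 16, 64, 4
cross = CrossLayerStack(input_dim)
deep = nn.Sequential(nn.Linear(input_dim, deep_dim), nn.ReLU())  # stand-in for the deep MLP
final = nn.Linear(input_dim + deep_dim, 1)

x = torch.randn(batch, input_dim)
logit = final(torch.cat([cross(x), deep(x)], dim=1))
print(logit.shape)  # torch.Size([4, 1])

When the deep tower's final width is not 1, a head sized `input_dim + 1` cannot match the concatenated features, which is what sizing by the real output width addresses.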
nextrec/models/ranking/deepfm.py
@@ -56,15 +56,16 @@ class DeepFM(BaseModel):
          return "DeepFM"

      @property
-     def task_type(self):
+     def default_task(self):
          return "binary"
-
+
      def __init__(self,
                   dense_features: list[DenseFeature]|list = [],
                   sparse_features: list[SparseFeature]|list = [],
                   sequence_features: list[SequenceFeature]|list = [],
                   mlp_params: dict = {},
                   target: list[str]|str = [],
+                  task: str | list[str] | None = None,
                   optimizer: str = "adam",
                   optimizer_params: dict = {},
                   loss: str | nn.Module | None = "bce",
@@ -80,13 +81,12 @@ class DeepFM(BaseModel):
              sparse_features=sparse_features,
              sequence_features=sequence_features,
              target=target,
-             task=self.task_type,
+             task=task or self.default_task,
              device=device,
              embedding_l1_reg=embedding_l1_reg,
              dense_l1_reg=dense_l1_reg,
              embedding_l2_reg=embedding_l2_reg,
              dense_l2_reg=dense_l2_reg,
-             early_stop_patience=20,
              **kwargs
          )

@@ -104,7 +104,7 @@ class DeepFM(BaseModel):
          self.linear = LR(fm_emb_dim_total)
          self.fm = FM(reduce_sum=True)
          self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
-         self.prediction_layer = PredictionLayer(task_type=self.task_type)
+         self.prediction_layer = PredictionLayer(task_type=self.default_task)

          # Register regularization weights
          self.register_regularization_weights(embedding_attr='embedding', include_modules=['linear', 'mlp'])
nextrec/models/ranking/dien.py
@@ -146,7 +146,7 @@ class DIEN(BaseModel):
          return "DIEN"

      @property
-     def task_type(self):
+     def default_task(self):
          return "binary"

      def __init__(self,
@@ -159,6 +159,7 @@ class DIEN(BaseModel):
                   attention_activation: str = 'sigmoid',
                   use_negsampling: bool = False,
                   target: list[str] = [],
+                  task: str | list[str] | None = None,
                   optimizer: str = "adam",
                   optimizer_params: dict = {},
                   loss: str | nn.Module | None = "bce",
@@ -175,13 +176,12 @@ class DIEN(BaseModel):
              sparse_features=sparse_features,
              sequence_features=sequence_features,
              target=target,
-             task=self.task_type,
+             task=task or self.default_task,
              device=device,
              embedding_l1_reg=embedding_l1_reg,
              dense_l1_reg=dense_l1_reg,
              embedding_l2_reg=embedding_l2_reg,
              dense_l2_reg=dense_l2_reg,
-             early_stop_patience=20,
              **kwargs
          )

@@ -235,7 +235,7 @@ class DIEN(BaseModel):
          mlp_input_dim += sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
          # MLP for final prediction
          self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
-         self.prediction_layer = PredictionLayer(task_type=self.task_type)
+         self.prediction_layer = PredictionLayer(task_type=self.task)
          # Register regularization weights
          self.register_regularization_weights(embedding_attr='embedding', include_modules=['interest_extractor', 'interest_evolution', 'attention_layer', 'mlp', 'candidate_proj'])
          self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
nextrec/models/ranking/din.py
@@ -22,7 +22,7 @@ class DIN(BaseModel):
          return "DIN"

      @property
-     def task_type(self):
+     def default_task(self):
          return "binary"

      def __init__(self,
@@ -34,6 +34,7 @@ class DIN(BaseModel):
                   attention_activation: str = 'sigmoid',
                   attention_use_softmax: bool = True,
                   target: list[str] = [],
+                  task: str | list[str] | None = None,
                   optimizer: str = "adam",
                   optimizer_params: dict = {},
                   loss: str | nn.Module | None = "bce",
@@ -50,13 +51,12 @@ class DIN(BaseModel):
              sparse_features=sparse_features,
              sequence_features=sequence_features,
              target=target,
-             task=self.task_type,
+             task=task or self.default_task,
              device=device,
              embedding_l1_reg=embedding_l1_reg,
              dense_l1_reg=dense_l1_reg,
              embedding_l2_reg=embedding_l2_reg,
              dense_l2_reg=dense_l2_reg,
-             early_stop_patience=20,
              **kwargs
          )

@@ -105,7 +105,7 @@ class DIN(BaseModel):

          # MLP for final prediction
          self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
-         self.prediction_layer = PredictionLayer(task_type=self.task_type)
+         self.prediction_layer = PredictionLayer(task_type=self.task)

          # Register regularization weights
          self.register_regularization_weights(
nextrec/models/ranking/fibinet.py
@@ -28,7 +28,7 @@ class FiBiNET(BaseModel):
          return "FiBiNET"

      @property
-     def task_type(self):
+     def default_task(self):
          return "binary"

      def __init__(self,
@@ -39,6 +39,7 @@ class FiBiNET(BaseModel):
                   bilinear_type: str = "field_interaction",
                   senet_reduction: int = 3,
                   target: list[str] | list = [],
+                  task: str | list[str] | None = None,
                   optimizer: str = "adam",
                   optimizer_params: dict = {},
                   loss: str | nn.Module | None = "bce",
@@ -55,13 +56,12 @@ class FiBiNET(BaseModel):
              sparse_features=sparse_features,
              sequence_features=sequence_features,
              target=target,
-             task=self.task_type,
+             task=task or self.default_task,
              device=device,
              embedding_l1_reg=embedding_l1_reg,
              dense_l1_reg=dense_l1_reg,
              embedding_l2_reg=embedding_l2_reg,
              dense_l2_reg=dense_l2_reg,
-             early_stop_patience=20,
              **kwargs
          )

@@ -101,7 +101,7 @@ class FiBiNET(BaseModel):
          num_pairs = self.num_fields * (self.num_fields - 1) // 2
          interaction_dim = num_pairs * self.embedding_dim * 2
          self.mlp = MLP(input_dim=interaction_dim, **mlp_params)
-         self.prediction_layer = PredictionLayer(task_type=self.task_type)
+         self.prediction_layer = PredictionLayer(task_type=self.default_task)

          # Register regularization weights
          self.register_regularization_weights(
nextrec/models/ranking/fm.py
@@ -19,7 +19,7 @@ class FM(BaseModel):
          return "FM"

      @property
-     def task_type(self):
+     def default_task(self):
          return "binary"

      def __init__(self,
@@ -27,6 +27,7 @@ class FM(BaseModel):
                   sparse_features: list[SparseFeature] | list = [],
                   sequence_features: list[SequenceFeature] | list = [],
                   target: list[str] | list = [],
+                  task: str | list[str] | None = None,
                   optimizer: str = "adam",
                   optimizer_params: dict = {},
                   loss: str | nn.Module | None = "bce",
@@ -43,13 +44,12 @@ class FM(BaseModel):
              sparse_features=sparse_features,
              sequence_features=sequence_features,
              target=target,
-             task=self.task_type,
+             task=task or self.default_task,
              device=device,
              embedding_l1_reg=embedding_l1_reg,
              dense_l1_reg=dense_l1_reg,
              embedding_l2_reg=embedding_l2_reg,
              dense_l2_reg=dense_l2_reg,
-             early_stop_patience=20,
              **kwargs
          )

@@ -66,7 +66,7 @@ class FM(BaseModel):
          fm_input_dim = sum([f.embedding_dim for f in self.fm_features])
          self.linear = LR(fm_input_dim)
          self.fm = FMInteraction(reduce_sum=True)
-         self.prediction_layer = PredictionLayer(task_type=self.task_type)
+         self.prediction_layer = PredictionLayer(task_type=self.task)

          # Register regularization weights
          self.register_regularization_weights(
nextrec/models/ranking/masknet.py
@@ -143,8 +143,7 @@ class MaskNet(BaseModel):
          return "MaskNet"

      @property
-     def task_type(self):
-         # Align with PredictionLayer supported task types
+     def default_task(self):
          return "binary"

      def __init__(
@@ -159,6 +158,7 @@ class MaskNet(BaseModel):
          block_dropout: float = 0.0,
          mlp_params: dict | None = None,
          target: list[str] | None = None,
+         task: str | list[str] | None = None,
          optimizer: str = "adam",
          optimizer_params: dict | None = None,
          loss: str | nn.Module | None = "bce",
@@ -182,13 +182,12 @@ class MaskNet(BaseModel):
              sparse_features=sparse_features,
              sequence_features=sequence_features,
              target=target,
-             task=self.task_type,
+             task=task or self.default_task,
              device=device,
              embedding_l1_reg=embedding_l1_reg,
              dense_l1_reg=dense_l1_reg,
              embedding_l2_reg=embedding_l2_reg,
              dense_l2_reg=dense_l2_reg,
-             early_stop_patience=20,
              **kwargs,
          )

@@ -231,7 +230,7 @@ class MaskNet(BaseModel):
          self.mask_blocks = nn.ModuleList([MaskBlockOnEmbedding(num_fields=self.num_fields, embedding_dim=self.embedding_dim, mask_hidden_dim=mask_hidden_dim, hidden_dim=block_hidden_dim) for _ in range(self.num_blocks)])
          self.final_mlp = MLP(input_dim=self.num_blocks * block_hidden_dim, **mlp_params)
          self.output_layer = None
-         self.prediction_layer = PredictionLayer(task_type=self.task_type)
+         self.prediction_layer = PredictionLayer(task_type=self.task)

          if self.model_type == "serial":
              self.register_regularization_weights(embedding_attr="embedding", include_modules=["mask_blocks", "output_layer"],)
nextrec/models/ranking/pnn.py
@@ -20,7 +20,7 @@ class PNN(BaseModel):
          return "PNN"

      @property
-     def task_type(self):
+     def default_task(self):
          return "binary"

      def __init__(self,
@@ -31,6 +31,7 @@ class PNN(BaseModel):
                   product_type: str = "inner",
                   outer_product_dim: int | None = None,
                   target: list[str] | list = [],
+                  task: str | list[str] | None = None,
                   optimizer: str = "adam",
                   optimizer_params: dict = {},
                   loss: str | nn.Module | None = "bce",
@@ -47,13 +48,12 @@ class PNN(BaseModel):
              sparse_features=sparse_features,
              sequence_features=sequence_features,
              target=target,
-             task=self.task_type,
+             task=task or self.default_task,
              device=device,
              embedding_l1_reg=embedding_l1_reg,
              dense_l1_reg=dense_l1_reg,
              embedding_l2_reg=embedding_l2_reg,
              dense_l2_reg=dense_l2_reg,
-             early_stop_patience=20,
              **kwargs
          )

@@ -86,7 +86,7 @@ class PNN(BaseModel):

          linear_dim = self.num_fields * self.embedding_dim
          self.mlp = MLP(input_dim=linear_dim + product_dim, **mlp_params)
-         self.prediction_layer = PredictionLayer(task_type=self.task_type)
+         self.prediction_layer = PredictionLayer(task_type=self.task)

          modules = ['mlp']
          if self.product_type == "outer":