nextrec 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +220 -106
  7. nextrec/basic/loggers.py +62 -43
  8. nextrec/basic/metrics.py +268 -119
  9. nextrec/basic/model.py +1082 -400
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +498 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +51 -45
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +272 -95
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +103 -38
  23. nextrec/models/match/dssm.py +82 -68
  24. nextrec/models/match/dssm_v2.py +72 -57
  25. nextrec/models/match/mind.py +175 -107
  26. nextrec/models/match/sdm.py +104 -87
  27. nextrec/models/match/youtube_dnn.py +73 -59
  28. nextrec/models/multi_task/esmm.py +53 -37
  29. nextrec/models/multi_task/mmoe.py +64 -45
  30. nextrec/models/multi_task/ple.py +101 -48
  31. nextrec/models/multi_task/poso.py +113 -36
  32. nextrec/models/multi_task/share_bottom.py +48 -35
  33. nextrec/models/ranking/afm.py +72 -37
  34. nextrec/models/ranking/autoint.py +72 -55
  35. nextrec/models/ranking/dcn.py +55 -35
  36. nextrec/models/ranking/dcn_v2.py +64 -23
  37. nextrec/models/ranking/deepfm.py +32 -22
  38. nextrec/models/ranking/dien.py +155 -99
  39. nextrec/models/ranking/din.py +85 -57
  40. nextrec/models/ranking/fibinet.py +52 -32
  41. nextrec/models/ranking/fm.py +29 -23
  42. nextrec/models/ranking/masknet.py +91 -29
  43. nextrec/models/ranking/pnn.py +31 -28
  44. nextrec/models/ranking/widedeep.py +34 -26
  45. nextrec/models/ranking/xdeepfm.py +60 -38
  46. nextrec/utils/__init__.py +59 -34
  47. nextrec/utils/config.py +490 -0
  48. nextrec/utils/device.py +30 -20
  49. nextrec/utils/distributed.py +36 -9
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +32 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/optimizer.py +25 -9
  55. nextrec/utils/synthetic_data.py +283 -165
  56. nextrec/utils/tensor.py +24 -13
  57. {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/METADATA +4 -4
  58. nextrec-0.4.2.dist-info/RECORD +69 -0
  59. nextrec-0.4.2.dist-info/entry_points.txt +2 -0
  60. nextrec-0.4.1.dist-info/RECORD +0 -66
  61. {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
  62. {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
@@ -12,7 +12,12 @@ import torch
12
12
  import torch.nn as nn
13
13
 
14
14
  from nextrec.basic.model import BaseModel
15
- from nextrec.basic.layers import EmbeddingLayer, MLP, AttentionPoolingLayer, PredictionLayer
15
+ from nextrec.basic.layers import (
16
+ EmbeddingLayer,
17
+ MLP,
18
+ AttentionPoolingLayer,
19
+ PredictionLayer,
20
+ )
16
21
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
17
22
 
18
23
 
@@ -24,28 +29,30 @@ class DIN(BaseModel):
24
29
  @property
25
30
  def default_task(self):
26
31
  return "binary"
27
-
28
- def __init__(self,
29
- dense_features: list[DenseFeature],
30
- sparse_features: list[SparseFeature],
31
- sequence_features: list[SequenceFeature],
32
- mlp_params: dict,
33
- attention_hidden_units: list[int] = [80, 40],
34
- attention_activation: str = 'sigmoid',
35
- attention_use_softmax: bool = True,
36
- target: list[str] = [],
37
- task: str | list[str] | None = None,
38
- optimizer: str = "adam",
39
- optimizer_params: dict = {},
40
- loss: str | nn.Module | None = "bce",
41
- loss_params: dict | list[dict] | None = None,
42
- device: str = 'cpu',
43
- embedding_l1_reg=1e-6,
44
- dense_l1_reg=1e-5,
45
- embedding_l2_reg=1e-5,
46
- dense_l2_reg=1e-4,
47
- **kwargs):
48
-
32
+
33
+ def __init__(
34
+ self,
35
+ dense_features: list[DenseFeature],
36
+ sparse_features: list[SparseFeature],
37
+ sequence_features: list[SequenceFeature],
38
+ mlp_params: dict,
39
+ attention_hidden_units: list[int] = [80, 40],
40
+ attention_activation: str = "sigmoid",
41
+ attention_use_softmax: bool = True,
42
+ target: list[str] = [],
43
+ task: str | list[str] | None = None,
44
+ optimizer: str = "adam",
45
+ optimizer_params: dict = {},
46
+ loss: str | nn.Module | None = "bce",
47
+ loss_params: dict | list[dict] | None = None,
48
+ device: str = "cpu",
49
+ embedding_l1_reg=1e-6,
50
+ dense_l1_reg=1e-5,
51
+ embedding_l2_reg=1e-5,
52
+ dense_l2_reg=1e-4,
53
+ **kwargs,
54
+ ):
55
+
49
56
  super(DIN, self).__init__(
50
57
  dense_features=dense_features,
51
58
  sparse_features=sparse_features,
@@ -57,43 +64,54 @@ class DIN(BaseModel):
57
64
  dense_l1_reg=dense_l1_reg,
58
65
  embedding_l2_reg=embedding_l2_reg,
59
66
  dense_l2_reg=dense_l2_reg,
60
- **kwargs
67
+ **kwargs,
61
68
  )
62
69
 
63
70
  self.loss = loss
64
71
  if self.loss is None:
65
72
  self.loss = "bce"
66
-
73
+
67
74
  # Features classification
68
75
  # DIN requires: candidate item + user behavior sequence + other features
69
76
  if len(sequence_features) == 0:
70
- raise ValueError("DIN requires at least one sequence feature for user behavior history")
71
-
77
+ raise ValueError(
78
+ "DIN requires at least one sequence feature for user behavior history"
79
+ )
80
+
72
81
  self.behavior_feature = sequence_features[0] # User behavior sequence
73
- self.candidate_feature = sparse_features[-1] if sparse_features else None # Candidate item
74
-
82
+ self.candidate_feature = (
83
+ sparse_features[-1] if sparse_features else None
84
+ ) # Candidate item
85
+
75
86
  # Other features (excluding behavior sequence in final concatenation)
76
- self.other_sparse_features = sparse_features[:-1] if self.candidate_feature else sparse_features
87
+ self.other_sparse_features = (
88
+ sparse_features[:-1] if self.candidate_feature else sparse_features
89
+ )
77
90
  self.dense_features_list = dense_features
78
-
91
+
79
92
  # All features for embedding
80
93
  self.all_features = dense_features + sparse_features + sequence_features
81
94
 
82
95
  # Embedding layer
83
96
  self.embedding = EmbeddingLayer(features=self.all_features)
84
-
97
+
85
98
  # Attention layer for behavior sequence
86
99
  behavior_emb_dim = self.behavior_feature.embedding_dim
87
100
  self.candidate_attention_proj = None
88
- if self.candidate_feature is not None and self.candidate_feature.embedding_dim != behavior_emb_dim:
89
- self.candidate_attention_proj = nn.Linear(self.candidate_feature.embedding_dim, behavior_emb_dim)
101
+ if (
102
+ self.candidate_feature is not None
103
+ and self.candidate_feature.embedding_dim != behavior_emb_dim
104
+ ):
105
+ self.candidate_attention_proj = nn.Linear(
106
+ self.candidate_feature.embedding_dim, behavior_emb_dim
107
+ )
90
108
  self.attention = AttentionPoolingLayer(
91
109
  embedding_dim=behavior_emb_dim,
92
110
  hidden_units=attention_hidden_units,
93
111
  activation=attention_activation,
94
- use_softmax=attention_use_softmax
112
+ use_softmax=attention_use_softmax,
95
113
  )
96
-
114
+
97
115
  # Calculate MLP input dimension
98
116
  # candidate + attention_pooled_behavior + other_sparse + dense
99
117
  mlp_input_dim = 0
@@ -101,16 +119,18 @@ class DIN(BaseModel):
101
119
  mlp_input_dim += self.candidate_feature.embedding_dim
102
120
  mlp_input_dim += behavior_emb_dim # attention pooled
103
121
  mlp_input_dim += sum([f.embedding_dim for f in self.other_sparse_features])
104
- mlp_input_dim += sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
105
-
122
+ mlp_input_dim += sum(
123
+ [getattr(f, "embedding_dim", 1) or 1 for f in dense_features]
124
+ )
125
+
106
126
  # MLP for final prediction
107
127
  self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
108
128
  self.prediction_layer = PredictionLayer(task_type=self.task)
109
129
 
110
130
  # Register regularization weights
111
131
  self.register_regularization_weights(
112
- embedding_attr='embedding',
113
- include_modules=['attention', 'mlp', 'candidate_attention_proj']
132
+ embedding_attr="embedding",
133
+ include_modules=["attention", "mlp", "candidate_attention_proj"],
114
134
  )
115
135
 
116
136
  self.compile(
@@ -123,61 +143,69 @@ class DIN(BaseModel):
123
143
  def forward(self, x):
124
144
  # Get candidate item embedding
125
145
  if self.candidate_feature:
126
- candidate_emb = self.embedding.embed_dict[self.candidate_feature.embedding_name](
146
+ candidate_emb = self.embedding.embed_dict[
147
+ self.candidate_feature.embedding_name
148
+ ](
127
149
  x[self.candidate_feature.name].long()
128
150
  ) # [B, emb_dim]
129
151
  else:
130
152
  candidate_emb = None
131
-
153
+
132
154
  # Get behavior sequence embedding
133
155
  behavior_seq = x[self.behavior_feature.name].long() # [B, seq_len]
134
156
  behavior_emb = self.embedding.embed_dict[self.behavior_feature.embedding_name](
135
157
  behavior_seq
136
158
  ) # [B, seq_len, emb_dim]
137
-
159
+
138
160
  # Create mask for padding
139
161
  if self.behavior_feature.padding_idx is not None:
140
- mask = (behavior_seq != self.behavior_feature.padding_idx).unsqueeze(-1).float()
162
+ mask = (
163
+ (behavior_seq != self.behavior_feature.padding_idx)
164
+ .unsqueeze(-1)
165
+ .float()
166
+ )
141
167
  else:
142
168
  mask = (behavior_seq != 0).unsqueeze(-1).float()
143
-
169
+
144
170
  # Apply attention pooling
145
171
  if candidate_emb is not None:
146
172
  candidate_query = candidate_emb
147
173
  if self.candidate_attention_proj is not None:
148
174
  candidate_query = self.candidate_attention_proj(candidate_query)
149
175
  pooled_behavior = self.attention(
150
- query=candidate_query,
151
- keys=behavior_emb,
152
- mask=mask
176
+ query=candidate_query, keys=behavior_emb, mask=mask
153
177
  ) # [B, emb_dim]
154
178
  else:
155
179
  # If no candidate, use mean pooling
156
- pooled_behavior = torch.sum(behavior_emb * mask, dim=1) / (mask.sum(dim=1) + 1e-9)
157
-
180
+ pooled_behavior = torch.sum(behavior_emb * mask, dim=1) / (
181
+ mask.sum(dim=1) + 1e-9
182
+ )
183
+
158
184
  # Get other features
159
185
  other_embeddings = []
160
-
186
+
161
187
  if candidate_emb is not None:
162
188
  other_embeddings.append(candidate_emb)
163
-
189
+
164
190
  other_embeddings.append(pooled_behavior)
165
-
191
+
166
192
  # Other sparse features
167
193
  for feat in self.other_sparse_features:
168
- feat_emb = self.embedding.embed_dict[feat.embedding_name](x[feat.name].long())
194
+ feat_emb = self.embedding.embed_dict[feat.embedding_name](
195
+ x[feat.name].long()
196
+ )
169
197
  other_embeddings.append(feat_emb)
170
-
198
+
171
199
  # Dense features
172
200
  for feat in self.dense_features_list:
173
201
  val = x[feat.name].float()
174
202
  if val.dim() == 1:
175
203
  val = val.unsqueeze(1)
176
204
  other_embeddings.append(val)
177
-
205
+
178
206
  # Concatenate all features
179
207
  concat_input = torch.cat(other_embeddings, dim=-1) # [B, total_dim]
180
-
208
+
181
209
  # MLP prediction
182
210
  y = self.mlp(concat_input) # [B, 1]
183
211
  return self.prediction_layer(y)
@@ -30,27 +30,29 @@ class FiBiNET(BaseModel):
30
30
  @property
31
31
  def default_task(self):
32
32
  return "binary"
33
-
34
- def __init__(self,
35
- dense_features: list[DenseFeature] | list = [],
36
- sparse_features: list[SparseFeature] | list = [],
37
- sequence_features: list[SequenceFeature] | list = [],
38
- mlp_params: dict = {},
39
- bilinear_type: str = "field_interaction",
40
- senet_reduction: int = 3,
41
- target: list[str] | list = [],
42
- task: str | list[str] | None = None,
43
- optimizer: str = "adam",
44
- optimizer_params: dict = {},
45
- loss: str | nn.Module | None = "bce",
46
- loss_params: dict | list[dict] | None = None,
47
- device: str = 'cpu',
48
- embedding_l1_reg=1e-6,
49
- dense_l1_reg=1e-5,
50
- embedding_l2_reg=1e-5,
51
- dense_l2_reg=1e-4,
52
- **kwargs):
53
-
33
+
34
+ def __init__(
35
+ self,
36
+ dense_features: list[DenseFeature] | list = [],
37
+ sparse_features: list[SparseFeature] | list = [],
38
+ sequence_features: list[SequenceFeature] | list = [],
39
+ mlp_params: dict = {},
40
+ bilinear_type: str = "field_interaction",
41
+ senet_reduction: int = 3,
42
+ target: list[str] | list = [],
43
+ task: str | list[str] | None = None,
44
+ optimizer: str = "adam",
45
+ optimizer_params: dict = {},
46
+ loss: str | nn.Module | None = "bce",
47
+ loss_params: dict | list[dict] | None = None,
48
+ device: str = "cpu",
49
+ embedding_l1_reg=1e-6,
50
+ dense_l1_reg=1e-5,
51
+ embedding_l2_reg=1e-5,
52
+ dense_l2_reg=1e-4,
53
+ **kwargs,
54
+ ):
55
+
54
56
  super(FiBiNET, self).__init__(
55
57
  dense_features=dense_features,
56
58
  sparse_features=sparse_features,
@@ -62,28 +64,36 @@ class FiBiNET(BaseModel):
62
64
  dense_l1_reg=dense_l1_reg,
63
65
  embedding_l2_reg=embedding_l2_reg,
64
66
  dense_l2_reg=dense_l2_reg,
65
- **kwargs
67
+ **kwargs,
66
68
  )
67
69
 
68
70
  self.loss = loss
69
71
  if self.loss is None:
70
72
  self.loss = "bce"
71
-
73
+
72
74
  self.linear_features = sparse_features + sequence_features
73
75
  self.deep_features = dense_features + sparse_features + sequence_features
74
76
  self.interaction_features = sparse_features + sequence_features
75
77
 
76
78
  if len(self.interaction_features) < 2:
77
- raise ValueError("FiBiNET requires at least two sparse/sequence features for interactions.")
79
+ raise ValueError(
80
+ "FiBiNET requires at least two sparse/sequence features for interactions."
81
+ )
78
82
 
79
83
  self.embedding = EmbeddingLayer(features=self.deep_features)
80
84
 
81
85
  self.num_fields = len(self.interaction_features)
82
86
  self.embedding_dim = self.interaction_features[0].embedding_dim
83
- if any(f.embedding_dim != self.embedding_dim for f in self.interaction_features):
84
- raise ValueError("All interaction features must share the same embedding_dim in FiBiNET.")
85
-
86
- self.senet = SENETLayer(num_fields=self.num_fields, reduction_ratio=senet_reduction)
87
+ if any(
88
+ f.embedding_dim != self.embedding_dim for f in self.interaction_features
89
+ ):
90
+ raise ValueError(
91
+ "All interaction features must share the same embedding_dim in FiBiNET."
92
+ )
93
+
94
+ self.senet = SENETLayer(
95
+ num_fields=self.num_fields, reduction_ratio=senet_reduction
96
+ )
87
97
  self.bilinear_standard = BiLinearInteractionLayer(
88
98
  input_dim=self.embedding_dim,
89
99
  num_fields=self.num_fields,
@@ -105,8 +115,14 @@ class FiBiNET(BaseModel):
105
115
 
106
116
  # Register regularization weights
107
117
  self.register_regularization_weights(
108
- embedding_attr='embedding',
109
- include_modules=['linear', 'senet', 'bilinear_standard', 'bilinear_senet', 'mlp']
118
+ embedding_attr="embedding",
119
+ include_modules=[
120
+ "linear",
121
+ "senet",
122
+ "bilinear_standard",
123
+ "bilinear_senet",
124
+ "mlp",
125
+ ],
110
126
  )
111
127
 
112
128
  self.compile(
@@ -117,10 +133,14 @@ class FiBiNET(BaseModel):
117
133
  )
118
134
 
119
135
  def forward(self, x):
120
- input_linear = self.embedding(x=x, features=self.linear_features, squeeze_dim=True)
136
+ input_linear = self.embedding(
137
+ x=x, features=self.linear_features, squeeze_dim=True
138
+ )
121
139
  y_linear = self.linear(input_linear)
122
140
 
123
- field_emb = self.embedding(x=x, features=self.interaction_features, squeeze_dim=False)
141
+ field_emb = self.embedding(
142
+ x=x, features=self.interaction_features, squeeze_dim=False
143
+ )
124
144
  senet_emb = self.senet(field_emb)
125
145
 
126
146
  bilinear_standard = self.bilinear_standard(field_emb).flatten(start_dim=1)
@@ -9,7 +9,12 @@ Reference:
9
9
  import torch.nn as nn
10
10
 
11
11
  from nextrec.basic.model import BaseModel
12
- from nextrec.basic.layers import EmbeddingLayer, FM as FMInteraction, LR, PredictionLayer
12
+ from nextrec.basic.layers import (
13
+ EmbeddingLayer,
14
+ FM as FMInteraction,
15
+ LR,
16
+ PredictionLayer,
17
+ )
13
18
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
14
19
 
15
20
 
@@ -21,24 +26,26 @@ class FM(BaseModel):
21
26
  @property
22
27
  def default_task(self):
23
28
  return "binary"
24
-
25
- def __init__(self,
26
- dense_features: list[DenseFeature] | list = [],
27
- sparse_features: list[SparseFeature] | list = [],
28
- sequence_features: list[SequenceFeature] | list = [],
29
- target: list[str] | list = [],
30
- task: str | list[str] | None = None,
31
- optimizer: str = "adam",
32
- optimizer_params: dict = {},
33
- loss: str | nn.Module | None = "bce",
34
- loss_params: dict | list[dict] | None = None,
35
- device: str = 'cpu',
36
- embedding_l1_reg=1e-6,
37
- dense_l1_reg=1e-5,
38
- embedding_l2_reg=1e-5,
39
- dense_l2_reg=1e-4,
40
- **kwargs):
41
-
29
+
30
+ def __init__(
31
+ self,
32
+ dense_features: list[DenseFeature] | list = [],
33
+ sparse_features: list[SparseFeature] | list = [],
34
+ sequence_features: list[SequenceFeature] | list = [],
35
+ target: list[str] | list = [],
36
+ task: str | list[str] | None = None,
37
+ optimizer: str = "adam",
38
+ optimizer_params: dict = {},
39
+ loss: str | nn.Module | None = "bce",
40
+ loss_params: dict | list[dict] | None = None,
41
+ device: str = "cpu",
42
+ embedding_l1_reg=1e-6,
43
+ dense_l1_reg=1e-5,
44
+ embedding_l2_reg=1e-5,
45
+ dense_l2_reg=1e-4,
46
+ **kwargs,
47
+ ):
48
+
42
49
  super(FM, self).__init__(
43
50
  dense_features=dense_features,
44
51
  sparse_features=sparse_features,
@@ -50,13 +57,13 @@ class FM(BaseModel):
50
57
  dense_l1_reg=dense_l1_reg,
51
58
  embedding_l2_reg=embedding_l2_reg,
52
59
  dense_l2_reg=dense_l2_reg,
53
- **kwargs
60
+ **kwargs,
54
61
  )
55
62
 
56
63
  self.loss = loss
57
64
  if self.loss is None:
58
65
  self.loss = "bce"
59
-
66
+
60
67
  self.fm_features = sparse_features + sequence_features
61
68
  if len(self.fm_features) == 0:
62
69
  raise ValueError("FM requires at least one sparse or sequence feature.")
@@ -70,8 +77,7 @@ class FM(BaseModel):
70
77
 
71
78
  # Register regularization weights
72
79
  self.register_regularization_weights(
73
- embedding_attr='embedding',
74
- include_modules=['linear']
80
+ embedding_attr="embedding", include_modules=["linear"]
75
81
  )
76
82
 
77
83
  self.compile(
@@ -69,12 +69,13 @@ class InstanceGuidedMask(nn.Module):
69
69
  self.fc2 = nn.Linear(hidden_dim, output_dim)
70
70
 
71
71
  def forward(self, v_emb_flat: torch.Tensor) -> torch.Tensor:
72
- # v_emb_flat: [batch, features count * embedding_dim]
72
+ # v_emb_flat: [batch, features count * embedding_dim]
73
73
  x = self.fc1(v_emb_flat)
74
74
  x = F.relu(x)
75
75
  v_mask = self.fc2(x)
76
76
  return v_mask
77
77
 
78
+
78
79
  class MaskBlockOnEmbedding(nn.Module):
79
80
  def __init__(
80
81
  self,
@@ -86,20 +87,28 @@ class MaskBlockOnEmbedding(nn.Module):
86
87
  super().__init__()
87
88
  self.num_fields = num_fields
88
89
  self.embedding_dim = embedding_dim
89
- self.input_dim = num_fields * embedding_dim # input_dim = features count * embedding_dim
90
+ self.input_dim = (
91
+ num_fields * embedding_dim
92
+ ) # input_dim = features count * embedding_dim
90
93
  self.ln_emb = nn.LayerNorm(embedding_dim)
91
- self.mask_gen = InstanceGuidedMask(input_dim=self.input_dim, hidden_dim=mask_hidden_dim, output_dim=self.input_dim,)
94
+ self.mask_gen = InstanceGuidedMask(
95
+ input_dim=self.input_dim,
96
+ hidden_dim=mask_hidden_dim,
97
+ output_dim=self.input_dim,
98
+ )
92
99
  self.ffn = nn.Linear(self.input_dim, hidden_dim)
93
100
  self.ln_hid = nn.LayerNorm(hidden_dim)
94
101
 
95
102
  # different from MaskBlockOnHidden: input is field embeddings
96
- def forward(self, field_emb: torch.Tensor, v_emb_flat: torch.Tensor) -> torch.Tensor:
103
+ def forward(
104
+ self, field_emb: torch.Tensor, v_emb_flat: torch.Tensor
105
+ ) -> torch.Tensor:
97
106
  B = field_emb.size(0)
98
- norm_emb = self.ln_emb(field_emb) # [B, features count, embedding_dim]
99
- norm_emb_flat = norm_emb.view(B, -1) # [B, features count * embedding_dim]
100
- v_mask = self.mask_gen(v_emb_flat) # [B, features count * embedding_dim]
101
- v_masked_emb = v_mask * norm_emb_flat # [B, features count * embedding_dim]
102
- hidden = self.ffn(v_masked_emb) # [B, hidden_dim]
107
+ norm_emb = self.ln_emb(field_emb) # [B, features count, embedding_dim]
108
+ norm_emb_flat = norm_emb.view(B, -1) # [B, features count * embedding_dim]
109
+ v_mask = self.mask_gen(v_emb_flat) # [B, features count * embedding_dim]
110
+ v_masked_emb = v_mask * norm_emb_flat # [B, features count * embedding_dim]
111
+ hidden = self.ffn(v_masked_emb) # [B, hidden_dim]
103
112
  hidden = self.ln_hid(hidden)
104
113
  hidden = F.relu(hidden)
105
114
 
@@ -123,15 +132,21 @@ class MaskBlockOnHidden(nn.Module):
123
132
  self.ln_input = nn.LayerNorm(hidden_dim)
124
133
  self.ln_output = nn.LayerNorm(hidden_dim)
125
134
 
126
- self.mask_gen = InstanceGuidedMask(input_dim=self.v_emb_dim, hidden_dim=mask_hidden_dim, output_dim=hidden_dim,)
135
+ self.mask_gen = InstanceGuidedMask(
136
+ input_dim=self.v_emb_dim,
137
+ hidden_dim=mask_hidden_dim,
138
+ output_dim=hidden_dim,
139
+ )
127
140
  self.ffn = nn.Linear(hidden_dim, hidden_dim)
128
141
 
129
142
  # different from MaskBlockOnEmbedding: input is hidden representation
130
- def forward(self, hidden_in: torch.Tensor, v_emb_flat: torch.Tensor) -> torch.Tensor:
131
- norm_hidden = self.ln_input(hidden_in)
143
+ def forward(
144
+ self, hidden_in: torch.Tensor, v_emb_flat: torch.Tensor
145
+ ) -> torch.Tensor:
146
+ norm_hidden = self.ln_input(hidden_in)
132
147
  v_mask = self.mask_gen(v_emb_flat)
133
- v_masked_hid = v_mask * norm_hidden
134
- out = self.ffn(v_masked_hid)
148
+ v_masked_hid = v_mask * norm_hidden
149
+ out = self.ffn(v_masked_hid)
135
150
  out = self.ln_output(out)
136
151
  out = F.relu(out)
137
152
  return out
@@ -151,7 +166,7 @@ class MaskNet(BaseModel):
151
166
  dense_features: list[DenseFeature] | None = None,
152
167
  sparse_features: list[SparseFeature] | None = None,
153
168
  sequence_features: list[SequenceFeature] | None = None,
154
- model_type: str = "parallel", # "serial" or "parallel"
169
+ model_type: str = "parallel", # "serial" or "parallel"
155
170
  num_blocks: int = 3,
156
171
  mask_hidden_dim: int = 64,
157
172
  block_hidden_dim: int = 256,
@@ -199,50 +214,97 @@ class MaskNet(BaseModel):
199
214
  self.sparse_features = sparse_features
200
215
  self.sequence_features = sequence_features
201
216
  self.mask_features = self.all_features # use all features for masking
202
- assert len(self.mask_features) > 0, "MaskNet requires at least one feature for masking."
217
+ assert (
218
+ len(self.mask_features) > 0
219
+ ), "MaskNet requires at least one feature for masking."
203
220
  self.embedding = EmbeddingLayer(features=self.mask_features)
204
221
  self.num_fields = len(self.mask_features)
205
222
  self.embedding_dim = getattr(self.mask_features[0], "embedding_dim", None)
206
- assert self.embedding_dim is not None, "MaskNet requires mask_features to have 'embedding_dim' defined."
223
+ assert (
224
+ self.embedding_dim is not None
225
+ ), "MaskNet requires mask_features to have 'embedding_dim' defined."
207
226
 
208
227
  for f in self.mask_features:
209
228
  edim = getattr(f, "embedding_dim", None)
210
229
  if edim is None or edim != self.embedding_dim:
211
- raise ValueError(f"MaskNet expects identical embedding_dim across all mask_features, but got {edim} for feature {getattr(f, 'name', type(f))}.")
230
+ raise ValueError(
231
+ f"MaskNet expects identical embedding_dim across all mask_features, but got {edim} for feature {getattr(f, 'name', type(f))}."
232
+ )
212
233
 
213
234
  self.v_emb_dim = self.num_fields * self.embedding_dim
214
235
  self.model_type = model_type.lower()
215
- assert self.model_type in ("serial", "parallel"), "model_type must be either 'serial' or 'parallel'."
236
+ assert self.model_type in (
237
+ "serial",
238
+ "parallel",
239
+ ), "model_type must be either 'serial' or 'parallel'."
216
240
 
217
241
  self.num_blocks = max(1, num_blocks)
218
242
  self.block_hidden_dim = block_hidden_dim
219
- self.block_dropout = nn.Dropout(block_dropout) if block_dropout > 0 else nn.Identity()
243
+ self.block_dropout = (
244
+ nn.Dropout(block_dropout) if block_dropout > 0 else nn.Identity()
245
+ )
220
246
 
221
247
  if self.model_type == "serial":
222
- self.first_block = MaskBlockOnEmbedding(num_fields=self.num_fields, embedding_dim=self.embedding_dim, mask_hidden_dim=mask_hidden_dim, hidden_dim=block_hidden_dim,)
248
+ self.first_block = MaskBlockOnEmbedding(
249
+ num_fields=self.num_fields,
250
+ embedding_dim=self.embedding_dim,
251
+ mask_hidden_dim=mask_hidden_dim,
252
+ hidden_dim=block_hidden_dim,
253
+ )
223
254
  self.hidden_blocks = nn.ModuleList(
224
- [MaskBlockOnHidden(num_fields=self.num_fields, embedding_dim=self.embedding_dim, mask_hidden_dim=mask_hidden_dim, hidden_dim=block_hidden_dim) for _ in range(self.num_blocks - 1)])
255
+ [
256
+ MaskBlockOnHidden(
257
+ num_fields=self.num_fields,
258
+ embedding_dim=self.embedding_dim,
259
+ mask_hidden_dim=mask_hidden_dim,
260
+ hidden_dim=block_hidden_dim,
261
+ )
262
+ for _ in range(self.num_blocks - 1)
263
+ ]
264
+ )
225
265
  self.mask_blocks = nn.ModuleList([self.first_block, *self.hidden_blocks])
226
266
  self.output_layer = nn.Linear(block_hidden_dim, 1)
227
267
  self.final_mlp = None
228
268
 
229
269
  else: # parallel
230
- self.mask_blocks = nn.ModuleList([MaskBlockOnEmbedding(num_fields=self.num_fields, embedding_dim=self.embedding_dim, mask_hidden_dim=mask_hidden_dim, hidden_dim=block_hidden_dim) for _ in range(self.num_blocks)])
231
- self.final_mlp = MLP(input_dim=self.num_blocks * block_hidden_dim, **mlp_params)
270
+ self.mask_blocks = nn.ModuleList(
271
+ [
272
+ MaskBlockOnEmbedding(
273
+ num_fields=self.num_fields,
274
+ embedding_dim=self.embedding_dim,
275
+ mask_hidden_dim=mask_hidden_dim,
276
+ hidden_dim=block_hidden_dim,
277
+ )
278
+ for _ in range(self.num_blocks)
279
+ ]
280
+ )
281
+ self.final_mlp = MLP(
282
+ input_dim=self.num_blocks * block_hidden_dim, **mlp_params
283
+ )
232
284
  self.output_layer = None
233
285
  self.prediction_layer = PredictionLayer(task_type=self.task)
234
286
 
235
287
  if self.model_type == "serial":
236
- self.register_regularization_weights(embedding_attr="embedding", include_modules=["mask_blocks", "output_layer"],)
288
+ self.register_regularization_weights(
289
+ embedding_attr="embedding",
290
+ include_modules=["mask_blocks", "output_layer"],
291
+ )
237
292
  # serial
238
293
  else:
239
- self.register_regularization_weights(embedding_attr="embedding", include_modules=["mask_blocks", "final_mlp"])
240
- self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
294
+ self.register_regularization_weights(
295
+ embedding_attr="embedding", include_modules=["mask_blocks", "final_mlp"]
296
+ )
297
+ self.compile(
298
+ optimizer=optimizer,
299
+ optimizer_params=optimizer_params,
300
+ loss=loss,
301
+ loss_params=loss_params,
302
+ )
241
303
 
242
304
  def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
243
305
  field_emb = self.embedding(x=x, features=self.mask_features, squeeze_dim=False)
244
306
  B = field_emb.size(0)
245
- v_emb_flat = field_emb.view(B, -1) # flattened embeddings
307
+ v_emb_flat = field_emb.view(B, -1) # flattened embeddings
246
308
 
247
309
  if self.model_type == "parallel":
248
310
  block_outputs = []
@@ -253,7 +315,7 @@ class MaskNet(BaseModel):
253
315
  concat_hidden = torch.cat(block_outputs, dim=-1)
254
316
  logit = self.final_mlp(concat_hidden) # [B, 1]
255
317
  # serial
256
- else:
318
+ else:
257
319
  hidden = self.first_block(field_emb, v_emb_flat)
258
320
  hidden = self.block_dropout(hidden)
259
321
  for block in self.hidden_blocks: