nextrec 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +250 -112
  7. nextrec/basic/loggers.py +63 -44
  8. nextrec/basic/metrics.py +270 -120
  9. nextrec/basic/model.py +1084 -402
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +492 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +51 -45
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +273 -96
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +103 -38
  23. nextrec/models/match/dssm.py +82 -68
  24. nextrec/models/match/dssm_v2.py +72 -57
  25. nextrec/models/match/mind.py +175 -107
  26. nextrec/models/match/sdm.py +104 -87
  27. nextrec/models/match/youtube_dnn.py +73 -59
  28. nextrec/models/multi_task/esmm.py +69 -46
  29. nextrec/models/multi_task/mmoe.py +91 -53
  30. nextrec/models/multi_task/ple.py +117 -58
  31. nextrec/models/multi_task/poso.py +163 -55
  32. nextrec/models/multi_task/share_bottom.py +63 -36
  33. nextrec/models/ranking/afm.py +80 -45
  34. nextrec/models/ranking/autoint.py +74 -57
  35. nextrec/models/ranking/dcn.py +110 -48
  36. nextrec/models/ranking/dcn_v2.py +265 -45
  37. nextrec/models/ranking/deepfm.py +39 -24
  38. nextrec/models/ranking/dien.py +335 -146
  39. nextrec/models/ranking/din.py +158 -92
  40. nextrec/models/ranking/fibinet.py +134 -52
  41. nextrec/models/ranking/fm.py +68 -26
  42. nextrec/models/ranking/masknet.py +95 -33
  43. nextrec/models/ranking/pnn.py +128 -58
  44. nextrec/models/ranking/widedeep.py +40 -28
  45. nextrec/models/ranking/xdeepfm.py +67 -40
  46. nextrec/utils/__init__.py +59 -34
  47. nextrec/utils/config.py +496 -0
  48. nextrec/utils/device.py +30 -20
  49. nextrec/utils/distributed.py +36 -9
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +33 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/model.py +22 -0
  55. nextrec/utils/optimizer.py +25 -9
  56. nextrec/utils/synthetic_data.py +283 -165
  57. nextrec/utils/tensor.py +24 -13
  58. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
  59. nextrec-0.4.3.dist-info/RECORD +69 -0
  60. nextrec-0.4.3.dist-info/entry_points.txt +2 -0
  61. nextrec-0.4.1.dist-info/RECORD +0 -66
  62. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
  63. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
@@ -1,18 +1,62 @@
1
1
  """
2
2
  Date: create on 09/11/2025
3
- Author:
4
- Yang Zhou,zyaztec@gmail.com
3
+ Checkpoint: edit on 09/12/2025
4
+ Author: Yang Zhou, zyaztec@gmail.com
5
5
  Reference:
6
- [1] Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]
7
- //Proceedings of the 24th ACM SIGKDD international conference on knowledge discovery & data mining. 2018: 1059-1068.
8
- (https://arxiv.org/abs/1706.06978)
6
+ [1] Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate
7
+ prediction[C] //Proceedings of the 24th ACM SIGKDD international conference on
8
+ knowledge discovery & data mining. 2018: 1059-1068.
9
+ (https://arxiv.org/abs/1706.06978)
10
+
11
+ Deep Interest Network (DIN) is a CTR model that builds a target-aware user
12
+ representation by attending over the historical behavior sequence. Instead of
13
+ compressing all behaviors into one static vector, DIN highlights the behaviors
14
+ most relevant to the current candidate item, enabling adaptive interest
15
+ modeling for each request.
16
+
17
+ Pipeline:
18
+ (1) Embed candidate item, user behavior sequence, and other sparse/dense fields
19
+ (2) Use a small attention MLP to score each historical behavior against the
20
+ candidate embedding
21
+ (3) Apply masked weighted pooling to obtain a target-specific interest vector
22
+ (4) Concatenate candidate, interest vector, other sparse embeddings, and dense
23
+ features
24
+ (5) Feed the combined representation into an MLP for final prediction
25
+
26
+ Key Advantages:
27
+ - Target-aware attention captures fine-grained interests per candidate item
28
+ - Adaptive pooling handles diverse behavior patterns without heavy feature crafting
29
+ - Masked weighting reduces noise from padded sequence positions
30
+ - Easily incorporates additional sparse/dense context features alongside behavior
31
+
32
+ DIN 是一个 CTR 预估模型,通过对用户历史行为序列进行目标感知的注意力加权,
33
+ 构建针对当前候选物品的兴趣表示。它不是将全部行为压缩为固定向量,而是突出
34
+ 与候选物品最相关的行为,实现请求级的自适应兴趣建模。
35
+
36
+ 处理流程:
37
+ (1) 对候选物品、用户行为序列及其他稀疏/稠密特征做 embedding
38
+ (2) 使用小型注意力 MLP 计算每个历史行为与候选 embedding 的相关性
39
+ (3) 通过掩码加权池化得到目标特定的兴趣向量
40
+ (4) 拼接候选、兴趣向量、其他稀疏 embedding 与稠密特征
41
+ (5) 输入 MLP 完成最终点击率预测
42
+
43
+ 主要优点:
44
+ - 目标感知注意力捕捉候选级的细粒度兴趣
45
+ - 自适应池化应对多样化行为模式,减少手工特征工程
46
+ - 掩码加权降低序列填充位置的噪声
47
+ - 便捷融合行为与额外稀疏/稠密上下文信息
9
48
  """
10
49
 
11
50
  import torch
12
51
  import torch.nn as nn
13
52
 
14
53
  from nextrec.basic.model import BaseModel
15
- from nextrec.basic.layers import EmbeddingLayer, MLP, AttentionPoolingLayer, PredictionLayer
54
+ from nextrec.basic.layers import (
55
+ EmbeddingLayer,
56
+ MLP,
57
+ AttentionPoolingLayer,
58
+ PredictionLayer,
59
+ )
16
60
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
17
61
 
18
62
 
@@ -24,28 +68,41 @@ class DIN(BaseModel):
24
68
  @property
25
69
  def default_task(self):
26
70
  return "binary"
27
-
28
- def __init__(self,
29
- dense_features: list[DenseFeature],
30
- sparse_features: list[SparseFeature],
31
- sequence_features: list[SequenceFeature],
32
- mlp_params: dict,
33
- attention_hidden_units: list[int] = [80, 40],
34
- attention_activation: str = 'sigmoid',
35
- attention_use_softmax: bool = True,
36
- target: list[str] = [],
37
- task: str | list[str] | None = None,
38
- optimizer: str = "adam",
39
- optimizer_params: dict = {},
40
- loss: str | nn.Module | None = "bce",
41
- loss_params: dict | list[dict] | None = None,
42
- device: str = 'cpu',
43
- embedding_l1_reg=1e-6,
44
- dense_l1_reg=1e-5,
45
- embedding_l2_reg=1e-5,
46
- dense_l2_reg=1e-4,
47
- **kwargs):
48
-
71
+
72
+ def __init__(
73
+ self,
74
+ dense_features: list[DenseFeature] | None = None,
75
+ sparse_features: list[SparseFeature] | None = None,
76
+ sequence_features: list[SequenceFeature] | None = None,
77
+ behavior_feature_name: str | None = None,
78
+ candidate_feature_name: str | None = None,
79
+ mlp_params: dict | None = None,
80
+ attention_hidden_units: list[int] | None = None,
81
+ attention_activation: str = "dice",
82
+ attention_use_softmax: bool = True,
83
+ target: list[str] | str | None = None,
84
+ task: str | list[str] | None = None,
85
+ optimizer: str = "adam",
86
+ optimizer_params: dict | None = None,
87
+ loss: str | nn.Module | None = "bce",
88
+ loss_params: dict | list[dict] | None = None,
89
+ device: str = "cpu",
90
+ embedding_l1_reg=1e-6,
91
+ dense_l1_reg=1e-5,
92
+ embedding_l2_reg=1e-5,
93
+ dense_l2_reg=1e-4,
94
+ **kwargs,
95
+ ):
96
+
97
+ dense_features = dense_features or []
98
+ sparse_features = sparse_features or []
99
+ sequence_features = sequence_features or []
100
+ mlp_params = mlp_params or {}
101
+ attention_hidden_units = attention_hidden_units or [80, 40]
102
+ optimizer_params = optimizer_params or {}
103
+ if loss is None:
104
+ loss = "bce"
105
+
49
106
  super(DIN, self).__init__(
50
107
  dense_features=dense_features,
51
108
  sparse_features=sparse_features,
@@ -57,43 +114,52 @@ class DIN(BaseModel):
57
114
  dense_l1_reg=dense_l1_reg,
58
115
  embedding_l2_reg=embedding_l2_reg,
59
116
  dense_l2_reg=dense_l2_reg,
60
- **kwargs
117
+ **kwargs,
61
118
  )
62
119
 
63
- self.loss = loss
64
- if self.loss is None:
65
- self.loss = "bce"
66
-
67
- # Features classification
68
- # DIN requires: candidate item + user behavior sequence + other features
120
+ # DIN requires: user behavior sequence + candidate item + other features
69
121
  if len(sequence_features) == 0:
70
- raise ValueError("DIN requires at least one sequence feature for user behavior history")
71
-
72
- self.behavior_feature = sequence_features[0] # User behavior sequence
73
- self.candidate_feature = sparse_features[-1] if sparse_features else None # Candidate item
74
-
75
- # Other features (excluding behavior sequence in final concatenation)
76
- self.other_sparse_features = sparse_features[:-1] if self.candidate_feature else sparse_features
77
- self.dense_features_list = dense_features
78
-
79
- # All features for embedding
80
- self.all_features = dense_features + sparse_features + sequence_features
122
+ raise ValueError(
123
+ "DIN requires at least one sequence feature for user behavior history"
124
+ )
125
+ if behavior_feature_name is None:
126
+ raise ValueError("DIN requires an explicit behavior_feature_name")
127
+
128
+ if candidate_feature_name is None:
129
+ raise ValueError("DIN requires an explicit candidate_feature_name")
130
+
131
+ self.behavior_feature = [
132
+ f for f in sequence_features if f.name == behavior_feature_name
133
+ ][0]
134
+ self.candidate_feature = [
135
+ f for f in sparse_features if f.name == candidate_feature_name
136
+ ][0]
137
+
138
+ # Other sparse features
139
+ self.other_sparse_features = [
140
+ f for f in sparse_features if f.name != self.candidate_feature.name
141
+ ]
81
142
 
82
143
  # Embedding layer
83
144
  self.embedding = EmbeddingLayer(features=self.all_features)
84
-
145
+
85
146
  # Attention layer for behavior sequence
86
147
  behavior_emb_dim = self.behavior_feature.embedding_dim
87
148
  self.candidate_attention_proj = None
88
- if self.candidate_feature is not None and self.candidate_feature.embedding_dim != behavior_emb_dim:
89
- self.candidate_attention_proj = nn.Linear(self.candidate_feature.embedding_dim, behavior_emb_dim)
149
+ if (
150
+ self.candidate_feature is not None
151
+ and self.candidate_feature.embedding_dim != behavior_emb_dim
152
+ ):
153
+ self.candidate_attention_proj = nn.Linear(
154
+ self.candidate_feature.embedding_dim, behavior_emb_dim
155
+ )
90
156
  self.attention = AttentionPoolingLayer(
91
157
  embedding_dim=behavior_emb_dim,
92
158
  hidden_units=attention_hidden_units,
93
159
  activation=attention_activation,
94
- use_softmax=attention_use_softmax
160
+ use_softmax=attention_use_softmax,
95
161
  )
96
-
162
+
97
163
  # Calculate MLP input dimension
98
164
  # candidate + attention_pooled_behavior + other_sparse + dense
99
165
  mlp_input_dim = 0
@@ -101,16 +167,18 @@ class DIN(BaseModel):
101
167
  mlp_input_dim += self.candidate_feature.embedding_dim
102
168
  mlp_input_dim += behavior_emb_dim # attention pooled
103
169
  mlp_input_dim += sum([f.embedding_dim for f in self.other_sparse_features])
104
- mlp_input_dim += sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
105
-
170
+ mlp_input_dim += sum(
171
+ [getattr(f, "embedding_dim", 1) or 1 for f in dense_features]
172
+ )
173
+
106
174
  # MLP for final prediction
107
175
  self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
108
176
  self.prediction_layer = PredictionLayer(task_type=self.task)
109
177
 
110
178
  # Register regularization weights
111
179
  self.register_regularization_weights(
112
- embedding_attr='embedding',
113
- include_modules=['attention', 'mlp', 'candidate_attention_proj']
180
+ embedding_attr="embedding",
181
+ include_modules=["attention", "mlp", "candidate_attention_proj"],
114
182
  )
115
183
 
116
184
  self.compile(
@@ -122,62 +190,60 @@ class DIN(BaseModel):
122
190
 
123
191
  def forward(self, x):
124
192
  # Get candidate item embedding
125
- if self.candidate_feature:
126
- candidate_emb = self.embedding.embed_dict[self.candidate_feature.embedding_name](
127
- x[self.candidate_feature.name].long()
128
- ) # [B, emb_dim]
129
- else:
130
- candidate_emb = None
131
-
193
+ if self.candidate_feature is None:
194
+ raise ValueError("DIN requires a candidate item feature")
195
+ candidate_emb = self.embedding.embed_dict[
196
+ self.candidate_feature.embedding_name
197
+ ](
198
+ x[self.candidate_feature.name].long()
199
+ ) # [B, emb_dim]
200
+
132
201
  # Get behavior sequence embedding
133
202
  behavior_seq = x[self.behavior_feature.name].long() # [B, seq_len]
134
203
  behavior_emb = self.embedding.embed_dict[self.behavior_feature.embedding_name](
135
204
  behavior_seq
136
205
  ) # [B, seq_len, emb_dim]
137
-
206
+
138
207
  # Create mask for padding
139
208
  if self.behavior_feature.padding_idx is not None:
140
- mask = (behavior_seq != self.behavior_feature.padding_idx).unsqueeze(-1).float()
209
+ mask = (
210
+ (behavior_seq != self.behavior_feature.padding_idx)
211
+ .unsqueeze(-1)
212
+ .float()
213
+ )
141
214
  else:
142
215
  mask = (behavior_seq != 0).unsqueeze(-1).float()
143
-
216
+
144
217
  # Apply attention pooling
145
- if candidate_emb is not None:
146
- candidate_query = candidate_emb
147
- if self.candidate_attention_proj is not None:
148
- candidate_query = self.candidate_attention_proj(candidate_query)
149
- pooled_behavior = self.attention(
150
- query=candidate_query,
151
- keys=behavior_emb,
152
- mask=mask
153
- ) # [B, emb_dim]
154
- else:
155
- # If no candidate, use mean pooling
156
- pooled_behavior = torch.sum(behavior_emb * mask, dim=1) / (mask.sum(dim=1) + 1e-9)
157
-
218
+ candidate_query = candidate_emb
219
+ if self.candidate_attention_proj is not None:
220
+ candidate_query = self.candidate_attention_proj(candidate_query)
221
+ pooled_behavior = self.attention(
222
+ query=candidate_query, keys=behavior_emb, mask=mask
223
+ ) # [B, emb_dim]
224
+
158
225
  # Get other features
159
226
  other_embeddings = []
160
-
161
- if candidate_emb is not None:
162
- other_embeddings.append(candidate_emb)
163
-
227
+
228
+ other_embeddings.append(candidate_emb)
229
+
164
230
  other_embeddings.append(pooled_behavior)
165
-
231
+
166
232
  # Other sparse features
167
233
  for feat in self.other_sparse_features:
168
- feat_emb = self.embedding.embed_dict[feat.embedding_name](x[feat.name].long())
234
+ feat_emb = self.embedding.embed_dict[feat.embedding_name](
235
+ x[feat.name].long()
236
+ )
169
237
  other_embeddings.append(feat_emb)
170
-
238
+
171
239
  # Dense features
172
- for feat in self.dense_features_list:
173
- val = x[feat.name].float()
174
- if val.dim() == 1:
175
- val = val.unsqueeze(1)
176
- other_embeddings.append(val)
177
-
240
+ for feat in self.dense_features:
241
+ dense_val = self.embedding.project_dense(feat, x)
242
+ other_embeddings.append(dense_val)
243
+
178
244
  # Concatenate all features
179
245
  concat_input = torch.cat(other_embeddings, dim=-1) # [B, total_dim]
180
-
246
+
181
247
  # MLP prediction
182
248
  y = self.mlp(concat_input) # [B, 1]
183
249
  return self.prediction_layer(y)
@@ -1,10 +1,43 @@
1
1
  """
2
2
  Date: create on 09/11/2025
3
- Author:
4
- Yang Zhou,zyaztec@gmail.com
3
+ Checkpoint: edit on 09/12/2025
4
+ Author: Yang Zhou, zyaztec@gmail.com
5
5
  Reference:
6
- [1] Huang T, Zhang Z, Zhang B, et al. FiBiNET: Combining feature importance and bilinear feature interaction
7
- for click-through rate prediction[C]//RecSys. 2019: 169-177.
6
+ [1] Huang T, Zhang Z, Zhang B, et al. FiBiNET: Combining feature importance and bilinear
7
+ feature interaction for click-through rate prediction[C]//RecSys. 2019: 169-177.
8
+ (https://arxiv.org/abs/1905.09433)
9
+
10
+ FiBiNET (Feature Importance and Bilinear Interaction Network) is a CTR model that
11
+ jointly learns which fields matter most and how they interact. It first uses SENET
12
+ to produce field-wise importance weights and recalibrate embeddings, then applies a
13
+ bilinear interaction layer to capture pairwise feature relationships with enhanced
14
+ expressiveness.
15
+
16
+ Pipeline:
17
+ (1) Embed sparse and sequence features that share a common embedding dimension
18
+ (2) SENET squeezes and excites across fields to generate importance scores
19
+ (3) Reweight embeddings with SENET scores to highlight informative fields
20
+ (4) Compute bilinear interactions on both the original and SENET-reweighted
21
+ embeddings to model pairwise relations
22
+ (5) Concatenate interaction outputs and feed them into an MLP alongside a linear
23
+ term for final prediction
24
+
25
+ Key Advantages:
26
+ - SENET recalibration emphasizes the most informative feature fields
27
+ - Bilinear interactions explicitly model pairwise relationships beyond simple dot
28
+ products
29
+ - Dual-path (standard + SENET-reweighted) interactions enrich representation power
30
+ - Combines linear and deep components for both memorization and generalization
31
+
32
+ FiBiNET 是一个 CTR 预估模型,通过 SENET 重新分配特征字段的重要性,再用双线性
33
+ 交互层捕捉成对特征关系。模型先对稀疏/序列特征做 embedding,SENET 生成字段权重并
34
+ 重标定 embedding,随后在原始和重标定的 embedding 上分别计算双线性交互,最后将
35
+ 交互结果与线性部分一起输入 MLP 得到预测。
36
+ 主要优点:
37
+ - SENET 让模型聚焦最重要的特征字段
38
+ - 双线性交互显式建模特征对关系,表达力强于简单点积
39
+ - 标准与重标定两路交互结合,丰富特征表示
40
+ - 线性与深层结构并行,兼顾记忆与泛化
8
41
  """
9
42
 
10
43
  import torch
@@ -13,6 +46,7 @@ import torch.nn as nn
13
46
  from nextrec.basic.model import BaseModel
14
47
  from nextrec.basic.layers import (
15
48
  BiLinearInteractionLayer,
49
+ HadamardInteractionLayer,
16
50
  EmbeddingLayer,
17
51
  LR,
18
52
  MLP,
@@ -30,27 +64,38 @@ class FiBiNET(BaseModel):
30
64
  @property
31
65
  def default_task(self):
32
66
  return "binary"
33
-
34
- def __init__(self,
35
- dense_features: list[DenseFeature] | list = [],
36
- sparse_features: list[SparseFeature] | list = [],
37
- sequence_features: list[SequenceFeature] | list = [],
38
- mlp_params: dict = {},
39
- bilinear_type: str = "field_interaction",
40
- senet_reduction: int = 3,
41
- target: list[str] | list = [],
42
- task: str | list[str] | None = None,
43
- optimizer: str = "adam",
44
- optimizer_params: dict = {},
45
- loss: str | nn.Module | None = "bce",
46
- loss_params: dict | list[dict] | None = None,
47
- device: str = 'cpu',
48
- embedding_l1_reg=1e-6,
49
- dense_l1_reg=1e-5,
50
- embedding_l2_reg=1e-5,
51
- dense_l2_reg=1e-4,
52
- **kwargs):
53
-
67
+
68
+ def __init__(
69
+ self,
70
+ dense_features: list[DenseFeature] | None = None,
71
+ sparse_features: list[SparseFeature] | None = None,
72
+ sequence_features: list[SequenceFeature] | None = None,
73
+ mlp_params: dict | None = None,
74
+ interaction_combo: str = "11", # "0": Hadamard, "1": Bilinear
75
+ bilinear_type: str = "field_interaction",
76
+ senet_reduction: int = 3,
77
+ target: list[str] | str | None = None,
78
+ task: str | list[str] | None = None,
79
+ optimizer: str = "adam",
80
+ optimizer_params: dict | None = None,
81
+ loss: str | nn.Module | None = "bce",
82
+ loss_params: dict | list[dict] | None = None,
83
+ device: str = "cpu",
84
+ embedding_l1_reg=1e-6,
85
+ dense_l1_reg=1e-5,
86
+ embedding_l2_reg=1e-5,
87
+ dense_l2_reg=1e-4,
88
+ **kwargs,
89
+ ):
90
+
91
+ dense_features = dense_features or []
92
+ sparse_features = sparse_features or []
93
+ sequence_features = sequence_features or []
94
+ mlp_params = mlp_params or {}
95
+ optimizer_params = optimizer_params or {}
96
+ if loss is None:
97
+ loss = "bce"
98
+
54
99
  super(FiBiNET, self).__init__(
55
100
  dense_features=dense_features,
56
101
  sparse_features=sparse_features,
@@ -62,39 +107,61 @@ class FiBiNET(BaseModel):
62
107
  dense_l1_reg=dense_l1_reg,
63
108
  embedding_l2_reg=embedding_l2_reg,
64
109
  dense_l2_reg=dense_l2_reg,
65
- **kwargs
110
+ **kwargs,
66
111
  )
67
112
 
68
113
  self.loss = loss
69
- if self.loss is None:
70
- self.loss = "bce"
71
-
72
114
  self.linear_features = sparse_features + sequence_features
73
- self.deep_features = dense_features + sparse_features + sequence_features
74
115
  self.interaction_features = sparse_features + sequence_features
75
116
 
76
117
  if len(self.interaction_features) < 2:
77
- raise ValueError("FiBiNET requires at least two sparse/sequence features for interactions.")
118
+ raise ValueError(
119
+ "FiBiNET requires at least two sparse/sequence features for interactions."
120
+ )
78
121
 
79
- self.embedding = EmbeddingLayer(features=self.deep_features)
122
+ self.embedding = EmbeddingLayer(features=self.all_features)
80
123
 
81
124
  self.num_fields = len(self.interaction_features)
82
125
  self.embedding_dim = self.interaction_features[0].embedding_dim
83
- if any(f.embedding_dim != self.embedding_dim for f in self.interaction_features):
84
- raise ValueError("All interaction features must share the same embedding_dim in FiBiNET.")
85
-
86
- self.senet = SENETLayer(num_fields=self.num_fields, reduction_ratio=senet_reduction)
87
- self.bilinear_standard = BiLinearInteractionLayer(
88
- input_dim=self.embedding_dim,
89
- num_fields=self.num_fields,
90
- bilinear_type=bilinear_type,
91
- )
92
- self.bilinear_senet = BiLinearInteractionLayer(
93
- input_dim=self.embedding_dim,
94
- num_fields=self.num_fields,
95
- bilinear_type=bilinear_type,
126
+ if any(
127
+ f.embedding_dim != self.embedding_dim for f in self.interaction_features
128
+ ):
129
+ raise ValueError(
130
+ "All interaction features must share the same embedding_dim in FiBiNET."
131
+ )
132
+
133
+ self.senet = SENETLayer(
134
+ num_fields=self.num_fields, reduction_ratio=senet_reduction
96
135
  )
97
136
 
137
+ self.interaction_combo = interaction_combo
138
+
139
+ # E interaction layers: original embeddings
140
+ if interaction_combo[0] == "0": # Hadamard
141
+ self.interaction_E = HadamardInteractionLayer(
142
+ num_fields=self.num_fields
143
+ ) # [B, num_pairs, D]
144
+ elif interaction_combo[0] == "1": # Bilinear
145
+ self.interaction_E = BiLinearInteractionLayer(
146
+ input_dim=self.embedding_dim,
147
+ num_fields=self.num_fields,
148
+ bilinear_type=bilinear_type,
149
+ ) # [B, num_pairs, D]
150
+ else:
151
+ raise ValueError("interaction_combo must be '01' or '11'")
152
+
153
+ # V interaction layers: SENET reweighted embeddings
154
+ if interaction_combo[1] == "0":
155
+ self.interaction_V = HadamardInteractionLayer(num_fields=self.num_fields)
156
+ elif interaction_combo[1] == "1":
157
+ self.interaction_V = BiLinearInteractionLayer(
158
+ input_dim=self.embedding_dim,
159
+ num_fields=self.num_fields,
160
+ bilinear_type=bilinear_type,
161
+ )
162
+ else:
163
+ raise ValueError("Deep-FiBiNET SENET side must be '01' or '11'")
164
+
98
165
  linear_dim = sum([f.embedding_dim for f in self.linear_features])
99
166
  self.linear = LR(linear_dim)
100
167
 
@@ -105,8 +172,14 @@ class FiBiNET(BaseModel):
105
172
 
106
173
  # Register regularization weights
107
174
  self.register_regularization_weights(
108
- embedding_attr='embedding',
109
- include_modules=['linear', 'senet', 'bilinear_standard', 'bilinear_senet', 'mlp']
175
+ embedding_attr="embedding",
176
+ include_modules=[
177
+ "linear",
178
+ "senet",
179
+ "mlp",
180
+ "interaction_E",
181
+ "interaction_V",
182
+ ],
110
183
  )
111
184
 
112
185
  self.compile(
@@ -117,15 +190,24 @@ class FiBiNET(BaseModel):
117
190
  )
118
191
 
119
192
  def forward(self, x):
120
- input_linear = self.embedding(x=x, features=self.linear_features, squeeze_dim=True)
193
+ input_linear = self.embedding(
194
+ x=x, features=self.linear_features, squeeze_dim=True
195
+ )
121
196
  y_linear = self.linear(input_linear)
122
197
 
123
- field_emb = self.embedding(x=x, features=self.interaction_features, squeeze_dim=False)
198
+ field_emb = self.embedding(
199
+ x=x, features=self.interaction_features, squeeze_dim=False
200
+ )
124
201
  senet_emb = self.senet(field_emb)
125
202
 
126
- bilinear_standard = self.bilinear_standard(field_emb).flatten(start_dim=1)
127
- bilinear_senet = self.bilinear_senet(senet_emb).flatten(start_dim=1)
128
- deep_input = torch.cat([bilinear_standard, bilinear_senet], dim=1)
203
+ out_E = self.interaction_E(field_emb) # [B, num_pairs, D]
204
+
205
+ out_V = self.interaction_V(senet_emb) # [B, num_pairs, D]
206
+
207
+ deep_input = torch.cat(
208
+ [out_E.flatten(start_dim=1), out_V.flatten(start_dim=1)], dim=1
209
+ )
210
+
129
211
  y_deep = self.mlp(deep_input)
130
212
 
131
213
  y = y_linear + y_deep