nextrec-0.4.1-py3-none-any.whl → nextrec-0.4.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +250 -112
  7. nextrec/basic/loggers.py +63 -44
  8. nextrec/basic/metrics.py +270 -120
  9. nextrec/basic/model.py +1084 -402
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +492 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +51 -45
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +273 -96
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +103 -38
  23. nextrec/models/match/dssm.py +82 -68
  24. nextrec/models/match/dssm_v2.py +72 -57
  25. nextrec/models/match/mind.py +175 -107
  26. nextrec/models/match/sdm.py +104 -87
  27. nextrec/models/match/youtube_dnn.py +73 -59
  28. nextrec/models/multi_task/esmm.py +69 -46
  29. nextrec/models/multi_task/mmoe.py +91 -53
  30. nextrec/models/multi_task/ple.py +117 -58
  31. nextrec/models/multi_task/poso.py +163 -55
  32. nextrec/models/multi_task/share_bottom.py +63 -36
  33. nextrec/models/ranking/afm.py +80 -45
  34. nextrec/models/ranking/autoint.py +74 -57
  35. nextrec/models/ranking/dcn.py +110 -48
  36. nextrec/models/ranking/dcn_v2.py +265 -45
  37. nextrec/models/ranking/deepfm.py +39 -24
  38. nextrec/models/ranking/dien.py +335 -146
  39. nextrec/models/ranking/din.py +158 -92
  40. nextrec/models/ranking/fibinet.py +134 -52
  41. nextrec/models/ranking/fm.py +68 -26
  42. nextrec/models/ranking/masknet.py +95 -33
  43. nextrec/models/ranking/pnn.py +128 -58
  44. nextrec/models/ranking/widedeep.py +40 -28
  45. nextrec/models/ranking/xdeepfm.py +67 -40
  46. nextrec/utils/__init__.py +59 -34
  47. nextrec/utils/config.py +496 -0
  48. nextrec/utils/device.py +30 -20
  49. nextrec/utils/distributed.py +36 -9
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +33 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/model.py +22 -0
  55. nextrec/utils/optimizer.py +25 -9
  56. nextrec/utils/synthetic_data.py +283 -165
  57. nextrec/utils/tensor.py +24 -13
  58. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
  59. nextrec-0.4.3.dist-info/RECORD +69 -0
  60. nextrec-0.4.3.dist-info/entry_points.txt +2 -0
  61. nextrec-0.4.1.dist-info/RECORD +0 -66
  62. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
  63. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
nextrec/models/ranking/fm.py

@@ -1,15 +1,52 @@
 """
 Date: create on 09/11/2025
-Author:
-Yang Zhou,zyaztec@gmail.com
+Checkpoint: edit on 09/12/2025
+Author: Yang Zhou, zyaztec@gmail.com
 Reference:
-[1] Rendle S. Factorization machines[C]//ICDM. 2010: 995-1000.
+[1] Rendle S. Factorization machines[C]//ICDM. 2010: 995-1000.
+
+Factorization Machines (FM) capture second-order feature interactions with
+linear complexity by factorizing the pairwise interaction matrix. Each field
+is embedded into a latent vector; FM models the dot product of every pair of
+embeddings and sums them along with a linear term, enabling strong performance
+with sparse high-dimensional data and minimal feature engineering.
+
+Pipeline:
+(1) Embed sparse and sequence fields into low-dimensional vectors
+(2) Compute linear logit over concatenated embeddings
+(3) Compute pairwise interaction logit via factorized dot products
+(4) Sum linear + interaction terms and apply prediction layer
+
+Key Advantages:
+- Models pairwise interactions efficiently (O(nk) vs. O(n^2))
+- Works well on sparse inputs without handcrafted crosses
+- Simple architecture with strong baseline performance
+
+FM is a CTR model that factorizes the second-order interaction matrix to model feature pairs at linear complexity.
+Each feature is mapped to a low-dimensional vector; FM sums the inner products over all feature pairs and adds a linear term,
+achieving robust results in sparse, high-dimensional settings without complex feature engineering.
+
+Pipeline:
+(1) Embed the sparse/sequence features
+(2) Compute the linear-part logit
+(3) Compute the second-order interaction logit over embedding pairs
+(4) Sum the linear and interaction terms, then apply the prediction layer
+
+Key Advantages:
+- Models second-order interactions at linear complexity
+- Friendly to sparse features; reduces manual feature crossing
+- Simple, robust architecture; a common CTR baseline
 """
 
 import torch.nn as nn
 
 from nextrec.basic.model import BaseModel
-from nextrec.basic.layers import EmbeddingLayer, FM as FMInteraction, LR, PredictionLayer
+from nextrec.basic.layers import (
+    EmbeddingLayer,
+    FM as FMInteraction,
+    LR,
+    PredictionLayer,
+)
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 
 
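The O(nk) claim in the new docstring rests on the standard FM identity: the sum of all pairwise dot products equals 0.5 * ((sum_i v_i)^2 - sum_i v_i^2), reduced over the latent dimension. A minimal PyTorch sketch of that equivalence (illustrative only; this diff does not show the internals of nextrec's FM layer):

    import torch

    # emb: [batch, num_fields, k], one latent vector per field
    emb = torch.randn(4, 6, 8)

    # Naive O(F^2): sum of <v_i, v_j> over all pairs i < j
    naive = sum(
        (emb[:, i] * emb[:, j]).sum(-1)
        for i in range(6)
        for j in range(i + 1, 6)
    )

    # FM identity, O(F*k): 0.5 * ((sum_i v_i)^2 - sum_i v_i^2), summed over k
    square_of_sum = emb.sum(dim=1).pow(2)   # [batch, k]
    sum_of_squares = emb.pow(2).sum(dim=1)  # [batch, k]
    fm_logit = 0.5 * (square_of_sum - sum_of_squares).sum(dim=1)  # [batch]

    assert torch.allclose(naive, fm_logit, atol=1e-5)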
@@ -21,24 +58,30 @@ class FM(BaseModel):
     @property
     def default_task(self):
         return "binary"
-
-    def __init__(self,
-                 dense_features: list[DenseFeature] | list = [],
-                 sparse_features: list[SparseFeature] | list = [],
-                 sequence_features: list[SequenceFeature] | list = [],
-                 target: list[str] | list = [],
-                 task: str | list[str] | None = None,
-                 optimizer: str = "adam",
-                 optimizer_params: dict = {},
-                 loss: str | nn.Module | None = "bce",
-                 loss_params: dict | list[dict] | None = None,
-                 device: str = 'cpu',
-                 embedding_l1_reg=1e-6,
-                 dense_l1_reg=1e-5,
-                 embedding_l2_reg=1e-5,
-                 dense_l2_reg=1e-4,
-                 **kwargs):
-
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature] | None = None,
+        sparse_features: list[SparseFeature] | None = None,
+        sequence_features: list[SequenceFeature] | None = None,
+        target: list[str] | str | None = None,
+        task: str | list[str] | None = None,
+        optimizer: str = "adam",
+        optimizer_params: dict | None = None,
+        loss: str | nn.Module | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        device: str = "cpu",
+        embedding_l1_reg=1e-6,
+        dense_l1_reg=1e-5,
+        embedding_l2_reg=1e-5,
+        dense_l2_reg=1e-4,
+        **kwargs,
+    ):
+
+        dense_features = dense_features or []
+        sparse_features = sparse_features or []
+        sequence_features = sequence_features or []
+
         super(FM, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
@@ -50,13 +93,13 @@ class FM(BaseModel):
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            **kwargs
+            **kwargs,
         )
 
         self.loss = loss
         if self.loss is None:
             self.loss = "bce"
-
+
         self.fm_features = sparse_features + sequence_features
         if len(self.fm_features) == 0:
             raise ValueError("FM requires at least one sparse or sequence feature.")
@@ -70,8 +113,7 @@ class FM(BaseModel):
 
         # Register regularization weights
         self.register_regularization_weights(
-            embedding_attr='embedding',
-            include_modules=['linear']
+            embedding_attr="embedding", include_modules=["linear"]
         )
 
         self.compile(
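Beyond reformatting, this signature change fixes a real Python pitfall: the 0.4.1 defaults (`dense_features: ... = []`, `optimizer_params: dict = {}`) are mutable objects created once at function definition, which 0.4.3 replaces with `None` plus `x = x or []` normalization. A minimal sketch of the bug class being avoided (illustrative names, not nextrec code):

    def bad(features=[]):        # one shared list across every call
        features.append("f")
        return features

    def good(features=None):     # fresh list per call, as in the 0.4.3 signatures
        features = features or []
        features.append("f")
        return features

    print(bad(), bad())    # ['f'] ['f', 'f']  <- state leaks between calls
    print(good(), good())  # ['f'] ['f']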
nextrec/models/ranking/masknet.py

@@ -69,12 +69,13 @@ class InstanceGuidedMask(nn.Module):
         self.fc2 = nn.Linear(hidden_dim, output_dim)
 
     def forward(self, v_emb_flat: torch.Tensor) -> torch.Tensor:
-        # v_emb_flat: [batch, features count * embedding_dim]
+        # v_emb_flat: [batch, features count * embedding_dim]
         x = self.fc1(v_emb_flat)
         x = F.relu(x)
         v_mask = self.fc2(x)
         return v_mask
 
+
 class MaskBlockOnEmbedding(nn.Module):
     def __init__(
         self,
@@ -86,20 +87,28 @@ class MaskBlockOnEmbedding(nn.Module):
         super().__init__()
         self.num_fields = num_fields
         self.embedding_dim = embedding_dim
-        self.input_dim = num_fields * embedding_dim  # input_dim = features count * embedding_dim
+        self.input_dim = (
+            num_fields * embedding_dim
+        )  # input_dim = features count * embedding_dim
         self.ln_emb = nn.LayerNorm(embedding_dim)
-        self.mask_gen = InstanceGuidedMask(input_dim=self.input_dim, hidden_dim=mask_hidden_dim, output_dim=self.input_dim,)
+        self.mask_gen = InstanceGuidedMask(
+            input_dim=self.input_dim,
+            hidden_dim=mask_hidden_dim,
+            output_dim=self.input_dim,
+        )
         self.ffn = nn.Linear(self.input_dim, hidden_dim)
         self.ln_hid = nn.LayerNorm(hidden_dim)
 
     # different from MaskBlockOnHidden: input is field embeddings
-    def forward(self, field_emb: torch.Tensor, v_emb_flat: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, field_emb: torch.Tensor, v_emb_flat: torch.Tensor
+    ) -> torch.Tensor:
         B = field_emb.size(0)
-        norm_emb = self.ln_emb(field_emb)  # [B, features count, embedding_dim]
-        norm_emb_flat = norm_emb.view(B, -1)  # [B, features count * embedding_dim]
-        v_mask = self.mask_gen(v_emb_flat)  # [B, features count * embedding_dim]
-        v_masked_emb = v_mask * norm_emb_flat  # [B, features count * embedding_dim]
-        hidden = self.ffn(v_masked_emb)  # [B, hidden_dim]
+        norm_emb = self.ln_emb(field_emb)  # [B, features count, embedding_dim]
+        norm_emb_flat = norm_emb.view(B, -1)  # [B, features count * embedding_dim]
+        v_mask = self.mask_gen(v_emb_flat)  # [B, features count * embedding_dim]
+        v_masked_emb = v_mask * norm_emb_flat  # [B, features count * embedding_dim]
+        hidden = self.ffn(v_masked_emb)  # [B, hidden_dim]
         hidden = self.ln_hid(hidden)
         hidden = F.relu(hidden)
 
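To summarize the dataflow that survives all the reformatting above: `MaskBlockOnEmbedding` LayerNorms the field embeddings, builds an instance-guided mask from the raw flattened embeddings, multiplies the two, then applies FFN, LayerNorm, and ReLU. A standalone shape walkthrough in plain torch (dimensions are illustrative; `mask_gen` stands in for `InstanceGuidedMask`'s fc1 -> ReLU -> fc2):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    B, num_fields, D, H = 8, 5, 16, 256     # batch, fields, embedding_dim, block_hidden_dim
    field_emb = torch.randn(B, num_fields, D)
    v_emb_flat = field_emb.reshape(B, -1)   # [B, F*D]; the raw embeddings drive the mask

    ln_emb = nn.LayerNorm(D)
    mask_gen = nn.Sequential(               # stand-in for InstanceGuidedMask
        nn.Linear(num_fields * D, 64), nn.ReLU(), nn.Linear(64, num_fields * D)
    )
    ffn = nn.Linear(num_fields * D, H)
    ln_hid = nn.LayerNorm(H)

    norm_flat = ln_emb(field_emb).reshape(B, -1)      # [B, F*D], normalized embeddings
    v_mask = mask_gen(v_emb_flat)                     # [B, F*D], instance-guided mask
    hidden = F.relu(ln_hid(ffn(v_mask * norm_flat)))  # [B, H]
    print(hidden.shape)                               # torch.Size([8, 256])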
@@ -123,15 +132,21 @@ class MaskBlockOnHidden(nn.Module):
         self.ln_input = nn.LayerNorm(hidden_dim)
         self.ln_output = nn.LayerNorm(hidden_dim)
 
-        self.mask_gen = InstanceGuidedMask(input_dim=self.v_emb_dim, hidden_dim=mask_hidden_dim, output_dim=hidden_dim,)
+        self.mask_gen = InstanceGuidedMask(
+            input_dim=self.v_emb_dim,
+            hidden_dim=mask_hidden_dim,
+            output_dim=hidden_dim,
+        )
         self.ffn = nn.Linear(hidden_dim, hidden_dim)
 
     # different from MaskBlockOnEmbedding: input is hidden representation
-    def forward(self, hidden_in: torch.Tensor, v_emb_flat: torch.Tensor) -> torch.Tensor:
-        norm_hidden = self.ln_input(hidden_in)
+    def forward(
+        self, hidden_in: torch.Tensor, v_emb_flat: torch.Tensor
+    ) -> torch.Tensor:
+        norm_hidden = self.ln_input(hidden_in)
         v_mask = self.mask_gen(v_emb_flat)
-        v_masked_hid = v_mask * norm_hidden
-        out = self.ffn(v_masked_hid)
+        v_masked_hid = v_mask * norm_hidden
+        out = self.ffn(v_masked_hid)
         out = self.ln_output(out)
         out = F.relu(out)
         return out
@@ -151,7 +166,7 @@ class MaskNet(BaseModel):
         dense_features: list[DenseFeature] | None = None,
         sparse_features: list[SparseFeature] | None = None,
         sequence_features: list[SequenceFeature] | None = None,
-        model_type: str = "parallel",  # "serial" or "parallel"
+        architecture: str = "parallel",  # "serial" or "parallel"
         num_blocks: int = 3,
         mask_hidden_dim: int = 64,
         block_hidden_dim: int = 256,
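Note this hunk is an API rename, not a reformat: the 0.4.1 keyword `model_type` becomes `architecture` in 0.4.3, so call sites passing `model_type=` must be updated (whether the old keyword raises a TypeError or is silently absorbed depends on BaseModel's **kwargs handling, which this diff does not show). A small migration sketch with hypothetical call-site kwargs:

    kwargs = {"model_type": "serial", "num_blocks": 3}   # a 0.4.1-style call site
    if "model_type" in kwargs:                           # migrate to the 0.4.3 keyword
        kwargs["architecture"] = kwargs.pop("model_type")
    print(kwargs)  # {'num_blocks': 3, 'architecture': 'serial'}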
@@ -199,52 +214,99 @@ class MaskNet(BaseModel):
         self.sparse_features = sparse_features
         self.sequence_features = sequence_features
         self.mask_features = self.all_features  # use all features for masking
-        assert len(self.mask_features) > 0, "MaskNet requires at least one feature for masking."
+        assert (
+            len(self.mask_features) > 0
+        ), "MaskNet requires at least one feature for masking."
         self.embedding = EmbeddingLayer(features=self.mask_features)
         self.num_fields = len(self.mask_features)
         self.embedding_dim = getattr(self.mask_features[0], "embedding_dim", None)
-        assert self.embedding_dim is not None, "MaskNet requires mask_features to have 'embedding_dim' defined."
+        assert (
+            self.embedding_dim is not None
+        ), "MaskNet requires mask_features to have 'embedding_dim' defined."
 
         for f in self.mask_features:
             edim = getattr(f, "embedding_dim", None)
             if edim is None or edim != self.embedding_dim:
-                raise ValueError(f"MaskNet expects identical embedding_dim across all mask_features, but got {edim} for feature {getattr(f, 'name', type(f))}.")
+                raise ValueError(
+                    f"MaskNet expects identical embedding_dim across all mask_features, but got {edim} for feature {getattr(f, 'name', type(f))}."
+                )
 
         self.v_emb_dim = self.num_fields * self.embedding_dim
-        self.model_type = model_type.lower()
-        assert self.model_type in ("serial", "parallel"), "model_type must be either 'serial' or 'parallel'."
+        self.architecture = architecture.lower()
+        assert self.architecture in (
+            "serial",
+            "parallel",
+        ), "architecture must be either 'serial' or 'parallel'."
 
         self.num_blocks = max(1, num_blocks)
         self.block_hidden_dim = block_hidden_dim
-        self.block_dropout = nn.Dropout(block_dropout) if block_dropout > 0 else nn.Identity()
+        self.block_dropout = (
+            nn.Dropout(block_dropout) if block_dropout > 0 else nn.Identity()
+        )
 
-        if self.model_type == "serial":
-            self.first_block = MaskBlockOnEmbedding(num_fields=self.num_fields, embedding_dim=self.embedding_dim, mask_hidden_dim=mask_hidden_dim, hidden_dim=block_hidden_dim,)
+        if self.architecture == "serial":
+            self.first_block = MaskBlockOnEmbedding(
+                num_fields=self.num_fields,
+                embedding_dim=self.embedding_dim,
+                mask_hidden_dim=mask_hidden_dim,
+                hidden_dim=block_hidden_dim,
+            )
             self.hidden_blocks = nn.ModuleList(
-                [MaskBlockOnHidden(num_fields=self.num_fields, embedding_dim=self.embedding_dim, mask_hidden_dim=mask_hidden_dim, hidden_dim=block_hidden_dim) for _ in range(self.num_blocks - 1)])
+                [
+                    MaskBlockOnHidden(
+                        num_fields=self.num_fields,
+                        embedding_dim=self.embedding_dim,
+                        mask_hidden_dim=mask_hidden_dim,
+                        hidden_dim=block_hidden_dim,
+                    )
+                    for _ in range(self.num_blocks - 1)
+                ]
+            )
             self.mask_blocks = nn.ModuleList([self.first_block, *self.hidden_blocks])
             self.output_layer = nn.Linear(block_hidden_dim, 1)
             self.final_mlp = None
 
         else:  # parallel
-            self.mask_blocks = nn.ModuleList([MaskBlockOnEmbedding(num_fields=self.num_fields, embedding_dim=self.embedding_dim, mask_hidden_dim=mask_hidden_dim, hidden_dim=block_hidden_dim) for _ in range(self.num_blocks)])
-            self.final_mlp = MLP(input_dim=self.num_blocks * block_hidden_dim, **mlp_params)
+            self.mask_blocks = nn.ModuleList(
+                [
+                    MaskBlockOnEmbedding(
+                        num_fields=self.num_fields,
+                        embedding_dim=self.embedding_dim,
+                        mask_hidden_dim=mask_hidden_dim,
+                        hidden_dim=block_hidden_dim,
+                    )
+                    for _ in range(self.num_blocks)
+                ]
+            )
+            self.final_mlp = MLP(
+                input_dim=self.num_blocks * block_hidden_dim, **mlp_params
+            )
             self.output_layer = None
         self.prediction_layer = PredictionLayer(task_type=self.task)
 
-        if self.model_type == "serial":
-            self.register_regularization_weights(embedding_attr="embedding", include_modules=["mask_blocks", "output_layer"],)
+        if self.architecture == "serial":
+            self.register_regularization_weights(
+                embedding_attr="embedding",
+                include_modules=["mask_blocks", "output_layer"],
+            )
         # serial
         else:
-            self.register_regularization_weights(embedding_attr="embedding", include_modules=["mask_blocks", "final_mlp"])
-        self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
+            self.register_regularization_weights(
+                embedding_attr="embedding", include_modules=["mask_blocks", "final_mlp"]
+            )
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            loss=loss,
+            loss_params=loss_params,
+        )
 
     def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
         field_emb = self.embedding(x=x, features=self.mask_features, squeeze_dim=False)
         B = field_emb.size(0)
-        v_emb_flat = field_emb.view(B, -1)  # flattened embeddings
+        v_emb_flat = field_emb.view(B, -1)  # flattened embeddings
 
-        if self.model_type == "parallel":
+        if self.architecture == "parallel":
             block_outputs = []
             for block in self.mask_blocks:
                 h = block(field_emb, v_emb_flat)  # [B, block_hidden_dim]
@@ -253,7 +315,7 @@ class MaskNet(BaseModel):
             concat_hidden = torch.cat(block_outputs, dim=-1)
             logit = self.final_mlp(concat_hidden)  # [B, 1]
         # serial
-        else:
+        else:
             hidden = self.first_block(field_emb, v_emb_flat)
             hidden = self.block_dropout(hidden)
             for block in self.hidden_blocks:
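The forward pass above distinguishes the two architectures: parallel runs every mask block on the field embeddings and concatenates their outputs, while serial threads a hidden state through the blocks (with `block_dropout` applied between them). A stub-level sketch of the two dataflows (the placeholder `block` stands in for the mask blocks; not nextrec code):

    import torch

    B, num_fields, D, H, num_blocks = 4, 5, 16, 256, 3
    field_emb = torch.randn(B, num_fields, D)
    v_emb_flat = field_emb.view(B, -1)

    def block(inp, v_flat):  # placeholder for MaskBlockOnEmbedding / MaskBlockOnHidden
        return torch.randn(inp.size(0), H)

    # parallel: every block reads the embeddings; outputs feed one MLP
    parallel_in = torch.cat(
        [block(field_emb, v_emb_flat) for _ in range(num_blocks)], dim=-1
    )
    print(parallel_in.shape)  # [4, 768] -> final_mlp(input_dim=num_blocks * block_hidden_dim)

    # serial: block 1 reads the embeddings, later blocks refine the hidden state,
    # re-reading v_emb_flat for the mask at every step
    hidden = block(field_emb, v_emb_flat)
    for _ in range(num_blocks - 1):
        hidden = block(hidden, v_emb_flat)
    print(hidden.shape)       # [4, 256] -> output_layer = nn.Linear(block_hidden_dim, 1)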
nextrec/models/ranking/pnn.py

@@ -3,7 +3,35 @@ Date: create on 09/11/2025
 Author:
 Yang Zhou,zyaztec@gmail.com
 Reference:
-[1] Qu Y, Cai H, Ren K, et al. Product-based neural networks for user response prediction[C]//ICDM. 2016: 1149-1154.
+[1] Qu Y, Cai H, Ren K, et al. Product-based neural networks for user response
+    prediction[C]//ICDM. 2016: 1149-1154. (https://arxiv.org/abs/1611.00144)
+
+Product-based Neural Networks (PNN) are CTR prediction models that explicitly
+encode feature interactions by combining:
+(1) A linear signal from concatenated field embeddings
+(2) A product signal capturing pairwise feature interactions (inner or outer)
+The product layer augments the linear input to an MLP, enabling the network to
+model both first-order and high-order feature interactions in a structured way.
+
+Computation workflow:
+- Embed each categorical/sequence field with a shared embedding dimension
+- Linear signal: flatten and concatenate all field embeddings
+- Product signal:
+    * Inner product: dot products over all field pairs
+    * Outer product: project embeddings then compute element-wise products
+- Concatenate linear and product signals; feed into MLP for prediction
+
+Key Advantages:
+- Explicit pairwise interaction modeling without heavy feature engineering
+- Flexible choice between inner/outer products to trade off capacity vs. cost
+- Combines linear context with interaction signal for stronger expressiveness
+- Simple architecture that integrates cleanly with standard MLP pipelines
+
+PNN is a CTR prediction model that explicitly models feature interactions by combining a linear signal with a product signal:
+- Linear signal: concatenate the per-field embeddings to preserve first-order information
+- Product signal: inner or outer products over all field pairs, capturing second- and higher-order interactions
+The two signals are then concatenated and fed into an MLP to predict the user response. The inner-product variant costs less to compute,
+while the outer-product variant is more expressive; choose according to the scenario.
 """
 
 import torch
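The two product signals the docstring describes can be sketched directly in torch. Note that in 0.4.3 the "outer" variant is a kernel-projected product reduced to one scalar per pair (see `compute_outer_kernel_products` further down), so both signals have width num_pairs. An illustrative sketch, where `W` stands in for the learned `kernel` parameter:

    import torch

    B, num_fields, D, K = 4, 5, 8, 8
    field_emb = torch.randn(B, num_fields, D)
    pairs = [(i, j) for i in range(num_fields - 1) for j in range(i + 1, num_fields)]

    # inner: one scalar <v_i, v_j> per field pair -> [B, num_pairs]
    inner = torch.cat(
        [(field_emb[:, i] * field_emb[:, j]).sum(-1, keepdim=True) for i, j in pairs],
        dim=1,
    )
    print(inner.shape)  # torch.Size([4, 10])

    # "outer" as 0.4.3 implements it: project with W [D, K], then reduce each
    # pair's element-wise product to a scalar -> also [B, num_pairs]
    W = torch.randn(D, K)
    proj = field_emb @ W  # [B, F, K]
    outer = torch.cat(
        [(proj[:, i] * proj[:, j]).sum(-1, keepdim=True) for i, j in pairs],
        dim=1,
    )
    print(outer.shape)  # torch.Size([4, 10])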
@@ -15,6 +43,7 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 
 
 class PNN(BaseModel):
+
     @property
     def model_name(self):
         return "PNN"
@@ -22,27 +51,39 @@ class PNN(BaseModel):
     @property
     def default_task(self):
         return "binary"
-
-    def __init__(self,
-                 dense_features: list[DenseFeature] | list = [],
-                 sparse_features: list[SparseFeature] | list = [],
-                 sequence_features: list[SequenceFeature] | list = [],
-                 mlp_params: dict = {},
-                 product_type: str = "inner",
-                 outer_product_dim: int | None = None,
-                 target: list[str] | list = [],
-                 task: str | list[str] | None = None,
-                 optimizer: str = "adam",
-                 optimizer_params: dict = {},
-                 loss: str | nn.Module | None = "bce",
-                 loss_params: dict | list[dict] | None = None,
-                 device: str = 'cpu',
-                 embedding_l1_reg=1e-6,
-                 dense_l1_reg=1e-5,
-                 embedding_l2_reg=1e-5,
-                 dense_l2_reg=1e-4,
-                 **kwargs):
-
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature] | None = None,
+        sparse_features: list[SparseFeature] | None = None,
+        sequence_features: list[SequenceFeature] | None = None,
+        mlp_params: dict | None = None,
+        product_type: str = "inner",  # "inner" (IPNN), "outer" (OPNN), "both" (PNN*)
+        outer_product_dim: int | None = None,
+        target: list[str] | str | None = None,
+        task: str | list[str] | None = None,
+        optimizer: str = "adam",
+        optimizer_params: dict | None = None,
+        loss: str | nn.Module | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        device: str = "cpu",
+        embedding_l1_reg=1e-6,
+        dense_l1_reg=1e-5,
+        embedding_l2_reg=1e-5,
+        dense_l2_reg=1e-4,
+        **kwargs,
+    ):
+
+        dense_features = dense_features or []
+        sparse_features = sparse_features or []
+        sequence_features = sequence_features or []
+        mlp_params = mlp_params or {}
+        if outer_product_dim is not None and outer_product_dim <= 0:
+            raise ValueError("outer_product_dim must be a positive integer.")
+        optimizer_params = optimizer_params or {}
+        if loss is None:
+            loss = "bce"
+
         super(PNN, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
@@ -54,46 +95,54 @@ class PNN(BaseModel):
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            **kwargs
+            **kwargs,
         )
 
-        self.loss = loss
-        if self.loss is None:
-            self.loss = "bce"
-
-        self.field_features = sparse_features + sequence_features
+        self.field_features = dense_features + sparse_features + sequence_features
         if len(self.field_features) < 2:
             raise ValueError("PNN requires at least two sparse/sequence features.")
 
         self.embedding = EmbeddingLayer(features=self.field_features)
         self.num_fields = len(self.field_features)
+
         self.embedding_dim = self.field_features[0].embedding_dim
         if any(f.embedding_dim != self.embedding_dim for f in self.field_features):
-            raise ValueError("All field features must share the same embedding_dim for PNN.")
+            raise ValueError(
+                "All field features must share the same embedding_dim for PNN."
+            )
 
         self.product_type = product_type.lower()
-        if self.product_type not in {"inner", "outer"}:
-            raise ValueError("product_type must be 'inner' or 'outer'.")
+        if self.product_type not in {"inner", "outer", "both"}:
+            raise ValueError("product_type must be 'inner', 'outer', or 'both'.")
 
         self.num_pairs = self.num_fields * (self.num_fields - 1) // 2
-        if self.product_type == "outer":
-            self.outer_dim = outer_product_dim or self.embedding_dim
-            self.kernel = nn.Linear(self.embedding_dim, self.outer_dim, bias=False)
-            product_dim = self.num_pairs * self.outer_dim
+        self.outer_product_dim = outer_product_dim or self.embedding_dim
+
+        if self.product_type in {"outer", "both"}:
+            self.kernel = nn.Parameter(
+                torch.randn(self.embedding_dim, self.outer_product_dim)
+            )
+            nn.init.xavier_uniform_(self.kernel)
         else:
-            self.outer_dim = None
-            product_dim = self.num_pairs
+            self.kernel = None
 
         linear_dim = self.num_fields * self.embedding_dim
+
+        if self.product_type == "inner":
+            product_dim = self.num_pairs
+        elif self.product_type == "outer":
+            product_dim = self.num_pairs
+        else:
+            product_dim = 2 * self.num_pairs
+
         self.mlp = MLP(input_dim=linear_dim + product_dim, **mlp_params)
         self.prediction_layer = PredictionLayer(task_type=self.task)
 
-        modules = ['mlp']
-        if self.product_type == "outer":
-            modules.append('kernel')
+        modules = ["mlp"]
+        if self.kernel is not None:
+            modules.append("kernel")
         self.register_regularization_weights(
-            embedding_attr='embedding',
-            include_modules=modules
+            embedding_attr="embedding", include_modules=modules
         )
 
         self.compile(
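The width bookkeeping in this hunk deserves a worked example: with F fields there are F(F-1)/2 pairs, each contributing one scalar, so "both" doubles the product width, and the MLP input is linear_dim + product_dim. (In 0.4.1 the outer branch contributed num_pairs * outer_dim inputs instead; 0.4.3's kernel product reduces each pair to a scalar.) For F = 4 fields and D = 16:

    num_fields, D = 4, 16
    num_pairs = num_fields * (num_fields - 1) // 2   # 6
    linear_dim = num_fields * D                      # 64
    widths = {"inner": num_pairs, "outer": num_pairs, "both": 2 * num_pairs}
    for mode, product_dim in widths.items():
        print(mode, linear_dim + product_dim)        # inner 70, outer 70, both 76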
@@ -103,27 +152,48 @@ class PNN(BaseModel):
             loss_params=loss_params,
         )
 
+    def compute_inner_products(self, field_emb: torch.Tensor) -> torch.Tensor:
+        interactions = []
+        for i in range(self.num_fields - 1):
+            vi = field_emb[:, i, :]  # [B, D]
+            for j in range(i + 1, self.num_fields):
+                vj = field_emb[:, j, :]  # [B, D]
+                # <v_i, v_j> = sum_k v_{i,k} * v_{j,k}
+                pij = torch.sum(vi * vj, dim=1, keepdim=True)  # [B, 1]
+                interactions.append(pij)
+        return torch.cat(interactions, dim=1)  # [B, num_pairs]
+
+    def compute_outer_kernel_products(self, field_emb: torch.Tensor) -> torch.Tensor:
+        if self.kernel is None:
+            raise RuntimeError("kernel is not initialized for outer product.")
+
+        interactions = []
+        for i in range(self.num_fields - 1):
+            vi = field_emb[:, i, :]  # [B, D]
+            # Project vi with kernel -> [B, K]
+            vi_proj = torch.matmul(vi, self.kernel)  # [B, K]
+            for j in range(i + 1, self.num_fields):
+                vj = field_emb[:, j, :]  # [B, D]
+                vj_proj = torch.matmul(vj, self.kernel)  # [B, K]
+                # g(v_i, v_j) = (v_i^T W) * (v_j^T W), summed over the projection dim
+                pij = torch.sum(vi_proj * vj_proj, dim=1, keepdim=True)  # [B, 1]
+                interactions.append(pij)
+        return torch.cat(interactions, dim=1)  # [B, num_pairs]
+
     def forward(self, x):
+        # field_emb: [B, F, D]
         field_emb = self.embedding(x=x, features=self.field_features, squeeze_dim=False)
-        linear_signal = field_emb.flatten(start_dim=1)
+        # Z = [v_1; v_2; ...; v_F]
+        linear_signal = field_emb.flatten(start_dim=1)  # [B, F*D]
 
         if self.product_type == "inner":
-            interactions = []
-            for i in range(self.num_fields - 1):
-                vi = field_emb[:, i, :]
-                for j in range(i + 1, self.num_fields):
-                    vj = field_emb[:, j, :]
-                    interactions.append(torch.sum(vi * vj, dim=1, keepdim=True))
-            product_signal = torch.cat(interactions, dim=1)
+            product_signal = self.compute_inner_products(field_emb)
+        elif self.product_type == "outer":
+            product_signal = self.compute_outer_kernel_products(field_emb)
         else:
-            transformed = self.kernel(field_emb)  # [B, F, outer_dim]
-            interactions = []
-            for i in range(self.num_fields - 1):
-                vi = transformed[:, i, :]
-                for j in range(i + 1, self.num_fields):
-                    vj = transformed[:, j, :]
-                    interactions.append(vi * vj)
-            product_signal = torch.stack(interactions, dim=1).flatten(start_dim=1)
+            inner_p = self.compute_inner_products(field_emb)
+            outer_p = self.compute_outer_kernel_products(field_emb)
+            product_signal = torch.cat([inner_p, outer_p], dim=1)
 
         deep_input = torch.cat([linear_signal, product_signal], dim=1)
         y = self.mlp(deep_input)
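One design note on the new helpers: both pair loops are O(F^2) in Python, which is fine for tens of fields but can be vectorized with `torch.triu_indices` for wide feature sets. A sketch of the equivalent inner-product computation (not part of nextrec):

    import torch

    B, num_fields, D = 4, 5, 8
    field_emb = torch.randn(B, num_fields, D)

    i, j = torch.triu_indices(num_fields, num_fields, offset=1)  # all pairs with i < j
    vectorized = (field_emb[:, i] * field_emb[:, j]).sum(-1)     # [B, num_pairs]

    # Reference: the pairwise loop from compute_inner_products
    loop = torch.cat(
        [
            torch.sum(field_emb[:, a] * field_emb[:, b], dim=1, keepdim=True)
            for a in range(num_fields - 1)
            for b in range(a + 1, num_fields)
        ],
        dim=1,
    )
    assert torch.allclose(vectorized, loop)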