nextrec 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +220 -106
  7. nextrec/basic/loggers.py +62 -43
  8. nextrec/basic/metrics.py +268 -119
  9. nextrec/basic/model.py +1082 -400
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +498 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +51 -45
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +272 -95
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +103 -38
  23. nextrec/models/match/dssm.py +82 -68
  24. nextrec/models/match/dssm_v2.py +72 -57
  25. nextrec/models/match/mind.py +175 -107
  26. nextrec/models/match/sdm.py +104 -87
  27. nextrec/models/match/youtube_dnn.py +73 -59
  28. nextrec/models/multi_task/esmm.py +53 -37
  29. nextrec/models/multi_task/mmoe.py +64 -45
  30. nextrec/models/multi_task/ple.py +101 -48
  31. nextrec/models/multi_task/poso.py +113 -36
  32. nextrec/models/multi_task/share_bottom.py +48 -35
  33. nextrec/models/ranking/afm.py +72 -37
  34. nextrec/models/ranking/autoint.py +72 -55
  35. nextrec/models/ranking/dcn.py +55 -35
  36. nextrec/models/ranking/dcn_v2.py +64 -23
  37. nextrec/models/ranking/deepfm.py +32 -22
  38. nextrec/models/ranking/dien.py +155 -99
  39. nextrec/models/ranking/din.py +85 -57
  40. nextrec/models/ranking/fibinet.py +52 -32
  41. nextrec/models/ranking/fm.py +29 -23
  42. nextrec/models/ranking/masknet.py +91 -29
  43. nextrec/models/ranking/pnn.py +31 -28
  44. nextrec/models/ranking/widedeep.py +34 -26
  45. nextrec/models/ranking/xdeepfm.py +60 -38
  46. nextrec/utils/__init__.py +59 -34
  47. nextrec/utils/config.py +490 -0
  48. nextrec/utils/device.py +30 -20
  49. nextrec/utils/distributed.py +36 -9
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +32 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/optimizer.py +25 -9
  55. nextrec/utils/synthetic_data.py +283 -165
  56. nextrec/utils/tensor.py +24 -13
  57. {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/METADATA +4 -4
  58. nextrec-0.4.2.dist-info/RECORD +69 -0
  59. nextrec-0.4.2.dist-info/entry_points.txt +2 -0
  60. nextrec-0.4.1.dist-info/RECORD +0 -66
  61. {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
  62. {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
@@ -79,7 +79,7 @@ class POSOGate(nn.Module):
79
79
  h = self.act(self.fc1(pc))
80
80
  g = torch.sigmoid(self.fc2(h)) # (B, out_dim) in (0,1)
81
81
  return self.scale_factor * g
82
-
82
+
83
83
 
84
84
  class POSOFC(nn.Module):
85
85
  """
@@ -116,10 +116,10 @@ class POSOFC(nn.Module):
116
116
  pc: (B, pc_dim)
117
117
  return: (B, out_dim)
118
118
  """
119
- h = self.act(self.linear(x)) # Standard FC with activation
120
- g = self.gate(pc) # (B, out_dim)
121
- return g * h # Element-wise gating
122
-
119
+ h = self.act(self.linear(x)) # Standard FC with activation
120
+ g = self.gate(pc) # (B, out_dim)
121
+ return g * h # Element-wise gating
122
+
123
123
 
124
124
  class POSOMLP(nn.Module):
125
125
  """
@@ -173,7 +173,7 @@ class POSOMLP(nn.Module):
173
173
  if self.dropout is not None:
174
174
  h = self.dropout(h)
175
175
  return h
176
-
176
+
177
177
 
178
178
  class POSOMMoE(nn.Module):
179
179
  """
@@ -183,7 +183,7 @@ class POSOMMoE(nn.Module):
183
183
  - Task gates aggregate the PC-masked expert outputs
184
184
 
185
185
  Concretely:
186
- h_e = expert_e(x) # (B, D)
186
+ h_e = expert_e(x) # (B, D)
187
187
  g_e = POSOGate(pc) in (0, C)^{D} # (B, D)
188
188
  h_e_tilde = g_e ⊙ h_e # (B, D)
189
189
  z_t = Σ_e gate_t,e(x) * h_e_tilde
@@ -192,14 +192,14 @@ class POSOMMoE(nn.Module):
192
192
  def __init__(
193
193
  self,
194
194
  input_dim: int,
195
- pc_dim: int, # for poso feature dimension
195
+ pc_dim: int, # for poso feature dimension
196
196
  num_experts: int,
197
197
  expert_hidden_dims: list[int],
198
198
  num_tasks: int,
199
199
  activation: str = "relu",
200
200
  expert_dropout: float = 0.0,
201
- gate_hidden_dim: int = 32, # for poso gate hidden dimension
202
- scale_factor: float = 2.0, # for poso gate scale factor
201
+ gate_hidden_dim: int = 32, # for poso gate hidden dimension
202
+ scale_factor: float = 2.0, # for poso gate scale factor
203
203
  gate_use_softmax: bool = True,
204
204
  ) -> None:
205
205
  super().__init__()
@@ -207,15 +207,41 @@ class POSOMMoE(nn.Module):
207
207
  self.num_tasks = num_tasks
208
208
 
209
209
  # Experts built with framework MLP, same as standard MMoE
210
- self.experts = nn.ModuleList([MLP(input_dim=input_dim, output_layer=False, dims=expert_hidden_dims, activation=activation, dropout=expert_dropout,) for _ in range(num_experts)])
211
- self.expert_output_dim = expert_hidden_dims[-1] if expert_hidden_dims else input_dim
210
+ self.experts = nn.ModuleList(
211
+ [
212
+ MLP(
213
+ input_dim=input_dim,
214
+ output_layer=False,
215
+ dims=expert_hidden_dims,
216
+ activation=activation,
217
+ dropout=expert_dropout,
218
+ )
219
+ for _ in range(num_experts)
220
+ ]
221
+ )
222
+ self.expert_output_dim = (
223
+ expert_hidden_dims[-1] if expert_hidden_dims else input_dim
224
+ )
212
225
 
213
226
  # Task-specific gates: gate_t(x) over experts
214
- self.gates = nn.ModuleList([nn.Linear(input_dim, num_experts) for _ in range(num_tasks)])
227
+ self.gates = nn.ModuleList(
228
+ [nn.Linear(input_dim, num_experts) for _ in range(num_tasks)]
229
+ )
215
230
  self.gate_use_softmax = gate_use_softmax
216
231
 
217
232
  # PC gate per expert: g_e(pc) ∈ R^D
218
- self.expert_pc_gates = nn.ModuleList([POSOGate(pc_dim=pc_dim, out_dim=self.expert_output_dim, hidden_dim=gate_hidden_dim, scale_factor=scale_factor, activation=activation,) for _ in range(num_experts)])
233
+ self.expert_pc_gates = nn.ModuleList(
234
+ [
235
+ POSOGate(
236
+ pc_dim=pc_dim,
237
+ out_dim=self.expert_output_dim,
238
+ hidden_dim=gate_hidden_dim,
239
+ scale_factor=scale_factor,
240
+ activation=activation,
241
+ )
242
+ for _ in range(num_experts)
243
+ ]
244
+ )
219
245
 
220
246
  def forward(self, x: torch.Tensor, pc: torch.Tensor) -> list[torch.Tensor]:
221
247
  """
@@ -226,9 +252,9 @@ class POSOMMoE(nn.Module):
226
252
  # 1) Expert outputs with POSO PC gate
227
253
  masked_expert_outputs = []
228
254
  for e, expert in enumerate(self.experts):
229
- h_e = expert(x) # (B, D)
230
- g_e = self.expert_pc_gates[e](pc) # (B, D)
231
- h_e_tilde = g_e * h_e # (B, D)
255
+ h_e = expert(x) # (B, D)
256
+ g_e = self.expert_pc_gates[e](pc) # (B, D)
257
+ h_e_tilde = g_e * h_e # (B, D)
232
258
  masked_expert_outputs.append(h_e_tilde)
233
259
 
234
260
  masked_expert_outputs = torch.stack(masked_expert_outputs, dim=1) # (B, E, D)
@@ -236,13 +262,13 @@ class POSOMMoE(nn.Module):
236
262
  # 2) Task gates depend on x as in standard MMoE
237
263
  task_outputs: list[torch.Tensor] = []
238
264
  for t in range(self.num_tasks):
239
- logits = self.gates[t](x) # (B, E)
265
+ logits = self.gates[t](x) # (B, E)
240
266
  if self.gate_use_softmax:
241
267
  gate = F.softmax(logits, dim=1)
242
268
  else:
243
269
  gate = logits
244
270
 
245
- gate = gate.unsqueeze(-1) # (B, E, 1)
271
+ gate = gate.unsqueeze(-1) # (B, E, 1)
246
272
  z_t = torch.sum(gate * masked_expert_outputs, dim=1) # (B, D)
247
273
  task_outputs.append(z_t)
248
274
 
@@ -312,12 +338,24 @@ class POSO(BaseModel):
312
338
  self.pc_sequence_features = list(pc_sequence_features or [])
313
339
  self.num_tasks = len(target)
314
340
 
315
- if not self.pc_dense_features and not self.pc_sparse_features and not self.pc_sequence_features:
316
- raise ValueError("POSO requires at least one PC feature for personalization.")
341
+ if (
342
+ not self.pc_dense_features
343
+ and not self.pc_sparse_features
344
+ and not self.pc_sequence_features
345
+ ):
346
+ raise ValueError(
347
+ "POSO requires at least one PC feature for personalization."
348
+ )
317
349
 
318
- dense_features = merge_features(self.main_dense_features, self.pc_dense_features)
319
- sparse_features = merge_features(self.main_sparse_features, self.pc_sparse_features)
320
- sequence_features = merge_features(self.main_sequence_features, self.pc_sequence_features)
350
+ dense_features = merge_features(
351
+ self.main_dense_features, self.pc_dense_features
352
+ )
353
+ sparse_features = merge_features(
354
+ self.main_sparse_features, self.pc_sparse_features
355
+ )
356
+ sequence_features = merge_features(
357
+ self.main_sequence_features, self.pc_sequence_features
358
+ )
321
359
 
322
360
  super().__init__(
323
361
  dense_features=dense_features,
@@ -338,10 +376,18 @@ class POSO(BaseModel):
338
376
 
339
377
  self.num_tasks = len(target)
340
378
  if len(tower_params_list) != self.num_tasks:
341
- raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
379
+ raise ValueError(
380
+ f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
381
+ )
342
382
 
343
- self.main_features = self.main_dense_features + self.main_sparse_features + self.main_sequence_features
344
- self.pc_features = self.pc_dense_features + self.pc_sparse_features + self.pc_sequence_features
383
+ self.main_features = (
384
+ self.main_dense_features
385
+ + self.main_sparse_features
386
+ + self.main_sequence_features
387
+ )
388
+ self.pc_features = (
389
+ self.pc_dense_features + self.pc_sparse_features + self.pc_sequence_features
390
+ )
345
391
 
346
392
  self.embedding = EmbeddingLayer(features=self.all_features)
347
393
  self.main_input_dim = self.embedding.get_input_dim(self.main_features)
@@ -349,7 +395,9 @@ class POSO(BaseModel):
349
395
 
350
396
  self.architecture = architecture.lower()
351
397
  if self.architecture not in {"mlp", "mmoe"}:
352
- raise ValueError(f"Unsupported architecture '{architecture}', choose from ['mlp', 'mmoe'].")
398
+ raise ValueError(
399
+ f"Unsupported architecture '{architecture}', choose from ['mlp', 'mmoe']."
400
+ )
353
401
 
354
402
  # Build backbones
355
403
  if self.architecture == "mlp":
@@ -358,13 +406,17 @@ class POSO(BaseModel):
358
406
  for tower_params in tower_params_list:
359
407
  dims = tower_params.get("dims")
360
408
  if not dims:
361
- raise ValueError("tower_params must include a non-empty 'dims' list for POSO-MLP towers.")
409
+ raise ValueError(
410
+ "tower_params must include a non-empty 'dims' list for POSO-MLP towers."
411
+ )
362
412
  dropout = tower_params.get("dropout", 0.0)
363
413
  tower = POSOMLP(
364
414
  input_dim=self.main_input_dim,
365
415
  pc_dim=self.pc_input_dim,
366
416
  dims=dims,
367
- gate_hidden_dim=tower_params.get("gate_hidden_dim", gate_hidden_dim),
417
+ gate_hidden_dim=tower_params.get(
418
+ "gate_hidden_dim", gate_hidden_dim
419
+ ),
368
420
  scale_factor=tower_params.get("scale_factor", gate_scale_factor),
369
421
  activation=tower_params.get("activation", gate_activation),
370
422
  use_bias=tower_params.get("use_bias", gate_use_bias),
@@ -375,7 +427,9 @@ class POSO(BaseModel):
375
427
  self.tower_heads.append(nn.Linear(tower_output_dim, 1))
376
428
  else:
377
429
  if expert_hidden_dims is None or not expert_hidden_dims:
378
- raise ValueError("expert_hidden_dims must be provided for MMoE architecture.")
430
+ raise ValueError(
431
+ "expert_hidden_dims must be provided for MMoE architecture."
432
+ )
379
433
  self.mmoe = POSOMMoE(
380
434
  input_dim=self.main_input_dim,
381
435
  pc_dim=self.pc_input_dim,
@@ -388,12 +442,35 @@ class POSO(BaseModel):
388
442
  scale_factor=expert_gate_scale_factor,
389
443
  gate_use_softmax=gate_use_softmax,
390
444
  )
391
- self.towers = nn.ModuleList([MLP(input_dim=self.mmoe.expert_output_dim, output_layer=True, **tower_params,) for tower_params in tower_params_list])
445
+ self.towers = nn.ModuleList(
446
+ [
447
+ MLP(
448
+ input_dim=self.mmoe.expert_output_dim,
449
+ output_layer=True,
450
+ **tower_params,
451
+ )
452
+ for tower_params in tower_params_list
453
+ ]
454
+ )
392
455
  self.tower_heads = None
393
- self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks,)
394
- include_modules = ["towers", "tower_heads"] if self.architecture == "mlp" else ["mmoe", "towers"]
395
- self.register_regularization_weights(embedding_attr="embedding", include_modules=include_modules)
396
- self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
456
+ self.prediction_layer = PredictionLayer(
457
+ task_type=self.default_task,
458
+ task_dims=[1] * self.num_tasks,
459
+ )
460
+ include_modules = (
461
+ ["towers", "tower_heads"]
462
+ if self.architecture == "mlp"
463
+ else ["mmoe", "towers"]
464
+ )
465
+ self.register_regularization_weights(
466
+ embedding_attr="embedding", include_modules=include_modules
467
+ )
468
+ self.compile(
469
+ optimizer=optimizer,
470
+ optimizer_params=optimizer_params,
471
+ loss=loss,
472
+ loss_params=loss_params,
473
+ )
397
474
 
398
475
  def forward(self, x):
399
476
  # Embed main and PC features separately so PC can gate hidden units
@@ -56,28 +56,30 @@ class ShareBottom(BaseModel):
56
56
  def default_task(self):
57
57
  num_tasks = getattr(self, "num_tasks", None)
58
58
  if num_tasks is not None and num_tasks > 0:
59
- return ['binary'] * num_tasks
60
- return ['binary']
61
-
62
- def __init__(self,
63
- dense_features: list[DenseFeature],
64
- sparse_features: list[SparseFeature],
65
- sequence_features: list[SequenceFeature],
66
- bottom_params: dict,
67
- tower_params_list: list[dict],
68
- target: list[str],
69
- task: str | list[str] | None = None,
70
- optimizer: str = "adam",
71
- optimizer_params: dict = {},
72
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
73
- loss_params: dict | list[dict] | None = None,
74
- device: str = 'cpu',
75
- embedding_l1_reg=1e-6,
76
- dense_l1_reg=1e-5,
77
- embedding_l2_reg=1e-5,
78
- dense_l2_reg=1e-4,
79
- **kwargs):
80
-
59
+ return ["binary"] * num_tasks
60
+ return ["binary"]
61
+
62
+ def __init__(
63
+ self,
64
+ dense_features: list[DenseFeature],
65
+ sparse_features: list[SparseFeature],
66
+ sequence_features: list[SequenceFeature],
67
+ bottom_params: dict,
68
+ tower_params_list: list[dict],
69
+ target: list[str],
70
+ task: str | list[str] | None = None,
71
+ optimizer: str = "adam",
72
+ optimizer_params: dict = {},
73
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
74
+ loss_params: dict | list[dict] | None = None,
75
+ device: str = "cpu",
76
+ embedding_l1_reg=1e-6,
77
+ dense_l1_reg=1e-5,
78
+ embedding_l2_reg=1e-5,
79
+ dense_l2_reg=1e-4,
80
+ **kwargs,
81
+ ):
82
+
81
83
  self.num_tasks = len(target)
82
84
 
83
85
  super(ShareBottom, self).__init__(
@@ -91,7 +93,7 @@ class ShareBottom(BaseModel):
91
93
  dense_l1_reg=dense_l1_reg,
92
94
  embedding_l2_reg=embedding_l2_reg,
93
95
  dense_l2_reg=dense_l2_reg,
94
- **kwargs
96
+ **kwargs,
95
97
  )
96
98
 
97
99
  self.loss = loss
@@ -100,7 +102,9 @@ class ShareBottom(BaseModel):
100
102
  # Number of tasks
101
103
  self.num_tasks = len(target)
102
104
  if len(tower_params_list) != self.num_tasks:
103
- raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
105
+ raise ValueError(
106
+ f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
107
+ )
104
108
  # Embedding layer
105
109
  self.embedding = EmbeddingLayer(features=self.all_features)
106
110
  # Calculate input dimension
@@ -108,39 +112,48 @@ class ShareBottom(BaseModel):
108
112
  # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
109
113
  # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
110
114
  # input_dim = emb_dim_total + dense_input_dim
111
-
115
+
112
116
  # Shared bottom network
113
117
  self.bottom = MLP(input_dim=input_dim, output_layer=False, **bottom_params)
114
-
118
+
115
119
  # Get bottom output dimension
116
- if 'dims' in bottom_params and len(bottom_params['dims']) > 0:
117
- bottom_output_dim = bottom_params['dims'][-1]
120
+ if "dims" in bottom_params and len(bottom_params["dims"]) > 0:
121
+ bottom_output_dim = bottom_params["dims"][-1]
118
122
  else:
119
123
  bottom_output_dim = input_dim
120
-
124
+
121
125
  # Task-specific towers
122
126
  self.towers = nn.ModuleList()
123
127
  for tower_params in tower_params_list:
124
128
  tower = MLP(input_dim=bottom_output_dim, output_layer=True, **tower_params)
125
129
  self.towers.append(tower)
126
- self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks)
130
+ self.prediction_layer = PredictionLayer(
131
+ task_type=self.default_task, task_dims=[1] * self.num_tasks
132
+ )
127
133
  # Register regularization weights
128
- self.register_regularization_weights(embedding_attr='embedding', include_modules=['bottom', 'towers'])
129
- self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
134
+ self.register_regularization_weights(
135
+ embedding_attr="embedding", include_modules=["bottom", "towers"]
136
+ )
137
+ self.compile(
138
+ optimizer=optimizer,
139
+ optimizer_params=optimizer_params,
140
+ loss=loss,
141
+ loss_params=loss_params,
142
+ )
130
143
 
131
144
  def forward(self, x):
132
145
  # Get all embeddings and flatten
133
146
  input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
134
-
147
+
135
148
  # Shared bottom
136
149
  bottom_output = self.bottom(input_flat) # [B, bottom_dim]
137
-
150
+
138
151
  # Task-specific towers
139
152
  task_outputs = []
140
153
  for tower in self.towers:
141
154
  tower_output = tower(bottom_output) # [B, 1]
142
155
  task_outputs.append(tower_output)
143
-
156
+
144
157
  # Stack outputs: [B, num_tasks]
145
158
  y = torch.cat(task_outputs, dim=1)
146
159
  return self.prediction_layer(y)
@@ -40,7 +40,7 @@ import torch
40
40
  import torch.nn as nn
41
41
 
42
42
  from nextrec.basic.model import BaseModel
43
- from nextrec.basic.layers import EmbeddingLayer, LR, PredictionLayer, InputMask
43
+ from nextrec.basic.layers import EmbeddingLayer, PredictionLayer, InputMask
44
44
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
45
45
 
46
46
 
@@ -52,25 +52,28 @@ class AFM(BaseModel):
52
52
  @property
53
53
  def default_task(self):
54
54
  return "binary"
55
-
56
- def __init__(self,
57
- dense_features: list[DenseFeature] | list = [],
58
- sparse_features: list[SparseFeature] | list = [],
59
- sequence_features: list[SequenceFeature] | list = [],
60
- attention_dim: int = 32,
61
- attention_dropout: float = 0.0,
62
- target: list[str] | list = [],
63
- task: str | list[str] | None = None,
64
- optimizer: str = "adam",
65
- optimizer_params: dict = {},
66
- loss: str | nn.Module | None = "bce",
67
- loss_params: dict | list[dict] | None = None,
68
- device: str = 'cpu',
69
- embedding_l1_reg=1e-6,
70
- dense_l1_reg=1e-5,
71
- embedding_l2_reg=1e-5,
72
- dense_l2_reg=1e-4, **kwargs):
73
-
55
+
56
+ def __init__(
57
+ self,
58
+ dense_features: list[DenseFeature] | list = [],
59
+ sparse_features: list[SparseFeature] | list = [],
60
+ sequence_features: list[SequenceFeature] | list = [],
61
+ attention_dim: int = 32,
62
+ attention_dropout: float = 0.0,
63
+ target: list[str] | list = [],
64
+ task: str | list[str] | None = None,
65
+ optimizer: str = "adam",
66
+ optimizer_params: dict = {},
67
+ loss: str | nn.Module | None = "bce",
68
+ loss_params: dict | list[dict] | None = None,
69
+ device: str = "cpu",
70
+ embedding_l1_reg=1e-6,
71
+ dense_l1_reg=1e-5,
72
+ embedding_l2_reg=1e-5,
73
+ dense_l2_reg=1e-4,
74
+ **kwargs,
75
+ ):
76
+
74
77
  super(AFM, self).__init__(
75
78
  dense_features=dense_features,
76
79
  sparse_features=sparse_features,
@@ -82,7 +85,7 @@ class AFM(BaseModel):
82
85
  dense_l1_reg=dense_l1_reg,
83
86
  embedding_l2_reg=embedding_l2_reg,
84
87
  dense_l2_reg=dense_l2_reg,
85
- **kwargs
88
+ **kwargs,
86
89
  )
87
90
 
88
91
  if target is None:
@@ -91,22 +94,30 @@ class AFM(BaseModel):
91
94
  optimizer_params = {}
92
95
  if loss is None:
93
96
  loss = "bce"
94
-
97
+
95
98
  self.fm_features = sparse_features + sequence_features
96
99
  if len(self.fm_features) < 2:
97
- raise ValueError("AFM requires at least two sparse/sequence features to build pairwise interactions.")
100
+ raise ValueError(
101
+ "AFM requires at least two sparse/sequence features to build pairwise interactions."
102
+ )
98
103
 
99
104
  # make sure all embedding dimension are the same for FM features
100
105
  self.embedding_dim = self.fm_features[0].embedding_dim
101
106
  if any(f.embedding_dim != self.embedding_dim for f in self.fm_features):
102
- raise ValueError("All FM features must share the same embedding_dim for AFM.")
107
+ raise ValueError(
108
+ "All FM features must share the same embedding_dim for AFM."
109
+ )
103
110
 
104
- self.embedding = EmbeddingLayer(features=self.fm_features) # [Batch, Field, Dim ]
111
+ self.embedding = EmbeddingLayer(
112
+ features=self.fm_features
113
+ ) # [Batch, Field, Dim ]
105
114
 
106
115
  # First-order terms: dense linear + one hot embeddings
107
116
  self.dense_features = list(dense_features)
108
117
  dense_input_dim = sum([f.input_dim for f in self.dense_features])
109
- self.linear_dense = nn.Linear(dense_input_dim, 1, bias=True) if dense_input_dim > 0 else None
118
+ self.linear_dense = (
119
+ nn.Linear(dense_input_dim, 1, bias=True) if dense_input_dim > 0 else None
120
+ )
110
121
 
111
122
  # First-order term: sparse/sequence features one-hot
112
123
  # **INFO**: source paper does not contain sequence features in experiments,
@@ -114,9 +125,15 @@ class AFM(BaseModel):
114
125
  # remove sequence features from fm_features.
115
126
  self.first_order_embeddings = nn.ModuleDict()
116
127
  for feature in self.fm_features:
117
- if feature.embedding_name in self.first_order_embeddings: # shared embedding
128
+ if (
129
+ feature.embedding_name in self.first_order_embeddings
130
+ ): # shared embedding
118
131
  continue
119
- emb = nn.Embedding(num_embeddings=feature.vocab_size, embedding_dim=1, padding_idx=feature.padding_idx) # equal to one-hot encoding weight
132
+ emb = nn.Embedding(
133
+ num_embeddings=feature.vocab_size,
134
+ embedding_dim=1,
135
+ padding_idx=feature.padding_idx,
136
+ ) # equal to one-hot encoding weight
120
137
  # nn.init.zeros_(emb.weight)
121
138
  self.first_order_embeddings[feature.embedding_name] = emb
122
139
 
@@ -129,11 +146,18 @@ class AFM(BaseModel):
129
146
 
130
147
  # Register regularization weights
131
148
  self.register_regularization_weights(
132
- embedding_attr='embedding',
133
- include_modules=['linear_dense', 'attention_linear', 'attention_p', 'output_projection']
149
+ embedding_attr="embedding",
150
+ include_modules=[
151
+ "linear_dense",
152
+ "attention_linear",
153
+ "attention_p",
154
+ "output_projection",
155
+ ],
134
156
  )
135
157
  # add first-order embeddings to embedding regularization list
136
- self.embedding_params.extend(emb.weight for emb in self.first_order_embeddings.values())
158
+ self.embedding_params.extend(
159
+ emb.weight for emb in self.first_order_embeddings.values()
160
+ )
137
161
 
138
162
  self.compile(
139
163
  optimizer=optimizer,
@@ -143,13 +167,17 @@ class AFM(BaseModel):
143
167
  )
144
168
 
145
169
  def forward(self, x):
146
- field_emb = self.embedding(x=x, features=self.fm_features, squeeze_dim=False) # [B, F, D]
170
+ field_emb = self.embedding(
171
+ x=x, features=self.fm_features, squeeze_dim=False
172
+ ) # [B, F, D]
147
173
  batch_size = field_emb.size(0)
148
174
  y_linear = torch.zeros(batch_size, 1, device=field_emb.device)
149
175
 
150
176
  # First-order dense part
151
177
  if self.linear_dense is not None:
152
- dense_inputs = [x[f.name].float().view(batch_size, -1) for f in self.dense_features]
178
+ dense_inputs = [
179
+ x[f.name].float().view(batch_size, -1) for f in self.dense_features
180
+ ]
153
181
  dense_stack = torch.cat(dense_inputs, dim=1) if dense_inputs else None
154
182
  if dense_stack is not None:
155
183
  y_linear = y_linear + self.linear_dense(dense_stack)
@@ -161,7 +189,7 @@ class AFM(BaseModel):
161
189
  if isinstance(feature, SparseFeature):
162
190
  term = emb(x[feature.name].long()) # [B, 1]
163
191
  else: # SequenceFeature
164
- seq_input = x[feature.name].long() # [B, 1]
192
+ seq_input = x[feature.name].long() # [B, 1]
165
193
  if feature.max_len is not None and seq_input.size(1) > feature.max_len:
166
194
  seq_input = seq_input[:, -feature.max_len :]
167
195
  mask = self.input_mask(x, feature, seq_input).squeeze(1) # [B, 1]
@@ -169,7 +197,9 @@ class AFM(BaseModel):
169
197
  term = (seq_weight * mask).sum(dim=1, keepdim=True) # [B, 1]
170
198
  first_order_terms.append(term)
171
199
  if first_order_terms:
172
- y_linear = y_linear + torch.sum(torch.cat(first_order_terms, dim=1), dim=1, keepdim=True)
200
+ y_linear = y_linear + torch.sum(
201
+ torch.cat(first_order_terms, dim=1), dim=1, keepdim=True
202
+ )
173
203
 
174
204
  interactions = []
175
205
  feature_values = []
@@ -182,13 +212,18 @@ class AFM(BaseModel):
182
212
  else:
183
213
  if isinstance(feature, SequenceFeature):
184
214
  seq_input = x[feature.name].long()
185
- if feature.max_len is not None and seq_input.size(1) > feature.max_len:
215
+ if (
216
+ feature.max_len is not None
217
+ and seq_input.size(1) > feature.max_len
218
+ ):
186
219
  seq_input = seq_input[:, -feature.max_len :]
187
220
  value = self.input_mask(x, feature, seq_input).sum(dim=2) # [B, 1]
188
221
  else:
189
222
  value = torch.ones(batch_size, 1, device=field_emb.device)
190
223
  feature_values.append(value)
191
- feature_values_tensor = torch.cat(feature_values, dim=1).unsqueeze(-1) # [B, F, 1]
224
+ feature_values_tensor = torch.cat(feature_values, dim=1).unsqueeze(
225
+ -1
226
+ ) # [B, F, 1]
192
227
  field_emb = field_emb * feature_values_tensor
193
228
 
194
229
  num_fields = field_emb.shape[1]