nextrec 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. nextrec/__init__.py +4 -4
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -9
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/dataloader.py +168 -127
  6. nextrec/basic/features.py +24 -27
  7. nextrec/basic/layers.py +328 -159
  8. nextrec/basic/loggers.py +50 -37
  9. nextrec/basic/metrics.py +255 -147
  10. nextrec/basic/model.py +817 -462
  11. nextrec/data/__init__.py +5 -5
  12. nextrec/data/data_utils.py +16 -12
  13. nextrec/data/preprocessor.py +276 -252
  14. nextrec/loss/__init__.py +12 -12
  15. nextrec/loss/loss_utils.py +30 -22
  16. nextrec/loss/match_losses.py +116 -83
  17. nextrec/models/match/__init__.py +5 -5
  18. nextrec/models/match/dssm.py +70 -61
  19. nextrec/models/match/dssm_v2.py +61 -51
  20. nextrec/models/match/mind.py +89 -71
  21. nextrec/models/match/sdm.py +93 -81
  22. nextrec/models/match/youtube_dnn.py +62 -53
  23. nextrec/models/multi_task/esmm.py +49 -43
  24. nextrec/models/multi_task/mmoe.py +65 -56
  25. nextrec/models/multi_task/ple.py +92 -65
  26. nextrec/models/multi_task/share_bottom.py +48 -42
  27. nextrec/models/ranking/__init__.py +7 -7
  28. nextrec/models/ranking/afm.py +39 -30
  29. nextrec/models/ranking/autoint.py +70 -57
  30. nextrec/models/ranking/dcn.py +43 -35
  31. nextrec/models/ranking/deepfm.py +34 -28
  32. nextrec/models/ranking/dien.py +115 -79
  33. nextrec/models/ranking/din.py +84 -60
  34. nextrec/models/ranking/fibinet.py +51 -35
  35. nextrec/models/ranking/fm.py +28 -26
  36. nextrec/models/ranking/masknet.py +31 -31
  37. nextrec/models/ranking/pnn.py +30 -31
  38. nextrec/models/ranking/widedeep.py +36 -31
  39. nextrec/models/ranking/xdeepfm.py +46 -39
  40. nextrec/utils/__init__.py +9 -9
  41. nextrec/utils/embedding.py +1 -1
  42. nextrec/utils/initializer.py +23 -15
  43. nextrec/utils/optimizer.py +14 -10
  44. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/METADATA +6 -40
  45. nextrec-0.1.2.dist-info/RECORD +51 -0
  46. nextrec-0.1.1.dist-info/RECORD +0 -51
  47. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/WHEEL +0 -0
  48. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/licenses/LICENSE +0 -0
nextrec/models/multi_task/ple.py

@@ -17,13 +17,13 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 class PLE(BaseModel):
     """
     Progressive Layered Extraction
-
+
     PLE is an advanced multi-task learning model that extends MMOE by introducing
     both task-specific experts and shared experts at each level. It uses a progressive
     routing mechanism where experts from level k feed into gates at level k+1.
     This design better captures task-specific and shared information progressively.
     """
-
+
     @property
     def model_name(self):
         return "PLE"
@@ -31,29 +31,31 @@ class PLE(BaseModel):
     @property
     def task_type(self):
         return self.task if isinstance(self.task, list) else [self.task]
-
-    def __init__(self,
-                 dense_features: list[DenseFeature],
-                 sparse_features: list[SparseFeature],
-                 sequence_features: list[SequenceFeature],
-                 shared_expert_params: dict,
-                 specific_expert_params: dict,
-                 num_shared_experts: int,
-                 num_specific_experts: int,
-                 num_levels: int,
-                 tower_params_list: list[dict],
-                 target: list[str],
-                 task: str | list[str] = 'binary',
-                 optimizer: str = "adam",
-                 optimizer_params: dict = {},
-                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
-                 device: str = 'cpu',
-                 model_id: str = "baseline",
-                 embedding_l1_reg=1e-6,
-                 dense_l1_reg=1e-5,
-                 embedding_l2_reg=1e-5,
-                 dense_l2_reg=1e-4):
-
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature],
+        sparse_features: list[SparseFeature],
+        sequence_features: list[SequenceFeature],
+        shared_expert_params: dict,
+        specific_expert_params: dict,
+        num_shared_experts: int,
+        num_specific_experts: int,
+        num_levels: int,
+        tower_params_list: list[dict],
+        target: list[str],
+        task: str | list[str] = "binary",
+        optimizer: str = "adam",
+        optimizer_params: dict = {},
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        device: str = "cpu",
+        model_id: str = "baseline",
+        embedding_l1_reg=1e-6,
+        dense_l1_reg=1e-5,
+        embedding_l2_reg=1e-5,
+        dense_l2_reg=1e-4,
+    ):
+
         super(PLE, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
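For orientation, here is a minimal usage sketch of the constructor shown in this hunk. The keyword arguments mirror the signature above; the SparseFeature arguments (name, vocab_size, embedding_dim) are assumptions, since the Feature classes' signatures are not part of this diff.

```python
# Hedged sketch: SparseFeature argument names are assumed, not confirmed by this diff.
from nextrec.basic.features import SparseFeature
from nextrec.models.multi_task.ple import PLE

sparse = [
    SparseFeature(name="user_id", vocab_size=10_000, embedding_dim=16),  # assumed signature
    SparseFeature(name="item_id", vocab_size=50_000, embedding_dim=16),  # assumed signature
]
model = PLE(
    dense_features=[],
    sparse_features=sparse,
    sequence_features=[],
    shared_expert_params={"dims": [64, 32]},    # the ctor reads the "dims" key (see below)
    specific_expert_params={"dims": [64, 32]},
    num_shared_experts=2,
    num_specific_experts=2,
    num_levels=2,
    tower_params_list=[{"dims": [16]}, {"dims": [16]}],  # one tower per task
    target=["click", "purchase"],
    task=["binary", "binary"],
)
```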
@@ -66,13 +68,13 @@ class PLE(BaseModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=20,
-            model_id=model_id
+            model_id=model_id,
         )

         self.loss = loss
         if self.loss is None:
             self.loss = "bce"
-
+
         # Number of tasks, experts, and levels
         self.num_tasks = len(target)
         self.num_shared_experts = num_shared_experts
@@ -80,10 +82,12 @@ class PLE(BaseModel):
         self.num_levels = num_levels
         if optimizer_params is None:
             optimizer_params = {}
-
+
         if len(tower_params_list) != self.num_tasks:
-            raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
-
+            raise ValueError(
+                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
+            )
+
         # All features
         self.all_features = dense_features + sparse_features + sequence_features

@@ -91,42 +95,60 @@ class PLE(BaseModel):
         self.embedding = EmbeddingLayer(features=self.all_features)

         # Calculate input dimension
-        emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
-        dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
+        emb_dim_total = sum(
+            [
+                f.embedding_dim
+                for f in self.all_features
+                if not isinstance(f, DenseFeature)
+            ]
+        )
+        dense_input_dim = sum(
+            [getattr(f, "embedding_dim", 1) or 1 for f in dense_features]
+        )
         input_dim = emb_dim_total + dense_input_dim
-
+
         # Get expert output dimension
-        if 'dims' in shared_expert_params and len(shared_expert_params['dims']) > 0:
-            expert_output_dim = shared_expert_params['dims'][-1]
+        if "dims" in shared_expert_params and len(shared_expert_params["dims"]) > 0:
+            expert_output_dim = shared_expert_params["dims"][-1]
         else:
             expert_output_dim = input_dim
-
+
         # Build extraction layers (CGC layers)
         self.shared_experts_layers = nn.ModuleList()  # [num_levels]
         self.specific_experts_layers = nn.ModuleList()  # [num_levels, num_tasks]
-        self.gates_layers = nn.ModuleList()  # [num_levels, num_tasks + 1] (+1 for shared gate)
-
+        self.gates_layers = (
+            nn.ModuleList()
+        )  # [num_levels, num_tasks + 1] (+1 for shared gate)
+
         for level in range(num_levels):
             # Input dimension for this level
             level_input_dim = input_dim if level == 0 else expert_output_dim
-
+
             # Shared experts for this level
             shared_experts = nn.ModuleList()
             for _ in range(num_shared_experts):
-                expert = MLP(input_dim=level_input_dim, output_layer=False, **shared_expert_params)
+                expert = MLP(
+                    input_dim=level_input_dim,
+                    output_layer=False,
+                    **shared_expert_params,
+                )
                 shared_experts.append(expert)
             self.shared_experts_layers.append(shared_experts)
-
+
             # Task-specific experts for this level
             specific_experts_for_tasks = nn.ModuleList()
             for _ in range(self.num_tasks):
                 task_experts = nn.ModuleList()
                 for _ in range(num_specific_experts):
-                    expert = MLP(input_dim=level_input_dim, output_layer=False, **specific_expert_params)
+                    expert = MLP(
+                        input_dim=level_input_dim,
+                        output_layer=False,
+                        **specific_expert_params,
+                    )
                     task_experts.append(expert)
                 specific_experts_for_tasks.append(task_experts)
             self.specific_experts_layers.append(specific_experts_for_tasks)
-
+
             # Gates for this level (num_tasks task gates + 1 shared gate)
             gates = nn.ModuleList()
             # Task-specific gates
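The reflowed dimension math above is unchanged in behavior: non-dense features contribute their embedding_dim, dense features contribute embedding_dim when set, otherwise 1. A worked example with stand-in feature objects (hypothetical dataclasses mirroring the expressions above, not nextrec's real Feature types):

```python
from dataclasses import dataclass

@dataclass
class FakeSparse:                    # stand-in for SparseFeature
    embedding_dim: int

@dataclass
class FakeDense:                     # stand-in for DenseFeature
    embedding_dim: int | None = None

sparse = [FakeSparse(16), FakeSparse(8)]
dense = [FakeDense(), FakeDense(4)]

emb_dim_total = sum(f.embedding_dim for f in sparse)                        # 16 + 8 = 24
dense_input_dim = sum(getattr(f, "embedding_dim", 1) or 1 for f in dense)   # 1 + 4 = 5
print(emb_dim_total + dense_input_dim)                                      # 29
```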
@@ -134,40 +156,42 @@ class PLE(BaseModel):
             for _ in range(self.num_tasks):
                 gate = nn.Sequential(
                     nn.Linear(level_input_dim, num_experts_for_task_gate),
-                    nn.Softmax(dim=1)
+                    nn.Softmax(dim=1),
                 )
                 gates.append(gate)
             # Shared gate: contains all tasks' specific experts + shared experts
             # expert counts = num_shared_experts + num_specific_experts * num_tasks
-            num_experts_for_shared_gate = num_shared_experts + num_specific_experts * self.num_tasks
+            num_experts_for_shared_gate = (
+                num_shared_experts + num_specific_experts * self.num_tasks
+            )
             shared_gate = nn.Sequential(
                 nn.Linear(level_input_dim, num_experts_for_shared_gate),
-                nn.Softmax(dim=1)
+                nn.Softmax(dim=1),
             )
             gates.append(shared_gate)
             self.gates_layers.append(gates)
-
+
         # Task-specific towers
         self.towers = nn.ModuleList()
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
         self.prediction_layer = PredictionLayer(
-            task_type=self.task_type,
-            task_dims=[1] * self.num_tasks
+            task_type=self.task_type, task_dims=[1] * self.num_tasks
         )

         # Register regularization weights
         self._register_regularization_weights(
-            embedding_attr='embedding',
-            include_modules=['shared_experts_layers', 'specific_experts_layers', 'gates_layers', 'towers']
+            embedding_attr="embedding",
+            include_modules=[
+                "shared_experts_layers",
+                "specific_experts_layers",
+                "gates_layers",
+                "towers",
+            ],
         )

-        self.compile(
-            optimizer=optimizer,
-            optimizer_params=optimizer_params,
-            loss=loss
-        )
+        self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss)

     def forward(self, x):
         # Get all embeddings and flatten
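The gate sizing in the hunk above is easy to sanity-check: a task gate mixes num_shared_experts + num_specific_experts experts, while the shared gate sees every expert at the level. A standalone sketch (plain PyTorch, not nextrec code):

```python
import torch
import torch.nn as nn

num_tasks, num_shared, num_specific, dim = 2, 2, 2, 32

# Task gate: one weight per shared expert plus this task's own experts.
task_gate = nn.Sequential(nn.Linear(dim, num_shared + num_specific), nn.Softmax(dim=1))
# Shared gate: one weight per expert at the level, across all tasks.
shared_gate = nn.Sequential(
    nn.Linear(dim, num_shared + num_specific * num_tasks), nn.Softmax(dim=1)
)

x = torch.randn(4, dim)
print(task_gate(x).shape)    # torch.Size([4, 4])  -> 2 shared + 2 specific
print(shared_gate(x).shape)  # torch.Size([4, 6])  -> 2 shared + 2*2 specific
```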
@@ -179,13 +203,17 @@ class PLE(BaseModel):

         # Progressive Layered Extraction: CGC
         for level in range(self.num_levels):
-            shared_experts = self.shared_experts_layers[level]  # ModuleList[num_shared_experts]
-            specific_experts = self.specific_experts_layers[level]  # ModuleList[num_tasks][num_specific_experts]
-            gates = self.gates_layers[level]  # ModuleList[num_tasks + 1]
+            shared_experts = self.shared_experts_layers[
+                level
+            ]  # ModuleList[num_shared_experts]
+            specific_experts = self.specific_experts_layers[
+                level
+            ]  # ModuleList[num_tasks][num_specific_experts]
+            gates = self.gates_layers[level]  # ModuleList[num_tasks + 1]

             # Compute shared experts output for this level
             # shared_expert_list: List[Tensor[B, expert_dim]]
-            shared_expert_list = [expert(shared_fea) for expert in shared_experts] # type: ignore[list-item]
+            shared_expert_list = [expert(shared_fea) for expert in shared_experts]  # type: ignore[list-item]
             # [num_shared_experts, B, expert_dim]
             shared_expert_outputs = torch.stack(shared_expert_list, dim=0)

@@ -198,7 +226,7 @@ class PLE(BaseModel):
                 current_task_in = task_fea[task_idx]

                 # Specific task experts for this task
-                task_expert_modules = specific_experts[task_idx] # type: ignore
+                task_expert_modules = specific_experts[task_idx]  # type: ignore

                 # Specific task expert output list List[Tensor[B, expert_dim]]
                 task_specific_list = []
@@ -214,8 +242,7 @@ class PLE(BaseModel):
                 # Input for gate: shared_experts + own specific task experts
                 # [num_shared + num_specific, B, expert_dim]
                 all_expert_outputs = torch.cat(
-                    [shared_expert_outputs, task_specific_outputs],
-                    dim=0
+                    [shared_expert_outputs, task_specific_outputs], dim=0
                 )
                 # [B, num_experts, expert_dim]
                 all_expert_outputs_t = all_expert_outputs.permute(1, 0, 2)
@@ -239,7 +266,7 @@ class PLE(BaseModel):
            all_for_shared = torch.stack(all_for_shared_list, dim=1)

            # [B, num_all_experts]
-           shared_gate_weights = gates[self.num_tasks](shared_fea) # type: ignore
+           shared_gate_weights = gates[self.num_tasks](shared_fea)  # type: ignore
            # [B, 1, num_all_experts]
            shared_gate_weights = shared_gate_weights.unsqueeze(1)

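The shape comments above ([B, num_all_experts] gate weights unsqueezed to [B, 1, E], expert stacks of [B, E, D]) imply a batched weighted sum; the lines that perform it fall between the hunks shown here. One standard way to fuse such a stack, as a shape-only sketch with dummy tensors:

```python
import torch

B, E, D = 4, 6, 32                      # batch, experts, expert output dim
expert_outputs = torch.randn(B, E, D)   # [B, num_experts, expert_dim]
gate_weights = torch.softmax(torch.randn(B, E), dim=1).unsqueeze(1)  # [B, 1, E]

# Weighted sum over the expert axis -- one common implementation, not
# necessarily the exact lines elided from this diff.
fused = torch.bmm(gate_weights, expert_outputs).squeeze(1)  # [B, D]
print(fused.shape)  # torch.Size([4, 32])
```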
@@ -257,4 +284,4 @@ class PLE(BaseModel):

         # [B, num_tasks]
         y = torch.cat(task_outputs, dim=1)
-        return self.prediction_layer(y)
\ No newline at end of file
+        return self.prediction_layer(y)
nextrec/models/multi_task/share_bottom.py

@@ -23,25 +23,27 @@ class ShareBottom(BaseModel):
     def task_type(self):
         # Multi-task model, return list of task types
         return self.task if isinstance(self.task, list) else [self.task]
-
-    def __init__(self,
-                 dense_features: list[DenseFeature],
-                 sparse_features: list[SparseFeature],
-                 sequence_features: list[SequenceFeature],
-                 bottom_params: dict,
-                 tower_params_list: list[dict],
-                 target: list[str],
-                 task: str | list[str] = 'binary',
-                 optimizer: str = "adam",
-                 optimizer_params: dict = {},
-                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
-                 device: str = 'cpu',
-                 model_id: str = "baseline",
-                 embedding_l1_reg=1e-6,
-                 dense_l1_reg=1e-5,
-                 embedding_l2_reg=1e-5,
-                 dense_l2_reg=1e-4):
-
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature],
+        sparse_features: list[SparseFeature],
+        sequence_features: list[SequenceFeature],
+        bottom_params: dict,
+        tower_params_list: list[dict],
+        target: list[str],
+        task: str | list[str] = "binary",
+        optimizer: str = "adam",
+        optimizer_params: dict = {},
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        device: str = "cpu",
+        model_id: str = "baseline",
+        embedding_l1_reg=1e-6,
+        dense_l1_reg=1e-5,
+        embedding_l2_reg=1e-5,
+        dense_l2_reg=1e-4,
+    ):
+
         super(ShareBottom, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
@@ -54,18 +56,20 @@ class ShareBottom(BaseModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=20,
-            model_id=model_id
+            model_id=model_id,
         )

         self.loss = loss
         if self.loss is None:
             self.loss = "bce"
-
+
         # Number of tasks
         self.num_tasks = len(target)
         if len(tower_params_list) != self.num_tasks:
-            raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
-
+            raise ValueError(
+                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
+            )
+
         # All features
         self.all_features = dense_features + sparse_features + sequence_features

@@ -73,54 +77,56 @@ class ShareBottom(BaseModel):
         self.embedding = EmbeddingLayer(features=self.all_features)

         # Calculate input dimension
-        emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
-        dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
+        emb_dim_total = sum(
+            [
+                f.embedding_dim
+                for f in self.all_features
+                if not isinstance(f, DenseFeature)
+            ]
+        )
+        dense_input_dim = sum(
+            [getattr(f, "embedding_dim", 1) or 1 for f in dense_features]
+        )
         input_dim = emb_dim_total + dense_input_dim
-
+
         # Shared bottom network
         self.bottom = MLP(input_dim=input_dim, output_layer=False, **bottom_params)
-
+
         # Get bottom output dimension
-        if 'dims' in bottom_params and len(bottom_params['dims']) > 0:
-            bottom_output_dim = bottom_params['dims'][-1]
+        if "dims" in bottom_params and len(bottom_params["dims"]) > 0:
+            bottom_output_dim = bottom_params["dims"][-1]
         else:
             bottom_output_dim = input_dim
-
+
         # Task-specific towers
         self.towers = nn.ModuleList()
         for tower_params in tower_params_list:
             tower = MLP(input_dim=bottom_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
         self.prediction_layer = PredictionLayer(
-            task_type=self.task_type,
-            task_dims=[1] * self.num_tasks
+            task_type=self.task_type, task_dims=[1] * self.num_tasks
         )

         # Register regularization weights
         self._register_regularization_weights(
-            embedding_attr='embedding',
-            include_modules=['bottom', 'towers']
+            embedding_attr="embedding", include_modules=["bottom", "towers"]
         )

-        self.compile(
-            optimizer=optimizer,
-            optimizer_params=optimizer_params,
-            loss=loss
-        )
+        self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss)

     def forward(self, x):
         # Get all embeddings and flatten
         input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
-
+
         # Shared bottom
         bottom_output = self.bottom(input_flat)  # [B, bottom_dim]
-
+
         # Task-specific towers
         task_outputs = []
         for tower in self.towers:
             tower_output = tower(bottom_output)  # [B, 1]
             task_outputs.append(tower_output)
-
+
         # Stack outputs: [B, num_tasks]
         y = torch.cat(task_outputs, dim=1)
         return self.prediction_layer(y)
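The forward pass above is the canonical shared-bottom pattern: one trunk, one head per task, outputs concatenated to [B, num_tasks]. A dependency-free sketch of the same control flow (illustrative stand-in, not nextrec's classes):

```python
import torch
import torch.nn as nn

class TinyShareBottom(nn.Module):
    """Minimal stand-in mirroring the ShareBottom forward pass above."""

    def __init__(self, input_dim: int, bottom_dim: int, num_tasks: int):
        super().__init__()
        self.bottom = nn.Sequential(nn.Linear(input_dim, bottom_dim), nn.ReLU())
        self.towers = nn.ModuleList(
            [nn.Linear(bottom_dim, 1) for _ in range(num_tasks)]
        )

    def forward(self, x):
        h = self.bottom(x)                                     # [B, bottom_dim]
        return torch.cat([t(h) for t in self.towers], dim=1)   # [B, num_tasks]

y = TinyShareBottom(input_dim=8, bottom_dim=16, num_tasks=2)(torch.randn(4, 8))
print(y.shape)  # torch.Size([4, 2])
```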
nextrec/models/ranking/__init__.py

@@ -7,11 +7,11 @@ from .din import DIN
 from .dien import DIEN

 __all__ = [
-    'DeepFM',
-    'AutoInt',
-    'WideDeep',
-    'xDeepFM',
-    'DCN',
-    'DIN',
-    'DIEN',
+    "DeepFM",
+    "AutoInt",
+    "WideDeep",
+    "xDeepFM",
+    "DCN",
+    "DIN",
+    "DIEN",
 ]
nextrec/models/ranking/afm.py

@@ -23,24 +23,26 @@ class AFM(BaseModel):
     @property
     def task_type(self):
         return "binary"
-
-    def __init__(self,
-                 dense_features: list[DenseFeature] | list = [],
-                 sparse_features: list[SparseFeature] | list = [],
-                 sequence_features: list[SequenceFeature] | list = [],
-                 attention_dim: int = 32,
-                 attention_dropout: float = 0.0,
-                 target: list[str] | list = [],
-                 optimizer: str = "adam",
-                 optimizer_params: dict = {},
-                 loss: str | nn.Module | None = "bce",
-                 device: str = 'cpu',
-                 model_id: str = "baseline",
-                 embedding_l1_reg=1e-6,
-                 dense_l1_reg=1e-5,
-                 embedding_l2_reg=1e-5,
-                 dense_l2_reg=1e-4):
-
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature] | list = [],
+        sparse_features: list[SparseFeature] | list = [],
+        sequence_features: list[SequenceFeature] | list = [],
+        attention_dim: int = 32,
+        attention_dropout: float = 0.0,
+        target: list[str] | list = [],
+        optimizer: str = "adam",
+        optimizer_params: dict = {},
+        loss: str | nn.Module | None = "bce",
+        device: str = "cpu",
+        model_id: str = "baseline",
+        embedding_l1_reg=1e-6,
+        dense_l1_reg=1e-5,
+        embedding_l2_reg=1e-5,
+        dense_l2_reg=1e-4,
+    ):
+
         super(AFM, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
@@ -53,21 +55,25 @@ class AFM(BaseModel):
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=20,
-            model_id=model_id
+            model_id=model_id,
         )

         self.loss = loss
         if self.loss is None:
             self.loss = "bce"
-
+
         self.fm_features = sparse_features + sequence_features
         if len(self.fm_features) < 2:
-            raise ValueError("AFM requires at least two sparse/sequence features to build pairwise interactions.")
+            raise ValueError(
+                "AFM requires at least two sparse/sequence features to build pairwise interactions."
+            )

         # Assume uniform embedding dimension across FM fields
         self.embedding_dim = self.fm_features[0].embedding_dim
         if any(f.embedding_dim != self.embedding_dim for f in self.fm_features):
-            raise ValueError("All FM features must share the same embedding_dim for AFM.")
+            raise ValueError(
+                "All FM features must share the same embedding_dim for AFM."
+            )

         self.embedding = EmbeddingLayer(features=self.fm_features)

@@ -82,18 +88,21 @@ class AFM(BaseModel):

         # Register regularization weights
         self._register_regularization_weights(
-            embedding_attr='embedding',
-            include_modules=['linear', 'attention_linear', 'attention_p', 'output_projection']
+            embedding_attr="embedding",
+            include_modules=[
+                "linear",
+                "attention_linear",
+                "attention_p",
+                "output_projection",
+            ],
         )

-        self.compile(
-            optimizer=optimizer,
-            optimizer_params=optimizer_params,
-            loss=loss
-        )
+        self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss)

     def forward(self, x):
-        field_emb = self.embedding(x=x, features=self.fm_features, squeeze_dim=False) # [B, F, D]
+        field_emb = self.embedding(
+            x=x, features=self.fm_features, squeeze_dim=False
+        )  # [B, F, D]
         input_linear = field_emb.flatten(start_dim=1)
         y_linear = self.linear(input_linear)