nextrec 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +220 -106
  7. nextrec/basic/loggers.py +62 -43
  8. nextrec/basic/metrics.py +268 -119
  9. nextrec/basic/model.py +1082 -400
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +498 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +51 -45
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +272 -95
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +103 -38
  23. nextrec/models/match/dssm.py +82 -68
  24. nextrec/models/match/dssm_v2.py +72 -57
  25. nextrec/models/match/mind.py +175 -107
  26. nextrec/models/match/sdm.py +104 -87
  27. nextrec/models/match/youtube_dnn.py +73 -59
  28. nextrec/models/multi_task/esmm.py +53 -37
  29. nextrec/models/multi_task/mmoe.py +64 -45
  30. nextrec/models/multi_task/ple.py +101 -48
  31. nextrec/models/multi_task/poso.py +113 -36
  32. nextrec/models/multi_task/share_bottom.py +48 -35
  33. nextrec/models/ranking/afm.py +72 -37
  34. nextrec/models/ranking/autoint.py +72 -55
  35. nextrec/models/ranking/dcn.py +55 -35
  36. nextrec/models/ranking/dcn_v2.py +64 -23
  37. nextrec/models/ranking/deepfm.py +32 -22
  38. nextrec/models/ranking/dien.py +155 -99
  39. nextrec/models/ranking/din.py +85 -57
  40. nextrec/models/ranking/fibinet.py +52 -32
  41. nextrec/models/ranking/fm.py +29 -23
  42. nextrec/models/ranking/masknet.py +91 -29
  43. nextrec/models/ranking/pnn.py +31 -28
  44. nextrec/models/ranking/widedeep.py +34 -26
  45. nextrec/models/ranking/xdeepfm.py +60 -38
  46. nextrec/utils/__init__.py +59 -34
  47. nextrec/utils/config.py +490 -0
  48. nextrec/utils/device.py +30 -20
  49. nextrec/utils/distributed.py +36 -9
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +32 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/optimizer.py +25 -9
  55. nextrec/utils/synthetic_data.py +283 -165
  56. nextrec/utils/tensor.py +24 -13
  57. {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/METADATA +4 -4
  58. nextrec-0.4.2.dist-info/RECORD +69 -0
  59. nextrec-0.4.2.dist-info/entry_points.txt +2 -0
  60. nextrec-0.4.1.dist-info/RECORD +0 -66
  61. {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
  62. {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
@@ -52,87 +52,103 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
52
52
  class ESMM(BaseModel):
53
53
  """
54
54
  Entire Space Multi-Task Model
55
-
55
+
56
56
  ESMM is designed for CVR (Conversion Rate) prediction. It models two related tasks:
57
57
  - CTR task: P(click | impression)
58
58
  - CVR task: P(conversion | click)
59
59
  - CTCVR task (auxiliary): P(click & conversion | impression) = P(click) * P(conversion | click)
60
-
60
+
61
61
  This design addresses the sample selection bias and data sparsity issues in CVR modeling.
62
62
  """
63
-
63
+
64
64
  @property
65
65
  def model_name(self):
66
66
  return "ESMM"
67
-
67
+
68
68
  @property
69
69
  def default_task(self):
70
- return ['binary', 'binary']
71
-
72
- def __init__(self,
73
- dense_features: list[DenseFeature],
74
- sparse_features: list[SparseFeature],
75
- sequence_features: list[SequenceFeature],
76
- ctr_params: dict,
77
- cvr_params: dict,
78
- target: list[str] = ['ctr', 'ctcvr'], # Note: ctcvr = ctr * cvr
79
- task: list[str] | None = None,
80
- optimizer: str = "adam",
81
- optimizer_params: dict = {},
82
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
83
- loss_params: dict | list[dict] | None = None,
84
- device: str = 'cpu',
85
- embedding_l1_reg=1e-6,
86
- dense_l1_reg=1e-5,
87
- embedding_l2_reg=1e-5,
88
- dense_l2_reg=1e-4,
89
- **kwargs):
90
-
70
+ return ["binary", "binary"]
71
+
72
+ def __init__(
73
+ self,
74
+ dense_features: list[DenseFeature],
75
+ sparse_features: list[SparseFeature],
76
+ sequence_features: list[SequenceFeature],
77
+ ctr_params: dict,
78
+ cvr_params: dict,
79
+ target: list[str] = ["ctr", "ctcvr"], # Note: ctcvr = ctr * cvr
80
+ task: list[str] | None = None,
81
+ optimizer: str = "adam",
82
+ optimizer_params: dict = {},
83
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
84
+ loss_params: dict | list[dict] | None = None,
85
+ device: str = "cpu",
86
+ embedding_l1_reg=1e-6,
87
+ dense_l1_reg=1e-5,
88
+ embedding_l2_reg=1e-5,
89
+ dense_l2_reg=1e-4,
90
+ **kwargs,
91
+ ):
92
+
91
93
  # ESMM requires exactly 2 targets: ctr and ctcvr
92
94
  if len(target) != 2:
93
- raise ValueError(f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}")
94
-
95
+ raise ValueError(
96
+ f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}"
97
+ )
98
+
95
99
  super(ESMM, self).__init__(
96
100
  dense_features=dense_features,
97
101
  sparse_features=sparse_features,
98
102
  sequence_features=sequence_features,
99
103
  target=target,
100
- task=task or self.default_task, # Both CTR and CTCVR are binary classification
104
+ task=task
105
+ or self.default_task, # Both CTR and CTCVR are binary classification
101
106
  device=device,
102
107
  embedding_l1_reg=embedding_l1_reg,
103
108
  dense_l1_reg=dense_l1_reg,
104
109
  embedding_l2_reg=embedding_l2_reg,
105
110
  dense_l2_reg=dense_l2_reg,
106
- **kwargs
111
+ **kwargs,
107
112
  )
108
113
 
109
114
  self.loss = loss
110
115
  if self.loss is None:
111
116
  self.loss = "bce"
112
-
117
+
113
118
  # All features
114
119
  self.all_features = dense_features + sparse_features + sequence_features
115
120
  # Shared embedding layer
116
121
  self.embedding = EmbeddingLayer(features=self.all_features)
117
- input_dim = self.embedding.input_dim # Calculate input dimension, better way than below
122
+ input_dim = (
123
+ self.embedding.input_dim
124
+ ) # Calculate input dimension, better way than below
118
125
  # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
119
126
  # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
120
127
  # input_dim = emb_dim_total + dense_input_dim
121
128
 
122
129
  # CTR tower
123
130
  self.ctr_tower = MLP(input_dim=input_dim, output_layer=True, **ctr_params)
124
-
131
+
125
132
  # CVR tower
126
133
  self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
127
- self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1, 1])
134
+ self.prediction_layer = PredictionLayer(
135
+ task_type=self.default_task, task_dims=[1, 1]
136
+ )
128
137
  # Register regularization weights
129
- self.register_regularization_weights(embedding_attr='embedding', include_modules=['ctr_tower', 'cvr_tower'])
130
- self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
138
+ self.register_regularization_weights(
139
+ embedding_attr="embedding", include_modules=["ctr_tower", "cvr_tower"]
140
+ )
141
+ self.compile(
142
+ optimizer=optimizer,
143
+ optimizer_params=optimizer_params,
144
+ loss=loss,
145
+ loss_params=loss_params,
146
+ )
131
147
 
132
148
  def forward(self, x):
133
149
  # Get all embeddings and flatten
134
150
  input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
135
-
151
+
136
152
  # CTR prediction: P(click | impression)
137
153
  ctr_logit = self.ctr_tower(input_flat) # [B, 1]
138
154
  cvr_logit = self.cvr_tower(input_flat) # [B, 1]
@@ -140,7 +156,7 @@ class ESMM(BaseModel):
140
156
  preds = self.prediction_layer(logits)
141
157
  ctr, cvr = preds.chunk(2, dim=1)
142
158
  ctcvr = ctr * cvr # [B, 1]
143
-
159
+
144
160
  # Output: [CTR, CTCVR], We supervise CTR with click labels and CTCVR with conversion labels
145
161
  y = torch.cat([ctr, ctcvr], dim=1) # [B, 2]
146
162
  return y # [B, 2], where y[:, 0] is CTR and y[:, 1] is CTCVR
@@ -53,13 +53,13 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
53
53
  class MMOE(BaseModel):
54
54
  """
55
55
  Multi-gate Mixture-of-Experts
56
-
56
+
57
57
  MMOE improves upon shared-bottom architecture by using multiple expert networks
58
58
  and task-specific gating networks. Each task has its own gate that learns to
59
59
  weight the contributions of different experts, allowing for both task-specific
60
60
  and shared representations.
61
61
  """
62
-
62
+
63
63
  @property
64
64
  def model_name(self):
65
65
  return "MMOE"
@@ -68,29 +68,31 @@ class MMOE(BaseModel):
68
68
  def default_task(self):
69
69
  num_tasks = getattr(self, "num_tasks", None)
70
70
  if num_tasks is not None and num_tasks > 0:
71
- return ['binary'] * num_tasks
72
- return ['binary']
73
-
74
- def __init__(self,
75
- dense_features: list[DenseFeature]=[],
76
- sparse_features: list[SparseFeature]=[],
77
- sequence_features: list[SequenceFeature]=[],
78
- expert_params: dict={},
79
- num_experts: int=3,
80
- tower_params_list: list[dict]=[],
81
- target: list[str]=[],
82
- task: str | list[str] | None = None,
83
- optimizer: str = "adam",
84
- optimizer_params: dict = {},
85
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
86
- loss_params: dict | list[dict] | None = None,
87
- device: str = 'cpu',
88
- embedding_l1_reg=1e-6,
89
- dense_l1_reg=1e-5,
90
- embedding_l2_reg=1e-5,
91
- dense_l2_reg=1e-4,
92
- **kwargs):
93
-
71
+ return ["binary"] * num_tasks
72
+ return ["binary"]
73
+
74
+ def __init__(
75
+ self,
76
+ dense_features: list[DenseFeature] = [],
77
+ sparse_features: list[SparseFeature] = [],
78
+ sequence_features: list[SequenceFeature] = [],
79
+ expert_params: dict = {},
80
+ num_experts: int = 3,
81
+ tower_params_list: list[dict] = [],
82
+ target: list[str] = [],
83
+ task: str | list[str] | None = None,
84
+ optimizer: str = "adam",
85
+ optimizer_params: dict = {},
86
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
87
+ loss_params: dict | list[dict] | None = None,
88
+ device: str = "cpu",
89
+ embedding_l1_reg=1e-6,
90
+ dense_l1_reg=1e-5,
91
+ embedding_l2_reg=1e-5,
92
+ dense_l2_reg=1e-4,
93
+ **kwargs,
94
+ ):
95
+
94
96
  self.num_tasks = len(target)
95
97
 
96
98
  super(MMOE, self).__init__(
@@ -104,19 +106,21 @@ class MMOE(BaseModel):
104
106
  dense_l1_reg=dense_l1_reg,
105
107
  embedding_l2_reg=embedding_l2_reg,
106
108
  dense_l2_reg=dense_l2_reg,
107
- **kwargs
109
+ **kwargs,
108
110
  )
109
111
 
110
112
  self.loss = loss
111
113
  if self.loss is None:
112
114
  self.loss = "bce"
113
-
115
+
114
116
  # Number of tasks and experts
115
117
  self.num_tasks = len(target)
116
118
  self.num_experts = num_experts
117
-
119
+
118
120
  if len(tower_params_list) != self.num_tasks:
119
- raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
121
+ raise ValueError(
122
+ f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
123
+ )
120
124
 
121
125
  self.all_features = dense_features + sparse_features + sequence_features
122
126
  self.embedding = EmbeddingLayer(features=self.all_features)
@@ -130,54 +134,69 @@ class MMOE(BaseModel):
130
134
  for _ in range(num_experts):
131
135
  expert = MLP(input_dim=input_dim, output_layer=False, **expert_params)
132
136
  self.experts.append(expert)
133
-
137
+
134
138
  # Get expert output dimension
135
- if 'dims' in expert_params and len(expert_params['dims']) > 0:
136
- expert_output_dim = expert_params['dims'][-1]
139
+ if "dims" in expert_params and len(expert_params["dims"]) > 0:
140
+ expert_output_dim = expert_params["dims"][-1]
137
141
  else:
138
142
  expert_output_dim = input_dim
139
-
143
+
140
144
  # Task-specific gates
141
145
  self.gates = nn.ModuleList()
142
146
  for _ in range(self.num_tasks):
143
147
  gate = nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1))
144
148
  self.gates.append(gate)
145
-
149
+
146
150
  # Task-specific towers
147
151
  self.towers = nn.ModuleList()
148
152
  for tower_params in tower_params_list:
149
153
  tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
150
154
  self.towers.append(tower)
151
- self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks)
155
+ self.prediction_layer = PredictionLayer(
156
+ task_type=self.default_task, task_dims=[1] * self.num_tasks
157
+ )
152
158
  # Register regularization weights
153
- self.register_regularization_weights(embedding_attr='embedding', include_modules=['experts', 'gates', 'towers'])
154
- self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params,)
159
+ self.register_regularization_weights(
160
+ embedding_attr="embedding", include_modules=["experts", "gates", "towers"]
161
+ )
162
+ self.compile(
163
+ optimizer=optimizer,
164
+ optimizer_params=optimizer_params,
165
+ loss=loss,
166
+ loss_params=loss_params,
167
+ )
155
168
 
156
169
  def forward(self, x):
157
170
  # Get all embeddings and flatten
158
171
  input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
159
-
172
+
160
173
  # Expert outputs: [num_experts, B, expert_dim]
161
174
  expert_outputs = [expert(input_flat) for expert in self.experts]
162
- expert_outputs = torch.stack(expert_outputs, dim=0) # [num_experts, B, expert_dim]
163
-
175
+ expert_outputs = torch.stack(
176
+ expert_outputs, dim=0
177
+ ) # [num_experts, B, expert_dim]
178
+
164
179
  # Task-specific processing
165
180
  task_outputs = []
166
181
  for task_idx in range(self.num_tasks):
167
182
  # Gate weights for this task: [B, num_experts]
168
183
  gate_weights = self.gates[task_idx](input_flat) # [B, num_experts]
169
-
184
+
170
185
  # Weighted sum of expert outputs
171
186
  # gate_weights: [B, num_experts, 1]
172
187
  # expert_outputs: [num_experts, B, expert_dim]
173
188
  gate_weights = gate_weights.unsqueeze(2) # [B, num_experts, 1]
174
- expert_outputs_t = expert_outputs.permute(1, 0, 2) # [B, num_experts, expert_dim]
175
- gated_output = torch.sum(gate_weights * expert_outputs_t, dim=1) # [B, expert_dim]
176
-
189
+ expert_outputs_t = expert_outputs.permute(
190
+ 1, 0, 2
191
+ ) # [B, num_experts, expert_dim]
192
+ gated_output = torch.sum(
193
+ gate_weights * expert_outputs_t, dim=1
194
+ ) # [B, expert_dim]
195
+
177
196
  # Tower output
178
197
  tower_output = self.towers[task_idx](gated_output) # [B, 1]
179
198
  task_outputs.append(tower_output)
180
-
199
+
181
200
  # Stack outputs: [B, num_tasks]
182
201
  y = torch.cat(task_outputs, dim=1)
183
202
  return self.prediction_layer(y)
@@ -52,6 +52,7 @@ from nextrec.basic.model import BaseModel
52
52
  from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
53
53
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
54
54
 
55
+
55
56
  class CGCLayer(nn.Module):
56
57
  """
57
58
  CGC (Customized Gate Control) block used by PLE.
@@ -71,26 +72,61 @@ class CGCLayer(nn.Module):
71
72
  if num_tasks < 1:
72
73
  raise ValueError("num_tasks must be >= 1")
73
74
 
74
- specific_params_list = self._normalize_specific_params(specific_expert_params, num_tasks)
75
+ specific_params_list = self._normalize_specific_params(
76
+ specific_expert_params, num_tasks
77
+ )
75
78
 
76
79
  self.output_dim = self._get_output_dim(shared_expert_params, input_dim)
77
- specific_dims = [self._get_output_dim(params, input_dim) for params in specific_params_list]
80
+ specific_dims = [
81
+ self._get_output_dim(params, input_dim) for params in specific_params_list
82
+ ]
78
83
  dims_set = set(specific_dims + [self.output_dim])
79
84
  if len(dims_set) != 1:
80
- raise ValueError(f"Shared/specific expert output dims must match, got {dims_set}")
85
+ raise ValueError(
86
+ f"Shared/specific expert output dims must match, got {dims_set}"
87
+ )
81
88
 
82
89
  # experts
83
- self.shared_experts = nn.ModuleList([MLP(input_dim=input_dim, output_layer=False, **shared_expert_params,) for _ in range(num_shared_experts)])
90
+ self.shared_experts = nn.ModuleList(
91
+ [
92
+ MLP(
93
+ input_dim=input_dim,
94
+ output_layer=False,
95
+ **shared_expert_params,
96
+ )
97
+ for _ in range(num_shared_experts)
98
+ ]
99
+ )
84
100
  self.specific_experts = nn.ModuleList()
85
101
  for params in specific_params_list:
86
- task_experts = nn.ModuleList([MLP(input_dim=input_dim, output_layer=False, **params,) for _ in range(num_specific_experts)])
102
+ task_experts = nn.ModuleList(
103
+ [
104
+ MLP(
105
+ input_dim=input_dim,
106
+ output_layer=False,
107
+ **params,
108
+ )
109
+ for _ in range(num_specific_experts)
110
+ ]
111
+ )
87
112
  self.specific_experts.append(task_experts)
88
113
 
89
114
  # gates
90
115
  task_gate_expert_num = num_shared_experts + num_specific_experts
91
- self.task_gates = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, task_gate_expert_num), nn.Softmax(dim=1),) for _ in range(num_tasks)])
116
+ self.task_gates = nn.ModuleList(
117
+ [
118
+ nn.Sequential(
119
+ nn.Linear(input_dim, task_gate_expert_num),
120
+ nn.Softmax(dim=1),
121
+ )
122
+ for _ in range(num_tasks)
123
+ ]
124
+ )
92
125
  shared_gate_expert_num = num_shared_experts + num_specific_experts * num_tasks
93
- self.shared_gate = nn.Sequential(nn.Linear(input_dim, shared_gate_expert_num), nn.Softmax(dim=1),)
126
+ self.shared_gate = nn.Sequential(
127
+ nn.Linear(input_dim, shared_gate_expert_num),
128
+ nn.Softmax(dim=1),
129
+ )
94
130
 
95
131
  self.num_tasks = num_tasks
96
132
 
@@ -98,7 +134,9 @@ class CGCLayer(nn.Module):
98
134
  self, task_inputs: list[torch.Tensor], shared_input: torch.Tensor
99
135
  ) -> tuple[list[torch.Tensor], torch.Tensor]:
100
136
  if len(task_inputs) != self.num_tasks:
101
- raise ValueError(f"Expected {self.num_tasks} task inputs, got {len(task_inputs)}")
137
+ raise ValueError(
138
+ f"Expected {self.num_tasks} task inputs, got {len(task_inputs)}"
139
+ )
102
140
 
103
141
  shared_outputs = [expert(shared_input) for expert in self.shared_experts]
104
142
  shared_stack = torch.stack(shared_outputs, dim=0) # [num_shared, B, D]
@@ -108,7 +146,7 @@ class CGCLayer(nn.Module):
108
146
 
109
147
  for task_idx in range(self.num_tasks):
110
148
  task_input = task_inputs[task_idx]
111
- task_specific_outputs = [expert(task_input) for expert in self.specific_experts[task_idx]] # type: ignore
149
+ task_specific_outputs = [expert(task_input) for expert in self.specific_experts[task_idx]] # type: ignore
112
150
  all_specific_for_shared.extend(task_specific_outputs)
113
151
  specific_stack = torch.stack(task_specific_outputs, dim=0)
114
152
 
@@ -139,7 +177,9 @@ class CGCLayer(nn.Module):
139
177
  ) -> list[dict]:
140
178
  if isinstance(params, list):
141
179
  if len(params) != num_tasks:
142
- raise ValueError(f"Length of specific_expert_params ({len(params)}) must match num_tasks ({num_tasks}).")
180
+ raise ValueError(
181
+ f"Length of specific_expert_params ({len(params)}) must match num_tasks ({num_tasks})."
182
+ )
143
183
  return [p.copy() for p in params]
144
184
  return [params.copy() for _ in range(num_tasks)]
145
185
 
@@ -147,13 +187,13 @@ class CGCLayer(nn.Module):
147
187
  class PLE(BaseModel):
148
188
  """
149
189
  Progressive Layered Extraction
150
-
190
+
151
191
  PLE is an advanced multi-task learning model that extends MMOE by introducing
152
192
  both task-specific experts and shared experts at each level. It uses a progressive
153
193
  routing mechanism where experts from level k feed into gates at level k+1.
154
194
  This design better captures task-specific and shared information progressively.
155
195
  """
156
-
196
+
157
197
  @property
158
198
  def model_name(self):
159
199
  return "PLE"
@@ -162,32 +202,34 @@ class PLE(BaseModel):
162
202
  def default_task(self):
163
203
  num_tasks = getattr(self, "num_tasks", None)
164
204
  if num_tasks is not None and num_tasks > 0:
165
- return ['binary'] * num_tasks
166
- return ['binary']
167
-
168
- def __init__(self,
169
- dense_features: list[DenseFeature],
170
- sparse_features: list[SparseFeature],
171
- sequence_features: list[SequenceFeature],
172
- shared_expert_params: dict,
173
- specific_expert_params: dict | list[dict],
174
- num_shared_experts: int,
175
- num_specific_experts: int,
176
- num_levels: int,
177
- tower_params_list: list[dict],
178
- target: list[str],
179
- task: str | list[str] | None = None,
180
- optimizer: str = "adam",
181
- optimizer_params: dict | None = None,
182
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
183
- loss_params: dict | list[dict] | None = None,
184
- device: str = 'cpu',
185
- embedding_l1_reg=1e-6,
186
- dense_l1_reg=1e-5,
187
- embedding_l2_reg=1e-5,
188
- dense_l2_reg=1e-4,
189
- **kwargs):
190
-
205
+ return ["binary"] * num_tasks
206
+ return ["binary"]
207
+
208
+ def __init__(
209
+ self,
210
+ dense_features: list[DenseFeature],
211
+ sparse_features: list[SparseFeature],
212
+ sequence_features: list[SequenceFeature],
213
+ shared_expert_params: dict,
214
+ specific_expert_params: dict | list[dict],
215
+ num_shared_experts: int,
216
+ num_specific_experts: int,
217
+ num_levels: int,
218
+ tower_params_list: list[dict],
219
+ target: list[str],
220
+ task: str | list[str] | None = None,
221
+ optimizer: str = "adam",
222
+ optimizer_params: dict | None = None,
223
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
224
+ loss_params: dict | list[dict] | None = None,
225
+ device: str = "cpu",
226
+ embedding_l1_reg=1e-6,
227
+ dense_l1_reg=1e-5,
228
+ embedding_l2_reg=1e-5,
229
+ dense_l2_reg=1e-4,
230
+ **kwargs,
231
+ ):
232
+
191
233
  self.num_tasks = len(target)
192
234
 
193
235
  super(PLE, self).__init__(
@@ -201,7 +243,7 @@ class PLE(BaseModel):
201
243
  dense_l1_reg=dense_l1_reg,
202
244
  embedding_l2_reg=embedding_l2_reg,
203
245
  dense_l2_reg=dense_l2_reg,
204
- **kwargs
246
+ **kwargs,
205
247
  )
206
248
 
207
249
  self.loss = loss
@@ -215,7 +257,9 @@ class PLE(BaseModel):
215
257
  if optimizer_params is None:
216
258
  optimizer_params = {}
217
259
  if len(tower_params_list) != self.num_tasks:
218
- raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
260
+ raise ValueError(
261
+ f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
262
+ )
219
263
  # Embedding layer
220
264
  self.embedding = EmbeddingLayer(features=self.all_features)
221
265
 
@@ -224,13 +268,13 @@ class PLE(BaseModel):
224
268
  # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
225
269
  # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
226
270
  # input_dim = emb_dim_total + dense_input_dim
227
-
271
+
228
272
  # Get expert output dimension
229
- if 'dims' in shared_expert_params and len(shared_expert_params['dims']) > 0:
230
- expert_output_dim = shared_expert_params['dims'][-1]
273
+ if "dims" in shared_expert_params and len(shared_expert_params["dims"]) > 0:
274
+ expert_output_dim = shared_expert_params["dims"][-1]
231
275
  else:
232
276
  expert_output_dim = input_dim
233
-
277
+
234
278
  # Build CGC layers
235
279
  self.cgc_layers = nn.ModuleList()
236
280
  for level in range(num_levels):
@@ -245,16 +289,25 @@ class PLE(BaseModel):
245
289
  )
246
290
  self.cgc_layers.append(cgc_layer)
247
291
  expert_output_dim = cgc_layer.output_dim
248
-
292
+
249
293
  # Task-specific towers
250
294
  self.towers = nn.ModuleList()
251
295
  for tower_params in tower_params_list:
252
296
  tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
253
297
  self.towers.append(tower)
254
- self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks)
298
+ self.prediction_layer = PredictionLayer(
299
+ task_type=self.default_task, task_dims=[1] * self.num_tasks
300
+ )
255
301
  # Register regularization weights
256
- self.register_regularization_weights(embedding_attr='embedding', include_modules=['cgc_layers', 'towers'])
257
- self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=self.loss, loss_params=loss_params)
302
+ self.register_regularization_weights(
303
+ embedding_attr="embedding", include_modules=["cgc_layers", "towers"]
304
+ )
305
+ self.compile(
306
+ optimizer=optimizer,
307
+ optimizer_params=optimizer_params,
308
+ loss=self.loss,
309
+ loss_params=loss_params,
310
+ )
258
311
 
259
312
  def forward(self, x):
260
313
  # Get all embeddings and flatten