nextrec 0.3.6__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +244 -113
  7. nextrec/basic/loggers.py +62 -43
  8. nextrec/basic/metrics.py +268 -119
  9. nextrec/basic/model.py +1373 -443
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +498 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +42 -24
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +303 -96
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +106 -40
  23. nextrec/models/match/dssm.py +82 -69
  24. nextrec/models/match/dssm_v2.py +72 -58
  25. nextrec/models/match/mind.py +175 -108
  26. nextrec/models/match/sdm.py +104 -88
  27. nextrec/models/match/youtube_dnn.py +73 -60
  28. nextrec/models/multi_task/esmm.py +53 -39
  29. nextrec/models/multi_task/mmoe.py +70 -47
  30. nextrec/models/multi_task/ple.py +107 -50
  31. nextrec/models/multi_task/poso.py +121 -41
  32. nextrec/models/multi_task/share_bottom.py +54 -38
  33. nextrec/models/ranking/afm.py +172 -45
  34. nextrec/models/ranking/autoint.py +84 -61
  35. nextrec/models/ranking/dcn.py +59 -42
  36. nextrec/models/ranking/dcn_v2.py +64 -23
  37. nextrec/models/ranking/deepfm.py +36 -26
  38. nextrec/models/ranking/dien.py +158 -102
  39. nextrec/models/ranking/din.py +88 -60
  40. nextrec/models/ranking/fibinet.py +55 -35
  41. nextrec/models/ranking/fm.py +32 -26
  42. nextrec/models/ranking/masknet.py +95 -34
  43. nextrec/models/ranking/pnn.py +34 -31
  44. nextrec/models/ranking/widedeep.py +37 -29
  45. nextrec/models/ranking/xdeepfm.py +63 -41
  46. nextrec/utils/__init__.py +61 -32
  47. nextrec/utils/config.py +490 -0
  48. nextrec/utils/device.py +52 -12
  49. nextrec/utils/distributed.py +141 -0
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +32 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/optimizer.py +25 -9
  55. nextrec/utils/synthetic_data.py +531 -0
  56. nextrec/utils/tensor.py +24 -13
  57. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
  58. nextrec-0.4.2.dist-info/RECORD +69 -0
  59. nextrec-0.4.2.dist-info/entry_points.txt +2 -0
  60. nextrec-0.3.6.dist-info/RECORD +0 -64
  61. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
  62. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
@@ -52,89 +52,103 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
52
52
  class ESMM(BaseModel):
53
53
  """
54
54
  Entire Space Multi-Task Model
55
-
55
+
56
56
  ESMM is designed for CVR (Conversion Rate) prediction. It models two related tasks:
57
57
  - CTR task: P(click | impression)
58
58
  - CVR task: P(conversion | click)
59
59
  - CTCVR task (auxiliary): P(click & conversion | impression) = P(click) * P(conversion | click)
60
-
60
+
61
61
  This design addresses the sample selection bias and data sparsity issues in CVR modeling.
62
62
  """
63
-
63
+
64
64
  @property
65
65
  def model_name(self):
66
66
  return "ESMM"
67
67
 
68
68
  @property
69
- def task_type(self):
70
- # ESMM has fixed task types: CTR (binary) and CVR (binary)
71
- return ['binary', 'binary']
72
-
73
- def __init__(self,
74
- dense_features: list[DenseFeature],
75
- sparse_features: list[SparseFeature],
76
- sequence_features: list[SequenceFeature],
77
- ctr_params: dict,
78
- cvr_params: dict,
79
- target: list[str] = ['ctr', 'ctcvr'], # Note: ctcvr = ctr * cvr
80
- task: list[str] = ['binary', 'binary'],
81
- optimizer: str = "adam",
82
- optimizer_params: dict = {},
83
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
84
- loss_params: dict | list[dict] | None = None,
85
- device: str = 'cpu',
86
- embedding_l1_reg=1e-6,
87
- dense_l1_reg=1e-5,
88
- embedding_l2_reg=1e-5,
89
- dense_l2_reg=1e-4,
90
- **kwargs):
91
-
69
+ def default_task(self):
70
+ return ["binary", "binary"]
71
+
72
+ def __init__(
73
+ self,
74
+ dense_features: list[DenseFeature],
75
+ sparse_features: list[SparseFeature],
76
+ sequence_features: list[SequenceFeature],
77
+ ctr_params: dict,
78
+ cvr_params: dict,
79
+ target: list[str] = ["ctr", "ctcvr"], # Note: ctcvr = ctr * cvr
80
+ task: list[str] | None = None,
81
+ optimizer: str = "adam",
82
+ optimizer_params: dict = {},
83
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
84
+ loss_params: dict | list[dict] | None = None,
85
+ device: str = "cpu",
86
+ embedding_l1_reg=1e-6,
87
+ dense_l1_reg=1e-5,
88
+ embedding_l2_reg=1e-5,
89
+ dense_l2_reg=1e-4,
90
+ **kwargs,
91
+ ):
92
+
92
93
  # ESMM requires exactly 2 targets: ctr and ctcvr
93
94
  if len(target) != 2:
94
- raise ValueError(f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}")
95
-
95
+ raise ValueError(
96
+ f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}"
97
+ )
98
+
96
99
  super(ESMM, self).__init__(
97
100
  dense_features=dense_features,
98
101
  sparse_features=sparse_features,
99
102
  sequence_features=sequence_features,
100
103
  target=target,
101
- task=task, # Both CTR and CTCVR are binary classification
104
+ task=task
105
+ or self.default_task, # Both CTR and CTCVR are binary classification
102
106
  device=device,
103
107
  embedding_l1_reg=embedding_l1_reg,
104
108
  dense_l1_reg=dense_l1_reg,
105
109
  embedding_l2_reg=embedding_l2_reg,
106
110
  dense_l2_reg=dense_l2_reg,
107
- early_stop_patience=20,
108
- **kwargs
111
+ **kwargs,
109
112
  )
110
113
 
111
114
  self.loss = loss
112
115
  if self.loss is None:
113
116
  self.loss = "bce"
114
-
117
+
115
118
  # All features
116
119
  self.all_features = dense_features + sparse_features + sequence_features
117
120
  # Shared embedding layer
118
121
  self.embedding = EmbeddingLayer(features=self.all_features)
119
- input_dim = self.embedding.input_dim # Calculate input dimension, better way than below
122
+ input_dim = (
123
+ self.embedding.input_dim
124
+ ) # Calculate input dimension, better way than below
120
125
  # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
121
126
  # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
122
127
  # input_dim = emb_dim_total + dense_input_dim
123
128
 
124
129
  # CTR tower
125
130
  self.ctr_tower = MLP(input_dim=input_dim, output_layer=True, **ctr_params)
126
-
131
+
127
132
  # CVR tower
128
133
  self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
129
- self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1, 1])
134
+ self.prediction_layer = PredictionLayer(
135
+ task_type=self.default_task, task_dims=[1, 1]
136
+ )
130
137
  # Register regularization weights
131
- self.register_regularization_weights(embedding_attr='embedding', include_modules=['ctr_tower', 'cvr_tower'])
132
- self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
138
+ self.register_regularization_weights(
139
+ embedding_attr="embedding", include_modules=["ctr_tower", "cvr_tower"]
140
+ )
141
+ self.compile(
142
+ optimizer=optimizer,
143
+ optimizer_params=optimizer_params,
144
+ loss=loss,
145
+ loss_params=loss_params,
146
+ )
133
147
 
134
148
  def forward(self, x):
135
149
  # Get all embeddings and flatten
136
150
  input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
137
-
151
+
138
152
  # CTR prediction: P(click | impression)
139
153
  ctr_logit = self.ctr_tower(input_flat) # [B, 1]
140
154
  cvr_logit = self.cvr_tower(input_flat) # [B, 1]
@@ -142,7 +156,7 @@ class ESMM(BaseModel):
142
156
  preds = self.prediction_layer(logits)
143
157
  ctr, cvr = preds.chunk(2, dim=1)
144
158
  ctcvr = ctr * cvr # [B, 1]
145
-
159
+
146
160
  # Output: [CTR, CTCVR], We supervise CTR with click labels and CTCVR with conversion labels
147
161
  y = torch.cat([ctr, ctcvr], dim=1) # [B, 2]
148
162
  return y # [B, 2], where y[:, 0] is CTR and y[:, 1] is CTCVR
@@ -53,66 +53,74 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
53
53
  class MMOE(BaseModel):
54
54
  """
55
55
  Multi-gate Mixture-of-Experts
56
-
56
+
57
57
  MMOE improves upon shared-bottom architecture by using multiple expert networks
58
58
  and task-specific gating networks. Each task has its own gate that learns to
59
59
  weight the contributions of different experts, allowing for both task-specific
60
60
  and shared representations.
61
61
  """
62
-
62
+
63
63
  @property
64
64
  def model_name(self):
65
65
  return "MMOE"
66
66
 
67
67
  @property
68
- def task_type(self):
69
- return self.task if isinstance(self.task, list) else [self.task]
70
-
71
- def __init__(self,
72
- dense_features: list[DenseFeature]=[],
73
- sparse_features: list[SparseFeature]=[],
74
- sequence_features: list[SequenceFeature]=[],
75
- expert_params: dict={},
76
- num_experts: int=3,
77
- tower_params_list: list[dict]=[],
78
- target: list[str]=[],
79
- task: str | list[str] = 'binary',
80
- optimizer: str = "adam",
81
- optimizer_params: dict = {},
82
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
83
- loss_params: dict | list[dict] | None = None,
84
- device: str = 'cpu',
85
- embedding_l1_reg=1e-6,
86
- dense_l1_reg=1e-5,
87
- embedding_l2_reg=1e-5,
88
- dense_l2_reg=1e-4,
89
- **kwargs):
90
-
68
+ def default_task(self):
69
+ num_tasks = getattr(self, "num_tasks", None)
70
+ if num_tasks is not None and num_tasks > 0:
71
+ return ["binary"] * num_tasks
72
+ return ["binary"]
73
+
74
+ def __init__(
75
+ self,
76
+ dense_features: list[DenseFeature] = [],
77
+ sparse_features: list[SparseFeature] = [],
78
+ sequence_features: list[SequenceFeature] = [],
79
+ expert_params: dict = {},
80
+ num_experts: int = 3,
81
+ tower_params_list: list[dict] = [],
82
+ target: list[str] = [],
83
+ task: str | list[str] | None = None,
84
+ optimizer: str = "adam",
85
+ optimizer_params: dict = {},
86
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
87
+ loss_params: dict | list[dict] | None = None,
88
+ device: str = "cpu",
89
+ embedding_l1_reg=1e-6,
90
+ dense_l1_reg=1e-5,
91
+ embedding_l2_reg=1e-5,
92
+ dense_l2_reg=1e-4,
93
+ **kwargs,
94
+ ):
95
+
96
+ self.num_tasks = len(target)
97
+
91
98
  super(MMOE, self).__init__(
92
99
  dense_features=dense_features,
93
100
  sparse_features=sparse_features,
94
101
  sequence_features=sequence_features,
95
102
  target=target,
96
- task=task,
103
+ task=task or self.default_task,
97
104
  device=device,
98
105
  embedding_l1_reg=embedding_l1_reg,
99
106
  dense_l1_reg=dense_l1_reg,
100
107
  embedding_l2_reg=embedding_l2_reg,
101
108
  dense_l2_reg=dense_l2_reg,
102
- early_stop_patience=20,
103
- **kwargs
109
+ **kwargs,
104
110
  )
105
111
 
106
112
  self.loss = loss
107
113
  if self.loss is None:
108
114
  self.loss = "bce"
109
-
115
+
110
116
  # Number of tasks and experts
111
117
  self.num_tasks = len(target)
112
118
  self.num_experts = num_experts
113
-
119
+
114
120
  if len(tower_params_list) != self.num_tasks:
115
- raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
121
+ raise ValueError(
122
+ f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
123
+ )
116
124
 
117
125
  self.all_features = dense_features + sparse_features + sequence_features
118
126
  self.embedding = EmbeddingLayer(features=self.all_features)
@@ -126,54 +134,69 @@ class MMOE(BaseModel):
126
134
  for _ in range(num_experts):
127
135
  expert = MLP(input_dim=input_dim, output_layer=False, **expert_params)
128
136
  self.experts.append(expert)
129
-
137
+
130
138
  # Get expert output dimension
131
- if 'dims' in expert_params and len(expert_params['dims']) > 0:
132
- expert_output_dim = expert_params['dims'][-1]
139
+ if "dims" in expert_params and len(expert_params["dims"]) > 0:
140
+ expert_output_dim = expert_params["dims"][-1]
133
141
  else:
134
142
  expert_output_dim = input_dim
135
-
143
+
136
144
  # Task-specific gates
137
145
  self.gates = nn.ModuleList()
138
146
  for _ in range(self.num_tasks):
139
147
  gate = nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1))
140
148
  self.gates.append(gate)
141
-
149
+
142
150
  # Task-specific towers
143
151
  self.towers = nn.ModuleList()
144
152
  for tower_params in tower_params_list:
145
153
  tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
146
154
  self.towers.append(tower)
147
- self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
155
+ self.prediction_layer = PredictionLayer(
156
+ task_type=self.default_task, task_dims=[1] * self.num_tasks
157
+ )
148
158
  # Register regularization weights
149
- self.register_regularization_weights(embedding_attr='embedding', include_modules=['experts', 'gates', 'towers'])
150
- self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params,)
159
+ self.register_regularization_weights(
160
+ embedding_attr="embedding", include_modules=["experts", "gates", "towers"]
161
+ )
162
+ self.compile(
163
+ optimizer=optimizer,
164
+ optimizer_params=optimizer_params,
165
+ loss=loss,
166
+ loss_params=loss_params,
167
+ )
151
168
 
152
169
  def forward(self, x):
153
170
  # Get all embeddings and flatten
154
171
  input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
155
-
172
+
156
173
  # Expert outputs: [num_experts, B, expert_dim]
157
174
  expert_outputs = [expert(input_flat) for expert in self.experts]
158
- expert_outputs = torch.stack(expert_outputs, dim=0) # [num_experts, B, expert_dim]
159
-
175
+ expert_outputs = torch.stack(
176
+ expert_outputs, dim=0
177
+ ) # [num_experts, B, expert_dim]
178
+
160
179
  # Task-specific processing
161
180
  task_outputs = []
162
181
  for task_idx in range(self.num_tasks):
163
182
  # Gate weights for this task: [B, num_experts]
164
183
  gate_weights = self.gates[task_idx](input_flat) # [B, num_experts]
165
-
184
+
166
185
  # Weighted sum of expert outputs
167
186
  # gate_weights: [B, num_experts, 1]
168
187
  # expert_outputs: [num_experts, B, expert_dim]
169
188
  gate_weights = gate_weights.unsqueeze(2) # [B, num_experts, 1]
170
- expert_outputs_t = expert_outputs.permute(1, 0, 2) # [B, num_experts, expert_dim]
171
- gated_output = torch.sum(gate_weights * expert_outputs_t, dim=1) # [B, expert_dim]
172
-
189
+ expert_outputs_t = expert_outputs.permute(
190
+ 1, 0, 2
191
+ ) # [B, num_experts, expert_dim]
192
+ gated_output = torch.sum(
193
+ gate_weights * expert_outputs_t, dim=1
194
+ ) # [B, expert_dim]
195
+
173
196
  # Tower output
174
197
  tower_output = self.towers[task_idx](gated_output) # [B, 1]
175
198
  task_outputs.append(tower_output)
176
-
199
+
177
200
  # Stack outputs: [B, num_tasks]
178
201
  y = torch.cat(task_outputs, dim=1)
179
202
  return self.prediction_layer(y)
@@ -52,6 +52,7 @@ from nextrec.basic.model import BaseModel
52
52
  from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
53
53
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
54
54
 
55
+
55
56
  class CGCLayer(nn.Module):
56
57
  """
57
58
  CGC (Customized Gate Control) block used by PLE.
@@ -71,26 +72,61 @@ class CGCLayer(nn.Module):
71
72
  if num_tasks < 1:
72
73
  raise ValueError("num_tasks must be >= 1")
73
74
 
74
- specific_params_list = self._normalize_specific_params(specific_expert_params, num_tasks)
75
+ specific_params_list = self._normalize_specific_params(
76
+ specific_expert_params, num_tasks
77
+ )
75
78
 
76
79
  self.output_dim = self._get_output_dim(shared_expert_params, input_dim)
77
- specific_dims = [self._get_output_dim(params, input_dim) for params in specific_params_list]
80
+ specific_dims = [
81
+ self._get_output_dim(params, input_dim) for params in specific_params_list
82
+ ]
78
83
  dims_set = set(specific_dims + [self.output_dim])
79
84
  if len(dims_set) != 1:
80
- raise ValueError(f"Shared/specific expert output dims must match, got {dims_set}")
85
+ raise ValueError(
86
+ f"Shared/specific expert output dims must match, got {dims_set}"
87
+ )
81
88
 
82
89
  # experts
83
- self.shared_experts = nn.ModuleList([MLP(input_dim=input_dim, output_layer=False, **shared_expert_params,) for _ in range(num_shared_experts)])
90
+ self.shared_experts = nn.ModuleList(
91
+ [
92
+ MLP(
93
+ input_dim=input_dim,
94
+ output_layer=False,
95
+ **shared_expert_params,
96
+ )
97
+ for _ in range(num_shared_experts)
98
+ ]
99
+ )
84
100
  self.specific_experts = nn.ModuleList()
85
101
  for params in specific_params_list:
86
- task_experts = nn.ModuleList([MLP(input_dim=input_dim, output_layer=False, **params,) for _ in range(num_specific_experts)])
102
+ task_experts = nn.ModuleList(
103
+ [
104
+ MLP(
105
+ input_dim=input_dim,
106
+ output_layer=False,
107
+ **params,
108
+ )
109
+ for _ in range(num_specific_experts)
110
+ ]
111
+ )
87
112
  self.specific_experts.append(task_experts)
88
113
 
89
114
  # gates
90
115
  task_gate_expert_num = num_shared_experts + num_specific_experts
91
- self.task_gates = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, task_gate_expert_num), nn.Softmax(dim=1),) for _ in range(num_tasks)])
116
+ self.task_gates = nn.ModuleList(
117
+ [
118
+ nn.Sequential(
119
+ nn.Linear(input_dim, task_gate_expert_num),
120
+ nn.Softmax(dim=1),
121
+ )
122
+ for _ in range(num_tasks)
123
+ ]
124
+ )
92
125
  shared_gate_expert_num = num_shared_experts + num_specific_experts * num_tasks
93
- self.shared_gate = nn.Sequential(nn.Linear(input_dim, shared_gate_expert_num), nn.Softmax(dim=1),)
126
+ self.shared_gate = nn.Sequential(
127
+ nn.Linear(input_dim, shared_gate_expert_num),
128
+ nn.Softmax(dim=1),
129
+ )
94
130
 
95
131
  self.num_tasks = num_tasks
96
132
 
@@ -98,7 +134,9 @@ class CGCLayer(nn.Module):
98
134
  self, task_inputs: list[torch.Tensor], shared_input: torch.Tensor
99
135
  ) -> tuple[list[torch.Tensor], torch.Tensor]:
100
136
  if len(task_inputs) != self.num_tasks:
101
- raise ValueError(f"Expected {self.num_tasks} task inputs, got {len(task_inputs)}")
137
+ raise ValueError(
138
+ f"Expected {self.num_tasks} task inputs, got {len(task_inputs)}"
139
+ )
102
140
 
103
141
  shared_outputs = [expert(shared_input) for expert in self.shared_experts]
104
142
  shared_stack = torch.stack(shared_outputs, dim=0) # [num_shared, B, D]
@@ -108,7 +146,7 @@ class CGCLayer(nn.Module):
108
146
 
109
147
  for task_idx in range(self.num_tasks):
110
148
  task_input = task_inputs[task_idx]
111
- task_specific_outputs = [expert(task_input) for expert in self.specific_experts[task_idx]] # type: ignore
149
+ task_specific_outputs = [expert(task_input) for expert in self.specific_experts[task_idx]] # type: ignore
112
150
  all_specific_for_shared.extend(task_specific_outputs)
113
151
  specific_stack = torch.stack(task_specific_outputs, dim=0)
114
152
 
@@ -139,7 +177,9 @@ class CGCLayer(nn.Module):
139
177
  ) -> list[dict]:
140
178
  if isinstance(params, list):
141
179
  if len(params) != num_tasks:
142
- raise ValueError(f"Length of specific_expert_params ({len(params)}) must match num_tasks ({num_tasks}).")
180
+ raise ValueError(
181
+ f"Length of specific_expert_params ({len(params)}) must match num_tasks ({num_tasks})."
182
+ )
143
183
  return [p.copy() for p in params]
144
184
  return [params.copy() for _ in range(num_tasks)]
145
185
 
@@ -147,57 +187,63 @@ class CGCLayer(nn.Module):
147
187
  class PLE(BaseModel):
148
188
  """
149
189
  Progressive Layered Extraction
150
-
190
+
151
191
  PLE is an advanced multi-task learning model that extends MMOE by introducing
152
192
  both task-specific experts and shared experts at each level. It uses a progressive
153
193
  routing mechanism where experts from level k feed into gates at level k+1.
154
194
  This design better captures task-specific and shared information progressively.
155
195
  """
156
-
196
+
157
197
  @property
158
198
  def model_name(self):
159
199
  return "PLE"
160
200
 
161
201
  @property
162
- def task_type(self):
163
- return self.task if isinstance(self.task, list) else [self.task]
164
-
165
- def __init__(self,
166
- dense_features: list[DenseFeature],
167
- sparse_features: list[SparseFeature],
168
- sequence_features: list[SequenceFeature],
169
- shared_expert_params: dict,
170
- specific_expert_params: dict | list[dict],
171
- num_shared_experts: int,
172
- num_specific_experts: int,
173
- num_levels: int,
174
- tower_params_list: list[dict],
175
- target: list[str],
176
- task: str | list[str] = 'binary',
177
- optimizer: str = "adam",
178
- optimizer_params: dict | None = None,
179
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
180
- loss_params: dict | list[dict] | None = None,
181
- device: str = 'cpu',
182
- embedding_l1_reg=1e-6,
183
- dense_l1_reg=1e-5,
184
- embedding_l2_reg=1e-5,
185
- dense_l2_reg=1e-4,
186
- **kwargs):
187
-
202
+ def default_task(self):
203
+ num_tasks = getattr(self, "num_tasks", None)
204
+ if num_tasks is not None and num_tasks > 0:
205
+ return ["binary"] * num_tasks
206
+ return ["binary"]
207
+
208
+ def __init__(
209
+ self,
210
+ dense_features: list[DenseFeature],
211
+ sparse_features: list[SparseFeature],
212
+ sequence_features: list[SequenceFeature],
213
+ shared_expert_params: dict,
214
+ specific_expert_params: dict | list[dict],
215
+ num_shared_experts: int,
216
+ num_specific_experts: int,
217
+ num_levels: int,
218
+ tower_params_list: list[dict],
219
+ target: list[str],
220
+ task: str | list[str] | None = None,
221
+ optimizer: str = "adam",
222
+ optimizer_params: dict | None = None,
223
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
224
+ loss_params: dict | list[dict] | None = None,
225
+ device: str = "cpu",
226
+ embedding_l1_reg=1e-6,
227
+ dense_l1_reg=1e-5,
228
+ embedding_l2_reg=1e-5,
229
+ dense_l2_reg=1e-4,
230
+ **kwargs,
231
+ ):
232
+
233
+ self.num_tasks = len(target)
234
+
188
235
  super(PLE, self).__init__(
189
236
  dense_features=dense_features,
190
237
  sparse_features=sparse_features,
191
238
  sequence_features=sequence_features,
192
239
  target=target,
193
- task=task,
240
+ task=task or self.default_task,
194
241
  device=device,
195
242
  embedding_l1_reg=embedding_l1_reg,
196
243
  dense_l1_reg=dense_l1_reg,
197
244
  embedding_l2_reg=embedding_l2_reg,
198
245
  dense_l2_reg=dense_l2_reg,
199
- early_stop_patience=20,
200
- **kwargs
246
+ **kwargs,
201
247
  )
202
248
 
203
249
  self.loss = loss
@@ -211,7 +257,9 @@ class PLE(BaseModel):
211
257
  if optimizer_params is None:
212
258
  optimizer_params = {}
213
259
  if len(tower_params_list) != self.num_tasks:
214
- raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
260
+ raise ValueError(
261
+ f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
262
+ )
215
263
  # Embedding layer
216
264
  self.embedding = EmbeddingLayer(features=self.all_features)
217
265
 
@@ -220,13 +268,13 @@ class PLE(BaseModel):
220
268
  # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
221
269
  # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
222
270
  # input_dim = emb_dim_total + dense_input_dim
223
-
271
+
224
272
  # Get expert output dimension
225
- if 'dims' in shared_expert_params and len(shared_expert_params['dims']) > 0:
226
- expert_output_dim = shared_expert_params['dims'][-1]
273
+ if "dims" in shared_expert_params and len(shared_expert_params["dims"]) > 0:
274
+ expert_output_dim = shared_expert_params["dims"][-1]
227
275
  else:
228
276
  expert_output_dim = input_dim
229
-
277
+
230
278
  # Build CGC layers
231
279
  self.cgc_layers = nn.ModuleList()
232
280
  for level in range(num_levels):
@@ -241,16 +289,25 @@ class PLE(BaseModel):
241
289
  )
242
290
  self.cgc_layers.append(cgc_layer)
243
291
  expert_output_dim = cgc_layer.output_dim
244
-
292
+
245
293
  # Task-specific towers
246
294
  self.towers = nn.ModuleList()
247
295
  for tower_params in tower_params_list:
248
296
  tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
249
297
  self.towers.append(tower)
250
- self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
298
+ self.prediction_layer = PredictionLayer(
299
+ task_type=self.default_task, task_dims=[1] * self.num_tasks
300
+ )
251
301
  # Register regularization weights
252
- self.register_regularization_weights(embedding_attr='embedding', include_modules=['cgc_layers', 'towers'])
253
- self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=self.loss, loss_params=loss_params)
302
+ self.register_regularization_weights(
303
+ embedding_attr="embedding", include_modules=["cgc_layers", "towers"]
304
+ )
305
+ self.compile(
306
+ optimizer=optimizer,
307
+ optimizer_params=optimizer_params,
308
+ loss=self.loss,
309
+ loss_params=loss_params,
310
+ )
254
311
 
255
312
  def forward(self, x):
256
313
  # Get all embeddings and flatten