nextrec 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +250 -112
  7. nextrec/basic/loggers.py +63 -44
  8. nextrec/basic/metrics.py +270 -120
  9. nextrec/basic/model.py +1084 -402
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +492 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +51 -45
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +273 -96
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +103 -38
  23. nextrec/models/match/dssm.py +82 -68
  24. nextrec/models/match/dssm_v2.py +72 -57
  25. nextrec/models/match/mind.py +175 -107
  26. nextrec/models/match/sdm.py +104 -87
  27. nextrec/models/match/youtube_dnn.py +73 -59
  28. nextrec/models/multi_task/esmm.py +69 -46
  29. nextrec/models/multi_task/mmoe.py +91 -53
  30. nextrec/models/multi_task/ple.py +117 -58
  31. nextrec/models/multi_task/poso.py +163 -55
  32. nextrec/models/multi_task/share_bottom.py +63 -36
  33. nextrec/models/ranking/afm.py +80 -45
  34. nextrec/models/ranking/autoint.py +74 -57
  35. nextrec/models/ranking/dcn.py +110 -48
  36. nextrec/models/ranking/dcn_v2.py +265 -45
  37. nextrec/models/ranking/deepfm.py +39 -24
  38. nextrec/models/ranking/dien.py +335 -146
  39. nextrec/models/ranking/din.py +158 -92
  40. nextrec/models/ranking/fibinet.py +134 -52
  41. nextrec/models/ranking/fm.py +68 -26
  42. nextrec/models/ranking/masknet.py +95 -33
  43. nextrec/models/ranking/pnn.py +128 -58
  44. nextrec/models/ranking/widedeep.py +40 -28
  45. nextrec/models/ranking/xdeepfm.py +67 -40
  46. nextrec/utils/__init__.py +59 -34
  47. nextrec/utils/config.py +496 -0
  48. nextrec/utils/device.py +30 -20
  49. nextrec/utils/distributed.py +36 -9
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +33 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/model.py +22 -0
  55. nextrec/utils/optimizer.py +25 -9
  56. nextrec/utils/synthetic_data.py +283 -165
  57. nextrec/utils/tensor.py +24 -13
  58. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
  59. nextrec-0.4.3.dist-info/RECORD +69 -0
  60. nextrec-0.4.3.dist-info/entry_points.txt +2 -0
  61. nextrec-0.4.1.dist-info/RECORD +0 -66
  62. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
  63. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
@@ -52,87 +52,110 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
52
52
  class ESMM(BaseModel):
53
53
  """
54
54
  Entire Space Multi-Task Model
55
-
55
+
56
56
  ESMM is designed for CVR (Conversion Rate) prediction. It models two related tasks:
57
57
  - CTR task: P(click | impression)
58
58
  - CVR task: P(conversion | click)
59
59
  - CTCVR task (auxiliary): P(click & conversion | impression) = P(click) * P(conversion | click)
60
-
60
+
61
61
  This design addresses the sample selection bias and data sparsity issues in CVR modeling.
62
62
  """
63
-
63
+
64
64
  @property
65
65
  def model_name(self):
66
66
  return "ESMM"
67
-
67
+
68
68
  @property
69
69
  def default_task(self):
70
- return ['binary', 'binary']
71
-
72
- def __init__(self,
73
- dense_features: list[DenseFeature],
74
- sparse_features: list[SparseFeature],
75
- sequence_features: list[SequenceFeature],
76
- ctr_params: dict,
77
- cvr_params: dict,
78
- target: list[str] = ['ctr', 'ctcvr'], # Note: ctcvr = ctr * cvr
79
- task: list[str] | None = None,
80
- optimizer: str = "adam",
81
- optimizer_params: dict = {},
82
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
83
- loss_params: dict | list[dict] | None = None,
84
- device: str = 'cpu',
85
- embedding_l1_reg=1e-6,
86
- dense_l1_reg=1e-5,
87
- embedding_l2_reg=1e-5,
88
- dense_l2_reg=1e-4,
89
- **kwargs):
90
-
91
- # ESMM requires exactly 2 targets: ctr and ctcvr
70
+ return ["binary", "binary"]
71
+
72
+ def __init__(
73
+ self,
74
+ dense_features: list[DenseFeature],
75
+ sparse_features: list[SparseFeature],
76
+ sequence_features: list[SequenceFeature],
77
+ ctr_params: dict,
78
+ cvr_params: dict,
79
+ target: list[str] | None = None, # Note: ctcvr = ctr * cvr
80
+ task: list[str] | None = None,
81
+ optimizer: str = "adam",
82
+ optimizer_params: dict | None = None,
83
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
84
+ loss_params: dict | list[dict] | None = None,
85
+ device: str = "cpu",
86
+ embedding_l1_reg=1e-6,
87
+ dense_l1_reg=1e-5,
88
+ embedding_l2_reg=1e-5,
89
+ dense_l2_reg=1e-4,
90
+ **kwargs,
91
+ ):
92
+
93
+ target = target or ["ctr", "ctcvr"]
94
+ optimizer_params = optimizer_params or {}
95
+ if loss is None:
96
+ loss = "bce"
97
+
92
98
  if len(target) != 2:
93
- raise ValueError(f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}")
94
-
99
+ raise ValueError(
100
+ f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}"
101
+ )
102
+
103
+ self.num_tasks = len(target)
104
+ resolved_task = task
105
+ if resolved_task is None:
106
+ resolved_task = self.default_task
107
+ elif isinstance(resolved_task, str):
108
+ resolved_task = [resolved_task] * self.num_tasks
109
+ elif len(resolved_task) == 1 and self.num_tasks > 1:
110
+ resolved_task = resolved_task * self.num_tasks
111
+ elif len(resolved_task) != self.num_tasks:
112
+ raise ValueError(
113
+ f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
114
+ )
115
+ # resolved_task is now guaranteed to be a list[str]
116
+
95
117
  super(ESMM, self).__init__(
96
118
  dense_features=dense_features,
97
119
  sparse_features=sparse_features,
98
120
  sequence_features=sequence_features,
99
121
  target=target,
100
- task=task or self.default_task, # Both CTR and CTCVR are binary classification
122
+ task=resolved_task, # Both CTR and CTCVR are binary classification
101
123
  device=device,
102
124
  embedding_l1_reg=embedding_l1_reg,
103
125
  dense_l1_reg=dense_l1_reg,
104
126
  embedding_l2_reg=embedding_l2_reg,
105
127
  dense_l2_reg=dense_l2_reg,
106
- **kwargs
128
+ **kwargs,
107
129
  )
108
130
 
109
131
  self.loss = loss
110
- if self.loss is None:
111
- self.loss = "bce"
112
-
113
- # All features
114
- self.all_features = dense_features + sparse_features + sequence_features
115
- # Shared embedding layer
132
+
116
133
  self.embedding = EmbeddingLayer(features=self.all_features)
117
- input_dim = self.embedding.input_dim # Calculate input dimension, better way than below
118
- # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
119
- # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
120
- # input_dim = emb_dim_total + dense_input_dim
134
+ input_dim = self.embedding.input_dim
121
135
 
122
136
  # CTR tower
123
137
  self.ctr_tower = MLP(input_dim=input_dim, output_layer=True, **ctr_params)
124
-
138
+
125
139
  # CVR tower
126
140
  self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
127
- self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1, 1])
141
+ self.prediction_layer = PredictionLayer(
142
+ task_type=self.default_task, task_dims=[1, 1]
143
+ )
128
144
  # Register regularization weights
129
- self.register_regularization_weights(embedding_attr='embedding', include_modules=['ctr_tower', 'cvr_tower'])
130
- self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
145
+ self.register_regularization_weights(
146
+ embedding_attr="embedding", include_modules=["ctr_tower", "cvr_tower"]
147
+ )
148
+ self.compile(
149
+ optimizer=optimizer,
150
+ optimizer_params=optimizer_params,
151
+ loss=loss,
152
+ loss_params=loss_params,
153
+ )
131
154
 
132
155
  def forward(self, x):
133
156
  # Get all embeddings and flatten
134
157
  input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
135
-
158
+
136
159
  # CTR prediction: P(click | impression)
137
160
  ctr_logit = self.ctr_tower(input_flat) # [B, 1]
138
161
  cvr_logit = self.cvr_tower(input_flat) # [B, 1]
@@ -140,7 +163,7 @@ class ESMM(BaseModel):
140
163
  preds = self.prediction_layer(logits)
141
164
  ctr, cvr = preds.chunk(2, dim=1)
142
165
  ctcvr = ctr * cvr # [B, 1]
143
-
166
+
144
167
  # Output: [CTR, CTCVR], We supervise CTR with click labels and CTCVR with conversion labels
145
168
  y = torch.cat([ctr, ctcvr], dim=1) # [B, 2]
146
169
  return y # [B, 2], where y[:, 0] is CTR and y[:, 1] is CTCVR
@@ -53,13 +53,13 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
53
53
  class MMOE(BaseModel):
54
54
  """
55
55
  Multi-gate Mixture-of-Experts
56
-
56
+
57
57
  MMOE improves upon shared-bottom architecture by using multiple expert networks
58
58
  and task-specific gating networks. Each task has its own gate that learns to
59
59
  weight the contributions of different experts, allowing for both task-specific
60
60
  and shared representations.
61
61
  """
62
-
62
+
63
63
  @property
64
64
  def model_name(self):
65
65
  return "MMOE"
@@ -68,116 +68,154 @@ class MMOE(BaseModel):
68
68
  def default_task(self):
69
69
  num_tasks = getattr(self, "num_tasks", None)
70
70
  if num_tasks is not None and num_tasks > 0:
71
- return ['binary'] * num_tasks
72
- return ['binary']
73
-
74
- def __init__(self,
75
- dense_features: list[DenseFeature]=[],
76
- sparse_features: list[SparseFeature]=[],
77
- sequence_features: list[SequenceFeature]=[],
78
- expert_params: dict={},
79
- num_experts: int=3,
80
- tower_params_list: list[dict]=[],
81
- target: list[str]=[],
82
- task: str | list[str] | None = None,
83
- optimizer: str = "adam",
84
- optimizer_params: dict = {},
85
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
86
- loss_params: dict | list[dict] | None = None,
87
- device: str = 'cpu',
88
- embedding_l1_reg=1e-6,
89
- dense_l1_reg=1e-5,
90
- embedding_l2_reg=1e-5,
91
- dense_l2_reg=1e-4,
92
- **kwargs):
93
-
94
- self.num_tasks = len(target)
71
+ return ["binary"] * num_tasks
72
+ return ["binary"]
73
+
74
+ def __init__(
75
+ self,
76
+ dense_features: list[DenseFeature] | None = None,
77
+ sparse_features: list[SparseFeature] | None = None,
78
+ sequence_features: list[SequenceFeature] | None = None,
79
+ expert_params: dict | None = None,
80
+ num_experts: int = 3,
81
+ tower_params_list: list[dict] | None = None,
82
+ target: list[str] | str | None = None,
83
+ task: str | list[str] = "binary",
84
+ optimizer: str = "adam",
85
+ optimizer_params: dict | None = None,
86
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
87
+ loss_params: dict | list[dict] | None = None,
88
+ device: str = "cpu",
89
+ embedding_l1_reg=1e-6,
90
+ dense_l1_reg=1e-5,
91
+ embedding_l2_reg=1e-5,
92
+ dense_l2_reg=1e-4,
93
+ **kwargs,
94
+ ):
95
+
96
+ dense_features = dense_features or []
97
+ sparse_features = sparse_features or []
98
+ sequence_features = sequence_features or []
99
+ expert_params = expert_params or {}
100
+ tower_params_list = tower_params_list or []
101
+ optimizer_params = optimizer_params or {}
102
+ if loss is None:
103
+ loss = "bce"
104
+ if target is None:
105
+ target = []
106
+ elif isinstance(target, str):
107
+ target = [target]
108
+
109
+ self.num_tasks = len(target) if target else 1
110
+
111
+ resolved_task = task
112
+ if resolved_task is None:
113
+ resolved_task = self.default_task
114
+ elif isinstance(resolved_task, str):
115
+ resolved_task = [resolved_task] * self.num_tasks
116
+ elif len(resolved_task) == 1 and self.num_tasks > 1:
117
+ resolved_task = resolved_task * self.num_tasks
118
+ elif len(resolved_task) != self.num_tasks:
119
+ raise ValueError(
120
+ f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
121
+ )
95
122
 
96
123
  super(MMOE, self).__init__(
97
124
  dense_features=dense_features,
98
125
  sparse_features=sparse_features,
99
126
  sequence_features=sequence_features,
100
127
  target=target,
101
- task=task or self.default_task,
128
+ task=resolved_task,
102
129
  device=device,
103
130
  embedding_l1_reg=embedding_l1_reg,
104
131
  dense_l1_reg=dense_l1_reg,
105
132
  embedding_l2_reg=embedding_l2_reg,
106
133
  dense_l2_reg=dense_l2_reg,
107
- **kwargs
134
+ **kwargs,
108
135
  )
109
136
 
110
137
  self.loss = loss
111
- if self.loss is None:
112
- self.loss = "bce"
113
-
138
+
114
139
  # Number of tasks and experts
115
140
  self.num_tasks = len(target)
116
141
  self.num_experts = num_experts
117
-
142
+
118
143
  if len(tower_params_list) != self.num_tasks:
119
- raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
144
+ raise ValueError(
145
+ f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
146
+ )
120
147
 
121
- self.all_features = dense_features + sparse_features + sequence_features
122
148
  self.embedding = EmbeddingLayer(features=self.all_features)
123
149
  input_dim = self.embedding.input_dim
124
- # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
125
- # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
126
- # input_dim = emb_dim_total + dense_input_dim
127
150
 
128
151
  # Expert networks (shared by all tasks)
129
152
  self.experts = nn.ModuleList()
130
153
  for _ in range(num_experts):
131
154
  expert = MLP(input_dim=input_dim, output_layer=False, **expert_params)
132
155
  self.experts.append(expert)
133
-
156
+
134
157
  # Get expert output dimension
135
- if 'dims' in expert_params and len(expert_params['dims']) > 0:
136
- expert_output_dim = expert_params['dims'][-1]
158
+ if "dims" in expert_params and len(expert_params["dims"]) > 0:
159
+ expert_output_dim = expert_params["dims"][-1]
137
160
  else:
138
161
  expert_output_dim = input_dim
139
-
162
+
140
163
  # Task-specific gates
141
164
  self.gates = nn.ModuleList()
142
165
  for _ in range(self.num_tasks):
143
166
  gate = nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1))
144
167
  self.gates.append(gate)
145
-
168
+
146
169
  # Task-specific towers
147
170
  self.towers = nn.ModuleList()
148
171
  for tower_params in tower_params_list:
149
172
  tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
150
173
  self.towers.append(tower)
151
- self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks)
174
+ self.prediction_layer = PredictionLayer(
175
+ task_type=self.default_task, task_dims=[1] * self.num_tasks
176
+ )
152
177
  # Register regularization weights
153
- self.register_regularization_weights(embedding_attr='embedding', include_modules=['experts', 'gates', 'towers'])
154
- self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params,)
178
+ self.register_regularization_weights(
179
+ embedding_attr="embedding", include_modules=["experts", "gates", "towers"]
180
+ )
181
+ self.compile(
182
+ optimizer=optimizer,
183
+ optimizer_params=optimizer_params,
184
+ loss=self.loss,
185
+ loss_params=loss_params,
186
+ )
155
187
 
156
188
  def forward(self, x):
157
189
  # Get all embeddings and flatten
158
190
  input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
159
-
191
+
160
192
  # Expert outputs: [num_experts, B, expert_dim]
161
193
  expert_outputs = [expert(input_flat) for expert in self.experts]
162
- expert_outputs = torch.stack(expert_outputs, dim=0) # [num_experts, B, expert_dim]
163
-
194
+ expert_outputs = torch.stack(
195
+ expert_outputs, dim=0
196
+ ) # [num_experts, B, expert_dim]
197
+
164
198
  # Task-specific processing
165
199
  task_outputs = []
166
200
  for task_idx in range(self.num_tasks):
167
201
  # Gate weights for this task: [B, num_experts]
168
202
  gate_weights = self.gates[task_idx](input_flat) # [B, num_experts]
169
-
203
+
170
204
  # Weighted sum of expert outputs
171
205
  # gate_weights: [B, num_experts, 1]
172
206
  # expert_outputs: [num_experts, B, expert_dim]
173
207
  gate_weights = gate_weights.unsqueeze(2) # [B, num_experts, 1]
174
- expert_outputs_t = expert_outputs.permute(1, 0, 2) # [B, num_experts, expert_dim]
175
- gated_output = torch.sum(gate_weights * expert_outputs_t, dim=1) # [B, expert_dim]
176
-
208
+ expert_outputs_t = expert_outputs.permute(
209
+ 1, 0, 2
210
+ ) # [B, num_experts, expert_dim]
211
+ gated_output = torch.sum(
212
+ gate_weights * expert_outputs_t, dim=1
213
+ ) # [B, expert_dim]
214
+
177
215
  # Tower output
178
216
  tower_output = self.towers[task_idx](gated_output) # [B, 1]
179
217
  task_outputs.append(tower_output)
180
-
218
+
181
219
  # Stack outputs: [B, num_tasks]
182
220
  y = torch.cat(task_outputs, dim=1)
183
221
  return self.prediction_layer(y)
@@ -51,6 +51,8 @@ import torch.nn as nn
51
51
  from nextrec.basic.model import BaseModel
52
52
  from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
53
53
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
54
+ from nextrec.utils.model import get_mlp_output_dim
55
+
54
56
 
55
57
  class CGCLayer(nn.Module):
56
58
  """
@@ -71,26 +73,61 @@ class CGCLayer(nn.Module):
71
73
  if num_tasks < 1:
72
74
  raise ValueError("num_tasks must be >= 1")
73
75
 
74
- specific_params_list = self._normalize_specific_params(specific_expert_params, num_tasks)
76
+ specific_params_list = self.normalize_specific_params(
77
+ specific_expert_params, num_tasks
78
+ )
75
79
 
76
- self.output_dim = self._get_output_dim(shared_expert_params, input_dim)
77
- specific_dims = [self._get_output_dim(params, input_dim) for params in specific_params_list]
80
+ self.output_dim = get_mlp_output_dim(shared_expert_params, input_dim)
81
+ specific_dims = [
82
+ get_mlp_output_dim(params, input_dim) for params in specific_params_list
83
+ ]
78
84
  dims_set = set(specific_dims + [self.output_dim])
79
85
  if len(dims_set) != 1:
80
- raise ValueError(f"Shared/specific expert output dims must match, got {dims_set}")
86
+ raise ValueError(
87
+ f"Shared/specific expert output dims must match, got {dims_set}"
88
+ )
81
89
 
82
90
  # experts
83
- self.shared_experts = nn.ModuleList([MLP(input_dim=input_dim, output_layer=False, **shared_expert_params,) for _ in range(num_shared_experts)])
91
+ self.shared_experts = nn.ModuleList(
92
+ [
93
+ MLP(
94
+ input_dim=input_dim,
95
+ output_layer=False,
96
+ **shared_expert_params,
97
+ )
98
+ for _ in range(num_shared_experts)
99
+ ]
100
+ )
84
101
  self.specific_experts = nn.ModuleList()
85
102
  for params in specific_params_list:
86
- task_experts = nn.ModuleList([MLP(input_dim=input_dim, output_layer=False, **params,) for _ in range(num_specific_experts)])
103
+ task_experts = nn.ModuleList(
104
+ [
105
+ MLP(
106
+ input_dim=input_dim,
107
+ output_layer=False,
108
+ **params,
109
+ )
110
+ for _ in range(num_specific_experts)
111
+ ]
112
+ )
87
113
  self.specific_experts.append(task_experts)
88
114
 
89
115
  # gates
90
116
  task_gate_expert_num = num_shared_experts + num_specific_experts
91
- self.task_gates = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, task_gate_expert_num), nn.Softmax(dim=1),) for _ in range(num_tasks)])
117
+ self.task_gates = nn.ModuleList(
118
+ [
119
+ nn.Sequential(
120
+ nn.Linear(input_dim, task_gate_expert_num),
121
+ nn.Softmax(dim=1),
122
+ )
123
+ for _ in range(num_tasks)
124
+ ]
125
+ )
92
126
  shared_gate_expert_num = num_shared_experts + num_specific_experts * num_tasks
93
- self.shared_gate = nn.Sequential(nn.Linear(input_dim, shared_gate_expert_num), nn.Softmax(dim=1),)
127
+ self.shared_gate = nn.Sequential(
128
+ nn.Linear(input_dim, shared_gate_expert_num),
129
+ nn.Softmax(dim=1),
130
+ )
94
131
 
95
132
  self.num_tasks = num_tasks
96
133
 
@@ -98,7 +135,9 @@ class CGCLayer(nn.Module):
98
135
  self, task_inputs: list[torch.Tensor], shared_input: torch.Tensor
99
136
  ) -> tuple[list[torch.Tensor], torch.Tensor]:
100
137
  if len(task_inputs) != self.num_tasks:
101
- raise ValueError(f"Expected {self.num_tasks} task inputs, got {len(task_inputs)}")
138
+ raise ValueError(
139
+ f"Expected {self.num_tasks} task inputs, got {len(task_inputs)}"
140
+ )
102
141
 
103
142
  shared_outputs = [expert(shared_input) for expert in self.shared_experts]
104
143
  shared_stack = torch.stack(shared_outputs, dim=0) # [num_shared, B, D]
@@ -108,7 +147,7 @@ class CGCLayer(nn.Module):
108
147
 
109
148
  for task_idx in range(self.num_tasks):
110
149
  task_input = task_inputs[task_idx]
111
- task_specific_outputs = [expert(task_input) for expert in self.specific_experts[task_idx]] # type: ignore
150
+ task_specific_outputs = [expert(task_input) for expert in self.specific_experts[task_idx]] # type: ignore
112
151
  all_specific_for_shared.extend(task_specific_outputs)
113
152
  specific_stack = torch.stack(task_specific_outputs, dim=0)
114
153
 
@@ -127,19 +166,14 @@ class CGCLayer(nn.Module):
127
166
  return new_task_fea, new_shared
128
167
 
129
168
  @staticmethod
130
- def _get_output_dim(params: dict, fallback: int) -> int:
131
- dims = params.get("dims")
132
- if dims:
133
- return dims[-1]
134
- return fallback
135
-
136
- @staticmethod
137
- def _normalize_specific_params(
169
+ def normalize_specific_params(
138
170
  params: dict | list[dict], num_tasks: int
139
171
  ) -> list[dict]:
140
172
  if isinstance(params, list):
141
173
  if len(params) != num_tasks:
142
- raise ValueError(f"Length of specific_expert_params ({len(params)}) must match num_tasks ({num_tasks}).")
174
+ raise ValueError(
175
+ f"Length of specific_expert_params ({len(params)}) must match num_tasks ({num_tasks})."
176
+ )
143
177
  return [p.copy() for p in params]
144
178
  return [params.copy() for _ in range(num_tasks)]
145
179
 
@@ -147,13 +181,13 @@ class CGCLayer(nn.Module):
147
181
  class PLE(BaseModel):
148
182
  """
149
183
  Progressive Layered Extraction
150
-
184
+
151
185
  PLE is an advanced multi-task learning model that extends MMOE by introducing
152
186
  both task-specific experts and shared experts at each level. It uses a progressive
153
187
  routing mechanism where experts from level k feed into gates at level k+1.
154
188
  This design better captures task-specific and shared information progressively.
155
189
  """
156
-
190
+
157
191
  @property
158
192
  def model_name(self):
159
193
  return "PLE"
@@ -162,46 +196,60 @@ class PLE(BaseModel):
162
196
  def default_task(self):
163
197
  num_tasks = getattr(self, "num_tasks", None)
164
198
  if num_tasks is not None and num_tasks > 0:
165
- return ['binary'] * num_tasks
166
- return ['binary']
167
-
168
- def __init__(self,
169
- dense_features: list[DenseFeature],
170
- sparse_features: list[SparseFeature],
171
- sequence_features: list[SequenceFeature],
172
- shared_expert_params: dict,
173
- specific_expert_params: dict | list[dict],
174
- num_shared_experts: int,
175
- num_specific_experts: int,
176
- num_levels: int,
177
- tower_params_list: list[dict],
178
- target: list[str],
179
- task: str | list[str] | None = None,
180
- optimizer: str = "adam",
181
- optimizer_params: dict | None = None,
182
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
183
- loss_params: dict | list[dict] | None = None,
184
- device: str = 'cpu',
185
- embedding_l1_reg=1e-6,
186
- dense_l1_reg=1e-5,
187
- embedding_l2_reg=1e-5,
188
- dense_l2_reg=1e-4,
189
- **kwargs):
190
-
199
+ return ["binary"] * num_tasks
200
+ return ["binary"]
201
+
202
+ def __init__(
203
+ self,
204
+ dense_features: list[DenseFeature],
205
+ sparse_features: list[SparseFeature],
206
+ sequence_features: list[SequenceFeature],
207
+ shared_expert_params: dict,
208
+ specific_expert_params: dict | list[dict],
209
+ num_shared_experts: int,
210
+ num_specific_experts: int,
211
+ num_levels: int,
212
+ tower_params_list: list[dict],
213
+ target: list[str],
214
+ task: str | list[str] | None = None,
215
+ optimizer: str = "adam",
216
+ optimizer_params: dict | None = None,
217
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
218
+ loss_params: dict | list[dict] | None = None,
219
+ device: str = "cpu",
220
+ embedding_l1_reg=1e-6,
221
+ dense_l1_reg=1e-5,
222
+ embedding_l2_reg=1e-5,
223
+ dense_l2_reg=1e-4,
224
+ **kwargs,
225
+ ):
226
+
191
227
  self.num_tasks = len(target)
192
228
 
229
+ resolved_task = task
230
+ if resolved_task is None:
231
+ resolved_task = self.default_task
232
+ elif isinstance(resolved_task, str):
233
+ resolved_task = [resolved_task] * self.num_tasks
234
+ elif len(resolved_task) == 1 and self.num_tasks > 1:
235
+ resolved_task = resolved_task * self.num_tasks
236
+ elif len(resolved_task) != self.num_tasks:
237
+ raise ValueError(
238
+ f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
239
+ )
240
+
193
241
  super(PLE, self).__init__(
194
242
  dense_features=dense_features,
195
243
  sparse_features=sparse_features,
196
244
  sequence_features=sequence_features,
197
245
  target=target,
198
- task=task or self.default_task,
246
+ task=resolved_task,
199
247
  device=device,
200
248
  embedding_l1_reg=embedding_l1_reg,
201
249
  dense_l1_reg=dense_l1_reg,
202
250
  embedding_l2_reg=embedding_l2_reg,
203
251
  dense_l2_reg=dense_l2_reg,
204
- **kwargs
252
+ **kwargs,
205
253
  )
206
254
 
207
255
  self.loss = loss
@@ -215,7 +263,9 @@ class PLE(BaseModel):
215
263
  if optimizer_params is None:
216
264
  optimizer_params = {}
217
265
  if len(tower_params_list) != self.num_tasks:
218
- raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
266
+ raise ValueError(
267
+ f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
268
+ )
219
269
  # Embedding layer
220
270
  self.embedding = EmbeddingLayer(features=self.all_features)
221
271
 
@@ -224,13 +274,13 @@ class PLE(BaseModel):
224
274
  # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
225
275
  # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
226
276
  # input_dim = emb_dim_total + dense_input_dim
227
-
277
+
228
278
  # Get expert output dimension
229
- if 'dims' in shared_expert_params and len(shared_expert_params['dims']) > 0:
230
- expert_output_dim = shared_expert_params['dims'][-1]
279
+ if "dims" in shared_expert_params and len(shared_expert_params["dims"]) > 0:
280
+ expert_output_dim = shared_expert_params["dims"][-1]
231
281
  else:
232
282
  expert_output_dim = input_dim
233
-
283
+
234
284
  # Build CGC layers
235
285
  self.cgc_layers = nn.ModuleList()
236
286
  for level in range(num_levels):
@@ -245,16 +295,25 @@ class PLE(BaseModel):
245
295
  )
246
296
  self.cgc_layers.append(cgc_layer)
247
297
  expert_output_dim = cgc_layer.output_dim
248
-
298
+
249
299
  # Task-specific towers
250
300
  self.towers = nn.ModuleList()
251
301
  for tower_params in tower_params_list:
252
302
  tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
253
303
  self.towers.append(tower)
254
- self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks)
304
+ self.prediction_layer = PredictionLayer(
305
+ task_type=self.default_task, task_dims=[1] * self.num_tasks
306
+ )
255
307
  # Register regularization weights
256
- self.register_regularization_weights(embedding_attr='embedding', include_modules=['cgc_layers', 'towers'])
257
- self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=self.loss, loss_params=loss_params)
308
+ self.register_regularization_weights(
309
+ embedding_attr="embedding", include_modules=["cgc_layers", "towers"]
310
+ )
311
+ self.compile(
312
+ optimizer=optimizer,
313
+ optimizer_params=optimizer_params,
314
+ loss=self.loss,
315
+ loss_params=loss_params,
316
+ )
258
317
 
259
318
  def forward(self, x):
260
319
  # Get all embeddings and flatten