nextrec 0.4.16__py3-none-any.whl → 0.4.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/heads.py +99 -0
  3. nextrec/basic/loggers.py +5 -5
  4. nextrec/basic/model.py +217 -88
  5. nextrec/cli.py +1 -1
  6. nextrec/data/dataloader.py +93 -95
  7. nextrec/data/preprocessor.py +108 -46
  8. nextrec/loss/grad_norm.py +13 -13
  9. nextrec/models/multi_task/esmm.py +10 -11
  10. nextrec/models/multi_task/mmoe.py +20 -19
  11. nextrec/models/multi_task/ple.py +35 -34
  12. nextrec/models/multi_task/poso.py +23 -21
  13. nextrec/models/multi_task/share_bottom.py +18 -17
  14. nextrec/models/ranking/afm.py +4 -3
  15. nextrec/models/ranking/autoint.py +4 -3
  16. nextrec/models/ranking/dcn.py +4 -3
  17. nextrec/models/ranking/dcn_v2.py +4 -3
  18. nextrec/models/ranking/deepfm.py +4 -3
  19. nextrec/models/ranking/dien.py +2 -2
  20. nextrec/models/ranking/din.py +2 -2
  21. nextrec/models/ranking/eulernet.py +4 -3
  22. nextrec/models/ranking/ffm.py +4 -3
  23. nextrec/models/ranking/fibinet.py +2 -2
  24. nextrec/models/ranking/fm.py +4 -3
  25. nextrec/models/ranking/lr.py +4 -3
  26. nextrec/models/ranking/masknet.py +4 -5
  27. nextrec/models/ranking/pnn.py +5 -4
  28. nextrec/models/ranking/widedeep.py +8 -8
  29. nextrec/models/ranking/xdeepfm.py +5 -4
  30. nextrec/utils/console.py +20 -6
  31. nextrec/utils/data.py +154 -32
  32. nextrec/utils/model.py +86 -1
  33. {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/METADATA +5 -6
  34. {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/RECORD +37 -36
  35. {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/WHEEL +0 -0
  36. {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/entry_points.txt +0 -0
  37. {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/grad_norm.py CHANGED
@@ -2,7 +2,7 @@
 GradNorm loss weighting for multi-task learning.
 
 Date: create on 27/10/2025
-Checkpoint: edit on 20/12/2025
+Checkpoint: edit on 24/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 
 Reference:
@@ -45,7 +45,7 @@ class GradNormLossWeighting:
     Adaptive multi-task loss weighting with GradNorm.
 
     Args:
-        num_tasks: Number of tasks.
+        nums_task: Number of tasks.
         alpha: GradNorm balancing strength.
         lr: Learning rate for the weight optimizer.
         init_weights: Optional initial weights per task.
@@ -58,7 +58,7 @@ class GradNormLossWeighting:
 
     def __init__(
         self,
-        num_tasks: int,
+        nums_task: int,
         alpha: float = 1.5,
         lr: float = 0.025,
         init_weights: Iterable[float] | None = None,
@@ -68,9 +68,9 @@ class GradNormLossWeighting:
         init_ema_decay: float = 0.9,
         eps: float = 1e-8,
     ) -> None:
-        if num_tasks <= 1:
-            raise ValueError("GradNorm requires num_tasks > 1.")
-        self.num_tasks = num_tasks
+        if nums_task <= 1:
+            raise ValueError("GradNorm requires nums_task > 1.")
+        self.nums_task = nums_task
         self.alpha = alpha
         self.eps = eps
         if ema_decay is not None:
@@ -87,12 +87,12 @@ class GradNormLossWeighting:
         self.init_ema_count = 0
 
         if init_weights is None:
-            weights = torch.ones(self.num_tasks, dtype=torch.float32)
+            weights = torch.ones(self.nums_task, dtype=torch.float32)
         else:
             weights = torch.tensor(list(init_weights), dtype=torch.float32)
-            if weights.numel() != self.num_tasks:
+            if weights.numel() != self.nums_task:
                 raise ValueError(
-                    "init_weights length must match num_tasks for GradNorm."
+                    "init_weights length must match nums_task for GradNorm."
                 )
         if device is not None:
             weights = weights.to(device)
@@ -123,9 +123,9 @@ class GradNormLossWeighting:
         """
        Return weighted total loss and update task weights with GradNorm.
        """
-        if len(task_losses) != self.num_tasks:
+        if len(task_losses) != self.nums_task:
             raise ValueError(
-                f"Expected {self.num_tasks} task losses, got {len(task_losses)}."
+                f"Expected {self.nums_task} task losses, got {len(task_losses)}."
             )
         shared_params = [p for p in shared_params if p.requires_grad]
         if not shared_params:
@@ -152,7 +152,7 @@ class GradNormLossWeighting:
 
         weights_detached = self.weights.detach()
         weighted_losses = [
-            weights_detached[i] * task_losses[i] for i in range(self.num_tasks)
+            weights_detached[i] * task_losses[i] for i in range(self.nums_task)
         ]
         total_loss = torch.stack(weighted_losses).sum()
 
@@ -226,7 +226,7 @@ class GradNormLossWeighting:
 
         with torch.no_grad():
             w = self.weights.clamp(min=self.eps)
-            w = w * self.num_tasks / (w.sum() + self.eps)
+            w = w * self.nums_task / (w.sum() + self.eps)
             self.weights.copy_(w)
 
         self.pending_grad = None
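For orientation, a minimal standalone sketch of the two steps visible in the hunks above: the weighted total loss and the weight renormalization. This is plain PyTorch, not the package's GradNormLossWeighting class, and renormalize_weights is a hypothetical helper named here only for illustration.

import torch

def renormalize_weights(weights: torch.Tensor, nums_task: int, eps: float = 1e-8) -> torch.Tensor:
    # Clamp away non-positive values, then rescale so the weights sum to nums_task,
    # mirroring the renormalization step in the last hunk above.
    w = weights.clamp(min=eps)
    return w * nums_task / (w.sum() + eps)

task_losses = [torch.tensor(0.7), torch.tensor(0.3), torch.tensor(1.2)]
weights = renormalize_weights(torch.tensor([0.5, 1.0, 1.5]), nums_task=3)
# Weighted total loss, as in the weighted_losses hunk: sum_i w_i * L_i
total_loss = torch.stack([w * l for w, l in zip(weights, task_losses)]).sum()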
nextrec/models/multi_task/esmm.py CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on 29/11/2025
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 Reference:
 [1] Ma X, Zhao L, Huang G, et al. Entire space multi-task model: An effective approach
@@ -45,7 +45,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -100,17 +101,17 @@ class ESMM(BaseModel):
                 f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}"
             )
 
-        self.num_tasks = len(target)
+        self.nums_task = len(target)
         resolved_task = task
         if resolved_task is None:
             resolved_task = self.default_task
         elif isinstance(resolved_task, str):
-            resolved_task = [resolved_task] * self.num_tasks
-        elif len(resolved_task) == 1 and self.num_tasks > 1:
-            resolved_task = resolved_task * self.num_tasks
-        elif len(resolved_task) != self.num_tasks:
+            resolved_task = [resolved_task] * self.nums_task
+        elif len(resolved_task) == 1 and self.nums_task > 1:
+            resolved_task = resolved_task * self.nums_task
+        elif len(resolved_task) != self.nums_task:
             raise ValueError(
-                f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
+                f"Length of task ({len(resolved_task)}) must match number of targets ({self.nums_task})."
             )
         # resolved_task is now guaranteed to be a list[str]
 
@@ -139,9 +140,7 @@ class ESMM(BaseModel):
         # CVR tower
         self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
         self.grad_norm_shared_modules = ["embedding"]
-        self.prediction_layer = PredictionLayer(
-            task_type=self.default_task, task_dims=[1, 1]
-        )
+        self.prediction_layer = TaskHead(task_type=self.default_task, task_dims=[1, 1])
         # Register regularization weights
         self.register_regularization_weights(
             embedding_attr="embedding", include_modules=["ctr_tower", "cvr_tower"]
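The target/task broadcasting touched in this hunk recurs in every multi-task model below. As a reading aid, a standalone sketch of that branching; resolve_tasks is a hypothetical helper written here for illustration, not a nextrec API.

def resolve_tasks(task: str | list[str] | None, nums_task: int, default: str = "binary") -> list[str]:
    # Broadcast a single task type to every target, mirroring the branching above.
    if task is None:
        return [default] * nums_task
    if isinstance(task, str):
        return [task] * nums_task
    if len(task) == 1 and nums_task > 1:
        return list(task) * nums_task
    if len(task) != nums_task:
        raise ValueError(f"Length of task ({len(task)}) must match number of targets ({nums_task}).")
    return list(task)

resolve_tasks("binary", nums_task=2)        # ['binary', 'binary']
resolve_tasks(["binary", "regression"], 2)  # ['binary', 'regression']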
nextrec/models/multi_task/mmoe.py CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on 29/11/2025
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 Reference:
 [1] Ma J, Zhao Z, Yi X, et al. Modeling task relationships in multi-task learning with
@@ -46,7 +46,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -66,9 +67,9 @@ class MMOE(BaseModel):
 
     @property
     def default_task(self):
-        num_tasks = getattr(self, "num_tasks", None)
-        if num_tasks is not None and num_tasks > 0:
-            return ["binary"] * num_tasks
+        nums_task = getattr(self, "nums_task", None)
+        if nums_task is not None and nums_task > 0:
+            return ["binary"] * nums_task
         return ["binary"]
 
     def __init__(
@@ -106,18 +107,18 @@ class MMOE(BaseModel):
         elif isinstance(target, str):
             target = [target]
 
-        self.num_tasks = len(target) if target else 1
+        self.nums_task = len(target) if target else 1
 
         resolved_task = task
         if resolved_task is None:
             resolved_task = self.default_task
         elif isinstance(resolved_task, str):
-            resolved_task = [resolved_task] * self.num_tasks
-        elif len(resolved_task) == 1 and self.num_tasks > 1:
-            resolved_task = resolved_task * self.num_tasks
-        elif len(resolved_task) != self.num_tasks:
+            resolved_task = [resolved_task] * self.nums_task
+        elif len(resolved_task) == 1 and self.nums_task > 1:
+            resolved_task = resolved_task * self.nums_task
+        elif len(resolved_task) != self.nums_task:
             raise ValueError(
-                f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
+                f"Length of task ({len(resolved_task)}) must match number of targets ({self.nums_task})."
             )
 
         super(MMOE, self).__init__(
@@ -137,12 +138,12 @@ class MMOE(BaseModel):
         self.loss = loss
 
         # Number of tasks and experts
-        self.num_tasks = len(target)
+        self.nums_task = len(target)
         self.num_experts = num_experts
 
-        if len(tower_params_list) != self.num_tasks:
+        if len(tower_params_list) != self.nums_task:
             raise ValueError(
-                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
+                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.nums_task})"
             )
 
         self.embedding = EmbeddingLayer(features=self.all_features)
@@ -162,7 +163,7 @@ class MMOE(BaseModel):
 
         # Task-specific gates
         self.gates = nn.ModuleList()
-        for _ in range(self.num_tasks):
+        for _ in range(self.nums_task):
             gate = nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1))
             self.gates.append(gate)
         self.grad_norm_shared_modules = ["embedding", "experts", "gates"]
@@ -172,8 +173,8 @@ class MMOE(BaseModel):
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(
-            task_type=self.default_task, task_dims=[1] * self.num_tasks
+        self.prediction_layer = TaskHead(
+            task_type=self.default_task, task_dims=[1] * self.nums_task
         )
         # Register regularization weights
         self.register_regularization_weights(
@@ -198,7 +199,7 @@ class MMOE(BaseModel):
 
         # Task-specific processing
         task_outputs = []
-        for task_idx in range(self.num_tasks):
+        for task_idx in range(self.nums_task):
             # Gate weights for this task: [B, num_experts]
             gate_weights = self.gates[task_idx](input_flat)  # [B, num_experts]
 
@@ -217,6 +218,6 @@ class MMOE(BaseModel):
             tower_output = self.towers[task_idx](gated_output)  # [B, 1]
             task_outputs.append(tower_output)
 
-        # Stack outputs: [B, num_tasks]
+        # Stack outputs: [B, nums_task]
         y = torch.cat(task_outputs, dim=1)
         return self.prediction_layer(y)
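For readers skimming the rename, a minimal plain-PyTorch sketch of the MMoE gating path these hunks touch. nn.Linear stands in for the package's MLP experts and towers, and the final TaskHead activation is omitted; shapes follow the comments in the diff.

import torch
import torch.nn as nn

B, input_dim, num_experts, nums_task, expert_dim = 4, 16, 3, 2, 8

experts = nn.ModuleList([nn.Linear(input_dim, expert_dim) for _ in range(num_experts)])
gates = nn.ModuleList(
    [nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1)) for _ in range(nums_task)]
)
towers = nn.ModuleList([nn.Linear(expert_dim, 1) for _ in range(nums_task)])

x = torch.randn(B, input_dim)
expert_out = torch.stack([e(x) for e in experts], dim=1)  # [B, num_experts, expert_dim]
task_outputs = []
for t in range(nums_task):
    w = gates[t](x).unsqueeze(-1)        # [B, num_experts, 1] gate weights for task t
    gated = (w * expert_out).sum(dim=1)  # [B, expert_dim] weighted mix of experts
    task_outputs.append(towers[t](gated))  # [B, 1]
y = torch.cat(task_outputs, dim=1)       # [B, nums_task]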
nextrec/models/multi_task/ple.py CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on 29/11/2025
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 Reference:
 [1] Tang H, Liu J, Zhao M, et al. Progressive layered extraction (PLE): A novel
@@ -49,7 +49,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 from nextrec.utils.model import get_mlp_output_dim
 
@@ -63,18 +64,18 @@ class CGCLayer(nn.Module):
     def __init__(
         self,
         input_dim: int,
-        num_tasks: int,
+        nums_task: int,
         num_shared_experts: int,
         num_specific_experts: int,
         shared_expert_params: dict,
         specific_expert_params: dict | list[dict],
     ):
         super().__init__()
-        if num_tasks < 1:
-            raise ValueError("num_tasks must be >= 1")
+        if nums_task < 1:
+            raise ValueError("nums_task must be >= 1")
 
         specific_params_list = self.normalize_specific_params(
-            specific_expert_params, num_tasks
+            specific_expert_params, nums_task
         )
 
         self.output_dim = get_mlp_output_dim(shared_expert_params, input_dim)
@@ -120,23 +121,23 @@ class CGCLayer(nn.Module):
                     nn.Linear(input_dim, task_gate_expert_num),
                     nn.Softmax(dim=1),
                 )
-                for _ in range(num_tasks)
+                for _ in range(nums_task)
             ]
         )
-        shared_gate_expert_num = num_shared_experts + num_specific_experts * num_tasks
+        shared_gate_expert_num = num_shared_experts + num_specific_experts * nums_task
         self.shared_gate = nn.Sequential(
             nn.Linear(input_dim, shared_gate_expert_num),
             nn.Softmax(dim=1),
         )
 
-        self.num_tasks = num_tasks
+        self.nums_task = nums_task
 
     def forward(
         self, task_inputs: list[torch.Tensor], shared_input: torch.Tensor
     ) -> tuple[list[torch.Tensor], torch.Tensor]:
-        if len(task_inputs) != self.num_tasks:
+        if len(task_inputs) != self.nums_task:
             raise ValueError(
-                f"Expected {self.num_tasks} task inputs, got {len(task_inputs)}"
+                f"Expected {self.nums_task} task inputs, got {len(task_inputs)}"
             )
 
         shared_outputs = [expert(shared_input) for expert in self.shared_experts]
@@ -145,7 +146,7 @@ class CGCLayer(nn.Module):
 
         new_task_fea: list[torch.Tensor] = []
         all_specific_for_shared: list[torch.Tensor] = []
-        for task_idx in range(self.num_tasks):
+        for task_idx in range(self.nums_task):
             task_input = task_inputs[task_idx]
             task_specific_outputs = [expert(task_input) for expert in self.specific_experts[task_idx]]  # type: ignore
             all_specific_for_shared.extend(task_specific_outputs)
@@ -167,15 +168,15 @@ class CGCLayer(nn.Module):
 
     @staticmethod
     def normalize_specific_params(
-        params: dict | list[dict], num_tasks: int
+        params: dict | list[dict], nums_task: int
    ) -> list[dict]:
        if isinstance(params, list):
-            if len(params) != num_tasks:
+            if len(params) != nums_task:
                raise ValueError(
-                    f"Length of specific_expert_params ({len(params)}) must match num_tasks ({num_tasks})."
+                    f"Length of specific_expert_params ({len(params)}) must match nums_task ({nums_task})."
                )
            return [p.copy() for p in params]
-        return [params.copy() for _ in range(num_tasks)]
+        return [params.copy() for _ in range(nums_task)]
 
 
 class PLE(BaseModel):
@@ -194,9 +195,9 @@ class PLE(BaseModel):
 
     @property
     def default_task(self):
-        num_tasks = getattr(self, "num_tasks", None)
-        if num_tasks is not None and num_tasks > 0:
-            return ["binary"] * num_tasks
+        nums_task = getattr(self, "nums_task", None)
+        if nums_task is not None and nums_task > 0:
+            return ["binary"] * nums_task
         return ["binary"]
 
     def __init__(
@@ -224,18 +225,18 @@ class PLE(BaseModel):
         **kwargs,
     ):
 
-        self.num_tasks = len(target)
+        self.nums_task = len(target)
 
         resolved_task = task
         if resolved_task is None:
             resolved_task = self.default_task
         elif isinstance(resolved_task, str):
-            resolved_task = [resolved_task] * self.num_tasks
-        elif len(resolved_task) == 1 and self.num_tasks > 1:
-            resolved_task = resolved_task * self.num_tasks
-        elif len(resolved_task) != self.num_tasks:
+            resolved_task = [resolved_task] * self.nums_task
+        elif len(resolved_task) == 1 and self.nums_task > 1:
+            resolved_task = resolved_task * self.nums_task
+        elif len(resolved_task) != self.nums_task:
             raise ValueError(
-                f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
+                f"Length of task ({len(resolved_task)}) must match number of targets ({self.nums_task})."
             )
 
         super(PLE, self).__init__(
@@ -256,15 +257,15 @@ class PLE(BaseModel):
         if self.loss is None:
             self.loss = "bce"
         # Number of tasks, experts, and levels
-        self.num_tasks = len(target)
+        self.nums_task = len(target)
         self.num_shared_experts = num_shared_experts
         self.num_specific_experts = num_specific_experts
         self.num_levels = num_levels
         if optimizer_params is None:
             optimizer_params = {}
-        if len(tower_params_list) != self.num_tasks:
+        if len(tower_params_list) != self.nums_task:
             raise ValueError(
-                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
+                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.nums_task})"
             )
         # Embedding layer
         self.embedding = EmbeddingLayer(features=self.all_features)
@@ -287,7 +288,7 @@ class PLE(BaseModel):
             level_input_dim = input_dim if level == 0 else expert_output_dim
             cgc_layer = CGCLayer(
                 input_dim=level_input_dim,
-                num_tasks=self.num_tasks,
+                nums_task=self.nums_task,
                 num_shared_experts=num_shared_experts,
                 num_specific_experts=num_specific_experts,
                 shared_expert_params=shared_expert_params,
@@ -302,8 +303,8 @@ class PLE(BaseModel):
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(
-            task_type=self.default_task, task_dims=[1] * self.num_tasks
+        self.prediction_layer = TaskHead(
+            task_type=self.default_task, task_dims=[1] * self.nums_task
         )
         # Register regularization weights
         self.register_regularization_weights(
@@ -321,7 +322,7 @@ class PLE(BaseModel):
         input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
 
         # Initial features for each task and shared
-        task_fea = [input_flat for _ in range(self.num_tasks)]
+        task_fea = [input_flat for _ in range(self.nums_task)]
         shared_fea = input_flat
 
         # Progressive Layered Extraction: CGC
@@ -330,10 +331,10 @@ class PLE(BaseModel):
 
         # task tower
         task_outputs = []
-        for task_idx in range(self.num_tasks):
+        for task_idx in range(self.nums_task):
             tower_output = self.towers[task_idx](task_fea[task_idx])  # [B, 1]
             task_outputs.append(tower_output)
 
-        # [B, num_tasks]
+        # [B, nums_task]
         y = torch.cat(task_outputs, dim=1)
         return self.prediction_layer(y)
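A minimal plain-PyTorch sketch of one CGC routing step for the task-specific branch, under the assumption that each task gate mixes that task's specific experts with the shared experts; the ordering of experts inside the gate and the use of nn.Linear in place of the package's MLP are assumptions, not details visible in this diff.

import torch
import torch.nn as nn

B, d, nums_task, n_shared, n_specific = 4, 16, 2, 2, 1

shared_experts = nn.ModuleList([nn.Linear(d, d) for _ in range(n_shared)])
specific_experts = nn.ModuleList(
    [nn.ModuleList([nn.Linear(d, d) for _ in range(n_specific)]) for _ in range(nums_task)]
)
task_gates = nn.ModuleList(
    [nn.Sequential(nn.Linear(d, n_specific + n_shared), nn.Softmax(dim=1)) for _ in range(nums_task)]
)

x = torch.randn(B, d)
task_inputs = [x for _ in range(nums_task)]
shared_out = [e(x) for e in shared_experts]
new_task_fea = []
for t in range(nums_task):
    specific_out = [e(task_inputs[t]) for e in specific_experts[t]]
    candidates = torch.stack(specific_out + shared_out, dim=1)  # [B, n_specific + n_shared, d]
    w = task_gates[t](task_inputs[t]).unsqueeze(-1)             # [B, n_specific + n_shared, 1]
    new_task_fea.append((w * candidates).sum(dim=1))            # [B, d] routed feature for task t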
nextrec/models/multi_task/poso.py CHANGED
@@ -1,5 +1,6 @@
 """
 Date: create on 28/11/2025
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 Reference:
 [1] Wang et al. "POSO: Personalized Cold Start Modules for Large-scale Recommender Systems", 2021.
@@ -44,7 +45,8 @@ import torch.nn.functional as F
 
 from nextrec.basic.activation import activation_layer
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 from nextrec.utils.model import select_features
 
@@ -195,7 +197,7 @@ class POSOMMoE(nn.Module):
         pc_dim: int,  # for poso feature dimension
         num_experts: int,
         expert_hidden_dims: list[int],
-        num_tasks: int,
+        nums_task: int,
         activation: str = "relu",
         expert_dropout: float = 0.0,
         gate_hidden_dim: int = 32,  # for poso gate hidden dimension
@@ -204,7 +206,7 @@ class POSOMMoE(nn.Module):
     ) -> None:
         super().__init__()
         self.num_experts = num_experts
-        self.num_tasks = num_tasks
+        self.nums_task = nums_task
 
         # Experts built with framework MLP, same as standard MMoE
         self.experts = nn.ModuleList(
@@ -225,7 +227,7 @@ class POSOMMoE(nn.Module):
 
         # Task-specific gates: gate_t(x) over experts
         self.gates = nn.ModuleList(
-            [nn.Linear(input_dim, num_experts) for _ in range(num_tasks)]
+            [nn.Linear(input_dim, num_experts) for _ in range(nums_task)]
         )
         self.gate_use_softmax = gate_use_softmax
 
@@ -247,7 +249,7 @@ class POSOMMoE(nn.Module):
         """
         x: (B, input_dim)
         pc: (B, pc_dim)
-        return: list of task outputs z_t with length num_tasks, each (B, D)
+        return: list of task outputs z_t with length nums_task, each (B, D)
         """
         # 1) Expert outputs with POSO PC gate
         masked_expert_outputs = []
@@ -261,7 +263,7 @@ class POSOMMoE(nn.Module):
 
         # 2) Task gates depend on x as in standard MMoE
         task_outputs: list[torch.Tensor] = []
-        for t in range(self.num_tasks):
+        for t in range(self.nums_task):
             logits = self.gates[t](x)  # (B, E)
             if self.gate_use_softmax:
                 gate = F.softmax(logits, dim=1)
@@ -288,9 +290,9 @@ class POSO(BaseModel):
 
     @property
     def default_task(self) -> list[str]:
-        num_tasks = getattr(self, "num_tasks", None)
-        if num_tasks is not None and num_tasks > 0:
-            return ["binary"] * num_tasks
+        nums_task = getattr(self, "nums_task", None)
+        if nums_task is not None and nums_task > 0:
+            return ["binary"] * nums_task
         return ["binary"]
 
     def __init__(
@@ -332,24 +334,24 @@ class POSO(BaseModel):
         dense_l2_reg: float = 1e-4,
         **kwargs,
     ):
-        self.num_tasks = len(target)
+        self.nums_task = len(target)
 
-        # Normalize task to match num_tasks
+        # Normalize task to match nums_task
         resolved_task = task
         if resolved_task is None:
             resolved_task = self.default_task
         elif isinstance(resolved_task, str):
-            resolved_task = [resolved_task] * self.num_tasks
-        elif len(resolved_task) == 1 and self.num_tasks > 1:
-            resolved_task = resolved_task * self.num_tasks
-        elif len(resolved_task) != self.num_tasks:
+            resolved_task = [resolved_task] * self.nums_task
+        elif len(resolved_task) == 1 and self.nums_task > 1:
+            resolved_task = resolved_task * self.nums_task
+        elif len(resolved_task) != self.nums_task:
             raise ValueError(
-                f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
+                f"Length of task ({len(resolved_task)}) must match number of targets ({self.nums_task})."
             )
 
-        if len(tower_params_list) != self.num_tasks:
+        if len(tower_params_list) != self.nums_task:
             raise ValueError(
-                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
+                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.nums_task})"
             )
 
         super().__init__(
@@ -465,7 +467,7 @@ class POSO(BaseModel):
             pc_dim=self.pc_input_dim,
             num_experts=num_experts,
            expert_hidden_dims=expert_hidden_dims,
-            num_tasks=self.num_tasks,
+            nums_task=self.nums_task,
            activation=expert_activation,
            expert_dropout=expert_dropout,
            gate_hidden_dim=expert_gate_hidden_dim,
@@ -487,9 +489,9 @@ class POSO(BaseModel):
             self.grad_norm_shared_modules = ["embedding"]
         else:
             self.grad_norm_shared_modules = ["embedding", "mmoe"]
-        self.prediction_layer = PredictionLayer(
+        self.prediction_layer = TaskHead(
             task_type=self.default_task,
-            task_dims=[1] * self.num_tasks,
+            task_dims=[1] * self.nums_task,
         )
         include_modules = (
             ["towers", "tower_heads"]
nextrec/models/multi_task/share_bottom.py CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on 24/11/2025
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 Reference:
 [1] Caruana R. Multitask learning[J]. Machine Learning, 1997, 28: 41-75.
@@ -43,7 +43,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -54,9 +55,9 @@ class ShareBottom(BaseModel):
 
     @property
     def default_task(self):
-        num_tasks = getattr(self, "num_tasks", None)
-        if num_tasks is not None and num_tasks > 0:
-            return ["binary"] * num_tasks
+        nums_task = getattr(self, "nums_task", None)
+        if nums_task is not None and nums_task > 0:
+            return ["binary"] * nums_task
         return ["binary"]
 
     def __init__(
@@ -82,18 +83,18 @@ class ShareBottom(BaseModel):
 
         optimizer_params = optimizer_params or {}
 
-        self.num_tasks = len(target)
+        self.nums_task = len(target)
 
         resolved_task = task
         if resolved_task is None:
             resolved_task = self.default_task
         elif isinstance(resolved_task, str):
-            resolved_task = [resolved_task] * self.num_tasks
-        elif len(resolved_task) == 1 and self.num_tasks > 1:
-            resolved_task = resolved_task * self.num_tasks
-        elif len(resolved_task) != self.num_tasks:
+            resolved_task = [resolved_task] * self.nums_task
+        elif len(resolved_task) == 1 and self.nums_task > 1:
+            resolved_task = resolved_task * self.nums_task
+        elif len(resolved_task) != self.nums_task:
             raise ValueError(
-                f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
+                f"Length of task ({len(resolved_task)}) must match number of targets ({self.nums_task})."
             )
 
         super(ShareBottom, self).__init__(
@@ -114,10 +115,10 @@ class ShareBottom(BaseModel):
         if self.loss is None:
             self.loss = "bce"
         # Number of tasks
-        self.num_tasks = len(target)
-        if len(tower_params_list) != self.num_tasks:
+        self.nums_task = len(target)
+        if len(tower_params_list) != self.nums_task:
             raise ValueError(
-                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
+                f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.nums_task})"
             )
         # Embedding layer
         self.embedding = EmbeddingLayer(features=self.all_features)
@@ -142,8 +143,8 @@ class ShareBottom(BaseModel):
         for tower_params in tower_params_list:
             tower = MLP(input_dim=bottom_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(
-            task_type=self.default_task, task_dims=[1] * self.num_tasks
+        self.prediction_layer = TaskHead(
+            task_type=self.default_task, task_dims=[1] * self.nums_task
         )
         # Register regularization weights
         self.register_regularization_weights(
@@ -169,6 +170,6 @@ class ShareBottom(BaseModel):
             tower_output = tower(bottom_output)  # [B, 1]
             task_outputs.append(tower_output)
 
-        # Stack outputs: [B, num_tasks]
+        # Stack outputs: [B, nums_task]
         y = torch.cat(task_outputs, dim=1)
         return self.prediction_layer(y)
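Finally, a minimal plain-PyTorch sketch of the shared-bottom-plus-towers flow in the forward hunk above. nn.Linear stands in for the package's MLP, and the trailing sigmoid stands in for the binary TaskHead applied to the concatenated tower outputs; both substitutions are assumptions for illustration.

import torch
import torch.nn as nn

B, input_dim, bottom_dim, nums_task = 4, 16, 8, 2

bottom = nn.Sequential(nn.Linear(input_dim, bottom_dim), nn.ReLU())
towers = nn.ModuleList([nn.Linear(bottom_dim, 1) for _ in range(nums_task)])

x = torch.randn(B, input_dim)
bottom_output = bottom(x)                                   # shared representation
task_outputs = [tower(bottom_output) for tower in towers]   # each [B, 1]
y = torch.cat(task_outputs, dim=1)                          # [B, nums_task]
probs = torch.sigmoid(y)                                    # per-task binary predictions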