nextrec 0.2.7__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,44 @@
1
1
  """
2
2
  Date: create on 09/11/2025
3
+ Checkpoint: edit on 29/11/2025
3
4
  Author: Yang Zhou, zyaztec@gmail.com
4
- Reference: [1] Ma X, Zhao L, Huang G, et al. Entire space multi-task model: An effective approach for estimating post-click conversion rate[C]//SIGIR. 2018: 1137-1140.
5
+ Reference:
6
+ [1] Ma X, Zhao L, Huang G, et al. Entire space multi-task model: An effective approach
7
+ for estimating post-click conversion rate[C]//SIGIR. 2018: 1137-1140.
8
+ (https://dl.acm.org/doi/10.1145/3209978.3210007)
9
+
10
+ Entire Space Multi-task Model (ESMM) targets CVR estimation by jointly optimizing
11
+ CTR and CTCVR on the full impression space, mitigating sample selection bias and
12
+ conversion sparsity. CTR predicts P(click | impression), CVR predicts P(conversion |
13
+ click), and their product forms CTCVR supervised on impression labels.
14
+
15
+ Workflow:
16
+ (1) Shared embeddings encode all features from impressions
17
+ (2) CTR tower outputs click probability conditioned on impression
18
+ (3) CVR tower outputs conversion probability conditioned on click
19
+ (4) CTCVR = CTR * CVR enables end-to-end training without filtering clicked data
20
+
21
+ Key Advantages:
22
+ - Trains on the entire impression space to remove selection bias
23
+ - Transfers rich click signals to sparse conversion prediction via shared embeddings
24
+ - Stable optimization by decomposing CTCVR into well-defined sub-tasks
25
+ - Simple architecture that can pair with other multi-task variants
26
+
27
+ ESMM(Entire Space Multi-task Model)用于 CVR 预估,通过在曝光全空间联合训练
28
+ CTR 与 CTCVR,缓解样本选择偏差和转化数据稀疏问题。CTR 预测 P(click|impression),
29
+ CVR 预测 P(conversion|click),二者相乘得到 CTCVR 并在曝光标签上直接监督。
30
+
31
+ 流程:
32
+ (1) 共享 embedding 统一处理曝光特征
33
+ (2) CTR 塔输出曝光下的点击概率
34
+ (3) CVR 塔输出点击后的转化概率
35
+ (4) CTR 与 CVR 相乘得到 CTCVR,无需只在点击子集上训练
36
+
37
+ 主要优点:
38
+ - 在曝光空间训练,避免样本选择偏差
39
+ - 通过共享表示将点击信号迁移到稀疏的转化任务
40
+ - 将 CTCVR 分解为子任务,优化稳定
41
+ - 结构简单,可与其它多任务方法组合使用
5
42
  """
6
43
 
7
44
  import torch
@@ -77,37 +114,22 @@ class ESMM(BaseModel):
77
114
 
78
115
  # All features
79
116
  self.all_features = dense_features + sparse_features + sequence_features
80
-
81
117
  # Shared embedding layer
82
118
  self.embedding = EmbeddingLayer(features=self.all_features)
119
+ input_dim = self.embedding.input_dim # input dimension from the embedding layer; cleaner than the manual computation kept below for reference
120
+ # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
121
+ # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
122
+ # input_dim = emb_dim_total + dense_input_dim
83
123
 
84
- # Calculate input dimension
85
- emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
86
- dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
87
- input_dim = emb_dim_total + dense_input_dim
88
-
89
124
  # CTR tower
90
125
  self.ctr_tower = MLP(input_dim=input_dim, output_layer=True, **ctr_params)
91
126
 
92
127
  # CVR tower
93
128
  self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
94
- self.prediction_layer = PredictionLayer(
95
- task_type=self.task_type,
96
- task_dims=[1, 1]
97
- )
98
-
129
+ self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1, 1])
99
130
  # Register regularization weights
100
- self._register_regularization_weights(
101
- embedding_attr='embedding',
102
- include_modules=['ctr_tower', 'cvr_tower']
103
- )
104
-
105
- self.compile(
106
- optimizer=optimizer,
107
- optimizer_params=optimizer_params,
108
- loss=loss,
109
- loss_params=loss_params,
110
- )
131
+ self._register_regularization_weights(embedding_attr='embedding', include_modules=['ctr_tower', 'cvr_tower'])
132
+ self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
111
133
 
112
134
  def forward(self, x):
113
135
  # Get all embeddings and flatten
@@ -119,11 +141,8 @@ class ESMM(BaseModel):
119
141
  logits = torch.cat([ctr_logit, cvr_logit], dim=1)
120
142
  preds = self.prediction_layer(logits)
121
143
  ctr, cvr = preds.chunk(2, dim=1)
122
-
123
- # CTCVR prediction: P(click & conversion | impression) = P(click) * P(conversion | click)
124
144
  ctcvr = ctr * cvr # [B, 1]
125
145
 
126
- # Output: [CTR, CTCVR]
127
- # Note: We supervise CTR with click labels and CTCVR with conversion labels
146
+ # Output: [CTR, CTCVR]; we supervise CTR with click labels and CTCVR with conversion labels
128
147
  y = torch.cat([ctr, ctcvr], dim=1) # [B, 2]
129
148
  return y # [B, 2], where y[:, 0] is CTR and y[:, 1] is CTCVR
@@ -1,7 +1,45 @@
1
1
  """
2
2
  Date: create on 09/11/2025
3
+ Checkpoint: edit on 29/11/2025
3
4
  Author: Yang Zhou, zyaztec@gmail.com
4
- Reference: [1] Ma J, Zhao Z, Yi X, et al. Modeling task relationships in multi-task learning with multi-gate mixture-of-experts[C]//KDD. 2018: 1930-1939.
5
+ Reference:
6
+ [1] Ma J, Zhao Z, Yi X, et al. Modeling task relationships in multi-task learning with
7
+ multi-gate mixture-of-experts[C]//KDD. 2018: 1930-1939.
8
+ (https://dl.acm.org/doi/10.1145/3219819.3220007)
9
+
10
+ Multi-gate Mixture-of-Experts (MMoE) extends shared-bottom multi-task learning by
11
+ introducing multiple experts and task-specific softmax gates. Each task learns its
12
+ own routing weights over the expert pool, enabling both shared and task-specialized
13
+ representations while alleviating gradient conflicts across tasks.
14
+
15
+ In each forward pass:
16
+ (1) Shared embeddings encode all dense/sparse/sequence features
17
+ (2) Each expert processes the same input to produce candidate shared representations
18
+ (3) Every task gate outputs a simplex over experts to softly route information
19
+ (4) The task-specific weighted sum is passed into its tower and prediction head
20
+
21
+ Key Advantages:
22
+ - Soft parameter sharing reduces negative transfer between heterogeneous tasks
23
+ - Gates let tasks adaptively allocate expert capacity based on their difficulty
24
+ - Supports many tasks without duplicating full networks
25
+ - Works with mixed feature types via unified embeddings
26
+ - Simple to scale the number of experts or gates for capacity control
27
+
28
+ MMoE(Multi-gate Mixture-of-Experts)是多任务学习框架,通过多个专家网络与
29
+ 任务特定门控进行软路由,兼顾共享表示与任务特化,减轻梯度冲突问题。
30
+
31
+ 一次前向流程:
32
+ (1) 共享 embedding 统一编码稠密、稀疏与序列特征
33
+ (2) 每个专家对相同输入进行特征变换,得到候选共享表示
34
+ (3) 每个任务的门控产生对专家的概率分布,完成软选择与加权
35
+ (4) 加权结果输入到对应任务的塔网络与预测头
36
+
37
+ 主要优点:
38
+ - 软参数共享,缓解任务间负迁移
39
+ - 按任务难度自适应分配专家容量
40
+ - 便于扩展多任务,而无需复制完整网络
41
+ - 支持多种特征类型的统一建模
42
+ - 专家与门控数量可灵活调节以控制模型容量
5
43
  """
6
44
 
7
45
  import torch
@@ -75,18 +113,14 @@ class MMOE(BaseModel):
75
113
 
76
114
  if len(tower_params_list) != self.num_tasks:
77
115
  raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
78
-
79
- # All features
80
- self.all_features = dense_features + sparse_features + sequence_features
81
116
 
82
- # Embedding layer
117
+ self.all_features = dense_features + sparse_features + sequence_features
83
118
  self.embedding = EmbeddingLayer(features=self.all_features)
119
+ input_dim = self.embedding.input_dim
120
+ # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
121
+ # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
122
+ # input_dim = emb_dim_total + dense_input_dim
84
123
 
85
- # Calculate input dimension
86
- emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
87
- dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
88
- input_dim = emb_dim_total + dense_input_dim
89
-
90
124
  # Expert networks (shared by all tasks)
91
125
  self.experts = nn.ModuleList()
92
126
  for _ in range(num_experts):
@@ -102,10 +136,7 @@ class MMOE(BaseModel):
102
136
  # Task-specific gates
103
137
  self.gates = nn.ModuleList()
104
138
  for _ in range(self.num_tasks):
105
- gate = nn.Sequential(
106
- nn.Linear(input_dim, num_experts),
107
- nn.Softmax(dim=1)
108
- )
139
+ gate = nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1))
109
140
  self.gates.append(gate)
110
141
 
111
142
  # Task-specific towers
@@ -113,23 +144,10 @@ class MMOE(BaseModel):
113
144
  for tower_params in tower_params_list:
114
145
  tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
115
146
  self.towers.append(tower)
116
- self.prediction_layer = PredictionLayer(
117
- task_type=self.task_type,
118
- task_dims=[1] * self.num_tasks
119
- )
120
-
147
+ self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
121
148
  # Register regularization weights
122
- self._register_regularization_weights(
123
- embedding_attr='embedding',
124
- include_modules=['experts', 'gates', 'towers']
125
- )
126
-
127
- self.compile(
128
- optimizer=optimizer,
129
- optimizer_params=optimizer_params,
130
- loss=loss,
131
- loss_params=loss_params,
132
- )
149
+ self._register_regularization_weights(embedding_attr='embedding', include_modules=['experts', 'gates', 'towers'])
150
+ self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
133
151
 
134
152
  def forward(self, x):
135
153
  # Get all embeddings and flatten
@@ -1,7 +1,48 @@
1
1
  """
2
2
  Date: create on 09/11/2025
3
+ Checkpoint: edit on 29/11/2025
3
4
  Author: Yang Zhou, zyaztec@gmail.com
4
- Reference: [1] Tang H, Liu J, Zhao M, et al. Progressive layered extraction (ple): A novel multi-task learning (mtl) model for personalized recommendations[C]//RecSys. 2020: 269-278.
5
+ Reference:
6
+ [1] Tang H, Liu J, Zhao M, et al. Progressive layered extraction (PLE): A novel
7
+ multi-task learning (MTL) model for personalized recommendations[C]//RecSys. 2020: 269-278.
8
+ (https://dl.acm.org/doi/10.1145/3383313.3412236)
9
+
10
+ Progressive Layered Extraction (PLE) advances multi-task learning by stacking CGC
11
+ (Customized Gate Control) blocks that mix shared and task-specific experts. Each
12
+ layer routes information via task gates and a shared gate, then feeds the outputs
13
+ forward to deeper layers, progressively disentangling shared vs. task-specific
14
+ signals and mitigating gradient interference.
15
+
16
+ Layer workflow:
17
+ (1) Shared and per-task experts transform the same inputs
18
+ (2) Task gates select among shared + task-specific experts
19
+ (3) A shared gate aggregates all experts for the shared branch
20
+ (4) Outputs become inputs to the next CGC layer (progressive refinement)
21
+ (5) Final task towers operate on the last-layer task representations
22
+
23
+ Key Advantages:
24
+ - Progressive routing reduces negative transfer across layers
25
+ - Explicit shared/specific experts improve feature disentanglement
26
+ - Flexible depth and expert counts to match task complexity
27
+ - Works with heterogeneous features via unified embeddings
28
+ - Stable training by separating gates for shared and task branches
29
+
30
+ PLE(Progressive Layered Extraction)通过堆叠 CGC 模块,联合共享与任务特定专家,
31
+ 利用任务门与共享门逐层软路由,逐步分离共享与任务差异信息,缓解多任务间的梯度冲突。
32
+
33
+ 层内流程:
34
+ (1) 共享与任务专家对同一输入做特征变换
35
+ (2) 任务门在共享+任务专家上进行软选择
36
+ (3) 共享门汇总全部专家,更新共享分支
37
+ (4) 输出作为下一层输入,完成逐层细化
38
+ (5) 最后由任务塔完成各任务预测
39
+
40
+ 主要优点:
41
+ - 逐层路由降低负迁移
42
+ - 显式区分共享/特定专家,增强特征解耦
43
+ - 专家数量与层数可按任务复杂度灵活设置
44
+ - 统一 embedding 支持多种特征类型
45
+ - 共享与任务门分离,训练更稳定
5
46
  """
6
47
 
7
48
  import torch
@@ -11,6 +52,97 @@ from nextrec.basic.model import BaseModel
11
52
  from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
12
53
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
13
54
 
55
+ class CGCLayer(nn.Module):
56
+ """
57
+ CGC (Customized Gate Control) block used by PLE.
58
+ It routes shared and task-specific experts with task gates and a shared gate.
59
+ """
60
+
61
+ def __init__(
62
+ self,
63
+ input_dim: int,
64
+ num_tasks: int,
65
+ num_shared_experts: int,
66
+ num_specific_experts: int,
67
+ shared_expert_params: dict,
68
+ specific_expert_params: dict | list[dict],
69
+ ):
70
+ super().__init__()
71
+ if num_tasks < 1:
72
+ raise ValueError("num_tasks must be >= 1")
73
+
74
+ specific_params_list = self._normalize_specific_params(specific_expert_params, num_tasks)
75
+
76
+ self.output_dim = self._get_output_dim(shared_expert_params, input_dim)
77
+ specific_dims = [self._get_output_dim(params, input_dim) for params in specific_params_list]
78
+ dims_set = set(specific_dims + [self.output_dim])
79
+ if len(dims_set) != 1:
80
+ raise ValueError(f"Shared/specific expert output dims must match, got {dims_set}")
81
+
82
+ # experts
83
+ self.shared_experts = nn.ModuleList([MLP(input_dim=input_dim, output_layer=False, **shared_expert_params,) for _ in range(num_shared_experts)])
84
+ self.specific_experts = nn.ModuleList()
85
+ for params in specific_params_list:
86
+ task_experts = nn.ModuleList([MLP(input_dim=input_dim, output_layer=False, **params,) for _ in range(num_specific_experts)])
87
+ self.specific_experts.append(task_experts)
88
+
89
+ # gates
90
+ task_gate_expert_num = num_shared_experts + num_specific_experts
91
+ self.task_gates = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, task_gate_expert_num), nn.Softmax(dim=1),) for _ in range(num_tasks)])
92
+ shared_gate_expert_num = num_shared_experts + num_specific_experts * num_tasks
93
+ self.shared_gate = nn.Sequential(nn.Linear(input_dim, shared_gate_expert_num), nn.Softmax(dim=1),)
94
+
95
+ self.num_tasks = num_tasks
96
+
97
+ def forward(
98
+ self, task_inputs: list[torch.Tensor], shared_input: torch.Tensor
99
+ ) -> tuple[list[torch.Tensor], torch.Tensor]:
100
+ if len(task_inputs) != self.num_tasks:
101
+ raise ValueError(f"Expected {self.num_tasks} task inputs, got {len(task_inputs)}")
102
+
103
+ shared_outputs = [expert(shared_input) for expert in self.shared_experts]
104
+ shared_stack = torch.stack(shared_outputs, dim=0) # [num_shared, B, D]
105
+
106
+ new_task_fea: list[torch.Tensor] = []
107
+ all_specific_for_shared: list[torch.Tensor] = []
108
+
109
+ for task_idx in range(self.num_tasks):
110
+ task_input = task_inputs[task_idx]
111
+ task_specific_outputs = [expert(task_input) for expert in self.specific_experts[task_idx]] # type: ignore
112
+ all_specific_for_shared.extend(task_specific_outputs)
113
+ specific_stack = torch.stack(task_specific_outputs, dim=0)
114
+
115
+ all_experts = torch.cat([shared_stack, specific_stack], dim=0)
116
+ all_experts_t = all_experts.permute(1, 0, 2) # [B, num_expert, D]
117
+
118
+ gate_weights = self.task_gates[task_idx](task_input).unsqueeze(2)
119
+ gated_output = torch.sum(gate_weights * all_experts_t, dim=1)
120
+ new_task_fea.append(gated_output)
121
+
122
+ all_for_shared = all_specific_for_shared + shared_outputs
123
+ all_for_shared_tensor = torch.stack(all_for_shared, dim=1) # [B, num_all, D]
124
+ shared_gate_weights = self.shared_gate(shared_input).unsqueeze(1)
125
+ new_shared = torch.bmm(shared_gate_weights, all_for_shared_tensor).squeeze(1)
126
+
127
+ return new_task_fea, new_shared
128
+
129
+ @staticmethod
130
+ def _get_output_dim(params: dict, fallback: int) -> int:
131
+ dims = params.get("dims")
132
+ if dims:
133
+ return dims[-1]
134
+ return fallback
135
+
136
+ @staticmethod
137
+ def _normalize_specific_params(
138
+ params: dict | list[dict], num_tasks: int
139
+ ) -> list[dict]:
140
+ if isinstance(params, list):
141
+ if len(params) != num_tasks:
142
+ raise ValueError(f"Length of specific_expert_params ({len(params)}) must match num_tasks ({num_tasks}).")
143
+ return [p.copy() for p in params]
144
+ return [params.copy() for _ in range(num_tasks)]
145
+
14
146
 
15
147
  class PLE(BaseModel):
16
148
  """
@@ -35,7 +167,7 @@ class PLE(BaseModel):
35
167
  sparse_features: list[SparseFeature],
36
168
  sequence_features: list[SequenceFeature],
37
169
  shared_expert_params: dict,
38
- specific_expert_params: dict,
170
+ specific_expert_params: dict | list[dict],
39
171
  num_shared_experts: int,
40
172
  num_specific_experts: int,
41
173
  num_levels: int,
@@ -43,7 +175,7 @@ class PLE(BaseModel):
43
175
  target: list[str],
44
176
  task: str | list[str] = 'binary',
45
177
  optimizer: str = "adam",
46
- optimizer_params: dict = {},
178
+ optimizer_params: dict | None = None,
47
179
  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
48
180
  loss_params: dict | list[dict] | None = None,
49
181
  device: str = 'cpu',
@@ -71,7 +203,6 @@ class PLE(BaseModel):
71
203
  self.loss = loss
72
204
  if self.loss is None:
73
205
  self.loss = "bce"
74
-
75
206
  # Number of tasks, experts, and levels
76
207
  self.num_tasks = len(target)
77
208
  self.num_shared_experts = num_shared_experts
@@ -79,20 +210,16 @@ class PLE(BaseModel):
79
210
  self.num_levels = num_levels
80
211
  if optimizer_params is None:
81
212
  optimizer_params = {}
82
-
83
213
  if len(tower_params_list) != self.num_tasks:
84
214
  raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
85
-
86
- # All features
87
- self.all_features = dense_features + sparse_features + sequence_features
88
-
89
215
  # Embedding layer
90
216
  self.embedding = EmbeddingLayer(features=self.all_features)
91
217
 
92
218
  # Calculate input dimension
93
- emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
94
- dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
95
- input_dim = emb_dim_total + dense_input_dim
219
+ input_dim = self.embedding.input_dim
220
+ # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
221
+ # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
222
+ # input_dim = emb_dim_total + dense_input_dim
96
223
 
97
224
  # Get expert output dimension
98
225
  if 'dims' in shared_expert_params and len(shared_expert_params['dims']) > 0:
@@ -100,74 +227,30 @@ class PLE(BaseModel):
100
227
  else:
101
228
  expert_output_dim = input_dim
102
229
 
103
- # Build extraction layers (CGC layers)
104
- self.shared_experts_layers = nn.ModuleList() # [num_levels]
105
- self.specific_experts_layers = nn.ModuleList() # [num_levels, num_tasks]
106
- self.gates_layers = nn.ModuleList() # [num_levels, num_tasks + 1] (+1 for shared gate)
107
-
230
+ # Build CGC layers
231
+ self.cgc_layers = nn.ModuleList()
108
232
  for level in range(num_levels):
109
- # Input dimension for this level
110
233
  level_input_dim = input_dim if level == 0 else expert_output_dim
111
-
112
- # Shared experts for this level
113
- shared_experts = nn.ModuleList()
114
- for _ in range(num_shared_experts):
115
- expert = MLP(input_dim=level_input_dim, output_layer=False, **shared_expert_params)
116
- shared_experts.append(expert)
117
- self.shared_experts_layers.append(shared_experts)
118
-
119
- # Task-specific experts for this level
120
- specific_experts_for_tasks = nn.ModuleList()
121
- for _ in range(self.num_tasks):
122
- task_experts = nn.ModuleList()
123
- for _ in range(num_specific_experts):
124
- expert = MLP(input_dim=level_input_dim, output_layer=False, **specific_expert_params)
125
- task_experts.append(expert)
126
- specific_experts_for_tasks.append(task_experts)
127
- self.specific_experts_layers.append(specific_experts_for_tasks)
128
-
129
- # Gates for this level (num_tasks task gates + 1 shared gate)
130
- gates = nn.ModuleList()
131
- # Task-specific gates
132
- num_experts_for_task_gate = num_shared_experts + num_specific_experts
133
- for _ in range(self.num_tasks):
134
- gate = nn.Sequential(
135
- nn.Linear(level_input_dim, num_experts_for_task_gate),
136
- nn.Softmax(dim=1)
137
- )
138
- gates.append(gate)
139
- # Shared gate: contains all tasks' specific experts + shared experts
140
- # expert counts = num_shared_experts + num_specific_experts * num_tasks
141
- num_experts_for_shared_gate = num_shared_experts + num_specific_experts * self.num_tasks
142
- shared_gate = nn.Sequential(
143
- nn.Linear(level_input_dim, num_experts_for_shared_gate),
144
- nn.Softmax(dim=1)
234
+ cgc_layer = CGCLayer(
235
+ input_dim=level_input_dim,
236
+ num_tasks=self.num_tasks,
237
+ num_shared_experts=num_shared_experts,
238
+ num_specific_experts=num_specific_experts,
239
+ shared_expert_params=shared_expert_params,
240
+ specific_expert_params=specific_expert_params,
145
241
  )
146
- gates.append(shared_gate)
147
- self.gates_layers.append(gates)
242
+ self.cgc_layers.append(cgc_layer)
243
+ expert_output_dim = cgc_layer.output_dim
148
244
 
149
245
  # Task-specific towers
150
246
  self.towers = nn.ModuleList()
151
247
  for tower_params in tower_params_list:
152
248
  tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
153
249
  self.towers.append(tower)
154
- self.prediction_layer = PredictionLayer(
155
- task_type=self.task_type,
156
- task_dims=[1] * self.num_tasks
157
- )
158
-
250
+ self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
159
251
  # Register regularization weights
160
- self._register_regularization_weights(
161
- embedding_attr='embedding',
162
- include_modules=['shared_experts_layers', 'specific_experts_layers', 'gates_layers', 'towers']
163
- )
164
-
165
- self.compile(
166
- optimizer=optimizer,
167
- optimizer_params=optimizer_params,
168
- loss=loss,
169
- loss_params=loss_params,
170
- )
252
+ self._register_regularization_weights(embedding_attr='embedding', include_modules=['cgc_layers', 'towers'])
253
+ self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=self.loss, loss_params=loss_params)
171
254
 
172
255
  def forward(self, x):
173
256
  # Get all embeddings and flatten
@@ -178,76 +261,8 @@ class PLE(BaseModel):
178
261
  shared_fea = input_flat
179
262
 
180
263
  # Progressive Layered Extraction: CGC
181
- for level in range(self.num_levels):
182
- shared_experts = self.shared_experts_layers[level] # ModuleList[num_shared_experts]
183
- specific_experts = self.specific_experts_layers[level] # ModuleList[num_tasks][num_specific_experts]
184
- gates = self.gates_layers[level] # ModuleList[num_tasks + 1]
185
-
186
- # Compute shared experts output for this level
187
- # shared_expert_list: List[Tensor[B, expert_dim]]
188
- shared_expert_list = [expert(shared_fea) for expert in shared_experts] # type: ignore[list-item]
189
- # [num_shared_experts, B, expert_dim]
190
- shared_expert_outputs = torch.stack(shared_expert_list, dim=0)
191
-
192
- all_specific_outputs_for_shared = []
193
-
194
- # Compute task's gated output and specific outputs
195
- new_task_fea = []
196
- for task_idx in range(self.num_tasks):
197
- # Current input for this task at this level
198
- current_task_in = task_fea[task_idx]
199
-
200
- # Specific task experts for this task
201
- task_expert_modules = specific_experts[task_idx] # type: ignore
202
-
203
- # Specific task expert output list List[Tensor[B, expert_dim]]
204
- task_specific_list = []
205
- for expert in task_expert_modules:
206
- out = expert(current_task_in)
207
- task_specific_list.append(out)
208
- # All specific task experts are candidates for the shared gate
209
- all_specific_outputs_for_shared.append(out)
210
-
211
- # [num_specific_taskexperts, B, expert_dim]
212
- task_specific_outputs = torch.stack(task_specific_list, dim=0)
213
-
214
- # Input for gate: shared_experts + own specific task experts
215
- # [num_shared + num_specific, B, expert_dim]
216
- all_expert_outputs = torch.cat(
217
- [shared_expert_outputs, task_specific_outputs],
218
- dim=0
219
- )
220
- # [B, num_experts, expert_dim]
221
- all_expert_outputs_t = all_expert_outputs.permute(1, 0, 2)
222
-
223
- # Gate for task (gates[task_idx])
224
- # Output shape: [B, num_shared + num_specific]
225
- gate_weights = gates[task_idx](current_task_in)
226
- # [B, num_experts, 1]
227
- gate_weights = gate_weights.unsqueeze(2)
228
-
229
- # Weighted sum to get this task's features at this level: [B, expert_dim]
230
- gated_output = torch.sum(gate_weights * all_expert_outputs_t, dim=1)
231
- new_task_fea.append(gated_output)
232
-
233
- # compute shared gate output
234
- # Input for shared gate: specific task experts + shared experts
235
- # all_specific_outputs_for_shared: List[Tensor[B, expert_dim]]
236
- # shared_expert_list: List[Tensor[B, expert_dim]]
237
- all_for_shared_list = all_specific_outputs_for_shared + shared_expert_list
238
- # [B, num_all_experts, expert_dim]
239
- all_for_shared = torch.stack(all_for_shared_list, dim=1)
240
-
241
- # [B, num_all_experts]
242
- shared_gate_weights = gates[self.num_tasks](shared_fea) # type: ignore
243
- # [B, 1, num_all_experts]
244
- shared_gate_weights = shared_gate_weights.unsqueeze(1)
245
-
246
- # weighted sum: [B, 1, expert_dim] → [B, expert_dim]
247
- new_shared_fea = torch.bmm(shared_gate_weights, all_for_shared).squeeze(1)
248
-
249
- task_fea = new_task_fea
250
- shared_fea = new_shared_fea
264
+ for layer in self.cgc_layers:
265
+ task_fea, shared_fea = layer(task_fea, shared_fea)
251
266
 
252
267
  # task tower
253
268
  task_outputs = []
@@ -257,4 +272,4 @@ class PLE(BaseModel):
257
272
 
258
273
  # [B, num_tasks]
259
274
  y = torch.cat(task_outputs, dim=1)
260
- return self.prediction_layer(y)
275
+ return self.prediction_layer(y)