nextrec 0.2.7__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +4 -8
- nextrec/basic/callback.py +1 -1
- nextrec/basic/features.py +33 -25
- nextrec/basic/layers.py +164 -601
- nextrec/basic/loggers.py +4 -5
- nextrec/basic/metrics.py +39 -115
- nextrec/basic/model.py +257 -177
- nextrec/basic/session.py +1 -5
- nextrec/data/__init__.py +12 -0
- nextrec/data/data_utils.py +3 -27
- nextrec/data/dataloader.py +26 -34
- nextrec/data/preprocessor.py +2 -1
- nextrec/loss/listwise.py +6 -4
- nextrec/loss/loss_utils.py +10 -6
- nextrec/loss/pairwise.py +5 -3
- nextrec/loss/pointwise.py +7 -13
- nextrec/models/generative/__init__.py +5 -0
- nextrec/models/generative/hstu.py +399 -0
- nextrec/models/match/mind.py +110 -1
- nextrec/models/multi_task/esmm.py +46 -27
- nextrec/models/multi_task/mmoe.py +48 -30
- nextrec/models/multi_task/ple.py +156 -141
- nextrec/models/multi_task/poso.py +413 -0
- nextrec/models/multi_task/share_bottom.py +43 -26
- nextrec/models/ranking/__init__.py +2 -0
- nextrec/models/ranking/dcn.py +20 -1
- nextrec/models/ranking/dcn_v2.py +84 -0
- nextrec/models/ranking/deepfm.py +44 -18
- nextrec/models/ranking/dien.py +130 -27
- nextrec/models/ranking/masknet.py +13 -67
- nextrec/models/ranking/widedeep.py +39 -18
- nextrec/models/ranking/xdeepfm.py +34 -1
- nextrec/utils/common.py +26 -1
- nextrec/utils/optimizer.py +7 -3
- nextrec-0.3.2.dist-info/METADATA +312 -0
- nextrec-0.3.2.dist-info/RECORD +57 -0
- nextrec-0.2.7.dist-info/METADATA +0 -281
- nextrec-0.2.7.dist-info/RECORD +0 -54
- {nextrec-0.2.7.dist-info → nextrec-0.3.2.dist-info}/WHEEL +0 -0
- {nextrec-0.2.7.dist-info → nextrec-0.3.2.dist-info}/licenses/LICENSE +0 -0
nextrec/models/multi_task/mmoe.py CHANGED

```diff
@@ -1,7 +1,45 @@
 """
 Date: create on 09/11/2025
+Checkpoint: edit on 29/11/2025
 Author: Yang Zhou,zyaztec@gmail.com
-Reference:
+Reference:
+[1] Ma J, Zhao Z, Yi X, et al. Modeling task relationships in multi-task learning with
+multi-gate mixture-of-experts[C]//KDD. 2018: 1930-1939.
+(https://dl.acm.org/doi/10.1145/3219819.3220007)
+
+Multi-gate Mixture-of-Experts (MMoE) extends shared-bottom multi-task learning by
+introducing multiple experts and task-specific softmax gates. Each task learns its
+own routing weights over the expert pool, enabling both shared and task-specialized
+representations while alleviating gradient conflicts across tasks.
+
+In each forward pass:
+(1) Shared embeddings encode all dense/sparse/sequence features
+(2) Each expert processes the same input to produce candidate shared representations
+(3) Every task gate outputs a simplex over experts to softly route information
+(4) The task-specific weighted sum is passed into its tower and prediction head
+
+Key Advantages:
+- Soft parameter sharing reduces negative transfer between heterogeneous tasks
+- Gates let tasks adaptively allocate expert capacity based on their difficulty
+- Supports many tasks without duplicating full networks
+- Works with mixed feature types via unified embeddings
+- Simple to scale the number of experts or gates for capacity control
+
+MMoE(Multi-gate Mixture-of-Experts)是多任务学习框架,通过多个专家网络与
+任务特定门控进行软路由,兼顾共享表示与任务特化,减轻梯度冲突问题。
+
+一次前向流程:
+(1) 共享 embedding 统一编码稠密、稀疏与序列特征
+(2) 每个专家对相同输入进行特征变换,得到候选共享表示
+(3) 每个任务的门控产生对专家的概率分布,完成软选择与加权
+(4) 加权结果输入到对应任务的塔网络与预测头
+
+主要优点:
+- 软参数共享,缓解任务间负迁移
+- 按任务难度自适应分配专家容量
+- 便于扩展多任务,而无需复制完整网络
+- 支持多种特征类型的统一建模
+- 专家与门控数量可灵活调节以控制模型容量
 """
 
 import torch
@@ -75,18 +113,14 @@ class MMOE(BaseModel):
 
         if len(tower_params_list) != self.num_tasks:
             raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
-
-        # All features
-        self.all_features = dense_features + sparse_features + sequence_features
 
-
+        self.all_features = dense_features + sparse_features + sequence_features
         self.embedding = EmbeddingLayer(features=self.all_features)
+        input_dim = self.embedding.input_dim
+        # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
+        # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
+        # input_dim = emb_dim_total + dense_input_dim
 
-        # Calculate input dimension
-        emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
-        dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
-        input_dim = emb_dim_total + dense_input_dim
-
         # Expert networks (shared by all tasks)
         self.experts = nn.ModuleList()
         for _ in range(num_experts):
@@ -102,10 +136,7 @@ class MMOE(BaseModel):
         # Task-specific gates
         self.gates = nn.ModuleList()
         for _ in range(self.num_tasks):
-            gate = nn.Sequential(
-                nn.Linear(input_dim, num_experts),
-                nn.Softmax(dim=1)
-            )
+            gate = nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1))
             self.gates.append(gate)
 
         # Task-specific towers
@@ -113,23 +144,10 @@ class MMOE(BaseModel):
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(
-            task_type=self.task_type,
-            task_dims=[1] * self.num_tasks
-        )
-
+        self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
         # Register regularization weights
-        self._register_regularization_weights(
-
-            include_modules=['experts', 'gates', 'towers']
-        )
-
-        self.compile(
-            optimizer=optimizer,
-            optimizer_params=optimizer_params,
-            loss=loss,
-            loss_params=loss_params,
-        )
+        self._register_regularization_weights(embedding_attr='embedding', include_modules=['experts', 'gates', 'towers'])
+        self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params,)
 
     def forward(self, x):
         # Get all embeddings and flatten
```
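The docstring added in the first hunk describes MMoE's routing in steps (1)-(4), and the gate rewrite above keeps each gate as Linear -> Softmax. Below is a minimal, self-contained sketch of that routing in plain PyTorch — toy sizes and stand-in Linear experts rather than nextrec's MLP and EmbeddingLayer:

```python
import torch
import torch.nn as nn

batch, input_dim, expert_dim = 4, 16, 8
num_experts, num_tasks = 3, 2

# Stand-in experts (single Linear layers instead of nextrec's MLP).
experts = nn.ModuleList([nn.Linear(input_dim, expert_dim) for _ in range(num_experts)])
# One Linear -> Softmax gate per task, matching the gate built in the diff.
gates = nn.ModuleList([
    nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1))
    for _ in range(num_tasks)
])

x = torch.randn(batch, input_dim)                     # flattened shared embeddings
expert_out = torch.stack([e(x) for e in experts], 1)  # [B, num_experts, D]

for gate in gates:
    weights = gate(x).unsqueeze(2)                    # [B, num_experts, 1], each row sums to 1
    task_fea = (weights * expert_out).sum(dim=1)      # [B, D], input to that task's tower
    print(task_fea.shape)                             # torch.Size([4, 8])
```

Because each gate outputs a simplex, every task consumes a convex combination of the same expert pool — the soft parameter sharing named in the docstring's advantages list.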
nextrec/models/multi_task/ple.py CHANGED
```diff
@@ -1,7 +1,48 @@
 """
 Date: create on 09/11/2025
+Checkpoint: edit on 29/11/2025
 Author: Yang Zhou,zyaztec@gmail.com
-Reference:
+Reference:
+[1] Tang H, Liu J, Zhao M, et al. Progressive layered extraction (PLE): A novel
+multi-task learning (MTL) model for personalized recommendations[C]//RecSys. 2020: 269-278.
+(https://dl.acm.org/doi/10.1145/3383313.3412236)
+
+Progressive Layered Extraction (PLE) advances multi-task learning by stacking CGC
+(Customized Gate Control) blocks that mix shared and task-specific experts. Each
+layer routes information via task gates and a shared gate, then feeds the outputs
+forward to deeper layers, progressively disentangling shared vs. task-specific
+signals and mitigating gradient interference.
+
+Layer workflow:
+(1) Shared and per-task experts transform the same inputs
+(2) Task gates select among shared + task-specific experts
+(3) A shared gate aggregates all experts for the shared branch
+(4) Outputs become inputs to the next CGC layer (progressive refinement)
+(5) Final task towers operate on the last-layer task representations
+
+Key Advantages:
+- Progressive routing reduces negative transfer across layers
+- Explicit shared/specific experts improve feature disentanglement
+- Flexible depth and expert counts to match task complexity
+- Works with heterogeneous features via unified embeddings
+- Stable training by separating gates for shared and task branches
+
+PLE(Progressive Layered Extraction)通过堆叠 CGC 模块,联合共享与任务特定专家,
+利用任务门与共享门逐层软路由,逐步分离共享与任务差异信息,缓解多任务间的梯度冲突。
+
+层内流程:
+(1) 共享与任务专家对同一输入做特征变换
+(2) 任务门在共享+任务专家上进行软选择
+(3) 共享门汇总全部专家,更新共享分支
+(4) 输出作为下一层输入,完成逐层细化
+(5) 最后由任务塔完成各任务预测
+
+主要优点:
+- 逐层路由降低负迁移
+- 显式区分共享/特定专家,增强特征解耦
+- 专家数量与层数可按任务复杂度灵活设置
+- 统一 embedding 支持多种特征类型
+- 共享与任务门分离,训练更稳定
 """
 
 import torch
```
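A quick arithmetic check on the gate fan-outs implied by steps (2) and (3) above: each task gate chooses among the shared experts plus that task's own experts, while the shared gate sees every expert. The counts below are illustrative and mirror the `task_gate_expert_num` / `shared_gate_expert_num` computed in the CGCLayer added in the next hunk:

```python
num_shared_experts, num_specific_experts, num_tasks = 2, 2, 3

task_gate_experts = num_shared_experts + num_specific_experts                 # 4 per task gate
shared_gate_experts = num_shared_experts + num_specific_experts * num_tasks  # 8 for the shared gate
```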
```diff
@@ -11,6 +52,97 @@ from nextrec.basic.model import BaseModel
 from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 
+class CGCLayer(nn.Module):
+    """
+    CGC (Customized Gate Control) block used by PLE.
+    It routes shared and task-specific experts with task gates and a shared gate.
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        num_tasks: int,
+        num_shared_experts: int,
+        num_specific_experts: int,
+        shared_expert_params: dict,
+        specific_expert_params: dict | list[dict],
+    ):
+        super().__init__()
+        if num_tasks < 1:
+            raise ValueError("num_tasks must be >= 1")
+
+        specific_params_list = self._normalize_specific_params(specific_expert_params, num_tasks)
+
+        self.output_dim = self._get_output_dim(shared_expert_params, input_dim)
+        specific_dims = [self._get_output_dim(params, input_dim) for params in specific_params_list]
+        dims_set = set(specific_dims + [self.output_dim])
+        if len(dims_set) != 1:
+            raise ValueError(f"Shared/specific expert output dims must match, got {dims_set}")
+
+        # experts
+        self.shared_experts = nn.ModuleList([MLP(input_dim=input_dim, output_layer=False, **shared_expert_params,) for _ in range(num_shared_experts)])
+        self.specific_experts = nn.ModuleList()
+        for params in specific_params_list:
+            task_experts = nn.ModuleList([MLP(input_dim=input_dim, output_layer=False, **params,) for _ in range(num_specific_experts)])
+            self.specific_experts.append(task_experts)
+
+        # gates
+        task_gate_expert_num = num_shared_experts + num_specific_experts
+        self.task_gates = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, task_gate_expert_num), nn.Softmax(dim=1),) for _ in range(num_tasks)])
+        shared_gate_expert_num = num_shared_experts + num_specific_experts * num_tasks
+        self.shared_gate = nn.Sequential(nn.Linear(input_dim, shared_gate_expert_num), nn.Softmax(dim=1),)
+
+        self.num_tasks = num_tasks
+
+    def forward(
+        self, task_inputs: list[torch.Tensor], shared_input: torch.Tensor
+    ) -> tuple[list[torch.Tensor], torch.Tensor]:
+        if len(task_inputs) != self.num_tasks:
+            raise ValueError(f"Expected {self.num_tasks} task inputs, got {len(task_inputs)}")
+
+        shared_outputs = [expert(shared_input) for expert in self.shared_experts]
+        shared_stack = torch.stack(shared_outputs, dim=0)  # [num_shared, B, D]
+
+        new_task_fea: list[torch.Tensor] = []
+        all_specific_for_shared: list[torch.Tensor] = []
+
+        for task_idx in range(self.num_tasks):
+            task_input = task_inputs[task_idx]
+            task_specific_outputs = [expert(task_input) for expert in self.specific_experts[task_idx]]  # type: ignore
+            all_specific_for_shared.extend(task_specific_outputs)
+            specific_stack = torch.stack(task_specific_outputs, dim=0)
+
+            all_experts = torch.cat([shared_stack, specific_stack], dim=0)
+            all_experts_t = all_experts.permute(1, 0, 2)  # [B, num_expert, D]
+
+            gate_weights = self.task_gates[task_idx](task_input).unsqueeze(2)
+            gated_output = torch.sum(gate_weights * all_experts_t, dim=1)
+            new_task_fea.append(gated_output)
+
+        all_for_shared = all_specific_for_shared + shared_outputs
+        all_for_shared_tensor = torch.stack(all_for_shared, dim=1)  # [B, num_all, D]
+        shared_gate_weights = self.shared_gate(shared_input).unsqueeze(1)
+        new_shared = torch.bmm(shared_gate_weights, all_for_shared_tensor).squeeze(1)
+
+        return new_task_fea, new_shared
+
+    @staticmethod
+    def _get_output_dim(params: dict, fallback: int) -> int:
+        dims = params.get("dims")
+        if dims:
+            return dims[-1]
+        return fallback
+
+    @staticmethod
+    def _normalize_specific_params(
+        params: dict | list[dict], num_tasks: int
+    ) -> list[dict]:
+        if isinstance(params, list):
+            if len(params) != num_tasks:
+                raise ValueError(f"Length of specific_expert_params ({len(params)}) must match num_tasks ({num_tasks}).")
+            return [p.copy() for p in params]
+        return [params.copy() for _ in range(num_tasks)]
+
 
 class PLE(BaseModel):
     """
```
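Since CGCLayer is a self-contained nn.Module, it can be exercised on its own. A hypothetical usage sketch — sizes invented, assuming the 0.3.2 wheel is installed and that MLP consumes the `dims` key this diff's code reads:

```python
import torch
from nextrec.models.multi_task.ple import CGCLayer

layer = CGCLayer(
    input_dim=32,
    num_tasks=2,
    num_shared_experts=2,
    num_specific_experts=2,
    shared_expert_params={"dims": [64, 16]},    # all expert output dims must match (here 16)
    specific_expert_params={"dims": [64, 16]},  # one dict reused for both tasks
)

batch = 4
# At the first level, every branch would typically see the same flattened
# embedding (the surrounding forward sets shared_fea = input_flat).
shared_in = torch.randn(batch, 32)
task_ins = [shared_in, shared_in]

task_feas, shared_fea = layer(task_ins, shared_in)
print([t.shape for t in task_feas])  # [torch.Size([4, 16]), torch.Size([4, 16])]
print(shared_fea.shape)              # torch.Size([4, 16])
```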
```diff
@@ -35,7 +167,7 @@ class PLE(BaseModel):
         sparse_features: list[SparseFeature],
         sequence_features: list[SequenceFeature],
         shared_expert_params: dict,
-        specific_expert_params: dict,
+        specific_expert_params: dict | list[dict],
         num_shared_experts: int,
         num_specific_experts: int,
         num_levels: int,
```
```diff
@@ -43,7 +175,7 @@ class PLE(BaseModel):
         target: list[str],
         task: str | list[str] = 'binary',
         optimizer: str = "adam",
-        optimizer_params: dict =
+        optimizer_params: dict | None = None,
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
         device: str = 'cpu',
```
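With the widened signature above, `specific_expert_params` now accepts either a single dict applied to every task or a per-task list. A hypothetical sketch of both forms (sizes invented; `dims` is the key the code in this diff reads):

```python
shared_expert_params = {"dims": [128, 64]}

# One dict, reused for every task:
specific_expert_params = {"dims": [128, 64]}

# Or one dict per task (length must equal num_tasks, and every expert
# must end in the same output dim as the shared experts):
specific_expert_params = [{"dims": [128, 64]}, {"dims": [256, 64]}]
```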
```diff
@@ -71,7 +203,6 @@ class PLE(BaseModel):
         self.loss = loss
         if self.loss is None:
             self.loss = "bce"
-
         # Number of tasks, experts, and levels
         self.num_tasks = len(target)
         self.num_shared_experts = num_shared_experts
```
```diff
@@ -79,20 +210,16 @@ class PLE(BaseModel):
         self.num_levels = num_levels
         if optimizer_params is None:
             optimizer_params = {}
-
         if len(tower_params_list) != self.num_tasks:
             raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
-
-        # All features
-        self.all_features = dense_features + sparse_features + sequence_features
-
         # Embedding layer
         self.embedding = EmbeddingLayer(features=self.all_features)
 
         # Calculate input dimension
-        emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
-        dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
-        input_dim = emb_dim_total + dense_input_dim
+        input_dim = self.embedding.input_dim
+        # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
+        # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
+        # input_dim = emb_dim_total + dense_input_dim
 
         # Get expert output dimension
         if 'dims' in shared_expert_params and len(shared_expert_params['dims']) > 0:
```
```diff
@@ -100,74 +227,30 @@ class PLE(BaseModel):
         else:
             expert_output_dim = input_dim
 
-        # Build
-        self.shared_experts_layers = nn.ModuleList()
-        self.specific_experts_layers = nn.ModuleList()  # [num_levels, num_tasks]
-        self.gates_layers = nn.ModuleList()  # [num_levels, num_tasks + 1] (+1 for shared gate)
-
+        # Build CGC layers
+        self.cgc_layers = nn.ModuleList()
         for level in range(num_levels):
-            # Input dimension for this level
             level_input_dim = input_dim if level == 0 else expert_output_dim
-
-
-
-
-
-
-
-
-            # Task-specific experts for this level
-            specific_experts_for_tasks = nn.ModuleList()
-            for _ in range(self.num_tasks):
-                task_experts = nn.ModuleList()
-                for _ in range(num_specific_experts):
-                    expert = MLP(input_dim=level_input_dim, output_layer=False, **specific_expert_params)
-                    task_experts.append(expert)
-                specific_experts_for_tasks.append(task_experts)
-            self.specific_experts_layers.append(specific_experts_for_tasks)
-
-            # Gates for this level (num_tasks task gates + 1 shared gate)
-            gates = nn.ModuleList()
-            # Task-specific gates
-            num_experts_for_task_gate = num_shared_experts + num_specific_experts
-            for _ in range(self.num_tasks):
-                gate = nn.Sequential(
-                    nn.Linear(level_input_dim, num_experts_for_task_gate),
-                    nn.Softmax(dim=1)
-                )
-                gates.append(gate)
-            # Shared gate: contains all tasks' specific experts + shared experts
-            # expert counts = num_shared_experts + num_specific_experts * num_tasks
-            num_experts_for_shared_gate = num_shared_experts + num_specific_experts * self.num_tasks
-            shared_gate = nn.Sequential(
-                nn.Linear(level_input_dim, num_experts_for_shared_gate),
-                nn.Softmax(dim=1)
+            cgc_layer = CGCLayer(
+                input_dim=level_input_dim,
+                num_tasks=self.num_tasks,
+                num_shared_experts=num_shared_experts,
+                num_specific_experts=num_specific_experts,
+                shared_expert_params=shared_expert_params,
+                specific_expert_params=specific_expert_params,
             )
-
-
+            self.cgc_layers.append(cgc_layer)
+            expert_output_dim = cgc_layer.output_dim
 
         # Task-specific towers
         self.towers = nn.ModuleList()
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(
-            task_type=self.task_type,
-            task_dims=[1] * self.num_tasks
-        )
-
+        self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
         # Register regularization weights
-        self._register_regularization_weights(
-
-            include_modules=['shared_experts_layers', 'specific_experts_layers', 'gates_layers', 'towers']
-        )
-
-        self.compile(
-            optimizer=optimizer,
-            optimizer_params=optimizer_params,
-            loss=loss,
-            loss_params=loss_params,
-        )
+        self._register_regularization_weights(embedding_attr='embedding', include_modules=['cgc_layers', 'towers'])
+        self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=self.loss, loss_params=loss_params)
 
     def forward(self, x):
         # Get all embeddings and flatten
```
```diff
@@ -178,76 +261,8 @@ class PLE(BaseModel):
         shared_fea = input_flat
 
         # Progressive Layered Extraction: CGC
-        for level in range(self.num_levels):
-            shared_experts = self.shared_experts_layers[level]
-            specific_experts = self.specific_experts_layers[level]  # ModuleList[num_tasks][num_specific_experts]
-            gates = self.gates_layers[level]  # ModuleList[num_tasks + 1]
-
-            # Compute shared experts output for this level
-            # shared_expert_list: List[Tensor[B, expert_dim]]
-            shared_expert_list = [expert(shared_fea) for expert in shared_experts]  # type: ignore[list-item]
-            # [num_shared_experts, B, expert_dim]
-            shared_expert_outputs = torch.stack(shared_expert_list, dim=0)
-
-            all_specific_outputs_for_shared = []
-
-            # Compute task's gated output and specific outputs
-            new_task_fea = []
-            for task_idx in range(self.num_tasks):
-                # Current input for this task at this level
-                current_task_in = task_fea[task_idx]
-
-                # Specific task experts for this task
-                task_expert_modules = specific_experts[task_idx]  # type: ignore
-
-                # Specific task expert output list List[Tensor[B, expert_dim]]
-                task_specific_list = []
-                for expert in task_expert_modules:
-                    out = expert(current_task_in)
-                    task_specific_list.append(out)
-                    # All specific task experts are candidates for the shared gate
-                    all_specific_outputs_for_shared.append(out)
-
-                # [num_specific_taskexperts, B, expert_dim]
-                task_specific_outputs = torch.stack(task_specific_list, dim=0)
-
-                # Input for gate: shared_experts + own specific task experts
-                # [num_shared + num_specific, B, expert_dim]
-                all_expert_outputs = torch.cat(
-                    [shared_expert_outputs, task_specific_outputs],
-                    dim=0
-                )
-                # [B, num_experts, expert_dim]
-                all_expert_outputs_t = all_expert_outputs.permute(1, 0, 2)
-
-                # Gate for task (gates[task_idx])
-                # Output shape: [B, num_shared + num_specific]
-                gate_weights = gates[task_idx](current_task_in)
-                # [B, num_experts, 1]
-                gate_weights = gate_weights.unsqueeze(2)
-
-                # Weighted sum to get this task's features at this level: [B, expert_dim]
-                gated_output = torch.sum(gate_weights * all_expert_outputs_t, dim=1)
-                new_task_fea.append(gated_output)
-
-            # compute shared gate output
-            # Input for shared gate: specific task experts + shared experts
-            # all_specific_outputs_for_shared: List[Tensor[B, expert_dim]]
-            # shared_expert_list: List[Tensor[B, expert_dim]]
-            all_for_shared_list = all_specific_outputs_for_shared + shared_expert_list
-            # [B, num_all_experts, expert_dim]
-            all_for_shared = torch.stack(all_for_shared_list, dim=1)
-
-            # [B, num_all_experts]
-            shared_gate_weights = gates[self.num_tasks](shared_fea)  # type: ignore
-            # [B, 1, num_all_experts]
-            shared_gate_weights = shared_gate_weights.unsqueeze(1)
-
-            # weighted sum: [B, 1, expert_dim] → [B, expert_dim]
-            new_shared_fea = torch.bmm(shared_gate_weights, all_for_shared).squeeze(1)
-
-            task_fea = new_task_fea
-            shared_fea = new_shared_fea
+        for layer in self.cgc_layers:
+            task_fea, shared_fea = layer(task_fea, shared_fea)
 
         # task tower
         task_outputs = []
```
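After this rewrite, forward only threads (task features, shared feature) through the stacked layers; all gating lives in CGCLayer. A toy stand-in showing just that layer-to-layer data flow (deliberately not the real CGC math):

```python
import torch
import torch.nn as nn

class ToyCGC(nn.Module):
    # Shape-preserving placeholder: averages inputs to mimic the interface only.
    def forward(self, task_inputs, shared_input):
        new_tasks = [0.5 * (t + shared_input) for t in task_inputs]
        new_shared = torch.stack(task_inputs + [shared_input], dim=0).mean(dim=0)
        return new_tasks, new_shared

layers = nn.ModuleList([ToyCGC(), ToyCGC()])
task_fea = [torch.randn(4, 8), torch.randn(4, 8)]
shared_fea = torch.randn(4, 8)
for layer in layers:
    task_fea, shared_fea = layer(task_fea, shared_fea)  # same shapes in and out
```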
```diff
@@ -257,4 +272,4 @@ class PLE(BaseModel):
 
         # [B, num_tasks]
         y = torch.cat(task_outputs, dim=1)
-        return self.prediction_layer(y)
+        return self.prediction_layer(y)
```