orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orbit/__init__.py +3 -1
- orbit/callback.py +4 -3
- orbit/dataset/__init__.py +1 -0
- orbit/dataset/cogn.py +138 -0
- orbit/dataset/data/cogn_en.jsonl +45 -0
- orbit/dataset/data/cogn_zh.jsonl +113 -0
- orbit/engine.py +210 -146
- orbit/kit/__init__.py +2 -0
- orbit/kit/interface.py +154 -0
- orbit/kit/wrapper.py +157 -0
- orbit/model/__init__.py +5 -0
- orbit/model/base.py +125 -0
- orbit/model/block/__init__.py +34 -0
- orbit/model/block/attention.py +265 -0
- orbit/model/block/bio.py +537 -0
- orbit/model/block/codebook.py +122 -0
- orbit/model/block/conv.py +505 -0
- orbit/model/block/embedding.py +252 -0
- orbit/model/block/film.py +176 -0
- orbit/model/block/fusion.py +335 -0
- orbit/model/block/gate.py +334 -0
- orbit/model/block/lora.py +776 -0
- orbit/model/block/mlp.py +68 -0
- orbit/model/block/moe.py +94 -0
- orbit/model/block/tcn.py +99 -0
- orbit/model/config.py +62 -0
- orbit/model/kit/__init__.py +6 -0
- orbit/model/kit/discriminator.py +46 -0
- orbit/model/kit/losses.py +193 -0
- orbit/model/motif/__init__.py +0 -0
- orbit/model/motif/vision/__init__.py +0 -0
- orbit/model/motif/vision/v1.py +645 -0
- orbit/model/registry.py +53 -0
- orbit/optim/__init__.py +2 -2
- orbit/optim/sam.py +10 -3
- orbit/plugin/__init__.py +12 -8
- orbit/plugin/board.py +1 -2
- orbit/plugin/checkpoint.py +137 -62
- orbit/plugin/classification.py +2 -2
- orbit/plugin/display_model.py +1 -2
- orbit/plugin/early_stopping.py +1 -2
- orbit/plugin/ema.py +1 -2
- orbit/plugin/gradient_accumulation.py +1 -2
- orbit/plugin/lora.py +346 -0
- orbit/plugin/memory_estimator.py +1 -2
- orbit/plugin/warmup.py +1 -2
- orbit/utils/__init__.py +24 -1
- orbit/utils/cuda.py +10 -0
- orbit/utils/freeze.py +61 -17
- orbit/utils/image.py +164 -0
- orbit/utils/initialization.py +184 -94
- orbit/utils/layer_io.py +66 -7
- orbit/utils/lora.py +480 -0
- orbit/utils/moe.py +55 -0
- orbit/utils/seed.py +3 -19
- orbit/utils/sft.py +93 -0
- orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
- orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
- orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
- orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/model/block/fusion.py
@@ -0,0 +1,335 @@

import torch
import torch.nn as nn

from typing import Optional, List

from orbit.model import BaseBlock, register_model


@register_model()
class LowRankFusion(BaseBlock):
    ''' Low-rank Multimodal Fusion (LMF) module.

    Approximates the multimodal outer product with a low-rank decomposition:
    the fusion weight tensor is factored into modality-specific factors,
    which makes the fusion efficient.
    '''

    def __init__(self, in_features: List[int], out_features: int, rank: int, dropout: float = 0.0, channel_first: bool = False):
        ''' Initialize LowRankFusion.

        Args:
            in_features (List[int]): Feature dimension of each input modality.
            out_features (int): Output feature dimension after fusion.
            rank (int): Rank of the low-rank decomposition.
            dropout (float, optional): Dropout probability. Defaults to 0.0.
            channel_first (bool, optional): Whether the feature dimension is at dim 1
                (e.g. CNN [B, C, H, W]). If False, features are assumed to sit in the
                last dimension (e.g. Transformer [B, L, C]). Defaults to False.
        '''
        super().__init__()

        if not in_features:
            raise ValueError("in_features cannot be empty")

        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.channel_first = channel_first

        self.modality_factors = nn.ModuleList([
            nn.Linear(dim, rank, bias=True) for dim in in_features
        ])

        self.fusion_weights = nn.Linear(rank, out_features, bias=True)

        self.dropout = nn.Dropout(dropout)

        self._init_weights(self)

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        ''' Forward pass.

        Args:
            inputs (List[torch.Tensor]): List of input tensors. All tensors must match
                in every dimension except the last.

        Returns:
            torch.Tensor: Fused output tensor.
        '''
        if len(inputs) != len(self.in_features):
            raise ValueError(f"Expected {len(self.in_features)} inputs, got {len(inputs)}")

        if self.channel_first:
            inputs = [x.movedim(1, -1) for x in inputs]

        fusion_tensor: Optional[torch.Tensor] = None

        # Project each modality into the shared rank space and multiply
        # element-wise; this realizes the low-rank outer-product approximation.
        for i, x in enumerate(inputs):
            projected = self.modality_factors[i](x)

            if fusion_tensor is None:
                fusion_tensor = projected
            else:
                fusion_tensor = fusion_tensor * projected

        if fusion_tensor is None:
            raise ValueError("No inputs processed")

        output = self.fusion_weights(self.dropout(fusion_tensor))

        if self.channel_first:
            output = output.movedim(-1, 1)

        return output

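A minimal usage sketch for LowRankFusion, assuming direct import from orbit/model/block/fusion.py as listed above; the shapes and modality names are illustrative, not taken from the package docs:

# Hypothetical usage sketch; shapes are illustrative.
import torch
from orbit.model.block.fusion import LowRankFusion

fusion = LowRankFusion(in_features=[64, 32], out_features=128, rank=8)
text = torch.randn(4, 10, 64)    # [B, L, C] text features
audio = torch.randn(4, 10, 32)   # [B, L, C] audio features
fused = fusion([text, audio])    # -> [4, 10, 128]
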
@register_model()
class GatedMultimodalUnit(BaseBlock):
    ''' Gated Multimodal Unit (GMU) module.

    Learns a gating mechanism that controls how much each modality
    contributes to the final fused representation.
    '''

    def __init__(self, in_features: List[int], out_features: int, channel_first: bool = False):
        ''' Initialize GatedMultimodalUnit.

        Args:
            in_features (List[int]): Feature dimension of each input modality.
            out_features (int): Hidden feature dimension.
            channel_first (bool, optional): Whether the feature dimension is at dim 1. Defaults to False.
        '''
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.channel_first = channel_first

        self.feature_transforms = nn.ModuleList([
            nn.Linear(dim, out_features) for dim in in_features
        ])

        total_in_features = sum(in_features)
        self.gate_net = nn.Linear(total_in_features, len(in_features))

        self._init_weights(self)

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        ''' Forward pass.

        Args:
            inputs (List[torch.Tensor]): List of input tensors.

        Returns:
            torch.Tensor: Fused output tensor.
        '''
        if len(inputs) != len(self.in_features):
            raise ValueError(f"Expected {len(self.in_features)} inputs, got {len(inputs)}")

        if self.channel_first:
            processed_inputs = [x.movedim(1, -1) for x in inputs]
        else:
            processed_inputs = inputs

        hidden_features = []
        for i, x in enumerate(processed_inputs):
            h = torch.tanh(self.feature_transforms[i](x))
            hidden_features.append(h)

        # One softmax-normalized gate per modality, predicted from the
        # concatenation of all raw inputs.
        concatenated_input = torch.cat(processed_inputs, dim=-1)
        gate_logits = self.gate_net(concatenated_input)  # (B, ..., num_modalities)
        gates = torch.softmax(gate_logits, dim=-1)

        output = torch.zeros_like(hidden_features[0])

        for i, h in enumerate(hidden_features):
            g = gates[..., i:i+1]  # (B, ..., 1)
            output = output + g * h

        if self.channel_first:
            output = output.movedim(-1, 1)

        return output

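A similar hedged sketch for GatedMultimodalUnit: the learned gates sum to 1 across modalities, so the output is a convex combination of the per-modality tanh features (shapes are illustrative):

# Hypothetical usage sketch; shapes are illustrative.
import torch
from orbit.model.block.fusion import GatedMultimodalUnit

gmu = GatedMultimodalUnit(in_features=[64, 32], out_features=128)
image = torch.randn(4, 64)   # [B, C] image features
text = torch.randn(4, 32)    # [B, C] text features
fused = gmu([image, text])   # -> [4, 128], convex mix of the two hidden features
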
@register_model()
class DiffusionMapsFusion(BaseBlock):
    ''' Diffusion Maps Fusion module.

    Based on the diffusion-maps idea from manifold learning.
    Builds a graph Laplacian (or normalized affinity matrix) across feature
    channels, then aligns features and performs cross-diffusion in the
    manifold space.

    Currently supports fusing exactly two modalities.
    '''

    def __init__(self, in_features: List[int], out_features: int, sigma: float = 1.0, channel_first: bool = False):
        ''' Initialize DiffusionMapsFusion.

        Args:
            in_features (List[int]): Feature dimensions of the two input modalities.
            out_features (int): Output feature dimension.
            sigma (float, optional): Bandwidth of the Gaussian kernel. Defaults to 1.0.
            channel_first (bool, optional): Whether the feature dimension is at dim 1. Defaults to False.
        '''
        super().__init__()

        if len(in_features) != 2:
            raise ValueError("DiffusionMapsFusion currently only supports exactly 2 modalities.")

        self.in_features = in_features
        self.out_features = out_features
        self.sigma = sigma
        self.channel_first = channel_first

        self.projections = nn.ModuleList([
            nn.Linear(dim, out_features) for dim in in_features
        ])

        self.output_proj = nn.Linear(out_features * 2, out_features)

        self._init_weights(self)

    def _compute_affinity(self, x: torch.Tensor) -> torch.Tensor:
        ''' Compute the normalized affinity matrix (diffusion operator) across feature channels.

        Args:
            x (torch.Tensor): Input features (C, N), where C is the number of feature
                channels and N is the number of samples.

        Returns:
            torch.Tensor: Normalized affinity matrix (C, C).
        '''
        # ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>
        sq_norm = (x ** 2).sum(1, keepdim=True)
        dist_sq = sq_norm + sq_norm.t() - 2 * torch.mm(x, x.t())

        W = torch.exp(-dist_sq / (2 * self.sigma ** 2))

        # Row-normalize so each row of P sums to 1 (a Markov transition matrix).
        D_inv = 1.0 / (W.sum(1, keepdim=True) + 1e-8)
        P = D_inv * W

        return P

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        ''' Forward pass.

        Args:
            inputs (List[torch.Tensor]): List of the two input tensors.

        Returns:
            torch.Tensor: Fused output tensor.
        '''
        if len(inputs) != 2:
            raise ValueError("Expected exactly 2 inputs")

        if self.channel_first:
            processed_inputs = [x.movedim(1, -1) for x in inputs]
        else:
            processed_inputs = inputs

        proj_feats = [self.projections[i](x) for i, x in enumerate(processed_inputs)]
        xA, xB = proj_feats[0], proj_feats[1]

        flat_xA = xA.reshape(-1, xA.shape[-1])
        flat_xB = xB.reshape(-1, xB.shape[-1])

        xA_T = flat_xA.t()
        xB_T = flat_xB.t()

        P_A = self._compute_affinity(xA_T)  # (C, C)
        P_B = self._compute_affinity(xB_T)  # (C, C)

        # Cross-diffusion: each modality's features are diffused by the
        # other modality's operator.
        diffused_A_T = torch.mm(P_B, xA_T)  # (C, N_samples)
        diffused_B_T = torch.mm(P_A, xB_T)  # (C, N_samples)

        # reshape (not view): the transpose is non-contiguous.
        diffused_A = diffused_A_T.t().reshape(xA.shape)
        diffused_B = diffused_B_T.t().reshape(xB.shape)

        combined = torch.cat([diffused_A, diffused_B], dim=-1)
        output = self.output_proj(combined)

        if self.channel_first:
            output = output.movedim(-1, 1)

        return output

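A hedged sketch of DiffusionMapsFusion under the same assumptions. Note that both inputs are first projected to out_features, so the two affinity matrices share the channel dimension C = out_features:

# Hypothetical usage sketch; shapes are illustrative.
import torch
from orbit.model.block.fusion import DiffusionMapsFusion

dmf = DiffusionMapsFusion(in_features=[64, 32], out_features=128, sigma=1.0)
a = torch.randn(4, 10, 64)
b = torch.randn(4, 10, 32)
fused = dmf([a, b])  # -> [4, 10, 128], after cross-diffusion of channel affinities
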
@register_model()
class CompactMultimodalPooling(BaseBlock):
    ''' Compact Multimodal Pooling (MCB/CBP) module.

    Approximates the outer product of multimodal features via Count Sketch
    and FFT. Supports fusing two or more modalities.
    '''

    def __init__(self, in_features: List[int], out_features: int, channel_first: bool = False):
        ''' Initialize CompactMultimodalPooling.

        Args:
            in_features (List[int]): Feature dimension of each input modality.
            out_features (int): Output feature dimension. Should usually be larger
                than the input dimensions to preserve information.
            channel_first (bool, optional): Whether the feature dimension is at dim 1
                (e.g. CNN [B, C, H, W]). If False, features are assumed to sit in the
                last dimension (e.g. Transformer [B, L, C]). Defaults to False.
        '''
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.channel_first = channel_first

        # Random but fixed Count Sketch hash indices h and signs s per modality.
        for i, dim in enumerate(in_features):
            self.register_buffer(f'h_{i}', torch.randint(0, out_features, (dim,)))
            self.register_buffer(f's_{i}', torch.randint(0, 2, (dim,)) * 2 - 1)  # Map {0, 1} to {-1, 1}

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        ''' Forward pass.

        Args:
            inputs (List[torch.Tensor]): List of input tensors.

        Returns:
            torch.Tensor: Fused output tensor.
        '''
        if len(inputs) != len(self.in_features):
            raise ValueError(f"Expected {len(self.in_features)} inputs, got {len(inputs)}")

        if self.channel_first:
            inputs = [x.movedim(1, -1) for x in inputs]

        fft_product: Optional[torch.Tensor] = None

        for i, x in enumerate(inputs):
            h = getattr(self, f'h_{i}')  # (dim,)
            s = getattr(self, f's_{i}')  # (dim,)

            output_shape = list(x.shape)
            output_shape[-1] = self.out_features
            sketch = torch.zeros(output_shape, device=x.device, dtype=x.dtype)

            weighted_x = x * s  # (..., dim)

            flat_x = weighted_x.reshape(-1, weighted_x.shape[-1])  # (N, dim)
            flat_sketch = sketch.view(-1, self.out_features)  # (N, out)

            h_expanded = h.expand(flat_x.shape[0], -1)

            # Count Sketch: scatter each signed feature into its hashed bucket.
            flat_sketch.scatter_add_(1, h_expanded, flat_x)

            sketch = flat_sketch.view(output_shape)

            # Circular convolution of sketches == element-wise product in the FFT domain.
            fft_x = torch.fft.rfft(sketch, dim=-1)

            if fft_product is None:
                fft_product = fft_x
            else:
                fft_product = fft_product * fft_x

        if fft_product is None:
            raise ValueError("No inputs processed")

        output = torch.fft.irfft(fft_product, n=self.out_features, dim=-1)

        if self.channel_first:
            output = output.movedim(-1, 1)

        return output

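A hedged sketch of CompactMultimodalPooling. As the docstring suggests, out_features is usually set well above the input dimensions so the Count Sketch loses little information (the numbers below are illustrative):

# Hypothetical usage sketch; shapes are illustrative.
import torch
from orbit.model.block.fusion import CompactMultimodalPooling

mcb = CompactMultimodalPooling(in_features=[512, 512], out_features=8000)
v = torch.randn(4, 512)   # visual features
q = torch.randn(4, 512)   # question features
fused = mcb([v, q])       # -> [4, 8000]
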
orbit/model/block/gate.py
@@ -0,0 +1,334 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional

from orbit.model import BaseBlock, register_model
from orbit.model.block.mlp import MLP


class BaseGate(BaseBlock):
    ''' Base class for gating modules.

    Provides the shared MLP/Linear transformation logic.
    '''

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        use_mlp: bool = False,
        hidden_features: Optional[int] = None,
        override_repr: bool = True
    ):
        ''' Initialize BaseGate.

        Args:
            in_features (int): Input feature dimension.
            out_features (int): Output feature dimension.
            bias (bool, optional): Whether to use a bias term. Defaults to True.
            use_mlp (bool, optional): Whether to use an MLP for the transformation. Defaults to False.
            hidden_features (int, optional): Hidden dimension of the MLP. Only used when use_mlp=True. Defaults to None.
            override_repr (bool, optional): Whether __repr__ prints the compact argument summary. Defaults to True.
        '''
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.use_mlp = use_mlp
        self.hidden_features = hidden_features
        self.override_repr = override_repr

        if use_mlp:
            hidden_features = hidden_features or in_features
            self.mlp = MLP(
                in_features=in_features,
                hidden_features=hidden_features,
                out_features=out_features,
                use_gate=False,
                dropout=0.0
            )
        else:
            self.linear = nn.Linear(in_features, out_features, bias=bias)

    def _transform(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_mlp:
            return self.mlp(x)
        else:
            return self.linear(x)

    def _get_repr_args(self) -> list[str]:
        args = [
            f"in_features={self.in_features}",
            f"out_features={self.out_features}",
            f"bias={self.bias}",
            f"use_mlp={self.use_mlp}"
        ]
        if self.use_mlp and self.hidden_features is not None:
            args.append(f"hidden_features={self.hidden_features}")
        return args

    def __repr__(self):
        if self.override_repr:
            return f"{self.__class__.__name__}({', '.join(self._get_repr_args())})"
        return super().__repr__()

@register_model()
class SigmoidGate(BaseGate):
    ''' Sigmoid gating module.

    Implements Sigmoid(Linear(x)) or Sigmoid(MLP(x)),
    producing gating values between 0 and 1.
    '''

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ''' Forward pass.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Gating values in [0, 1].
        '''
        return torch.sigmoid(self._transform(x))


@register_model()
class TanhGate(BaseGate):
    ''' Tanh gating module.

    Implements Tanh(Linear(x)) or Tanh(MLP(x)),
    producing gating values between -1 and 1.
    '''

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ''' Forward pass.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Gating values in [-1, 1].
        '''
        return torch.tanh(self._transform(x))

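One common way such simple gates are used is to modulate a residual branch; a hedged sketch (the residual pattern is an assumption, not prescribed by the package):

# Hypothetical usage sketch: per-feature gating of a residual branch.
import torch
from orbit.model.block.gate import SigmoidGate

gate = SigmoidGate(in_features=256, out_features=256)
x = torch.randn(4, 256)
branch = torch.randn(4, 256)   # output of some side computation
out = x + gate(x) * branch     # gate(x) is in [0, 1] per feature
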
@register_model()
class SoftmaxGate(BaseGate):
    ''' Softmax gating module.

    Implements Softmax(Linear(x)) or Softmax(MLP(x)),
    producing gating values that sum to 1.
    '''

    def __init__(
        self,
        in_features: int,
        out_features: int,
        dim: int = -1,
        temperature: float = 1.0,
        bias: bool = True,
        use_mlp: bool = False,
        hidden_features: Optional[int] = None,
        override_repr: bool = True
    ):
        ''' Initialize SoftmaxGate.

        Args:
            in_features (int): Input feature dimension.
            out_features (int): Output feature dimension.
            dim (int, optional): Dimension over which to apply the softmax. Defaults to -1.
            temperature (float, optional): Temperature controlling the sharpness of the distribution. Defaults to 1.0.
            bias (bool, optional): Whether to use a bias term. Defaults to True.
            use_mlp (bool, optional): Whether to use an MLP for the transformation. Defaults to False.
            hidden_features (int, optional): Hidden dimension of the MLP. Only used when use_mlp=True. Defaults to None.
        '''
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            use_mlp=use_mlp,
            hidden_features=hidden_features,
            override_repr=override_repr
        )
        self.dim = dim
        self.temperature = temperature

    def _get_repr_args(self) -> list[str]:
        args = super()._get_repr_args()
        args.append(f"dim={self.dim}")
        args.append(f"temperature={self.temperature}")
        return args

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ''' Forward pass.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Gating values summing to 1.
        '''
        x = self._transform(x)
        return F.softmax(x / self.temperature, dim=self.dim)

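A hedged sketch of SoftmaxGate as a soft selector over parallel branches; lowering the temperature sharpens the selection (the branch setup is illustrative):

# Hypothetical usage sketch: soft selection over 4 branches.
import torch
from orbit.model.block.gate import SoftmaxGate

gate = SoftmaxGate(in_features=256, out_features=4, temperature=0.5)
x = torch.randn(8, 256)
weights = gate(x)                  # [8, 4], each row sums to 1
branches = torch.randn(8, 4, 256)  # stacked outputs of 4 branches
mixed = (weights.unsqueeze(-1) * branches).sum(dim=1)  # [8, 256]
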
@register_model()
class GLUGate(BaseBlock):
    ''' Gated Linear Unit (GLU) module.

    Supports several activation functions (Sigmoid, Tanh, ReLU, GELU, SiLU).
    Implements: (x * W + b) * Activation(x * V + c)
    '''

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        activation: str = 'sigmoid',
        bias: bool = True,
        override_repr: bool = True
    ):
        ''' Initialize GLUGate.

        Args:
            in_features (int): Input dimension.
            hidden_features (int): Hidden (projection) dimension.
            out_features (int): Output dimension.
            activation (str, optional): Activation type ('sigmoid', 'tanh', 'relu', 'gelu', 'silu'). Defaults to 'sigmoid'.
            bias (bool, optional): Whether the Linear layers use a bias. Defaults to True.
            override_repr (bool, optional): Whether __repr__ prints the compact argument summary. Defaults to True.
        '''
        super().__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.activation_name = activation
        self.bias = bias
        self.override_repr = override_repr

        # A single projection produces both the value and gate halves.
        self.proj = nn.Linear(in_features, hidden_features * 2, bias=bias)
        self.out_proj = nn.Linear(hidden_features, out_features, bias=bias)

        if activation == 'sigmoid':
            self.act = nn.Sigmoid()
        elif activation == 'tanh':
            self.act = nn.Tanh()
        elif activation == 'relu':
            self.act = nn.ReLU()
        elif activation == 'gelu':
            self.act = nn.GELU()
        elif activation == 'silu':
            self.act = nn.SiLU()
        else:
            raise ValueError(f"Unsupported activation: {activation}")

    def __repr__(self):
        if self.override_repr:
            args = [
                f"in_features={self.in_features}",
                f"hidden_features={self.hidden_features}",
                f"out_features={self.out_features}",
                f"activation='{self.activation_name}'",
                f"bias={self.bias}"
            ]
            return f"{self.__class__.__name__}({', '.join(args)})"
        return super().__repr__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ''' Forward pass.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        '''
        x = self.proj(x)
        x, gate = x.chunk(2, dim=-1)
        x = x * self.act(gate)
        return self.out_proj(x)

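With activation='silu' this GLUGate reduces to a SwiGLU-style feed-forward block; a hedged sketch (shapes are illustrative):

# Hypothetical usage sketch: GLUGate as a SwiGLU-style FFN.
import torch
from orbit.model.block.gate import GLUGate

ffn = GLUGate(in_features=256, hidden_features=512, out_features=256, activation='silu')
x = torch.randn(4, 16, 256)
y = ffn(x)  # [4, 16, 256]
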
@dataclass
class TopKGateOutput:
    ''' Routing result returned by TopKGate. '''
    logits: torch.Tensor
    indices: torch.Tensor
    values: torch.Tensor

    @property
    def output(self) -> torch.Tensor:
        # Scatter the top-k values back into a dense tensor shaped like logits.
        output = torch.zeros_like(self.logits)
        output.scatter_(-1, self.indices, self.values)
        return output


@register_model()
class TopKGate(BaseGate):
    ''' Top-K gating module.

    Keeps only the K largest values and zeroes out the rest.
    Returns detailed routing information for use in MoE layers.
    '''

    def __init__(
        self,
        in_features: int,
        out_features: int,
        k: int = 1,
        bias: bool = True,
        use_mlp: bool = False,
        hidden_features: Optional[int] = None,
        post_softmax: bool = False,
        override_repr: bool = True
    ):
        ''' Initialize TopKGate.

        Args:
            in_features (int): Input feature dimension.
            out_features (int): Output feature dimension.
            k (int, optional): Number of top values to keep. Defaults to 1.
            bias (bool, optional): Whether to use a bias term. Defaults to True.
            use_mlp (bool, optional): Whether to use an MLP for the transformation. Defaults to False.
            hidden_features (int, optional): Hidden dimension of the MLP. Only used when use_mlp=True. Defaults to None.
            post_softmax (bool, optional): Whether to softmax-normalize the values after the top-k selection. Defaults to False.
        '''
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            use_mlp=use_mlp,
            hidden_features=hidden_features,
            override_repr=override_repr
        )
        self.k = k
        self.post_softmax = post_softmax

    def _get_repr_args(self) -> list[str]:
        args = super()._get_repr_args()
        args.append(f"k={self.k}")
        args.append(f"post_softmax={self.post_softmax}")
        return args

    def forward(self, x: torch.Tensor) -> TopKGateOutput:
        ''' Forward pass.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            TopKGateOutput: Dataclass holding logits, indices, and values.
        '''
        logits = self._transform(x)

        topk_values, topk_indices = torch.topk(logits, self.k, dim=-1)

        if self.post_softmax:
            topk_values = F.softmax(topk_values, dim=-1)

        return TopKGateOutput(
            logits=logits,
            indices=topk_indices,
            values=topk_values
        )

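A hedged sketch of TopKGate used as an MoE router, which is what the TopKGateOutput fields are shaped for (the expert count and shapes are illustrative):

# Hypothetical usage sketch: routing tokens to 2 of 8 experts.
import torch
from orbit.model.block.gate import TopKGate

router = TopKGate(in_features=256, out_features=8, k=2, post_softmax=True)
x = torch.randn(4, 16, 256)   # [B, L, C] token features
routing = router(x)
print(routing.indices.shape)  # [4, 16, 2] chosen expert ids
print(routing.values.shape)   # [4, 16, 2] weights, each pair sums to 1
dense = routing.output        # [4, 16, 8] dense gate with zeros off the top-k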