orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Files changed (62)
  1. orbit/__init__.py +3 -1
  2. orbit/callback.py +4 -3
  3. orbit/dataset/__init__.py +1 -0
  4. orbit/dataset/cogn.py +138 -0
  5. orbit/dataset/data/cogn_en.jsonl +45 -0
  6. orbit/dataset/data/cogn_zh.jsonl +113 -0
  7. orbit/engine.py +210 -146
  8. orbit/kit/__init__.py +2 -0
  9. orbit/kit/interface.py +154 -0
  10. orbit/kit/wrapper.py +157 -0
  11. orbit/model/__init__.py +5 -0
  12. orbit/model/base.py +125 -0
  13. orbit/model/block/__init__.py +34 -0
  14. orbit/model/block/attention.py +265 -0
  15. orbit/model/block/bio.py +537 -0
  16. orbit/model/block/codebook.py +122 -0
  17. orbit/model/block/conv.py +505 -0
  18. orbit/model/block/embedding.py +252 -0
  19. orbit/model/block/film.py +176 -0
  20. orbit/model/block/fusion.py +335 -0
  21. orbit/model/block/gate.py +334 -0
  22. orbit/model/block/lora.py +776 -0
  23. orbit/model/block/mlp.py +68 -0
  24. orbit/model/block/moe.py +94 -0
  25. orbit/model/block/tcn.py +99 -0
  26. orbit/model/config.py +62 -0
  27. orbit/model/kit/__init__.py +6 -0
  28. orbit/model/kit/discriminator.py +46 -0
  29. orbit/model/kit/losses.py +193 -0
  30. orbit/model/motif/__init__.py +0 -0
  31. orbit/model/motif/vision/__init__.py +0 -0
  32. orbit/model/motif/vision/v1.py +645 -0
  33. orbit/model/registry.py +53 -0
  34. orbit/optim/__init__.py +2 -2
  35. orbit/optim/sam.py +10 -3
  36. orbit/plugin/__init__.py +12 -8
  37. orbit/plugin/board.py +1 -2
  38. orbit/plugin/checkpoint.py +137 -62
  39. orbit/plugin/classification.py +2 -2
  40. orbit/plugin/display_model.py +1 -2
  41. orbit/plugin/early_stopping.py +1 -2
  42. orbit/plugin/ema.py +1 -2
  43. orbit/plugin/gradient_accumulation.py +1 -2
  44. orbit/plugin/lora.py +346 -0
  45. orbit/plugin/memory_estimator.py +1 -2
  46. orbit/plugin/warmup.py +1 -2
  47. orbit/utils/__init__.py +24 -1
  48. orbit/utils/cuda.py +10 -0
  49. orbit/utils/freeze.py +61 -17
  50. orbit/utils/image.py +164 -0
  51. orbit/utils/initialization.py +184 -94
  52. orbit/utils/layer_io.py +66 -7
  53. orbit/utils/lora.py +480 -0
  54. orbit/utils/moe.py +55 -0
  55. orbit/utils/seed.py +3 -19
  56. orbit/utils/sft.py +93 -0
  57. orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
  58. orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
  59. orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
  60. orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
  61. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
  62. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/model/block/fusion.py (new file)
@@ -0,0 +1,335 @@
+import torch
+import torch.nn as nn
+
+from typing import Optional, List
+from dataclasses import dataclass
+
+from orbit.model import BaseBlock, register_model
+
+
+@register_model()
+class LowRankFusion(BaseBlock):
+    ''' Low-rank Multimodal Fusion (LMF) block.
+
+    Approximates the multimodal outer product with a low-rank decomposition, fusing efficiently by factoring the weight tensor into modality-specific factors.
+    '''
+
+    def __init__(self, in_features: List[int], out_features: int, rank: int, dropout: float = 0.0, channel_first: bool = False):
+        ''' Initialize LowRankFusion.
+
+        Args:
+            in_features (List[int]): Feature dimension of each input modality.
+            out_features (int): Output feature dimension after fusion.
+            rank (int): Rank of the low-rank decomposition.
+            dropout (float, optional): Dropout probability. Defaults to 0.0.
+            channel_first (bool, optional): Whether the feature dimension is dim 1 (e.g. CNN [B, C, H, W]).
+                If False, features are assumed to be in the last dimension (e.g. Transformer [B, L, C]). Defaults to False.
+        '''
+        super().__init__()
+
+        if not in_features:
+            raise ValueError("in_features cannot be empty")
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.rank = rank
+        self.channel_first = channel_first
+
+        self.modality_factors = nn.ModuleList([
+            nn.Linear(dim, rank, bias=True) for dim in in_features
+        ])
+
+        self.fusion_weights = nn.Linear(rank, out_features, bias=True)
+
+        self.dropout = nn.Dropout(dropout)
+
+        self._init_weights(self)
+
+    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
+        ''' Forward pass.
+
+        Args:
+            inputs (List[torch.Tensor]): List of input tensors. All tensors must match in every dimension except the last.
+
+        Returns:
+            torch.Tensor: The fused output tensor.
+        '''
+        if len(inputs) != len(self.in_features):
+            raise ValueError(f"Expected {len(self.in_features)} inputs, got {len(inputs)}")
+
+        if self.channel_first:
+            inputs = [x.movedim(1, -1) for x in inputs]
+
+        fusion_tensor: Optional[torch.Tensor] = None
+
+        for i, x in enumerate(inputs):
+            projected = self.modality_factors[i](x)
+
+            if fusion_tensor is None:
+                fusion_tensor = projected
+            else:
+                fusion_tensor = fusion_tensor * projected
+
+        if fusion_tensor is None:
+            raise ValueError("No inputs processed")
+
+        output = self.fusion_weights(self.dropout(fusion_tensor))
+
+        if self.channel_first:
+            output = output.movedim(-1, 1)
+
+        return output
+
+
+@register_model()
+class GatedMultimodalUnit(BaseBlock):
+    ''' Gated Multimodal Unit (GMU) block.
+
+    Learns a gating mechanism that controls how much each modality contributes to the final fused representation.
+    '''
+
+    def __init__(self, in_features: List[int], out_features: int, channel_first: bool = False):
+        ''' Initialize GatedMultimodalUnit.
+
+        Args:
+            in_features (List[int]): Feature dimension of each input modality.
+            out_features (int): Hidden feature dimension.
+            channel_first (bool, optional): Whether the feature dimension is dim 1. Defaults to False.
+        '''
+        super().__init__()
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.channel_first = channel_first
+
+        self.feature_transforms = nn.ModuleList([
+            nn.Linear(dim, out_features) for dim in in_features
+        ])
+
+        total_in_features = sum(in_features)
+        self.gate_net = nn.Linear(total_in_features, len(in_features))
+
+        self._init_weights(self)
+
+    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
+        ''' Forward pass.
+
+        Args:
+            inputs (List[torch.Tensor]): List of input tensors.
+
+        Returns:
+            torch.Tensor: The fused output tensor.
+        '''
+        if len(inputs) != len(self.in_features):
+            raise ValueError(f"Expected {len(self.in_features)} inputs, got {len(inputs)}")
+
+        processed_inputs = []
+        if self.channel_first:
+            processed_inputs = [x.movedim(1, -1) for x in inputs]
+        else:
+            processed_inputs = inputs
+
+        hidden_features = []
+        for i, x in enumerate(processed_inputs):
+            h = torch.tanh(self.feature_transforms[i](x))
+            hidden_features.append(h)
+
+        concatenated_input = torch.cat(processed_inputs, dim=-1)
+        gate_logits = self.gate_net(concatenated_input)  # (B, ..., num_modalities)
+        gates = torch.softmax(gate_logits, dim=-1)
+
+        output = torch.zeros_like(hidden_features[0])
+
+        for i, h in enumerate(hidden_features):
+            g = gates[..., i:i+1]  # (B, ..., 1)
+            output += g * h
+
+        if self.channel_first:
+            output = output.movedim(-1, 1)
+
+        return output
+
+
+@register_model()
+class DiffusionMapsFusion(BaseBlock):
+    ''' Diffusion Maps Fusion block.
+
+    Based on the diffusion-maps idea from manifold learning.
+    Builds a graph Laplacian (or normalized affinity matrix) across feature channels
+    and performs feature alignment and cross-diffusion in the manifold space.
+
+    Currently supports fusing exactly two modalities.
+    '''
+
+    def __init__(self, in_features: List[int], out_features: int, sigma: float = 1.0, channel_first: bool = False):
+        ''' Initialize DiffusionMapsFusion.
+
+        Args:
+            in_features (List[int]): Feature dimensions of the two input modalities.
+            out_features (int): Output feature dimension.
+            sigma (float, optional): Bandwidth of the Gaussian kernel. Defaults to 1.0.
+            channel_first (bool, optional): Whether the feature dimension is dim 1. Defaults to False.
+        '''
+        super().__init__()
+
+        if len(in_features) != 2:
+            raise ValueError("DiffusionMapsFusion currently only supports exactly 2 modalities.")
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.sigma = sigma
+        self.channel_first = channel_first
+
+        self.projections = nn.ModuleList([
+            nn.Linear(dim, out_features) for dim in in_features
+        ])
+
+        self.output_proj = nn.Linear(out_features * 2, out_features)
+
+        self._init_weights(self)
+
+    def _compute_affinity(self, x: torch.Tensor) -> torch.Tensor:
+        ''' Compute the row-normalized affinity matrix (diffusion operator) between feature channels.
+
+        Args:
+            x (torch.Tensor): Input features (C, N), where C is the number of feature channels and N the number of samples.
+
+        Returns:
+            torch.Tensor: Normalized affinity matrix (C, C).
+        '''
+        # ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>
+        sq_norm = (x ** 2).sum(1, keepdim=True)
+        dist_sq = sq_norm + sq_norm.t() - 2 * torch.mm(x, x.t())
+
+        W = torch.exp(-dist_sq / (2 * self.sigma ** 2))
+
+        D_inv = 1.0 / (W.sum(1, keepdim=True) + 1e-8)
+        P = D_inv * W
+
+        return P
+
+    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
+        ''' Forward pass.
+
+        Args:
+            inputs (List[torch.Tensor]): List of the two input tensors.
+
+        Returns:
+            torch.Tensor: The fused output tensor.
+        '''
+        if len(inputs) != 2:
+            raise ValueError("Expected exactly 2 inputs")
+
+        processed_inputs = []
+        if self.channel_first:
+            processed_inputs = [x.movedim(1, -1) for x in inputs]
+        else:
+            processed_inputs = inputs
+
+        proj_feats = [self.projections[i](x) for i, x in enumerate(processed_inputs)]
+        xA, xB = proj_feats[0], proj_feats[1]
+
+        flat_xA = xA.reshape(-1, xA.shape[-1])
+        flat_xB = xB.reshape(-1, xB.shape[-1])
+
+        xA_T = flat_xA.t()
+        xB_T = flat_xB.t()
+
+        P_A = self._compute_affinity(xA_T)  # (C, C)
+        P_B = self._compute_affinity(xB_T)  # (C, C)
+
+        diffused_A_T = torch.mm(P_B, xA_T)  # (C, N_samples)
+        diffused_B_T = torch.mm(P_A, xB_T)  # (C, N_samples)
+
+        diffused_A = diffused_A_T.t().view(xA.shape)
+        diffused_B = diffused_B_T.t().view(xB.shape)
+
+        combined = torch.cat([diffused_A, diffused_B], dim=-1)
+        output = self.output_proj(combined)
+
+        if self.channel_first:
+            output = output.movedim(-1, 1)
+
+        return output
+
+
+@register_model()
+class CompactMultimodalPooling(BaseBlock):
+    ''' Compact Multimodal Pooling (MCB/CBP) block.
+
+    Approximates the outer product of multimodal features via Count Sketch and FFT.
+    Supports fusing two or more modalities.
+    '''
+
+    def __init__(self, in_features: List[int], out_features: int, channel_first: bool = False):
+        ''' Initialize CompactMultimodalPooling.
+
+        Args:
+            in_features (List[int]): Feature dimension of each input modality.
+            out_features (int): Output feature dimension. Should usually be larger than the input dimensions to preserve information.
+            channel_first (bool, optional): Whether the feature dimension is dim 1 (e.g. CNN [B, C, H, W]).
+                If False, features are assumed to be in the last dimension (e.g. Transformer [B, L, C]). Defaults to False.
+        '''
+        super().__init__()
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.channel_first = channel_first
+
+        for i, dim in enumerate(in_features):
+            self.register_buffer(f'h_{i}', torch.randint(0, out_features, (dim,)))
+            self.register_buffer(f's_{i}', torch.randint(0, 2, (dim,)) * 2 - 1)  # Map {0, 1} to {-1, 1}
+
+    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
+        ''' Forward pass.
+
+        Args:
+            inputs (List[torch.Tensor]): List of input tensors.
+
+        Returns:
+            torch.Tensor: The fused output tensor.
+        '''
+        if len(inputs) != len(self.in_features):
+            raise ValueError(f"Expected {len(self.in_features)} inputs, got {len(inputs)}")
+
+        if self.channel_first:
+            inputs = [x.movedim(1, -1) for x in inputs]
+
+        batch_size = inputs[0].size(0)
+        fft_product: Optional[torch.Tensor] = None
+
+        for i, x in enumerate(inputs):
+            h = getattr(self, f'h_{i}')  # (dim,)
+            s = getattr(self, f's_{i}')  # (dim,)
+
+            output_shape = list(x.shape)
+            output_shape[-1] = self.out_features
+            sketch = torch.zeros(output_shape, device=x.device, dtype=x.dtype)
+
+            weighted_x = x * s  # (..., dim)
+
+            flat_x = weighted_x.reshape(-1, weighted_x.shape[-1])  # (N, dim)
+            flat_sketch = sketch.view(-1, self.out_features)  # (N, out)
+
+            h_expanded = h.expand(flat_x.shape[0], -1)
+
+            flat_sketch.scatter_add_(1, h_expanded, flat_x)
+
+            sketch = flat_sketch.view(output_shape)
+
+            fft_x = torch.fft.rfft(sketch, dim=-1)
+
+            if fft_product is None:
+                fft_product = fft_x
+            else:
+                fft_product = fft_product * fft_x
+
+        if fft_product is None:
+            raise ValueError("No inputs processed")
+
+        output = torch.fft.irfft(fft_product, n=self.out_features, dim=-1)
+
+        if self.channel_first:
+            output = output.movedim(-1, 1)
+
+        return output
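For orientation, a minimal usage sketch of the two lighter fusion blocks above, assuming the 0.1.0b1 wheel is installed and the classes are importable from orbit.model.block.fusion (the path shown in the file list); all tensor shapes and variable names are illustrative:

import torch
from orbit.model.block.fusion import LowRankFusion, GatedMultimodalUnit

text  = torch.randn(8, 16, 256)   # modality A, [B, L, C] layout (channel_first=False)
audio = torch.randn(8, 16, 128)   # modality B, same leading dims, different feature dim

# LMF: project each modality to `rank` dims, multiply elementwise, map rank -> out.
lmf = LowRankFusion(in_features=[256, 128], out_features=64, rank=4, dropout=0.1)
fused = lmf([text, audio])        # -> [8, 16, 64]

# GMU: tanh-transform each modality, then mix with softmax gates computed
# from the concatenated raw inputs.
gmu = GatedMultimodalUnit(in_features=[256, 128], out_features=64)
gated = gmu([text, audio])        # -> [8, 16, 64]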
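Likewise for the two heavier blocks, under the same import assumption. Note that DiffusionMapsFusion accepts exactly two modalities and builds its (C, C) affinity over the projected channels, while CompactMultimodalPooling typically wants out_features well above the input dims so the count sketch stays informative; the 8000 below is illustrative:

import torch
from orbit.model.block.fusion import DiffusionMapsFusion, CompactMultimodalPooling

img = torch.randn(8, 49, 512)     # e.g. a flattened 7x7 CNN grid, [B, L, C]
txt = torch.randn(8, 49, 300)

dmf = DiffusionMapsFusion(in_features=[512, 300], out_features=64, sigma=1.0)
aligned = dmf([img, txt])         # cross-diffused features, -> [8, 49, 64]

mcb = CompactMultimodalPooling(in_features=[512, 300], out_features=8000)
pooled = mcb([img, txt])          # FFT-domain product of count sketches, -> [8, 49, 8000]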
orbit/model/block/gate.py (new file)
@@ -0,0 +1,334 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List, Optional
+from dataclasses import dataclass
+from orbit.model import BaseBlock, register_model
+from orbit.model.block.mlp import MLP
+
+
+class BaseGate(BaseBlock):
+    ''' Base class for gate blocks.
+
+    Provides the shared MLP/Linear transform logic.
+    '''
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        use_mlp: bool = False,
+        hidden_features: Optional[int] = None,
+        override_repr: bool = True
+    ):
+        ''' Initialize BaseGate.
+
+        Args:
+            in_features (int): Input feature dimension.
+            out_features (int): Output feature dimension.
+            bias (bool, optional): Whether to use a bias. Defaults to True.
+            use_mlp (bool, optional): Whether to use an MLP for the transform. Defaults to False.
+            hidden_features (int, optional): Hidden dimension of the MLP. Only used when use_mlp=True. Defaults to None.
+        '''
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.bias = bias
+        self.use_mlp = use_mlp
+        self.hidden_features = hidden_features
+        self.override_repr = override_repr
+
+        if use_mlp:
+            hidden_features = hidden_features or in_features
+            self.mlp = MLP(
+                in_features=in_features,
+                hidden_features=hidden_features,
+                out_features=out_features,
+                use_gate=False,
+                dropout=0.0
+            )
+        else:
+            self.linear = nn.Linear(in_features, out_features, bias=bias)
+
+    def _transform(self, x: torch.Tensor) -> torch.Tensor:
+        if self.use_mlp:
+            return self.mlp(x)
+        else:
+            return self.linear(x)
+
+    def _get_repr_args(self) -> List[str]:
+        args = [
+            f"in_features={self.in_features}",
+            f"out_features={self.out_features}",
+            f"bias={self.bias}",
+            f"use_mlp={self.use_mlp}"
+        ]
+        if self.use_mlp and self.hidden_features is not None:
+            args.append(f"hidden_features={self.hidden_features}")
+        return args
+
+    def __repr__(self):
+        if self.override_repr:
+            return f"{self.__class__.__name__}({', '.join(self._get_repr_args())})"
+        return super().__repr__()
+
+
+@register_model()
+class SigmoidGate(BaseGate):
+    ''' Sigmoid gate block.
+
+    Computes: Sigmoid(Linear(x)) or Sigmoid(MLP(x))
+    Produces gate values between 0 and 1.
+    '''
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        ''' Forward pass.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Gate values in [0, 1].
+        '''
+        return torch.sigmoid(self._transform(x))
+
+
+@register_model()
+class TanhGate(BaseGate):
+    ''' Tanh gate block.
+
+    Computes: Tanh(Linear(x)) or Tanh(MLP(x))
+    Produces gate values between -1 and 1.
+    '''
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        ''' Forward pass.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Gate values in [-1, 1].
+        '''
+        return torch.tanh(self._transform(x))
+
+
+@register_model()
+class SoftmaxGate(BaseGate):
+    ''' Softmax gate block.
+
+    Computes: Softmax(Linear(x)) or Softmax(MLP(x))
+    Produces gate values that sum to 1.
+    '''
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        dim: int = -1,
+        temperature: float = 1.0,
+        bias: bool = True,
+        use_mlp: bool = False,
+        hidden_features: Optional[int] = None,
+        override_repr: bool = True
+    ):
+        ''' Initialize SoftmaxGate.
+
+        Args:
+            in_features (int): Input feature dimension.
+            out_features (int): Output feature dimension.
+            dim (int, optional): Dimension along which to apply the softmax. Defaults to -1.
+            temperature (float, optional): Temperature controlling the sharpness of the distribution. Defaults to 1.0.
+            bias (bool, optional): Whether to use a bias. Defaults to True.
+            use_mlp (bool, optional): Whether to use an MLP for the transform. Defaults to False.
+            hidden_features (int, optional): Hidden dimension of the MLP. Only used when use_mlp=True. Defaults to None.
+        '''
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            use_mlp=use_mlp,
+            hidden_features=hidden_features,
+            override_repr=override_repr
+        )
+        self.dim = dim
+        self.temperature = temperature
+
+    def _get_repr_args(self) -> List[str]:
+        args = super()._get_repr_args()
+        args.append(f"dim={self.dim}")
+        args.append(f"temperature={self.temperature}")
+        return args
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        ''' Forward pass.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Gate values summing to 1.
+        '''
+        x = self._transform(x)
+        return F.softmax(x / self.temperature, dim=self.dim)
+
+
+@register_model()
+class GLUGate(BaseBlock):
+    ''' Gated Linear Unit (GLU) block.
+
+    Supports several activation functions (Sigmoid, Tanh, ReLU, GELU, SiLU).
+    Computes: (x @ W + b) * Activation(x @ V + c)
+    '''
+
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int,
+        out_features: int,
+        activation: str = 'sigmoid',
+        bias: bool = True,
+        override_repr: bool = True
+    ):
+        ''' Initialize GLUGate.
+
+        Args:
+            in_features (int): Input dimension.
+            hidden_features (int): Hidden (projection) dimension.
+            out_features (int): Output dimension.
+            activation (str, optional): Activation type ('sigmoid', 'tanh', 'relu', 'gelu', 'silu'). Defaults to 'sigmoid'.
+            bias (bool, optional): Whether the Linear layers use a bias. Defaults to True.
+        '''
+        super().__init__()
+        self.in_features = in_features
+        self.hidden_features = hidden_features
+        self.out_features = out_features
+        self.activation_name = activation
+        self.bias = bias
+        self.override_repr = override_repr
+
+        self.proj = nn.Linear(in_features, hidden_features * 2, bias=bias)
+        self.out_proj = nn.Linear(hidden_features, out_features, bias=bias)
+
+        if activation == 'sigmoid':
+            self.act = nn.Sigmoid()
+        elif activation == 'tanh':
+            self.act = nn.Tanh()
+        elif activation == 'relu':
+            self.act = nn.ReLU()
+        elif activation == 'gelu':
+            self.act = nn.GELU()
+        elif activation == 'silu':
+            self.act = nn.SiLU()
+        else:
+            raise ValueError(f"Unsupported activation: {activation}")
+
+    def __repr__(self):
+        if self.override_repr:
+            args = [
+                f"in_features={self.in_features}",
+                f"hidden_features={self.hidden_features}",
+                f"out_features={self.out_features}",
+                f"activation='{self.activation_name}'",
+                f"bias={self.bias}"
+            ]
+            return f"{self.__class__.__name__}({', '.join(args)})"
+        return super().__repr__()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        ''' Forward pass.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output tensor.
+        '''
+        x = self.proj(x)
+        x, gate = x.chunk(2, dim=-1)
+        x = x * self.act(gate)
+        return self.out_proj(x)
+
+
+@dataclass
+class TopKGateOutput:
+    logits: torch.Tensor
+    indices: torch.Tensor
+    values: torch.Tensor
+
+    @property
+    def output(self) -> torch.Tensor:
+        output = torch.zeros_like(self.logits)
+        output.scatter_(-1, self.indices, self.values)
+        return output
+
+
+@register_model()
+class TopKGate(BaseGate):
+    ''' Top-K gate block.
+
+    Keeps only the K largest values and zeroes out the rest.
+    Can return detailed routing information for MoE use.
+    '''
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        k: int = 1,
+        bias: bool = True,
+        use_mlp: bool = False,
+        hidden_features: Optional[int] = None,
+        post_softmax: bool = False,
+        override_repr: bool = True
+    ):
+        ''' Initialize TopKGate.
+
+        Args:
+            in_features (int): Input feature dimension.
+            out_features (int): Output feature dimension.
+            k (int, optional): Number of top-K values to keep. Defaults to 1.
+            bias (bool, optional): Whether to use a bias. Defaults to True.
+            use_mlp (bool, optional): Whether to use an MLP for the transform. Defaults to False.
+            hidden_features (int, optional): Hidden dimension of the MLP. Only used when use_mlp=True. Defaults to None.
+            post_softmax (bool, optional): Whether to softmax-normalize the values after top-K selection. Defaults to False.
+        '''
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            use_mlp=use_mlp,
+            hidden_features=hidden_features,
+            override_repr=override_repr
+        )
+        self.k = k
+        self.post_softmax = post_softmax
+
+    def _get_repr_args(self) -> List[str]:
+        args = super()._get_repr_args()
+        args.append(f"k={self.k}")
+        args.append(f"post_softmax={self.post_softmax}")
+        return args
+
+    def forward(self, x: torch.Tensor) -> TopKGateOutput:
+        ''' Forward pass.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            TopKGateOutput: Dataclass holding logits, indices, and values.
+        '''
+        logits = self._transform(x)
+
+        topk_values, topk_indices = torch.topk(logits, self.k, dim=-1)
+
+        if self.post_softmax:
+            topk_values = F.softmax(topk_values, dim=-1)
+
+        return TopKGateOutput(
+            logits=logits,
+            indices=topk_indices,
+            values=topk_values
+        )
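A short sketch of the elementwise gates and GLUGate in use, again assuming the classes are importable from orbit.model.block.gate as the file list suggests; the gated residual update is an illustrative pattern, not part of the package:

import torch
from orbit.model.block.gate import SigmoidGate, SoftmaxGate, GLUGate

h = torch.randn(4, 32, 512)
update = torch.randn(4, 32, 512)

gate = SigmoidGate(in_features=512, out_features=512)
g = gate(h)                       # per-feature gates in (0, 1)
h = g * update + (1 - g) * h      # e.g. a gated residual update

mix = SoftmaxGate(in_features=512, out_features=4, temperature=0.5)
w = mix(h)                        # [4, 32, 4], non-negative weights summing to 1

# activation='silu' gives the SwiGLU-style variant: (xW + b) * SiLU(xV + c) -> out_proj
glu = GLUGate(in_features=512, hidden_features=1024, out_features=512, activation='silu')
y = glu(h)                        # -> [4, 32, 512]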
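And TopKGate as a sparse MoE router, showing how TopKGateOutput is meant to be consumed; the expert count and shapes are illustrative:

import torch
from orbit.model.block.gate import TopKGate

router = TopKGate(in_features=512, out_features=8, k=2, post_softmax=True)

tokens = torch.randn(4, 32, 512)
routing = router(tokens)

routing.logits.shape              # [4, 32, 8]  raw scores over 8 experts
routing.indices.shape             # [4, 32, 2]  ids of the top-2 experts per token
routing.values.shape              # [4, 32, 2]  softmax-renormalized top-2 weights
dense = routing.output            # [4, 32, 8]  top-2 weights scattered back, zeros elsewhere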