fusion_bench-0.2.13-py3-none-any.whl → fusion_bench-0.2.14-py3-none-any.whl
- fusion_bench/method/ada_svd/clip_vision.py +4 -1
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +46 -145
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +229 -0
- fusion_bench/method/smile_upscaling/smile_upscaling.py +6 -336
- fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +2 -203
- fusion_bench/models/modeling_smile_qwen2/__init__.py +8 -0
- fusion_bench/models/modeling_smile_qwen2/configuration_smile_qwen2.py +21 -0
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +922 -0
- fusion_bench/models/modeling_smile_qwen2/register.py +11 -0
- fusion_bench/models/rankone_moe.py +2 -88
- fusion_bench/models/smile_moe/linear_from_hf_config.py +373 -0
- fusion_bench/models/smile_moe/{linear.py → linear_from_module.py} +103 -33
- fusion_bench/models/smile_moe/utils/__init__.py +24 -0
- fusion_bench/models/smile_moe/utils/svd_utils.py +46 -0
- fusion_bench/taskpool/__init__.py +2 -0
- fusion_bench/taskpool/lm_eval_harness/__init__.py +3 -0
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +87 -0
- {fusion_bench-0.2.13.dist-info → fusion_bench-0.2.14.dist-info}/METADATA +22 -2
- {fusion_bench-0.2.13.dist-info → fusion_bench-0.2.14.dist-info}/RECORD +27 -14
- {fusion_bench-0.2.13.dist-info → fusion_bench-0.2.14.dist-info}/WHEEL +1 -1
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -2
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +13 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +17 -0
- fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +12 -0
- {fusion_bench-0.2.13.dist-info → fusion_bench-0.2.14.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.13.dist-info → fusion_bench-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.13.dist-info → fusion_bench-0.2.14.dist-info}/top_level.txt +0 -0
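The headline change is that the SMILE MoE building blocks (`SmileGate`, `SmileCompressedLinear`, `SmileMoELinear`, `ExpertNotTrainedError`) move out of `smile_upscaling.py` into the shared `smile_moe` package, with `linear.py` renamed to `linear_from_module.py`. A minimal sketch of the 0.2.14 import path, assuming downstream code previously imported these classes from `smile_upscaling` or the old `linear` module:

```python
# Sketch only: in 0.2.14 the SMILE MoE classes live in the renamed module
# (previously fusion_bench/models/smile_moe/linear.py); the old inline
# definitions in smile_upscaling.py are gone.
from fusion_bench.models.smile_moe.linear_from_module import (
    ExpertNotTrainedError,
    SmileCompressedLinear,
    SmileGate,
    SmileMoELinear,
)
```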
```diff
--- a/fusion_bench/method/smile_upscaling/smile_upscaling.py
+++ b/fusion_bench/method/smile_upscaling/smile_upscaling.py
@@ -13,348 +13,18 @@ from fusion_bench.method import BaseAlgorithm
 from fusion_bench.method.simple_average import simple_average
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.models.smile_moe.linear_from_module import (
+    ExpertNotTrainedError,
+    SmileCompressedLinear,
+    SmileGate,
+    SmileMoELinear,
+)
 from fusion_bench.models.utils import get_attr, set_attr
 from fusion_bench.utils.parameters import print_parameters
 
 log = logging.getLogger(__name__)
 
 
-class ExpertNotTrainedError(Exception):
-    pass
-
-
-def _is_all_zeros(tensor: Tensor | List[Tensor]) -> bool:
-    """
-    Check if a tensor or a list of tensors are all zeros.
-
-    Args:
-        tensor (Tensor | List[Tensor]): A tensor or a list of tensors.
-
-    Returns:
-        bool: True if all elements are zeros, False otherwise.
-    """
-    if isinstance(tensor, Tensor):
-        return torch.allclose(tensor, torch.zeros_like(tensor))
-    else:
-        return all(_is_all_zeros(t) for t in tensor)
-
-
-def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]:
-    """
-    Perform Singular Value Decomposition (SVD) on a tensor.
-
-    Args:
-        w (Tensor): The input tensor.
-        full_matrices (bool): Whether to compute the full-sized U and V matrices.
-
-    Returns:
-        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
-    """
-    u, s, vh = torch.linalg.svd(
-        w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
-    )
-    v = vh.T
-    return u, s, v
-
-
-def svd(
-    w: Tensor, full_matrices=True, accelerator=None
-) -> Tuple[Tensor, Tensor, Tensor]:
-    """
-    Perform SVD on a tensor, optionally using a specified accelerator.
-
-    Args:
-        w (Tensor): The input tensor.
-        full_matrices (bool): Whether to compute the full-sized U and V matrices.
-        accelerator (str): The device to perform the computation on.
-
-    Returns:
-        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
-    """
-    if accelerator is None:
-        return _svd(w, full_matrices=full_matrices)
-    original_device = w.device
-    w = w.to(accelerator)
-    u, s, v = _svd(w)
-    return u.to(original_device), s.to(original_device), v.to(original_device)
-
-
-class SmileGate(nn.Module):
-    def __init__(
-        self,
-        input_features: int,
-        w_diff_list: List[Tensor],
-        k: int,
-        svd_list=None,  # cached `svd_list`, pass it to avoid recomputing
-        upscaling_accelerator=None,
-    ):
-        """
-        Initialize the SmileGate module.
-
-        Args:
-            input_features (int): The number of input features.
-            w_diff_list (List[Tensor]): A list of weight difference tensors.
-            k (int): The number of singular values to keep.
-            svd_list (List[Tuple[Tensor, Tensor, Tensor]]): Cached SVD results.
-            upscaling_accelerator (str): The device to perform the computation on.
-        """
-        super().__init__()
-        self.input_features = input_features
-        self.num_experts = len(w_diff_list)
-        weights = []
-        for i, w_diff in enumerate(w_diff_list):
-            if svd_list is None:
-                u, s, v = svd(w_diff, accelerator=upscaling_accelerator)
-            else:
-                u, s, v = svd_list[i]
-            u = u[:, :k]
-            s = s[:k]
-            v = v[:, :k]
-
-            # weights.append((s * v).T)
-            weights.append(v.T)
-        self.k = s.size(0)  # k is the actual k after truncation
-
-        weights = (
-            torch.stack(weights, dim=0)
-            .reshape(self.num_experts * self.k, -1)
-            .contiguous()
-        )
-        self.weights = nn.Parameter(
-            weights
-        )  # weights should be a tensor of shape (num_experts * k, n)
-
-    def forward(self, x: Tensor):
-        """
-        Forward pass of the SmileGate module.
-
-        Args:
-            x (Tensor): The input tensor.
-
-        Returns:
-            Tensor: The routing weights.
-        """
-        batch_size = x.size(0)
-        if self.num_experts == 1:
-            return torch.ones(batch_size, 1, device=x.device, dtype=x.dtype)
-
-        routing_weights = F.linear(x, self.weights).view(
-            batch_size, self.num_experts, self.k
-        )
-        routing_weights = routing_weights.norm(p=2, dim=2)
-        return routing_weights
-
-
-class SmileCompressedLinear(nn.Module):
-    def __init__(self, model: nn.Linear, k: int, svd_cache=None):
-        """
-        Initialize the SmileCompressedLinear module.
-
-        Args:
-            model (nn.Linear): The linear model to compress.
-            k (int): The number of singular values to keep.
-            svd_cache (Tuple[Tensor, Tensor, Tensor]): Cached SVD results.
-        """
-        super().__init__()
-        if svd_cache is None:
-            u, s, v = svd(model.weight)
-        else:
-            u, s, v = svd_cache
-        if k > 0:
-            u = u[:, :k]
-            s = s[:k]
-            v = v[:, :k]
-
-        self.u = nn.Parameter(u)
-        self.svh = nn.Parameter((s * v).T)
-
-        if model.bias is not None:
-            self.bias = nn.Parameter(model.bias.data, requires_grad=True)
-        else:
-            self.register_parameter("bias", None)
-
-    def forward(self, x):
-        """
-        Forward pass of the SmileCompressedLinear module.
-
-        Args:
-            x (Tensor): The input tensor.
-
-        Returns:
-            Tensor: The output tensor.
-        """
-        x = F.linear(x, self.svh)
-        x = F.linear(x, self.u, self.bias)
-        return x
-
-
-class SmileMoELinear(nn.Module):
-    @torch.no_grad()
-    def __init__(
-        self,
-        pretrained_model: nn.Linear,
-        finetuned_models: List[nn.Linear],
-        gate_k: int,
-        k: int,
-        top_k: int = 1,
-        full_matrices=True,
-        upscaling_accelerator=None,
-        routing_use_diff=True,
-    ):
-        """
-        Initialize the SmileMoELinear module.
-
-        Args:
-            pretrained_model (nn.Linear): The pretrained linear model.
-            finetuned_models (List[nn.Linear]): A list of fine-tuned linear models.
-            gate_k (int): The number of singular values to keep for the gate.
-            k (int): The number of singular values to keep for the experts.
-            top_k (int): The number of top experts to select.
-            full_matrices (bool): Whether to compute the full-sized U and V matrices.
-            upscaling_accelerator (str): The device to perform the computation on.
-            routing_use_diff (bool): Whether to use weight differences for routing.
-        """
-        super().__init__()
-        self.num_experts = len(finetuned_models)
-        self.top_k = top_k
-        self.k = k
-        self.gate_k = gate_k
-        self.in_features = pretrained_model.in_features
-        self.out_features = pretrained_model.out_features
-
-        w_diff_list = [m.weight - pretrained_model.weight for m in finetuned_models]
-        if _is_all_zeros(w_diff_list):
-            # All fine-tuned models are identical to the pretrained model
-            raise ExpertNotTrainedError()
-
-        if routing_use_diff or k > 0:
-            svd_cache_list = [
-                svd(w, full_matrices=full_matrices, accelerator=upscaling_accelerator)
-                for w in w_diff_list
-            ]  # the svd cache list to avoid recomputing
-
-        # construct the gate network
-        if routing_use_diff:
-            self.gate = SmileGate(
-                input_features=self.in_features,
-                w_diff_list=w_diff_list,
-                k=gate_k,
-                svd_list=svd_cache_list,
-                upscaling_accelerator=upscaling_accelerator,
-            )
-        else:
-            self.gate = SmileGate(
-                input_features=self.in_features,
-                w_diff_list=[m.weight for m in finetuned_models],
-                k=gate_k,
-                svd_list=None,
-                upscaling_accelerator=upscaling_accelerator,
-            )
-
-        # construct experts
-        for m, w_diff in zip(finetuned_models, w_diff_list):
-            m.weight.data = w_diff
-        if k > 0:
-            experts = [
-                SmileCompressedLinear(m, k, svd_cache=svd_cache)
-                for m, svd_cache in zip(finetuned_models, svd_cache_list)
-            ]
-        else:
-            # if k is not set (<0), we use the full fine-tuned model
-            experts = finetuned_models
-        self.experts = nn.ModuleList(experts)
-
-        if pretrained_model.bias is not None:
-            for m in experts:
-                m.bias.data = m.bias.data - pretrained_model.bias
-        # assign the pretrained model (the shared part)
-        self.pretrained_model = pretrained_model
-
-    def forward(self, hidden_states: Tensor):
-        """
-        Forward pass of the SmileMoELinear module.
-
-        Args:
-            hidden_states (Tensor): The input tensor.
-
-        Returns:
-            Tensor: The output tensor.
-        """
-        pretrained_out = self.pretrained_model(hidden_states)
-
-        input_shape = hidden_states.size()
-        hidden_states = hidden_states.view(-1, self.in_features)
-
-        router_logits = self.gate(hidden_states)
-        routing_weights = F.softmax(router_logits, dim=1)
-        # sample the expert according to the routing weights
-        routing_weights, selected_experts = torch.topk(
-            routing_weights, self.top_k, dim=-1
-        )
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-
-        final_hidden_states = torch.zeros(
-            (hidden_states.size(0), self.out_features),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
-
-        # One hot encode the selected experts to create an expert mask
-        # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = torch.nn.functional.one_hot(
-            selected_experts, num_classes=self.num_experts
-        ).permute(2, 1, 0)
-
-        # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_experts):
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx])
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x].reshape(-1, self.in_features)
-            if current_state.numel() == 0:
-                continue
-            current_hidden_states = (
-                expert_layer(current_state) * routing_weights[top_x, idx, None]
-            )
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(
-                0, top_x, current_hidden_states.to(hidden_states.dtype)
-            )
-        final_hidden_states = final_hidden_states.reshape(
-            *input_shape[:-1], self.out_features
-        )
-        final_hidden_states = pretrained_out + final_hidden_states
-        return final_hidden_states
-
-    @property
-    def weight(self):
-        """
-        Mimic linear layer. Bacause in some cases, user might indicate the device (or dtype of parameters) of the linear layer using `linear_layer.weight.device`
-        """
-        return self.pretrained_model.weight
-
-    @property
-    def bias(self):
-        return self.pretrained_model.bias
-
-    def __repr__(self):
-        return (
-            f"SingularMoELinear("
-            f"in_features={self.pretrained_model.in_features}, "
-            f"out_features={self.pretrained_model.out_features}, "
-            f"num_experts={self.num_experts}, "
-            f"top_k={self.top_k}, "
-            f"gate_k={self.gate_k}, "
-            f"k={self.k}"
-            f")"
-        )
-
-
 class SmileUpscalingAlgorithm(
     SimpleProfilerMixin,
     BaseAlgorithm,
```
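The removed `SmileCompressedLinear` is the heart of the upscaling scheme: a task vector (fine-tuned weight minus pretrained weight) is truncated to rank k via SVD and applied as two thin linear maps. A standalone sketch (not the fusion_bench API) of the factorization it stores:

```python
# The task vector w_diff is approximated as U_k (S_k V_k^T); forward() then
# computes y = (x @ svh.T) @ u.T, i.e. two small F.linear calls instead of
# one dense (out_features x in_features) matmul.
import torch

out_features, in_features, k = 64, 32, 8
w_diff = torch.randn(out_features, in_features)

u, s, vh = torch.linalg.svd(w_diff, full_matrices=False)
u_k = u[:, :k]                   # plays the role of `self.u`
svh_k = (s[:k] * vh.T[:, :k]).T  # plays the role of `self.svh`, shape (k, in_features)

x = torch.randn(5, in_features)
y_lowrank = x @ svh_k.T @ u_k.T  # what the compressed expert computes
y_exact = x @ w_diff.T           # what the uncompressed expert would compute
print((y_lowrank - y_exact).abs().max())  # shrinks as k approaches full rank
```

The same truncated SVD also feeds `SmileGate`, which scores each expert by the L2 norm of the input's projection onto that expert's top-k right singular vectors (`v.T`).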
```diff
--- a/fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py
+++ b/fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py
@@ -27,6 +27,8 @@ from transformers.models.mistral.modeling_mistral import (
     MistralRotaryEmbedding,
 )
 
+from fusion_bench.models.smile_moe.linear_from_hf_config import SmileLinear
+
 from .configuration_smile_mistral import SmileMistralConfig
 
 logger = logging.getLogger(__name__)
```
```diff
@@ -80,209 +82,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
-class SmileGate(nn.Module):
-    __constants__ = ["in_features", "num_experts", "k"]
-    in_features: int
-    num_experts: int
-    k: int
-    weight: Tensor
-
-    def __init__(
-        self,
-        in_features: int,
-        num_experts: int,
-        k: int,
-        device=None,
-        dtype=None,
-    ):
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.input_features = in_features
-        self.num_experts = num_experts
-        self.k = k
-
-        self.weight = nn.Parameter(
-            torch.empty(num_experts * k, in_features, **factory_kwargs)
-        )
-
-    def forward(self, x: Tensor):
-        batch_size = x.size(0)
-        if self.num_experts == 1:
-            return torch.ones(batch_size, 1, device=x.device, dtype=x.dtype)
-
-        routing_weights = F.linear(x, self.weight).view(
-            batch_size, self.num_experts, self.k
-        )
-        routing_weights = routing_weights.norm(p=2, dim=2)
-        return routing_weights
-
-
-class SmileLinearExpert(nn.Module):
-    __constants__ = ["in_features", "out_features", "k"]
-    in_features: int
-    out_features: int
-    k: int
-
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        k: int,
-        bias: bool,
-        device=None,
-        dtype=None,
-    ):
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.k = k
-
-        self.u = nn.Parameter(torch.empty(out_features, k, **factory_kwargs))
-        self.svh = nn.Parameter(torch.empty(k, in_features, **factory_kwargs))
-
-        if bias:
-            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
-        else:
-            self.register_parameter("bias", None)
-
-    def forward(self, x):
-        x = F.linear(x, self.svh)
-        x = F.linear(x, self.u, self.bias)
-        return x
-
-
-class SmileLinear(nn.Module):
-    @torch.no_grad()
-    def __init__(
-        self,
-        config: SmileMistralConfig,
-        in_features,
-        out_features,
-        bias: bool,
-        device=None,
-        dtype=None,
-    ):
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.num_local_experts = config.num_local_experts
-        self.num_experts_per_tok = config.num_experts_per_tok
-        self.rank_of_expert = config.rank_of_expert
-        self.rank_of_router = config.rank_of_router
-        self.in_features = in_features
-        self.out_features = out_features
-
-        # construct the gate network
-        self.gate = SmileGate(
-            in_features=in_features,
-            num_experts=self.num_local_experts,
-            k=self.rank_of_router,
-            **factory_kwargs,
-        )
-
-        # the shared linear
-        self.shared_linear = nn.Linear(
-            in_features, out_features, bias=bias, **factory_kwargs
-        )
-
-        # construct experts
-        if self.rank_of_expert > 0:
-            self.experts = nn.ModuleList(
-                [
-                    SmileLinearExpert(
-                        in_features=in_features,
-                        out_features=out_features,
-                        bias=bias,
-                        k=self.rank_of_expert,
-                        **factory_kwargs,
-                    )
-                    for _ in range(self.num_local_experts)
-                ]
-            )
-        else:
-            self.experts = nn.ModuleList(
-                [
-                    nn.Linear(in_features, out_features, bias=bias, **factory_kwargs)
-                    for _ in range(self.num_local_experts)
-                ]
-            )
-
-    def forward(self, hidden_states: Tensor):
-        pretrained_out = self.shared_linear(hidden_states)
-
-        input_shape = hidden_states.size()
-        hidden_states = hidden_states.view(-1, self.in_features)
-
-        router_logits = self.gate(hidden_states)
-        routing_weights = F.softmax(router_logits, dim=1)
-        # sample the expert according to the routing weights
-        routing_weights, selected_experts = torch.topk(
-            routing_weights, self.num_experts_per_tok, dim=-1
-        )
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-
-        final_hidden_states = torch.zeros(
-            (hidden_states.size(0), self.out_features),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
-
-        # One hot encode the selected experts to create an expert mask
-        # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = torch.nn.functional.one_hot(
-            selected_experts, num_classes=self.num_local_experts
-        ).permute(2, 1, 0)
-
-        # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_local_experts):
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx])
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x].reshape(-1, self.in_features)
-            if current_state.numel() == 0:
-                continue
-            current_hidden_states = (
-                expert_layer(current_state) * routing_weights[top_x, idx, None]
-            )
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(
-                0, top_x, current_hidden_states.to(hidden_states.dtype)
-            )
-        final_hidden_states = final_hidden_states.reshape(
-            *input_shape[:-1], self.out_features
-        )
-        final_hidden_states = pretrained_out + final_hidden_states
-        return final_hidden_states
-
-    @property
-    def weight(self):
-        """
-        Mimic linear layer. Bacause in some cases, user might indicate the device (or dtype of parameters) of the linear layer using `linear_layer.weight.device`
-        """
-        return self.shared_linear.weight
-
-    @property
-    def bias(self):
-        return self.shared_linear.bias
-
-    def __repr__(self):
-        return (
-            f"SingularMoELinear("
-            f"in_features={self.shared_linear.in_features}, "
-            f"out_features={self.shared_linear.out_features}, "
-            f"num_local_experts={self.num_local_experts}, "
-            f"num_experts_per_tok={self.num_experts_per_tok}, "
-            f"rank_of_router={self.rank_of_router}, "
-            f"rank_of_expert={self.rank_of_expert}"
-            f")"
-        )
-
-
 class SmileMistralAttention(nn.Module):
     """
     Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
```
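Both removed `forward()` methods (and, presumably, their replacements in `linear_from_hf_config.py`) use the same Mixtral-style token dispatch: softmax the gate scores, keep the top-k experts per token, renormalize, then scatter each expert's output back with `index_add_`. A self-contained sketch of that pattern (standalone, not the fusion_bench API):

```python
# Per-token top-k expert dispatch, following the loop in the removed code.
import torch
import torch.nn.functional as F

num_experts, top_k, d_in, d_out, n_tokens = 4, 2, 16, 16, 10
experts = [torch.nn.Linear(d_in, d_out) for _ in range(num_experts)]
hidden_states = torch.randn(n_tokens, d_in)
router_logits = torch.randn(n_tokens, num_experts)  # stands in for the gate

routing_weights = F.softmax(router_logits, dim=1)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # renormalize top-k

out = torch.zeros(n_tokens, d_out)
# expert_mask[e, slot, token] == 1 iff `token` routed to expert e in that top-k slot
expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
for e in range(num_experts):
    idx, top_x = torch.where(expert_mask[e])  # slot indices, token indices
    if top_x.numel() == 0:
        continue  # no tokens routed to this expert
    weighted = experts[e](hidden_states[top_x]) * routing_weights[top_x, idx, None]
    out.index_add_(0, top_x, weighted)  # scatter back to the token positions
```

In the actual modules this MoE correction is added on top of the shared (pretrained) linear's output rather than standing alone.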
```diff
--- /dev/null
+++ b/fusion_bench/models/modeling_smile_qwen2/configuration_smile_qwen2.py
@@ -0,0 +1,21 @@
+from transformers import PretrainedConfig
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+
+
+class SmileQwen2Config(Qwen2Config):
+    model_type = "smile_qwen2"
+
+    def __init__(
+        self,
+        num_experts_per_tok: int = 1,
+        rank_of_router: int = None,
+        rank_of_expert: int = None,
+        num_local_experts: int = None,
+        **kwargs,
+    ):
+        self.num_experts_per_tok = num_experts_per_tok
+        self.rank_of_router = rank_of_router
+        self.rank_of_expert = rank_of_expert
+        self.num_local_experts = num_local_experts
+
+        super().__init__(**kwargs)
```