sglang 0.4.6__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to its public registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in that registry.
- sglang/bench_one_batch.py +2 -0
- sglang/check_env.py +3 -3
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +15 -0
- sglang/srt/conversation.py +122 -1
- sglang/srt/disaggregation/decode.py +8 -2
- sglang/srt/disaggregation/fake/__init__.py +1 -0
- sglang/srt/disaggregation/fake/conn.py +88 -0
- sglang/srt/disaggregation/prefill.py +12 -3
- sglang/srt/disaggregation/utils.py +16 -2
- sglang/srt/entrypoints/engine.py +52 -21
- sglang/srt/entrypoints/http_server.py +27 -2
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
- sglang/srt/layers/attention/flashinfer_backend.py +107 -82
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +1 -1
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -8
- sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +1 -1
- sglang/srt/layers/quantization/fp8.py +20 -22
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +84 -35
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +34 -15
- sglang/srt/managers/scheduler.py +273 -67
- sglang/srt/managers/scheduler_output_processor_mixin.py +26 -10
- sglang/srt/managers/tp_worker.py +52 -17
- sglang/srt/managers/tp_worker_overlap_thread.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +123 -58
- sglang/srt/models/deepseek_nextn.py +1 -257
- sglang/srt/models/deepseek_v2.py +78 -18
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +92 -30
- sglang/srt/models/llama4.py +2 -1
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +0 -12
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/openai_api/adapter.py +49 -8
- sglang/srt/openai_api/protocol.py +13 -1
- sglang/srt/reasoning_parser.py +25 -1
- sglang/srt/server_args.py +83 -24
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +91 -9
- sglang/test/runners.py +4 -0
- sglang/test/send_one.py +84 -28
- sglang/test/test_utils.py +67 -0
- sglang/version.py +1 -1
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +85 -60
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_nextn.py CHANGED
@@ -177,263 +177,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        if hasattr(self.config, "num_nextn_predict_layers"):
-            num_nextn_layers = self.config.num_nextn_predict_layers
-            assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
-            assert num_nextn_layers == self.config.num_hidden_layers
-        else:
-            raise ValueError("num_nextn_predict_layers is not in the config")
-
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        if self.n_share_experts_fusion > 0:
-            logger.info(
-                f"Cloning {self.n_share_experts_fusion} "
-                "replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
-            )
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
-                suffix_list = [
-                    "down_proj.weight",
-                    "down_proj.weight_scale",
-                    "gate_proj.weight",
-                    "gate_proj.weight_scale",
-                    "up_proj.weight",
-                    "up_proj.weight_scale",
-                ]
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "down_proj.weight_scale_inv",
-                    "gate_proj.weight",
-                    "gate_proj.weight_scale_inv",
-                    "up_proj.weight",
-                    "up_proj.weight_scale_inv",
-                ]
-            names_to_remove = []
-            for suffix in suffix_list:
-                shared_expert_weight_name = (
-                    f"model.layers.0.mlp.shared_experts.{suffix}"
-                )
-                for num_repeat in range(self.n_share_experts_fusion):
-                    weights_list.append(
-                        (
-                            f"model.layers.0."
-                            f"mlp.experts."
-                            f"{self.config.n_routed_experts + num_repeat}"
-                            f".{suffix}",
-                            weights_dict[shared_expert_weight_name],
-                        )
-                    )
-                names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
-        )
-
-        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
-        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
-            self.config.q_lora_rank is not None
-        )
-        cached_a_proj = {} if fuse_qkv_a_proj else None
-
-        nextn_layer_prefix = "model.layers.0"
-        nextn_spec_weight_names = [
-            "shared_head.norm",
-            "eh_proj",
-            "enorm",
-            "hnorm",
-        ]
-
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            if not name.startswith(nextn_layer_prefix):
-                continue
-
-            # Use shared head and embed weights from target model
-            if "shared_head.head" in name or "embed_tokens" in name:
-                continue
-
-            is_decoder = True
-            # For nextn specific weights
-            for weight_name in nextn_spec_weight_names:
-                if weight_name in name:
-                    name = name.replace(nextn_layer_prefix, "model")
-                    is_decoder = False
-                    break
-            # For decoder layer weights
-            if is_decoder:
-                name = name.replace(nextn_layer_prefix, "model.decoder")
-
-            if "rotary_emb.inv_freq" in name:
-                continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                # Skip non-stacked layers and experts (experts handled below).
-                if weight_name not in name:
-                    continue
-                # We have mlp.experts[0].gate_proj in the checkpoint.
-                # Since we handle the experts below in expert_params_mapping,
-                # we need to skip here BEFORE we update the name, otherwise
-                # name will be updated to mlp.experts[0].gate_up_proj, which
-                # will then be updated below in expert_params_mapping
-                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if ("mlp.experts." in name) and name not in params_dict:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-
-                    # Handle fused_qkv_a_proj
-                    if fuse_qkv_a_proj and (
-                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
-                    ):
-                        cached_a_proj[name] = loaded_weight
-                        q_a_proj_name = (
-                            name
-                            if "q_a_proj" in name
-                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
-                        )
-                        kv_a_proj_name = (
-                            name
-                            if "kv_a_proj_with_mqa" in name
-                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
-                        )
-
-                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
-                        if (
-                            q_a_proj_name in cached_a_proj
-                            and kv_a_proj_name in cached_a_proj
-                        ):
-
-                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
-                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
-                            fused_weight = torch.cat(
-                                [q_a_proj_weight, kv_a_proj_weight], dim=0
-                            )
-
-                            param_name = name.replace(
-                                "q_a_proj", "fused_qkv_a_proj_with_mqa"
-                            )
-                            param = params_dict[param_name]
-
-                            weight_loader = getattr(
-                                param, "weight_loader", default_weight_loader
-                            )
-                            weight_loader(param, fused_weight)
-                            cached_a_proj.pop(q_a_proj_name)
-                            cached_a_proj.pop(kv_a_proj_name)
-                    else:
-                        param = params_dict[name]
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, loaded_weight)
-
-        self_attn = self.model.decoder.self_attn
-        if hasattr(self_attn.kv_b_proj, "qweight"):
-            # AWQ compatible
-            if _is_cuda:
-                w = awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                ).T
-            else:
-                w = awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                    0,
-                    0,
-                    0,
-                ).T
-        else:
-            w = self_attn.kv_b_proj.weight
-        # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-        # This may affect the accuracy of fp8 model.
-        if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
-            torch.float8_e4m3fn,
-            torch.float8_e4m3fnuz,
-        ):
-            weight_block_size = self.quant_config.weight_block_size
-            if weight_block_size is not None:
-                assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                if _is_hip:
-                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=w,
-                        weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                        input_scale=None,
-                    )
-                else:
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
-                w, scale = block_quant_to_tensor_quant(
-                    weight, weight_scale, weight_block_size
-                )
-                self_attn.w_scale = scale
-        if w.dtype == torch.int8:
-            if hasattr(self.quant_config, "weight_block_size"):
-                # block-wise int8 need it
-                weight_block_size = self.quant_config.weight_block_size
-                if weight_block_size is not None:
-                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                    w = int8_block_dequant(weight, weight_scale, weight_block_size).to(
-                        torch.bfloat16
-                    )
-            else:
-                # channel-wise int8 need it
-                assert hasattr(self_attn.kv_b_proj, "weight_scale")
-                w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
-                    torch.bfloat16
-                )
-        w_kc, w_vc = w.unflatten(
-            0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-        ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-        self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-        self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-        if hasattr(self_attn.kv_b_proj, "weight_scale") and self_attn.w_scale is None:
-            self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-            if _is_hip:
-                self_attn.w_scale *= 2.0
+        super().load_weights(weights, is_nextn=True)


 EntryClass = [DeepseekV3ForCausalLMNextN]
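With this release, `DeepseekV3ForCausalLMNextN.load_weights` shrinks from roughly 260 lines to a single delegation: the NextN (multi-token-prediction) draft model reuses the parent's `load_weights`, implemented in `DeepseekV2ForCausalLM`, which now accepts an `is_nextn` flag (see the `deepseek_v2.py` diff below). A simplified sketch of the pattern, using stand-in classes rather than the real sglang ones:

```python
from typing import Iterable, Tuple

import torch


class BaseCausalLM:
    """Stand-in for the target model that owns the full weight-mapping logic."""

    def load_weights(
        self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn: bool = False
    ):
        target = "the NextN draft layer" if is_nextn else "the decoder layers"
        for name, tensor in weights:
            print(f"mapping {name} {tuple(tensor.shape)} into {target}")


class NextNCausalLM(BaseCausalLM):
    """Stand-in for the draft model: no duplicated loader, just a flag."""

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        super().load_weights(weights, is_nextn=True)


NextNCausalLM().load_weights([("model.layers.61.eh_proj.weight", torch.zeros(2, 2))])
```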
sglang/srt/models/deepseek_v2.py CHANGED
@@ -752,7 +752,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = q_nope_out.transpose(0, 1)

         k_nope = latent_cache[..., : self.kv_lora_rank]
-        k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
+        k_nope = self.kv_a_layernorm(k_nope.contiguous()).unsqueeze(1)
         k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
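The functional change in this hunk is the `.contiguous()` call: `latent_cache[..., : self.kv_lora_rank]` is a strided view, and slicing along the last dimension generally leaves it non-contiguous, which fused normalization kernels may not accept. A minimal, self-contained illustration (our own example, not code from the package; the 512/64 split mirrors DeepSeek's kv_lora_rank and rope dimensions):

```python
import torch

latent_cache = torch.randn(8, 576)           # e.g. kv_lora_rank (512) + rope dim (64)
k_nope = latent_cache[..., :512]             # a view that keeps the parent's strides

print(k_nope.is_contiguous())                # False: strides are (576, 1), not (512, 1)
print(k_nope.contiguous().is_contiguous())   # True: materializes a dense copy
```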
@@ -1391,6 +1391,9 @@ class DeepseekV2Model(nn.Module):

         self.dp_size = get_attention_dp_size()

+    def get_input_embeddings(self) -> torch.Tensor:
+        return self.embed_tokens
+
     def forward(
         self,
         input_ids: torch.Tensor,
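`DeepseekV2Model` now exposes its token-embedding module through `get_input_embeddings()` (the accessor returns the embedding module itself, despite the `torch.Tensor` annotation in the source). Callers that embed token ids themselves, as multimodal wrappers commonly do, can use it as in this simplified stand-in (our own illustration, not the sglang classes):

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Stand-in with the same accessor shape as DeepseekV2Model."""

    def __init__(self, vocab_size: int = 32, hidden_size: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def get_input_embeddings(self) -> nn.Module:
        return self.embed_tokens


model = TinyModel()
inputs_embeds = model.get_input_embeddings()(torch.tensor([[1, 2, 3]]))
print(inputs_embeds.shape)  # torch.Size([1, 3, 16])
```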
@@ -1502,11 +1505,20 @@ class DeepseekV2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )

-    def post_load_weights(self):
+    def post_load_weights(self, is_nextn=False):

         # Perform post-processing after loading weights
-        for layer_id in range(self.config.num_hidden_layers):
-            self_attn = self.model.layers[layer_id].self_attn
+        layer_ids = (
+            range(self.config.num_hidden_layers)
+            if not is_nextn
+            else [self.config.num_hidden_layers]
+        )
+        for layer_id in layer_ids:
+            self_attn = (
+                self.model.layers[layer_id].self_attn
+                if not is_nextn
+                else self.model.decoder.self_attn
+            )
             if hasattr(self_attn.kv_b_proj, "qweight"):
                 # AWQ compatible
                 if _is_cuda:
@@ -1612,7 +1624,20 @@ class DeepseekV2ForCausalLM(nn.Module):
                 self_attn.w_vc = w_vc.contiguous()
                 self_attn.use_deep_gemm_bmm = True

-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+        if is_nextn:
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
+                # compatible with old design
+                nextn_layer_id = (
+                    0
+                    if self.config.num_hidden_layers == 1
+                    else self.config.num_hidden_layers
+                )
+            else:
+                raise ValueError("num_nextn_predict_layers is not in the config")
+
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
@@ -1640,12 +1665,19 @@ class DeepseekV2ForCausalLM(nn.Module):
                     "up_proj.weight_scale_inv",
                 ]
             names_to_remove = []
-            for moe_layer in tqdm(
+
+            moe_layers = (
                 range(
                     self.config.first_k_dense_replace,
                     self.config.num_hidden_layers,
                     self.config.moe_layer_freq,
-                ),
+                )
+                if not is_nextn
+                else [nextn_layer_id]
+            )
+
+            for moe_layer in tqdm(
+                moe_layers,
                 desc=f"Cloning {self.n_share_experts_fusion} "
                 "replicas of the shared expert into MoE",
             ):
@@ -1686,18 +1718,46 @@ class DeepseekV2ForCausalLM(nn.Module):
         )
         cached_a_proj = {} if fuse_qkv_a_proj else None

+        if is_nextn:
+            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
+            nextn_spec_weight_names = [
+                "shared_head.norm",
+                "eh_proj",
+                "enorm",
+                "hnorm",
+            ]
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
-            # TODO(HandH1998): Modify it when nextn is supported.
-            if hasattr(self.config, "num_nextn_predict_layers"):
-                num_nextn_layers = self.config.num_nextn_predict_layers
-                if num_nextn_layers > 0 and name.startswith("model.layers"):
-                    name_list = name.split(".")
-                    if (
-                        len(name_list) >= 3
-                        and int(name_list[2]) >= self.config.num_hidden_layers
-                    ):
-                        continue
+            if not is_nextn:
+                if hasattr(self.config, "num_nextn_predict_layers"):
+                    num_nextn_layers = self.config.num_nextn_predict_layers
+                    if num_nextn_layers > 0 and name.startswith("model.layers"):
+                        name_list = name.split(".")
+                        if (
+                            len(name_list) >= 3
+                            and int(name_list[2]) >= self.config.num_hidden_layers
+                        ):
+                            continue
+            else:
+                if not name.startswith(nextn_layer_prefix):
+                    continue
+
+                # Use shared head and embed weights from target model
+                if "shared_head.head" in name or "embed_tokens" in name:
+                    continue
+
+                is_decoder = True
+                # For nextn specific weights
+                for weight_name in nextn_spec_weight_names:
+                    if weight_name in name:
+                        name = name.replace(nextn_layer_prefix, "model")
+                        is_decoder = False
+                        break
+                # For decoder layer weights
+                if is_decoder:
+                    name = name.replace(nextn_layer_prefix, "model.decoder")
+
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -1786,7 +1846,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                         )
                         weight_loader(param, loaded_weight)

-        self.post_load_weights()
+        self.post_load_weights(is_nextn=is_nextn)

     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight