sglang 0.4.6__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registries. It is provided for informational purposes only.
Files changed (85)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/check_env.py +3 -3
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/kimi_vl.py +38 -0
  5. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  6. sglang/srt/configs/model_config.py +15 -0
  7. sglang/srt/conversation.py +122 -1
  8. sglang/srt/disaggregation/decode.py +8 -2
  9. sglang/srt/disaggregation/fake/__init__.py +1 -0
  10. sglang/srt/disaggregation/fake/conn.py +88 -0
  11. sglang/srt/disaggregation/prefill.py +12 -3
  12. sglang/srt/disaggregation/utils.py +16 -2
  13. sglang/srt/entrypoints/engine.py +52 -21
  14. sglang/srt/entrypoints/http_server.py +27 -2
  15. sglang/srt/function_call_parser.py +97 -0
  16. sglang/srt/hf_transformers_utils.py +2 -0
  17. sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
  18. sglang/srt/layers/attention/flashinfer_backend.py +107 -82
  19. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
  20. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  21. sglang/srt/layers/attention/utils.py +1 -1
  22. sglang/srt/layers/dp_attention.py +5 -2
  23. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -8
  41. sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
  42. sglang/srt/layers/quantization/__init__.py +2 -2
  43. sglang/srt/layers/quantization/deep_gemm.py +1 -1
  44. sglang/srt/layers/quantization/fp8.py +20 -22
  45. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  46. sglang/srt/layers/utils.py +35 -0
  47. sglang/srt/lora/layers.py +35 -9
  48. sglang/srt/lora/lora_manager.py +84 -35
  49. sglang/srt/managers/data_parallel_controller.py +52 -34
  50. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  51. sglang/srt/managers/schedule_batch.py +34 -15
  52. sglang/srt/managers/scheduler.py +273 -67
  53. sglang/srt/managers/scheduler_output_processor_mixin.py +26 -10
  54. sglang/srt/managers/tp_worker.py +52 -17
  55. sglang/srt/managers/tp_worker_overlap_thread.py +18 -7
  56. sglang/srt/mem_cache/memory_pool.py +70 -36
  57. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  58. sglang/srt/model_executor/forward_batch_info.py +31 -1
  59. sglang/srt/model_executor/model_runner.py +123 -58
  60. sglang/srt/models/deepseek_nextn.py +1 -257
  61. sglang/srt/models/deepseek_v2.py +78 -18
  62. sglang/srt/models/kimi_vl.py +308 -0
  63. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  64. sglang/srt/models/llama.py +92 -30
  65. sglang/srt/models/llama4.py +2 -1
  66. sglang/srt/models/llama_eagle.py +4 -1
  67. sglang/srt/models/llama_eagle3.py +4 -1
  68. sglang/srt/models/qwen2_moe.py +8 -3
  69. sglang/srt/models/qwen2_vl.py +0 -12
  70. sglang/srt/models/qwen3_moe.py +8 -3
  71. sglang/srt/openai_api/adapter.py +49 -8
  72. sglang/srt/openai_api/protocol.py +13 -1
  73. sglang/srt/reasoning_parser.py +25 -1
  74. sglang/srt/server_args.py +83 -24
  75. sglang/srt/speculative/eagle_worker.py +3 -2
  76. sglang/srt/utils.py +91 -9
  77. sglang/test/runners.py +4 -0
  78. sglang/test/send_one.py +84 -28
  79. sglang/test/test_utils.py +67 -0
  80. sglang/version.py +1 -1
  81. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
  82. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +85 -60
  83. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
  84. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
  85. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
@@ -177,263 +177,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        if hasattr(self.config, "num_nextn_predict_layers"):
-            num_nextn_layers = self.config.num_nextn_predict_layers
-            assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
-            assert num_nextn_layers == self.config.num_hidden_layers
-        else:
-            raise ValueError("num_nextn_predict_layers is not in the config")
-
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        if self.n_share_experts_fusion > 0:
-            logger.info(
-                f"Cloning {self.n_share_experts_fusion} "
-                "replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
-            )
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
-                suffix_list = [
-                    "down_proj.weight",
-                    "down_proj.weight_scale",
-                    "gate_proj.weight",
-                    "gate_proj.weight_scale",
-                    "up_proj.weight",
-                    "up_proj.weight_scale",
-                ]
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "down_proj.weight_scale_inv",
-                    "gate_proj.weight",
-                    "gate_proj.weight_scale_inv",
-                    "up_proj.weight",
-                    "up_proj.weight_scale_inv",
-                ]
-            names_to_remove = []
-            for suffix in suffix_list:
-                shared_expert_weight_name = (
-                    f"model.layers.0.mlp.shared_experts.{suffix}"
-                )
-                for num_repeat in range(self.n_share_experts_fusion):
-                    weights_list.append(
-                        (
-                            f"model.layers.0."
-                            f"mlp.experts."
-                            f"{self.config.n_routed_experts + num_repeat}"
-                            f".{suffix}",
-                            weights_dict[shared_expert_weight_name],
-                        )
-                    )
-                names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
-        )
-
-        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
-        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
-            self.config.q_lora_rank is not None
-        )
-        cached_a_proj = {} if fuse_qkv_a_proj else None
-
-        nextn_layer_prefix = "model.layers.0"
-        nextn_spec_weight_names = [
-            "shared_head.norm",
-            "eh_proj",
-            "enorm",
-            "hnorm",
-        ]
-
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            if not name.startswith(nextn_layer_prefix):
-                continue
-
-            # Use shared head and embed weights from target model
-            if "shared_head.head" in name or "embed_tokens" in name:
-                continue
-
-            is_decoder = True
-            # For nextn specific weights
-            for weight_name in nextn_spec_weight_names:
-                if weight_name in name:
-                    name = name.replace(nextn_layer_prefix, "model")
-                    is_decoder = False
-                    break
-            # For decoder layer weights
-            if is_decoder:
-                name = name.replace(nextn_layer_prefix, "model.decoder")
-
-            if "rotary_emb.inv_freq" in name:
-                continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                # Skip non-stacked layers and experts (experts handled below).
-                if weight_name not in name:
-                    continue
-                # We have mlp.experts[0].gate_proj in the checkpoint.
-                # Since we handle the experts below in expert_params_mapping,
-                # we need to skip here BEFORE we update the name, otherwise
-                # name will be updated to mlp.experts[0].gate_up_proj, which
-                # will then be updated below in expert_params_mapping
-                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if ("mlp.experts." in name) and name not in params_dict:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-
-                    # Handle fused_qkv_a_proj
-                    if fuse_qkv_a_proj and (
-                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
-                    ):
-                        cached_a_proj[name] = loaded_weight
-                        q_a_proj_name = (
-                            name
-                            if "q_a_proj" in name
-                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
-                        )
-                        kv_a_proj_name = (
-                            name
-                            if "kv_a_proj_with_mqa" in name
-                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
-                        )
-
-                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
-                        if (
-                            q_a_proj_name in cached_a_proj
-                            and kv_a_proj_name in cached_a_proj
-                        ):
-
-                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
-                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
-                            fused_weight = torch.cat(
-                                [q_a_proj_weight, kv_a_proj_weight], dim=0
-                            )
-
-                            param_name = name.replace(
-                                "q_a_proj", "fused_qkv_a_proj_with_mqa"
-                            )
-                            param = params_dict[param_name]
-
-                            weight_loader = getattr(
-                                param, "weight_loader", default_weight_loader
-                            )
-                            weight_loader(param, fused_weight)
-                            cached_a_proj.pop(q_a_proj_name)
-                            cached_a_proj.pop(kv_a_proj_name)
-                    else:
-                        param = params_dict[name]
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, loaded_weight)
-
-        self_attn = self.model.decoder.self_attn
-        if hasattr(self_attn.kv_b_proj, "qweight"):
-            # AWQ compatible
-            if _is_cuda:
-                w = awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                ).T
-            else:
-                w = awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                    0,
-                    0,
-                    0,
-                ).T
-        else:
-            w = self_attn.kv_b_proj.weight
-        # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-        # This may affect the accuracy of fp8 model.
-        if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
-            torch.float8_e4m3fn,
-            torch.float8_e4m3fnuz,
-        ):
-            weight_block_size = self.quant_config.weight_block_size
-            if weight_block_size is not None:
-                assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                if _is_hip:
-                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=w,
-                        weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                        input_scale=None,
-                    )
-                else:
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
-                w, scale = block_quant_to_tensor_quant(
-                    weight, weight_scale, weight_block_size
-                )
-                self_attn.w_scale = scale
-        if w.dtype == torch.int8:
-            if hasattr(self.quant_config, "weight_block_size"):
-                # block-wise int8 need it
-                weight_block_size = self.quant_config.weight_block_size
-                if weight_block_size is not None:
-                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                    w = int8_block_dequant(weight, weight_scale, weight_block_size).to(
-                        torch.bfloat16
-                    )
-            else:
-                # channel-wise int8 need it
-                assert hasattr(self_attn.kv_b_proj, "weight_scale")
-                w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
-                    torch.bfloat16
-                )
-        w_kc, w_vc = w.unflatten(
-            0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-        ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-        self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-        self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-        if hasattr(self_attn.kv_b_proj, "weight_scale") and self_attn.w_scale is None:
-            self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-            if _is_hip:
-                self_attn.w_scale *= 2.0
+        super().load_weights(weights, is_nextn=True)
 
 
 EntryClass = [DeepseekV3ForCausalLMNextN]
@@ -752,7 +752,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = q_nope_out.transpose(0, 1)
 
         k_nope = latent_cache[..., : self.kv_lora_rank]
-        k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
+        k_nope = self.kv_a_layernorm(k_nope.contiguous()).unsqueeze(1)
         k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
 
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
@@ -1391,6 +1391,9 @@ class DeepseekV2Model(nn.Module):
 
         self.dp_size = get_attention_dp_size()
 
+    def get_input_embeddings(self) -> torch.Tensor:
+        return self.embed_tokens
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1502,11 +1505,20 @@ class DeepseekV2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
-    def post_load_weights(self):
+    def post_load_weights(self, is_nextn=False):
 
         # Perform post-processing after loading weights
-        for layer_id in range(self.config.num_hidden_layers):
-            self_attn = self.model.layers[layer_id].self_attn
+        layer_ids = (
+            range(self.config.num_hidden_layers)
+            if not is_nextn
+            else [self.config.num_hidden_layers]
+        )
+        for layer_id in layer_ids:
+            self_attn = (
+                self.model.layers[layer_id].self_attn
+                if not is_nextn
+                else self.model.decoder.self_attn
+            )
             if hasattr(self_attn.kv_b_proj, "qweight"):
                 # AWQ compatible
                 if _is_cuda:
@@ -1612,7 +1624,20 @@ class DeepseekV2ForCausalLM(nn.Module):
                 self_attn.w_vc = w_vc.contiguous()
                 self_attn.use_deep_gemm_bmm = True
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+        if is_nextn:
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
+                # compatible with old design
+                nextn_layer_id = (
+                    0
+                    if self.config.num_hidden_layers == 1
+                    else self.config.num_hidden_layers
+                )
+            else:
+                raise ValueError("num_nextn_predict_layers is not in the config")
+
         stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
@@ -1640,12 +1665,19 @@ class DeepseekV2ForCausalLM(nn.Module):
                    "up_proj.weight_scale_inv",
                ]
            names_to_remove = []
-           for moe_layer in tqdm(
+
+           moe_layers = (
                range(
                    self.config.first_k_dense_replace,
                    self.config.num_hidden_layers,
                    self.config.moe_layer_freq,
-               ),
+               )
+               if not is_nextn
+               else [nextn_layer_id]
+           )
+
+           for moe_layer in tqdm(
+               moe_layers,
                desc=f"Cloning {self.n_share_experts_fusion} "
                "replicas of the shared expert into MoE",
            ):
@@ -1686,18 +1718,46 @@ class DeepseekV2ForCausalLM(nn.Module):
        )
        cached_a_proj = {} if fuse_qkv_a_proj else None
 
+       if is_nextn:
+           nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
+           nextn_spec_weight_names = [
+               "shared_head.norm",
+               "eh_proj",
+               "enorm",
+               "hnorm",
+           ]
+
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
-           # TODO(HandH1998): Modify it when nextn is supported.
-           if hasattr(self.config, "num_nextn_predict_layers"):
-               num_nextn_layers = self.config.num_nextn_predict_layers
-               if num_nextn_layers > 0 and name.startswith("model.layers"):
-                   name_list = name.split(".")
-                   if (
-                       len(name_list) >= 3
-                       and int(name_list[2]) >= self.config.num_hidden_layers
-                   ):
-                       continue
+           if not is_nextn:
+               if hasattr(self.config, "num_nextn_predict_layers"):
+                   num_nextn_layers = self.config.num_nextn_predict_layers
+                   if num_nextn_layers > 0 and name.startswith("model.layers"):
+                       name_list = name.split(".")
+                       if (
+                           len(name_list) >= 3
+                           and int(name_list[2]) >= self.config.num_hidden_layers
+                       ):
+                           continue
+           else:
+               if not name.startswith(nextn_layer_prefix):
+                   continue
+
+               # Use shared head and embed weights from target model
+               if "shared_head.head" in name or "embed_tokens" in name:
+                   continue
+
+               is_decoder = True
+               # For nextn specific weights
+               for weight_name in nextn_spec_weight_names:
+                   if weight_name in name:
+                       name = name.replace(nextn_layer_prefix, "model")
+                       is_decoder = False
+                       break
+               # For decoder layer weights
+               if is_decoder:
+                   name = name.replace(nextn_layer_prefix, "model.decoder")
+
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -1786,7 +1846,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                        )
                        weight_loader(param, loaded_weight)
 
-       self.post_load_weights()
+       self.post_load_weights(is_nextn=is_nextn)
 
    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight