sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219):
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +208 -295
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  64. sglang/srt/layers/moe/topk.py +13 -4
  65. sglang/srt/layers/quantization/__init__.py +111 -7
  66. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  69. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  71. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  72. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  73. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  75. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  80. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  82. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  83. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  87. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  89. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  90. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  91. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  93. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  94. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  95. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  96. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  97. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  98. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  99. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  100. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  101. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  102. sglang/srt/layers/quantization/fp8.py +69 -28
  103. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  104. sglang/srt/layers/quantization/gptq.py +416 -0
  105. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  106. sglang/srt/layers/quantization/int8_utils.py +73 -0
  107. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  108. sglang/srt/layers/radix_attention.py +1 -0
  109. sglang/srt/layers/rotary_embedding.py +0 -1
  110. sglang/srt/layers/sampler.py +76 -31
  111. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  112. sglang/srt/lora/lora.py +17 -1
  113. sglang/srt/lora/lora_config.py +5 -0
  114. sglang/srt/lora/lora_manager.py +1 -3
  115. sglang/srt/managers/cache_controller.py +193 -62
  116. sglang/srt/managers/configure_logging.py +2 -1
  117. sglang/srt/managers/data_parallel_controller.py +6 -2
  118. sglang/srt/managers/detokenizer_manager.py +124 -102
  119. sglang/srt/managers/image_processor.py +2 -1
  120. sglang/srt/managers/io_struct.py +143 -6
  121. sglang/srt/managers/schedule_batch.py +238 -197
  122. sglang/srt/managers/schedule_policy.py +29 -29
  123. sglang/srt/managers/scheduler.py +681 -259
  124. sglang/srt/managers/session_controller.py +6 -2
  125. sglang/srt/managers/tokenizer_manager.py +224 -68
  126. sglang/srt/managers/tp_worker.py +15 -4
  127. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  128. sglang/srt/mem_cache/chunk_cache.py +18 -11
  129. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  130. sglang/srt/mem_cache/memory_pool.py +44 -18
  131. sglang/srt/mem_cache/radix_cache.py +58 -47
  132. sglang/srt/metrics/collector.py +94 -36
  133. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  134. sglang/srt/model_executor/forward_batch_info.py +49 -16
  135. sglang/srt/model_executor/model_runner.py +209 -28
  136. sglang/srt/model_loader/loader.py +3 -3
  137. sglang/srt/model_loader/weight_utils.py +36 -14
  138. sglang/srt/models/baichuan.py +31 -6
  139. sglang/srt/models/chatglm.py +39 -7
  140. sglang/srt/models/commandr.py +29 -5
  141. sglang/srt/models/dbrx.py +31 -5
  142. sglang/srt/models/deepseek.py +43 -6
  143. sglang/srt/models/deepseek_nextn.py +32 -19
  144. sglang/srt/models/deepseek_v2.py +265 -29
  145. sglang/srt/models/exaone.py +19 -9
  146. sglang/srt/models/gemma.py +22 -8
  147. sglang/srt/models/gemma2.py +25 -12
  148. sglang/srt/models/gemma2_reward.py +5 -1
  149. sglang/srt/models/gpt2.py +28 -13
  150. sglang/srt/models/gpt_bigcode.py +27 -5
  151. sglang/srt/models/granite.py +21 -9
  152. sglang/srt/models/grok.py +21 -4
  153. sglang/srt/models/internlm2.py +36 -6
  154. sglang/srt/models/internlm2_reward.py +5 -1
  155. sglang/srt/models/llama.py +26 -9
  156. sglang/srt/models/llama_classification.py +5 -1
  157. sglang/srt/models/llama_eagle.py +17 -4
  158. sglang/srt/models/llama_embedding.py +5 -1
  159. sglang/srt/models/llama_reward.py +7 -2
  160. sglang/srt/models/llava.py +19 -3
  161. sglang/srt/models/llavavid.py +10 -1
  162. sglang/srt/models/minicpm.py +26 -2
  163. sglang/srt/models/minicpm3.py +39 -3
  164. sglang/srt/models/minicpmv.py +45 -14
  165. sglang/srt/models/mixtral.py +20 -9
  166. sglang/srt/models/mixtral_quant.py +50 -8
  167. sglang/srt/models/mllama.py +57 -11
  168. sglang/srt/models/olmo.py +34 -6
  169. sglang/srt/models/olmo2.py +34 -13
  170. sglang/srt/models/olmoe.py +26 -4
  171. sglang/srt/models/phi3_small.py +29 -10
  172. sglang/srt/models/qwen.py +26 -3
  173. sglang/srt/models/qwen2.py +26 -4
  174. sglang/srt/models/qwen2_5_vl.py +46 -8
  175. sglang/srt/models/qwen2_eagle.py +17 -5
  176. sglang/srt/models/qwen2_moe.py +44 -6
  177. sglang/srt/models/qwen2_rm.py +78 -0
  178. sglang/srt/models/qwen2_vl.py +39 -8
  179. sglang/srt/models/stablelm.py +32 -5
  180. sglang/srt/models/torch_native_llama.py +5 -2
  181. sglang/srt/models/xverse.py +21 -9
  182. sglang/srt/models/xverse_moe.py +45 -7
  183. sglang/srt/models/yivl.py +2 -1
  184. sglang/srt/openai_api/adapter.py +109 -24
  185. sglang/srt/openai_api/protocol.py +17 -1
  186. sglang/srt/reasoning_parser.py +154 -0
  187. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  188. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  189. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  190. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  191. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  192. sglang/srt/sampling/sampling_batch_info.py +79 -157
  193. sglang/srt/sampling/sampling_params.py +16 -13
  194. sglang/srt/server_args.py +136 -52
  195. sglang/srt/speculative/build_eagle_tree.py +2 -8
  196. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  197. sglang/srt/speculative/eagle_utils.py +92 -58
  198. sglang/srt/speculative/eagle_worker.py +186 -94
  199. sglang/srt/speculative/spec_info.py +1 -13
  200. sglang/srt/utils.py +43 -17
  201. sglang/srt/warmup.py +47 -0
  202. sglang/test/few_shot_gsm8k.py +4 -1
  203. sglang/test/runners.py +389 -126
  204. sglang/test/send_one.py +88 -0
  205. sglang/test/test_block_fp8_ep.py +361 -0
  206. sglang/test/test_programs.py +1 -1
  207. sglang/test/test_utils.py +138 -84
  208. sglang/utils.py +50 -60
  209. sglang/version.py +1 -1
  210. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  211. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
  212. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  213. sglang/bench_latency.py +0 -1
  214. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  215. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  216. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  217. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  218. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  219. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
@@ -46,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
46
46
  )
47
47
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
48
48
  from sglang.srt.model_loader.weight_utils import default_weight_loader
49
+ from sglang.srt.utils import add_prefix
49
50
 
50
51
 
51
52
  class DeepseekMLP(nn.Module):
@@ -57,10 +58,15 @@ class DeepseekMLP(nn.Module):
57
58
  hidden_act: str,
58
59
  quant_config: Optional[QuantizationConfig] = None,
59
60
  reduce_results: bool = True,
61
+ prefix: str = "",
60
62
  ) -> None:
61
63
  super().__init__()
62
64
  self.gate_up_proj = MergedColumnParallelLinear(
63
- hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
65
+ hidden_size,
66
+ [intermediate_size] * 2,
67
+ bias=False,
68
+ quant_config=quant_config,
69
+ prefix=add_prefix("gate_up_proj", prefix),
64
70
  )
65
71
  self.down_proj = RowParallelLinear(
66
72
  intermediate_size,
@@ -68,6 +74,7 @@ class DeepseekMLP(nn.Module):
68
74
  bias=False,
69
75
  quant_config=quant_config,
70
76
  reduce_results=reduce_results,
77
+ prefix=add_prefix("down_proj", prefix),
71
78
  )
72
79
  if hidden_act != "silu":
73
80
  raise ValueError(
@@ -89,6 +96,7 @@ class DeepseekMoE(nn.Module):
89
96
  self,
90
97
  config: PretrainedConfig,
91
98
  quant_config: Optional[QuantizationConfig] = None,
99
+ prefix: str = "",
92
100
  ):
93
101
  super().__init__()
94
102
  self.config = config
@@ -110,6 +118,7 @@ class DeepseekMoE(nn.Module):
110
118
  hidden_act=config.hidden_act,
111
119
  quant_config=quant_config,
112
120
  reduce_results=False,
121
+ prefix=add_prefix(f"{idx}.experts", prefix),
113
122
  )
114
123
  for idx in range(self.n_routed_experts)
115
124
  ]
@@ -117,7 +126,11 @@ class DeepseekMoE(nn.Module):
117
126
  self.pack_params()
118
127
 
119
128
  self.gate = ReplicatedLinear(
120
- config.hidden_size, self.n_routed_experts, bias=False, quant_config=None
129
+ config.hidden_size,
130
+ self.n_routed_experts,
131
+ bias=False,
132
+ quant_config=None,
133
+ prefix=add_prefix("gate", prefix),
121
134
  )
122
135
 
123
136
  if config.n_shared_experts is not None:
@@ -128,6 +141,7 @@ class DeepseekMoE(nn.Module):
128
141
  hidden_act=config.hidden_act,
129
142
  quant_config=quant_config,
130
143
  reduce_results=False,
144
+ prefix=add_prefix("shared_experts", prefix),
131
145
  )
132
146
 
133
147
  def pack_params(self):
@@ -185,6 +199,7 @@ class DeepseekAttention(nn.Module):
185
199
  rope_scaling: Optional[Dict[str, Any]] = None,
186
200
  max_position_embeddings: int = 8192,
187
201
  quant_config: Optional[QuantizationConfig] = None,
202
+ prefix: str = "",
188
203
  ) -> None:
189
204
  super().__init__()
190
205
  self.hidden_size = hidden_size
@@ -216,6 +231,7 @@ class DeepseekAttention(nn.Module):
216
231
  self.total_num_kv_heads,
217
232
  bias=False,
218
233
  quant_config=quant_config,
234
+ prefix=add_prefix("qkv_proj", prefix),
219
235
  )
220
236
 
221
237
  self.o_proj = RowParallelLinear(
@@ -223,6 +239,7 @@ class DeepseekAttention(nn.Module):
223
239
  hidden_size,
224
240
  bias=False,
225
241
  quant_config=quant_config,
242
+ prefix=add_prefix("o_proj", prefix),
226
243
  )
227
244
 
228
245
  self.rotary_emb = get_rope(
@@ -238,6 +255,7 @@ class DeepseekAttention(nn.Module):
238
255
  self.scaling,
239
256
  num_kv_heads=self.num_kv_heads,
240
257
  layer_id=layer_id,
258
+ prefix=add_prefix("attn", prefix),
241
259
  )
242
260
 
243
261
  def forward(
@@ -261,6 +279,7 @@ class DeepseekDecoderLayer(nn.Module):
261
279
  config: PretrainedConfig,
262
280
  layer_id: int,
263
281
  quant_config: Optional[QuantizationConfig] = None,
282
+ prefix: str = "",
264
283
  ) -> None:
265
284
  super().__init__()
266
285
  self.hidden_size = config.hidden_size
@@ -276,19 +295,25 @@ class DeepseekDecoderLayer(nn.Module):
276
295
  rope_scaling=rope_scaling,
277
296
  max_position_embeddings=max_position_embeddings,
278
297
  quant_config=quant_config,
298
+ prefix=add_prefix("self_attn", prefix),
279
299
  )
280
300
  if (
281
301
  config.n_routed_experts is not None
282
302
  and layer_id >= config.first_k_dense_replace
283
303
  and layer_id % config.moe_layer_freq == 0
284
304
  ):
285
- self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
305
+ self.mlp = DeepseekMoE(
306
+ config=config,
307
+ quant_config=quant_config,
308
+ prefix=add_prefix("mlp", prefix),
309
+ )
286
310
  else:
287
311
  self.mlp = DeepseekMLP(
288
312
  hidden_size=config.hidden_size,
289
313
  intermediate_size=config.intermediate_size,
290
314
  hidden_act=config.hidden_act,
291
315
  quant_config=quant_config,
316
+ prefix=add_prefix("mlp", prefix),
292
317
  )
293
318
  self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
294
319
  self.post_attention_layernorm = RMSNorm(
@@ -328,6 +353,7 @@ class DeepseekModel(nn.Module):
328
353
  self,
329
354
  config: PretrainedConfig,
330
355
  quant_config: Optional[QuantizationConfig] = None,
356
+ prefix: str = "",
331
357
  ) -> None:
332
358
  super().__init__()
333
359
  self.padding_idx = config.pad_token_id
@@ -339,7 +365,12 @@ class DeepseekModel(nn.Module):
339
365
  )
340
366
  self.layers = nn.ModuleList(
341
367
  [
342
- DeepseekDecoderLayer(config, layer_id, quant_config=quant_config)
368
+ DeepseekDecoderLayer(
369
+ config,
370
+ layer_id,
371
+ quant_config=quant_config,
372
+ prefix=add_prefix(f"layers.{layer_id}", prefix),
373
+ )
343
374
  for layer_id in range(config.num_hidden_layers)
344
375
  ]
345
376
  )
@@ -368,13 +399,19 @@ class DeepseekForCausalLM(nn.Module):
368
399
  self,
369
400
  config: PretrainedConfig,
370
401
  quant_config: Optional[QuantizationConfig] = None,
402
+ prefix: str = "",
371
403
  ) -> None:
372
404
  super().__init__()
373
405
  self.config = config
374
406
  self.quant_config = quant_config
375
- self.model = DeepseekModel(config, quant_config)
407
+ self.model = DeepseekModel(
408
+ config, quant_config, prefix=add_prefix("model", prefix)
409
+ )
376
410
  self.lm_head = ParallelLMHead(
377
- config.vocab_size, config.hidden_size, quant_config=quant_config
411
+ config.vocab_size,
412
+ config.hidden_size,
413
+ quant_config=quant_config,
414
+ prefix=add_prefix("lm_head", prefix),
378
415
  )
379
416
  self.logits_processor = LogitsProcessor(config)
380
417
 
@@ -38,7 +38,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
38
38
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
39
39
  from sglang.srt.model_loader.weight_utils import default_weight_loader
40
40
  from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
41
- from sglang.srt.utils import is_hip
41
+ from sglang.srt.utils import add_prefix, is_hip
42
42
 
43
43
  is_hip_ = is_hip()
44
44
 
@@ -48,6 +48,7 @@ class DeepseekModelNextN(nn.Module):
48
48
  self,
49
49
  config: PretrainedConfig,
50
50
  quant_config: Optional[QuantizationConfig] = None,
51
+ prefix: str = "",
51
52
  ) -> None:
52
53
  super().__init__()
53
54
  self.vocab_size = config.vocab_size
@@ -56,6 +57,7 @@ class DeepseekModelNextN(nn.Module):
56
57
  config.vocab_size,
57
58
  config.hidden_size,
58
59
  enable_tp=not global_server_args_dict["enable_dp_attention"],
60
+ prefix=add_prefix("embed_tokens", prefix),
59
61
  )
60
62
 
61
63
  self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -64,7 +66,11 @@ class DeepseekModelNextN(nn.Module):
64
66
  self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
65
67
 
66
68
  self.decoder = DeepseekV2DecoderLayer(
67
- config, 0, quant_config=quant_config, is_nextn=True
69
+ config,
70
+ 0,
71
+ quant_config=quant_config,
72
+ is_nextn=True,
73
+ prefix=add_prefix("decoder", prefix),
68
74
  )
69
75
 
70
76
  self.shared_head = nn.Module()
@@ -108,25 +114,30 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
108
114
  self,
109
115
  config: PretrainedConfig,
110
116
  quant_config: Optional[QuantizationConfig] = None,
117
+ prefix: str = "",
111
118
  ) -> None:
112
119
  nn.Module.__init__(self)
113
120
  self.config = config
114
121
  self.quant_config = quant_config
115
122
 
116
- self.model = DeepseekModelNextN(config, quant_config)
123
+ self.model = DeepseekModelNextN(
124
+ config, quant_config, prefix=add_prefix("model", prefix)
125
+ )
117
126
 
118
127
  if global_server_args_dict["enable_dp_attention"]:
119
- self.model.shared_head.head = ReplicatedLinear(
128
+ self.lm_head = ReplicatedLinear(
120
129
  config.hidden_size,
121
130
  config.vocab_size,
122
131
  bias=False,
132
+ prefix=add_prefix("model.shared_head.head", prefix),
123
133
  )
124
134
  self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
125
135
  else:
126
- self.model.shared_head.head = ParallelLMHead(
136
+ self.lm_head = ParallelLMHead(
127
137
  config.vocab_size,
128
138
  config.hidden_size,
129
139
  quant_config=quant_config,
140
+ prefix=add_prefix("model.shared_head.head", prefix),
130
141
  )
131
142
  self.logits_processor = LogitsProcessor(config)
132
143
 
@@ -139,7 +150,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
139
150
  ) -> torch.Tensor:
140
151
  hidden_states = self.model(input_ids, positions, forward_batch)
141
152
  return self.logits_processor(
142
- input_ids, hidden_states, self.model.shared_head.head, forward_batch
153
+ input_ids, hidden_states, self.lm_head, forward_batch
143
154
  )
144
155
 
145
156
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -168,10 +179,8 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
168
179
 
169
180
  nextn_layer_prefix = "model.layers.0"
170
181
  nextn_spec_weight_names = [
171
- "shared_head.head",
172
182
  "shared_head.norm",
173
183
  "eh_proj",
174
- "embed_tokens",
175
184
  "enorm",
176
185
  "hnorm",
177
186
  ]
@@ -180,17 +189,21 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
180
189
  for name, loaded_weight in weights:
181
190
  if not name.startswith(nextn_layer_prefix):
182
191
  continue
183
- else:
184
- is_decoder = True
185
- # For nextn specific weights
186
- for weight_name in nextn_spec_weight_names:
187
- if weight_name in name:
188
- name = name.replace(nextn_layer_prefix, "model")
189
- is_decoder = False
190
- break
191
- # For decoder layer weights
192
- if is_decoder:
193
- name = name.replace(nextn_layer_prefix, "model.decoder")
192
+
193
+ # Use shared head and embed weights from target model
194
+ if "shared_head.head" in name or "embed_tokens" in name:
195
+ continue
196
+
197
+ is_decoder = True
198
+ # For nextn specific weights
199
+ for weight_name in nextn_spec_weight_names:
200
+ if weight_name in name:
201
+ name = name.replace(nextn_layer_prefix, "model")
202
+ is_decoder = False
203
+ break
204
+ # For decoder layer weights
205
+ if is_decoder:
206
+ name = name.replace(nextn_layer_prefix, "model.decoder")
194
207
 
195
208
  if "rotary_emb.inv_freq" in name:
196
209
  continue