sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +220 -378
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +143 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +681 -259
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +224 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +44 -18
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +94 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +208 -28
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +136 -52
  181. sglang/srt/speculative/build_eagle_tree.py +2 -8
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  183. sglang/srt/speculative/eagle_utils.py +92 -58
  184. sglang/srt/speculative/eagle_worker.py +186 -94
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/llava.py

@@ -42,6 +42,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.utils import add_prefix


 class LlavaBaseForCausalLM(nn.Module):
@@ -475,6 +476,7 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()

@@ -484,7 +486,11 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
         self.config.text_config.hidden_size = config.hidden_size

         self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
+        self.language_model = LlamaForCausalLM(
+            config,
+            quant_config=quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
             self.language_model.model.image_newline = nn.Parameter(
                 torch.empty(config.text_config.hidden_size, dtype=torch.float16)
@@ -496,6 +502,7 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()

@@ -516,7 +523,11 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
             self.config.image_token_index = 151646

         self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config)
+        self.language_model = Qwen2ForCausalLM(
+            config,
+            quant_config=quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
             self.language_model.model.image_newline = nn.Parameter(
                 torch.empty(config.text_config.hidden_size, dtype=torch.float16)
@@ -528,6 +539,7 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()

@@ -548,7 +560,11 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
             self.config.image_token_index = 32000

         self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = MistralForCausalLM(config, quant_config=quant_config)
+        self.language_model = MistralForCausalLM(
+            config,
+            quant_config=quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
             self.language_model.model.image_newline = nn.Parameter(
                 torch.empty(config.text_config.hidden_size, dtype=torch.float16)
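The recurring change in this file (and in the model files below) is that every constructor now accepts a prefix argument and passes add_prefix(<child name>, prefix) down to its submodules, so each module carries its fully qualified parameter name. The helper itself is not shown in this diff (it lives in sglang/srt/utils.py, entry 186 in the file list); the following is a minimal sketch of the presumed behavior, inferred only from the call sites above, not the actual implementation.

# Sketch of the presumed behavior of sglang.srt.utils.add_prefix, inferred
# from the call sites in this diff -- not the actual implementation.
def add_prefix(name: str, prefix: str) -> str:
    # Join a child module name onto its parent's dotted prefix.
    return name if not prefix else f"{prefix}.{name}"

assert add_prefix("language_model", "") == "language_model"
assert add_prefix("qkv_proj", "model.layers.0.self_attn") == "model.layers.0.self_attn.qkv_proj"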
sglang/srt/models/llavavid.py

@@ -26,6 +26,7 @@ from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
+from sglang.srt.utils import add_prefix


 class LlavaVidForCausalLM(nn.Module):
@@ -33,6 +34,7 @@ class LlavaVidForCausalLM(nn.Module):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -44,7 +46,11 @@ class LlavaVidForCausalLM(nn.Module):
         self.resampler = nn.AvgPool2d(
             kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride
         )
-        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
+        self.language_model = LlamaForCausalLM(
+            config,
+            quant_config=quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
         self.num_frames = getattr(self.config, "num_frames", 16)
         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
             self.language_model.model.image_newline = nn.Parameter(
@@ -110,6 +116,9 @@ class LlavaVidForCausalLM(nn.Module):
         if forward_batch.forward_mode.is_extend():
             bs = forward_batch.batch_size

+            # Clamp input ids. See llava.py for more details
+            input_ids = input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
             # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)

sglang/srt/models/minicpm.py

@@ -37,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix


 class MiniCPMMLP(nn.Module):
@@ -46,6 +47,7 @@ class MiniCPMMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -53,12 +55,14 @@ class MiniCPMMLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -85,6 +89,7 @@ class MiniCPMAttention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -116,12 +121,14 @@ class MiniCPMAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )

         self.rotary_emb = get_rope(
@@ -139,6 +146,7 @@ class MiniCPMAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -164,6 +172,7 @@ class MiniCPMDecoderLayer(nn.Module):
         config,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -180,12 +189,14 @@ class MiniCPMDecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = MiniCPMMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -227,6 +238,7 @@ class MiniCPMModel(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -236,10 +248,16 @@ class MiniCPMModel(nn.Module):
             self.vocab_size,
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = nn.ModuleList(
             [
-                MiniCPMDecoderLayer(config, i, quant_config=quant_config)
+                MiniCPMDecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -275,19 +293,23 @@ class MiniCPMForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config

         self.num_experts = getattr(self.config, "num_experts", 0)
         self.quant_config = quant_config
-        self.model = MiniCPMModel(config, quant_config=quant_config)
+        self.model = MiniCPMModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
         # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         if not self.config.tie_word_embeddings:
             self.lm_head = ParallelLMHead(
                 config.vocab_size,
                 config.hidden_size,
                 org_num_embeddings=config.vocab_size,
+                prefix=add_prefix("lm_head", prefix),
             )

         self.scale_width = self.config.hidden_size / self.config.dim_model_base
@@ -339,6 +361,8 @@ class MiniCPMForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue

             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
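The new tie_word_embeddings check in load_weights skips any lm_head.weight entry a checkpoint may still carry: with tied embeddings the model builds no separate lm_head parameter and reuses embed_tokens.weight for the output projection, so there is nothing to load that tensor into. A rough sketch of the filtering logic; the helper name and sample keys are illustrative, not sglang's actual loader.

# Illustrative filter only -- the helper name and checkpoint keys are made up.
def should_skip(name: str, tie_word_embeddings: bool) -> bool:
    if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
        return True  # ColossalAI checkpoint artifacts, as noted in the diff
    if tie_word_embeddings and "lm_head.weight" in name:
        return True  # no separate lm_head parameter exists to load into
    return False

names = ["model.embed_tokens.weight", "lm_head.weight"]
kept = [n for n in names if not should_skip(n, tie_word_embeddings=True)]
# kept == ["model.embed_tokens.weight"]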
sglang/srt/models/minicpm3.py

@@ -40,7 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import add_prefix, is_cuda_available

 if is_cuda_available():
     from sgl_kernel import bmm_fp8
@@ -53,6 +53,7 @@ class MiniCPM3MLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -60,12 +61,14 @@ class MiniCPM3MLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -107,6 +110,7 @@ class MiniCPM3Attention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -131,6 +135,7 @@ class MiniCPM3Attention(nn.Module):
                 self.q_lora_rank,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_a_proj", prefix),
             )
             self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(
@@ -138,6 +143,7 @@ class MiniCPM3Attention(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_b_proj", prefix),
             )
         else:
             self.q_proj = ColumnParallelLinear(
@@ -145,6 +151,7 @@ class MiniCPM3Attention(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_proj", prefix),
             )

         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -152,6 +159,7 @@ class MiniCPM3Attention(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -159,6 +167,7 @@ class MiniCPM3Attention(nn.Module):
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("kv_b_proj", prefix),
         )
         # O projection.
         self.o_proj = RowParallelLinear(
@@ -166,6 +175,7 @@ class MiniCPM3Attention(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
@@ -182,6 +192,7 @@ class MiniCPM3Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -250,6 +261,7 @@ class MiniCPM3AttentionMLA(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -274,6 +286,7 @@ class MiniCPM3AttentionMLA(nn.Module):
                 self.q_lora_rank,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_a_proj", prefix),
             )
             self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(
@@ -281,6 +294,7 @@ class MiniCPM3AttentionMLA(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_b_proj", prefix),
             )
         else:
             self.q_proj = ColumnParallelLinear(
@@ -288,6 +302,7 @@ class MiniCPM3AttentionMLA(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_proj", prefix),
             )

         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -295,6 +310,7 @@ class MiniCPM3AttentionMLA(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -302,6 +318,7 @@ class MiniCPM3AttentionMLA(nn.Module):
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("kv_b_proj", prefix),
         )
         # O projection.
         self.o_proj = RowParallelLinear(
@@ -309,6 +326,7 @@ class MiniCPM3AttentionMLA(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
@@ -325,6 +343,7 @@ class MiniCPM3AttentionMLA(nn.Module):
             num_kv_heads=1,
             layer_id=layer_id,
             v_head_dim=self.kv_lora_rank,
+            prefix=add_prefix("attn", prefix),
         )

         self.w_kc = None
@@ -405,6 +424,7 @@ class MiniCPM3DecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -429,6 +449,7 @@ class MiniCPM3DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                prefix=add_prefix("self_attn", prefix),
             )
         else:
             self.self_attn = MiniCPM3Attention(
@@ -447,12 +468,14 @@ class MiniCPM3DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                prefix=add_prefix("self_attn", prefix),
             )
         self.mlp = MiniCPM3MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -494,6 +517,7 @@ class MiniCPM3Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -503,10 +527,16 @@ class MiniCPM3Model(nn.Module):
             self.vocab_size,
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = nn.ModuleList(
             [
-                MiniCPM3DecoderLayer(config, i, quant_config=quant_config)
+                MiniCPM3DecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -542,19 +572,23 @@ class MiniCPM3ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config

         self.num_experts = getattr(self.config, "num_experts", 0)
         self.quant_config = quant_config
-        self.model = MiniCPM3Model(config, quant_config=quant_config)
+        self.model = MiniCPM3Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
         # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         if not self.config.tie_word_embeddings:
             self.lm_head = ParallelLMHead(
                 config.vocab_size,
                 config.hidden_size,
                 org_num_embeddings=config.vocab_size,
+                prefix=add_prefix("lm_head", prefix),
             )

         self.scale_width = self.config.hidden_size / self.config.dim_model_base
@@ -603,6 +637,8 @@ class MiniCPM3ForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue

             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
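Taken together, the prefix threading in this file composes top-down: MiniCPM3ForCausalLM passes "model" to MiniCPM3Model, which passes "model.layers.{i}" to each decoder layer, which in turn prefixes its attention and MLP projections. A short worked example, reusing the add_prefix sketch from above (still an assumption inferred from the call sites, not the actual utility).

def add_prefix(name: str, prefix: str) -> str:  # same sketch as above, assumed behavior
    return name if not prefix else f"{prefix}.{name}"

root = ""                               # MiniCPM3ForCausalLM(prefix="")
model = add_prefix("model", root)       # "model"
layer = add_prefix("layers.3", model)   # "model.layers.3"
attn = add_prefix("self_attn", layer)   # "model.layers.3.self_attn"
proj = add_prefix("q_a_proj", attn)     # "model.layers.3.self_attn.q_a_proj"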