sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff shows the changes between two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (219)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +208 -295
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  64. sglang/srt/layers/moe/topk.py +13 -4
  65. sglang/srt/layers/quantization/__init__.py +111 -7
  66. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  69. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  71. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  72. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  73. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  75. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  80. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  82. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  83. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  87. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  89. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  90. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  91. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  93. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  94. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  95. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  96. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  97. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  98. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  99. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  100. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  101. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  102. sglang/srt/layers/quantization/fp8.py +69 -28
  103. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  104. sglang/srt/layers/quantization/gptq.py +416 -0
  105. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  106. sglang/srt/layers/quantization/int8_utils.py +73 -0
  107. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  108. sglang/srt/layers/radix_attention.py +1 -0
  109. sglang/srt/layers/rotary_embedding.py +0 -1
  110. sglang/srt/layers/sampler.py +76 -31
  111. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  112. sglang/srt/lora/lora.py +17 -1
  113. sglang/srt/lora/lora_config.py +5 -0
  114. sglang/srt/lora/lora_manager.py +1 -3
  115. sglang/srt/managers/cache_controller.py +193 -62
  116. sglang/srt/managers/configure_logging.py +2 -1
  117. sglang/srt/managers/data_parallel_controller.py +6 -2
  118. sglang/srt/managers/detokenizer_manager.py +124 -102
  119. sglang/srt/managers/image_processor.py +2 -1
  120. sglang/srt/managers/io_struct.py +143 -6
  121. sglang/srt/managers/schedule_batch.py +238 -197
  122. sglang/srt/managers/schedule_policy.py +29 -29
  123. sglang/srt/managers/scheduler.py +681 -259
  124. sglang/srt/managers/session_controller.py +6 -2
  125. sglang/srt/managers/tokenizer_manager.py +224 -68
  126. sglang/srt/managers/tp_worker.py +15 -4
  127. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  128. sglang/srt/mem_cache/chunk_cache.py +18 -11
  129. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  130. sglang/srt/mem_cache/memory_pool.py +44 -18
  131. sglang/srt/mem_cache/radix_cache.py +58 -47
  132. sglang/srt/metrics/collector.py +94 -36
  133. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  134. sglang/srt/model_executor/forward_batch_info.py +49 -16
  135. sglang/srt/model_executor/model_runner.py +209 -28
  136. sglang/srt/model_loader/loader.py +3 -3
  137. sglang/srt/model_loader/weight_utils.py +36 -14
  138. sglang/srt/models/baichuan.py +31 -6
  139. sglang/srt/models/chatglm.py +39 -7
  140. sglang/srt/models/commandr.py +29 -5
  141. sglang/srt/models/dbrx.py +31 -5
  142. sglang/srt/models/deepseek.py +43 -6
  143. sglang/srt/models/deepseek_nextn.py +32 -19
  144. sglang/srt/models/deepseek_v2.py +265 -29
  145. sglang/srt/models/exaone.py +19 -9
  146. sglang/srt/models/gemma.py +22 -8
  147. sglang/srt/models/gemma2.py +25 -12
  148. sglang/srt/models/gemma2_reward.py +5 -1
  149. sglang/srt/models/gpt2.py +28 -13
  150. sglang/srt/models/gpt_bigcode.py +27 -5
  151. sglang/srt/models/granite.py +21 -9
  152. sglang/srt/models/grok.py +21 -4
  153. sglang/srt/models/internlm2.py +36 -6
  154. sglang/srt/models/internlm2_reward.py +5 -1
  155. sglang/srt/models/llama.py +26 -9
  156. sglang/srt/models/llama_classification.py +5 -1
  157. sglang/srt/models/llama_eagle.py +17 -4
  158. sglang/srt/models/llama_embedding.py +5 -1
  159. sglang/srt/models/llama_reward.py +7 -2
  160. sglang/srt/models/llava.py +19 -3
  161. sglang/srt/models/llavavid.py +10 -1
  162. sglang/srt/models/minicpm.py +26 -2
  163. sglang/srt/models/minicpm3.py +39 -3
  164. sglang/srt/models/minicpmv.py +45 -14
  165. sglang/srt/models/mixtral.py +20 -9
  166. sglang/srt/models/mixtral_quant.py +50 -8
  167. sglang/srt/models/mllama.py +57 -11
  168. sglang/srt/models/olmo.py +34 -6
  169. sglang/srt/models/olmo2.py +34 -13
  170. sglang/srt/models/olmoe.py +26 -4
  171. sglang/srt/models/phi3_small.py +29 -10
  172. sglang/srt/models/qwen.py +26 -3
  173. sglang/srt/models/qwen2.py +26 -4
  174. sglang/srt/models/qwen2_5_vl.py +46 -8
  175. sglang/srt/models/qwen2_eagle.py +17 -5
  176. sglang/srt/models/qwen2_moe.py +44 -6
  177. sglang/srt/models/qwen2_rm.py +78 -0
  178. sglang/srt/models/qwen2_vl.py +39 -8
  179. sglang/srt/models/stablelm.py +32 -5
  180. sglang/srt/models/torch_native_llama.py +5 -2
  181. sglang/srt/models/xverse.py +21 -9
  182. sglang/srt/models/xverse_moe.py +45 -7
  183. sglang/srt/models/yivl.py +2 -1
  184. sglang/srt/openai_api/adapter.py +109 -24
  185. sglang/srt/openai_api/protocol.py +17 -1
  186. sglang/srt/reasoning_parser.py +154 -0
  187. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  188. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  189. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  190. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  191. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  192. sglang/srt/sampling/sampling_batch_info.py +79 -157
  193. sglang/srt/sampling/sampling_params.py +16 -13
  194. sglang/srt/server_args.py +136 -52
  195. sglang/srt/speculative/build_eagle_tree.py +2 -8
  196. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  197. sglang/srt/speculative/eagle_utils.py +92 -58
  198. sglang/srt/speculative/eagle_worker.py +186 -94
  199. sglang/srt/speculative/spec_info.py +1 -13
  200. sglang/srt/utils.py +43 -17
  201. sglang/srt/warmup.py +47 -0
  202. sglang/test/few_shot_gsm8k.py +4 -1
  203. sglang/test/runners.py +389 -126
  204. sglang/test/send_one.py +88 -0
  205. sglang/test/test_block_fp8_ep.py +361 -0
  206. sglang/test/test_programs.py +1 -1
  207. sglang/test/test_utils.py +138 -84
  208. sglang/utils.py +50 -60
  209. sglang/version.py +1 -1
  210. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  211. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
  212. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  213. sglang/bench_latency.py +0 -1
  214. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  215. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  216. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  217. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  218. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  219. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/minicpmv.py

@@ -56,6 +56,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
+from sglang.srt.utils import add_prefix
 
 RawImageType = Union[Image.Image, torch.Tensor]
 
@@ -158,14 +159,14 @@ class Idefics2VisionMLP(nn.Module):
             config.intermediate_size,
             bias=True,
             quant_config=quant_config,
-            prefix=f"{prefix}.fc1",
+            prefix=add_prefix("fc1", prefix),
         )
         self.fc2 = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             bias=True,
             quant_config=quant_config,
-            prefix=f"{prefix}.fc2",
+            prefix=add_prefix("fc2", prefix),
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -199,10 +200,14 @@ class Idefics2EncoderLayer(nn.Module):
             use_context_forward=False,
             use_full_precision_softmax=True,
             flatten_batch=False,
-            prefix=f"{prefix}.self_attn",
+            prefix=add_prefix("self_attn", prefix),
         )
         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-        self.mlp = Idefics2VisionMLP(config, quant_config=quant_config)
+        self.mlp = Idefics2VisionMLP(
+            config,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
 
     def forward(
@@ -242,6 +247,7 @@ class Idefics2Encoder(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
 
@@ -251,8 +257,9 @@
                 Idefics2EncoderLayer(
                     config,
                     quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
-                for _ in range(config.num_hidden_layers)
+                for i in range(config.num_hidden_layers)
             ]
         )
 
@@ -379,13 +386,18 @@ class Idefics2VisionTransformer(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
 
         embed_dim = config.hidden_size
         self.config = config
         self.embeddings = Idefics2VisionEmbeddings(config)
-        self.encoder = Idefics2Encoder(config=config, quant_config=quant_config)
+        self.encoder = Idefics2Encoder(
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("encoder", prefix),
+        )
         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
 
     def get_input_embeddings(self):
@@ -503,7 +515,7 @@ class BaseResampler(nn.Module):
                 embed_dim,
                 bias=False,
                 quant_config=quant_config,
-                prefix=f"{prefix}.kv_proj",
+                prefix=add_prefix("kv_proj", prefix),
             )
         else:
             # Maintain the same return value with ReplicatedLinear.forward
@@ -660,6 +672,7 @@ class MiniCPMVBaseModel(nn.Module):
         *,
         config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         # All MiniCPM-V models disable `tie_word_embeddings` but
@@ -669,8 +682,12 @@
         self.config = config
 
         self.version = get_version_by_config(self.config)
-        self.llm = self.init_llm(config=config, quant_config=quant_config)
-        self.vpm = self.init_vision_module(config, quant_config)
+        self.llm = self.init_llm(
+            config=config, quant_config=quant_config, prefix=add_prefix("llm", prefix)
+        )
+        self.vpm = self.init_vision_module(
+            config, quant_config, add_prefix("vpm", prefix)
+        )
         self.vision_dim = (
             self.vpm.embed_dim
             if self.version == (2, 0)
@@ -679,7 +696,10 @@
         self.embed_dim = self.config.hidden_size
 
         self.resampler = self.init_resampler(
-            self.embed_dim, self.vision_dim, quant_config=quant_config
+            self.embed_dim,
+            self.vision_dim,
+            quant_config=quant_config,
+            prefix=add_prefix("resampler", prefix),
         )
 
         self.logits_processor = LogitsProcessor(config)
@@ -937,6 +957,7 @@ class MiniCPMVBaseModel(nn.Module):
         self,
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> nn.Module:
         raise NotImplementedError
 
@@ -944,6 +965,7 @@ class MiniCPMVBaseModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
     ) -> nn.Module:
         raise NotImplementedError
 
@@ -952,6 +974,7 @@ class MiniCPMVBaseModel(nn.Module):
         embed_dim: int,
         vision_dim: int,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> nn.Module:
         raise NotImplementedError
 
@@ -1011,24 +1034,27 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
-        super().__init__(config=config, quant_config=quant_config)
+        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
         assert self.version == (2, 6)
 
     def init_llm(
         self,
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> nn.Module:
-        return Qwen2ForCausalLM(config=config, quant_config=quant_config)
+        return Qwen2ForCausalLM(config=config, quant_config=quant_config, prefix=prefix)
 
     def init_vision_module(
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
     ) -> nn.Module:
         model = Idefics2VisionTransformer(
-            config=config.vision_config, quant_config=quant_config
+            config=config.vision_config, quant_config=quant_config, prefix=prefix
         )
         if self.config.drop_vision_last_layer:
             model.encoder.layers = model.encoder.layers[:-1]
@@ -1042,6 +1068,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         embed_dim: int,
         vision_dim: int,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> nn.Module:
         with set_default_torch_dtype(torch.float16):
             # The resampler in 2.6 remains consistent with the one in 2.5.
@@ -1051,6 +1078,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
                 num_heads=embed_dim // 128,
                 kv_dim=vision_dim,
                 quant_config=quant_config,
+                prefix=prefix,
             )
 
         return resampler.to(device="cuda", dtype=torch.get_default_dtype())
@@ -1207,6 +1235,7 @@ class MiniCPMV:
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
 
@@ -1221,7 +1250,9 @@ class MiniCPMV:
             raise ValueError("Currently, MiniCPMV only supports versions 2.6")
 
         try:
-            minicpmv = instance_class(config=config, quant_config=quant_config)
+            minicpmv = instance_class(
+                config=config, quant_config=quant_config, prefix=prefix
+            )
             self.minicpmv = minicpmv
         except Exception as e:
             print(f"Failed to instantiate MiniCPMV: {e}")
sglang/srt/models/mixtral.py

@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 
 
 class MixtralMoE(nn.Module):
@@ -78,7 +79,7 @@ class MixtralMoE(nn.Module):
             bias=False,
             params_dtype=params_dtype,
             quant_config=None,
-            prefix=f"{prefix}.gate",
+            prefix=add_prefix("gate", prefix),
         )
         MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
         self.experts = MoEImpl(
@@ -90,7 +91,7 @@
             renormalize=True,
             quant_config=quant_config,
             tp_size=tp_size,
-            prefix=f"{prefix}.experts",
+            prefix=add_prefix("experts", prefix),
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -146,14 +147,14 @@ class MixtralAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.o_proj",
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -168,6 +169,7 @@
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
 
     def forward(
@@ -204,7 +206,7 @@ class MixtralDecoderLayer(nn.Module):
             layer_id=layer_id,
             rope_theta=rope_theta,
             quant_config=quant_config,
-            prefix=f"{prefix}.self_attn",
+            prefix=add_prefix("self_attn", prefix),
         )
         self.block_sparse_moe = MixtralMoE(
             num_experts=config.num_local_experts,
@@ -212,7 +214,7 @@
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
-            prefix=f"{prefix}.block_sparse_moe",
+            prefix=add_prefix("block_sparse_moe", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -258,11 +260,15 @@ class MixtralModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = nn.ModuleList(
             [
                 MixtralDecoderLayer(
-                    config, i, quant_config=quant_config, prefix=f"{prefix}.layers"
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -296,12 +302,17 @@ class MixtralForCausalLM(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.model = MixtralModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)
 
     def forward(
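
One behavioral note on the Mixtral hunks above: the old code handed every decoder layer the same prefix (f"{prefix}.layers", with no layer index), whereas the new add_prefix(f"layers.{i}", prefix) call gives each layer its own index, so nested submodules end up with fully qualified dotted names. A small illustration, reusing the assumed helper semantics from the sketch above; the concrete names are illustrative only:

# Illustration only (not from the diff): composing prefixes top-down through
# the refactored Mixtral constructors, assuming add_prefix joins names with a
# dot and skips an empty parent prefix.
model_prefix = add_prefix("model", "")               # "model"
layer_prefix = add_prefix("layers.0", model_prefix)  # "model.layers.0"
attn_prefix = add_prefix("self_attn", layer_prefix)  # "model.layers.0.self_attn"
qkv_prefix = add_prefix("qkv_proj", attn_prefix)     # "model.layers.0.self_attn.qkv_proj"
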
sglang/srt/models/mixtral_quant.py

@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 
 
 class MixtralMLP(nn.Module):
@@ -54,6 +55,7 @@ class MixtralMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.num_experts = num_experts
@@ -61,13 +63,25 @@
         self.hidden_dim = hidden_size
 
         self.w1 = ReplicatedLinear(
-            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
+            self.hidden_dim,
+            self.ffn_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("w1", prefix),
         )
         self.w2 = ReplicatedLinear(
-            self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
+            self.ffn_dim,
+            self.hidden_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("w2", prefix),
        )
         self.w3 = ReplicatedLinear(
-            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
+            self.hidden_dim,
+            self.ffn_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("w3", prefix),
         )
 
         # TODO: Use vllm's SiluAndMul
@@ -87,6 +101,7 @@ class MixtralMoE(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -114,6 +129,7 @@
                         config.hidden_size,
                         config.intermediate_size,
                         quant_config=quant_config,
+                        prefix=add_prefix(f"experts.{idx}", prefix),
                     )
                     if idx in self.expert_indicies
                     else None
@@ -122,7 +138,11 @@
             ]
         )
         self.gate = ReplicatedLinear(
-            config.hidden_size, self.num_total_experts, bias=False, quant_config=None
+            config.hidden_size,
+            self.num_total_experts,
+            bias=False,
+            quant_config=None,
+            prefix=add_prefix("gate", prefix),
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -159,6 +179,7 @@ class MixtralAttention(nn.Module):
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -189,12 +210,14 @@
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -209,6 +232,7 @@
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
 
     def forward(
@@ -231,6 +255,7 @@ class MixtralDecoderLayer(nn.Module):
         config: MixtralConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -244,8 +269,13 @@
             layer_id=layer_id,
             rope_theta=rope_theta,
             quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+        )
+        self.block_sparse_moe = MixtralMoE(
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("block_sparse_moe", prefix),
         )
-        self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -281,6 +311,7 @@ class MixtralModel(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -289,10 +320,16 @@
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = nn.ModuleList(
             [
-                MixtralDecoderLayer(config, i, quant_config=quant_config)
+                MixtralDecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -324,12 +361,17 @@ class QuantMixtralForCausalLM(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = MixtralModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.model = MixtralModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()