sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +302 -414
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +13 -8
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +144 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +773 -334
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +225 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +68 -37
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +102 -36
- sglang/srt/model_executor/cuda_graph_runner.py +56 -31
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +280 -81
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +135 -60
- sglang/srt/speculative/build_eagle_tree.py +8 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
- sglang/srt/speculative/eagle_utils.py +92 -57
- sglang/srt/speculative/eagle_worker.py +238 -111
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/models/llava.py
CHANGED
```diff
@@ -42,6 +42,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.utils import add_prefix
 
 
 class LlavaBaseForCausalLM(nn.Module):
@@ -475,6 +476,7 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
 
@@ -484,7 +486,11 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
         self.config.text_config.hidden_size = config.hidden_size
 
         self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = LlamaForCausalLM(
+        self.language_model = LlamaForCausalLM(
+            config,
+            quant_config=quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
             self.language_model.model.image_newline = nn.Parameter(
                 torch.empty(config.text_config.hidden_size, dtype=torch.float16)
@@ -496,6 +502,7 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
 
@@ -516,7 +523,11 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
             self.config.image_token_index = 151646
 
         self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = Qwen2ForCausalLM(
+        self.language_model = Qwen2ForCausalLM(
+            config,
+            quant_config=quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
             self.language_model.model.image_newline = nn.Parameter(
                 torch.empty(config.text_config.hidden_size, dtype=torch.float16)
@@ -528,6 +539,7 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
 
@@ -548,7 +560,11 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
             self.config.image_token_index = 32000
 
         self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = MistralForCausalLM(
+        self.language_model = MistralForCausalLM(
+            config,
+            quant_config=quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
             self.language_model.model.image_newline = nn.Parameter(
                 torch.empty(config.text_config.hidden_size, dtype=torch.float16)
```
sglang/srt/models/llavavid.py
CHANGED
```diff
@@ -26,6 +26,7 @@ from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
+from sglang.srt.utils import add_prefix
 
 
 class LlavaVidForCausalLM(nn.Module):
@@ -33,6 +34,7 @@ class LlavaVidForCausalLM(nn.Module):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -44,7 +46,11 @@ class LlavaVidForCausalLM(nn.Module):
         self.resampler = nn.AvgPool2d(
             kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride
         )
-        self.language_model = LlamaForCausalLM(
+        self.language_model = LlamaForCausalLM(
+            config,
+            quant_config=quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
         self.num_frames = getattr(self.config, "num_frames", 16)
         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
             self.language_model.model.image_newline = nn.Parameter(
@@ -110,6 +116,9 @@ class LlavaVidForCausalLM(nn.Module):
         if forward_batch.forward_mode.is_extend():
             bs = forward_batch.batch_size
 
+            # Clamp input ids. See llava.py for more details
+            input_ids = input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
             # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)
 
```
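The clamp added in the last hunk follows the comment's pointer to llava.py: multimodal placeholder positions in `input_ids` can carry ids outside the text vocabulary, which would make the embedding lookup index out of range. A small illustration of the idea (the ids, vocab size, and embedding below are made up, not from the diff):

```python
# Illustration only: out-of-vocab placeholder ids are clamped so embed_tokens
# can index safely; in the LLaVA-style flow those positions are subsequently
# overwritten with image embeddings, so the clamped values do not affect output.
import torch

vocab_size = 32000                                    # assumed text vocab size
embed = torch.nn.Embedding(vocab_size, 16)

input_ids = torch.tensor([1, 15, 2**31 - 5, 7, 2])    # third id is a fake image placeholder
safe_ids = input_ids.clamp(min=0, max=vocab_size - 1)
text_embeds = embed(safe_ids)                         # no index error
print(text_embeds.shape)                              # torch.Size([5, 16])
```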
sglang/srt/models/minicpm.py
CHANGED
```diff
@@ -37,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 
 
 class MiniCPMMLP(nn.Module):
@@ -46,6 +47,7 @@ class MiniCPMMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -53,12 +55,14 @@ class MiniCPMMLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -85,6 +89,7 @@ class MiniCPMAttention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -116,12 +121,14 @@ class MiniCPMAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
 
         self.rotary_emb = get_rope(
@@ -139,6 +146,7 @@ class MiniCPMAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
 
     def forward(
@@ -164,6 +172,7 @@ class MiniCPMDecoderLayer(nn.Module):
         config,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -180,12 +189,14 @@ class MiniCPMDecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = MiniCPMMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -227,6 +238,7 @@ class MiniCPMModel(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -236,10 +248,16 @@ class MiniCPMModel(nn.Module):
             self.vocab_size,
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = nn.ModuleList(
             [
-                MiniCPMDecoderLayer(
+                MiniCPMDecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -275,19 +293,23 @@ class MiniCPMForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
 
         self.num_experts = getattr(self.config, "num_experts", 0)
         self.quant_config = quant_config
-        self.model = MiniCPMModel(
+        self.model = MiniCPMModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
         # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         if not self.config.tie_word_embeddings:
             self.lm_head = ParallelLMHead(
                 config.vocab_size,
                 config.hidden_size,
                 org_num_embeddings=config.vocab_size,
+                prefix=add_prefix("lm_head", prefix),
             )
 
         self.scale_width = self.config.hidden_size / self.config.dim_model_base
@@ -339,6 +361,8 @@ class MiniCPMForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
```
sglang/srt/models/minicpm3.py
CHANGED
```diff
@@ -40,7 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import add_prefix, is_cuda_available
 
 if is_cuda_available():
     from sgl_kernel import bmm_fp8
@@ -53,6 +53,7 @@ class MiniCPM3MLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -60,12 +61,14 @@ class MiniCPM3MLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -107,6 +110,7 @@ class MiniCPM3Attention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -131,6 +135,7 @@ class MiniCPM3Attention(nn.Module):
                 self.q_lora_rank,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_a_proj", prefix),
             )
             self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(
@@ -138,6 +143,7 @@ class MiniCPM3Attention(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_b_proj", prefix),
             )
         else:
             self.q_proj = ColumnParallelLinear(
@@ -145,6 +151,7 @@ class MiniCPM3Attention(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_proj", prefix),
             )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -152,6 +159,7 @@ class MiniCPM3Attention(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -159,6 +167,7 @@ class MiniCPM3Attention(nn.Module):
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("kv_b_proj", prefix),
         )
         # O projection.
         self.o_proj = RowParallelLinear(
@@ -166,6 +175,7 @@ class MiniCPM3Attention(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
@@ -182,6 +192,7 @@ class MiniCPM3Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
 
     def forward(
@@ -250,6 +261,7 @@ class MiniCPM3AttentionMLA(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -274,6 +286,7 @@ class MiniCPM3AttentionMLA(nn.Module):
                 self.q_lora_rank,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_a_proj", prefix),
             )
             self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(
@@ -281,6 +294,7 @@ class MiniCPM3AttentionMLA(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_b_proj", prefix),
             )
         else:
             self.q_proj = ColumnParallelLinear(
@@ -288,6 +302,7 @@ class MiniCPM3AttentionMLA(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_proj", prefix),
             )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -295,6 +310,7 @@ class MiniCPM3AttentionMLA(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -302,6 +318,7 @@ class MiniCPM3AttentionMLA(nn.Module):
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("kv_b_proj", prefix),
         )
         # O projection.
         self.o_proj = RowParallelLinear(
@@ -309,6 +326,7 @@ class MiniCPM3AttentionMLA(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
@@ -325,6 +343,7 @@ class MiniCPM3AttentionMLA(nn.Module):
             num_kv_heads=1,
             layer_id=layer_id,
             v_head_dim=self.kv_lora_rank,
+            prefix=add_prefix("attn", prefix),
         )
 
         self.w_kc = None
@@ -405,6 +424,7 @@ class MiniCPM3DecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -429,6 +449,7 @@ class MiniCPM3DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                prefix=add_prefix("self_attn", prefix),
             )
         else:
             self.self_attn = MiniCPM3Attention(
@@ -447,12 +468,14 @@ class MiniCPM3DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                prefix=add_prefix("self_attn", prefix),
             )
         self.mlp = MiniCPM3MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -494,6 +517,7 @@ class MiniCPM3Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -503,10 +527,16 @@ class MiniCPM3Model(nn.Module):
             self.vocab_size,
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = nn.ModuleList(
             [
-                MiniCPM3DecoderLayer(
+                MiniCPM3DecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -542,19 +572,23 @@ class MiniCPM3ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
 
         self.num_experts = getattr(self.config, "num_experts", 0)
         self.quant_config = quant_config
-        self.model = MiniCPM3Model(
+        self.model = MiniCPM3Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
         # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         if not self.config.tie_word_embeddings:
             self.lm_head = ParallelLMHead(
                 config.vocab_size,
                 config.hidden_size,
                 org_num_embeddings=config.vocab_size,
+                prefix=add_prefix("lm_head", prefix),
             )
 
         self.scale_width = self.config.hidden_size / self.config.dim_model_base
@@ -603,6 +637,8 @@ class MiniCPM3ForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
```