sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/model_loader/weight_utils.py
CHANGED

@@ -37,8 +37,11 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.dp_attention import get_attention_tp_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
-from sglang.srt.layers.quantization.modelopt_quant import
-
+from sglang.srt.layers.quantization.modelopt_quant import (
+    ModelOptFp4Config,
+    ModelOptFp8Config,
+)
+from sglang.srt.utils import find_local_repo_dir, log_info_on_rank0, print_warning_once
 from sglang.utils import is_in_ci
 
 logger = logging.getLogger(__name__)

@@ -110,6 +113,9 @@ def convert_bin_to_safetensor_file(
 
     dirname = os.path.dirname(sf_filename)
     os.makedirs(dirname, exist_ok=True)
+
+    from safetensors.torch import save_file
+
     save_file(loaded, sf_filename, metadata={"format": "pt"})
 
     # check file size

@@ -132,11 +138,26 @@ def convert_bin_to_safetensor_file(
         raise RuntimeError(f"The output tensors do not match for key {k}")
 
 
+def replace_prefix(key: str, prefix_mapping: dict[str, str]) -> str:
+    for prefix, new_prefix in prefix_mapping.items():
+        if key.startswith(prefix):
+            key = key.replace(prefix, new_prefix, 1)
+    return key
+
+
+def replace_substrings(key: str, substring_mapping: dict[str, str]) -> str:
+    for substr, new_substr in substring_mapping.items():
+        if substr in key:
+            key = key.replace(substr, new_substr)
+    return key
+
+
 # TODO(woosuk): Move this to other place.
 def get_quant_config(
     model_config: ModelConfig,
     load_config: LoadConfig,
     packed_modules_mapping: Dict[str, List[str]],
+    remap_prefix: Dict[str, str] | None = None,
 ) -> QuantizationConfig:
     quant_cls = get_quantization_config(model_config.quantization)

@@ -206,35 +227,33 @@ def get_quant_config(
     quant_config_file = quant_config_files[0]
     with open(quant_config_file) as f:
         config = json.load(f)
+    if remap_prefix is not None:
+        exclude_modules = [
+            replace_prefix(key, remap_prefix)
+            for key in config["quantization"]["exclude_modules"]
+        ]
+        config["quantization"]["exclude_modules"] = exclude_modules
+    config["packed_modules_mapping"] = packed_modules_mapping
 
     if model_config.quantization == "bitsandbytes":
         config["adapter_name_or_path"] = model_name_or_path
-    elif model_config.quantization
-
+    elif model_config.quantization.startswith("modelopt") and (
+        config["producer"]["name"].startswith("modelopt")
+    ):
+        quant_algo = config["quantization"]["quant_algo"]
+        if quant_algo is None:
             # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
-            if
-
-            model_config.
-
-            )
-
-
-
-
-
-                return ModelOptFp4Config.from_config(config)
-            else:
-                return quant_cls.from_config(config)
-        else:
-            raise ValueError(
-                f"Unsupported quantization config"
-                f" found for {model_config.quantization} in {f}."
-            )
-    elif model_config.quantization == "w8a8_int8":
-        config["packed_modules_mapping"] = packed_modules_mapping
-
-        return quant_cls.from_config(config)
+            if model_config.hf_config.architectures[0] != "LlamaForCausalLMEagle3":
+                raise ValueError(
+                    f"Invalid quant_config, quantization method: {model_config.quantization},"
+                    f"hf architectures: {model_config.hf_config.architectures[0]}. "
+                )
+            return None
+        elif quant_algo == "FP8" or model_config.quantization == "modelopt_fp8":
+            return ModelOptFp8Config.from_config(config)
+        elif "FP4" in quant_algo:
+            return ModelOptFp4Config.from_config(config)
+    return quant_cls.from_config(config)

@@ -426,7 +445,7 @@ def download_weights_from_hf(
             allow_patterns = [pattern]
            break
 
-    logger
+    log_info_on_rank0(logger, f"Using model weights format {allow_patterns}")
     # Use file lock to prevent multiple processes from
     # downloading the same model weights at the same time.
     with get_lock(model_name_or_path, cache_dir):
sglang/srt/models/apertus.py
CHANGED

@@ -46,15 +46,14 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
     maybe_remap_kv_scale_name,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, make_layers
-from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)

@@ -447,7 +446,7 @@ class ApertusForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
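The apertus.py diff above shows a mechanical migration that repeats in arcee.py and most other model files below: the `global_server_args_dict` dict import is dropped in favor of the typed `get_global_server_args()` accessor. A minimal sketch of the new access pattern (the wrapper function is ours for illustration; the truncated removed lines presumably read the same flag from the old dict):

```python
# New style, as used throughout this release: attribute access on the
# global ServerArgs object instead of a string-keyed dict lookup.
from sglang.srt.server_args import get_global_server_args

def lm_head_uses_attn_tp_group() -> bool:  # hypothetical helper
    return get_global_server_args().enable_dp_lm_head
```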
sglang/srt/models/arcee.py
CHANGED

@@ -42,13 +42,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
     maybe_remap_kv_scale_name,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, make_layers
 
 logger = logging.getLogger(__name__)

@@ -407,7 +407,7 @@ class ArceeForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
sglang/srt/models/bailing_moe.py
CHANGED

@@ -17,9 +17,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
+"""SGLang BailingMoE model."""
 import logging
-from typing import
+from typing import Iterable, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F

@@ -54,12 +54,11 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import get_moe_a2a_backend
+from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import TopK
-from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope

@@ -68,7 +67,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader

@@ -76,6 +74,7 @@ from sglang.srt.models.utils import (
     create_fused_set_kv_buffer_arg,
     enable_fused_set_kv_buffer,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers
 
 LoraConfig = None

@@ -204,8 +203,8 @@ class BailingMoESparseMoeBlock(nn.Module):
         else:
             self.router_dtype = torch.bfloat16
 
-        # TODO
-        assert
+        # TODO global_server_args.ep_num_redundant_experts is used for eplb, not supported now
+        assert get_global_server_args().ep_num_redundant_experts == 0
         # check group topk
         self.num_expert_group = getattr(config, "n_group", 0)
         self.topk_group = getattr(config, "topk_group", 0)

@@ -220,7 +219,7 @@ class BailingMoESparseMoeBlock(nn.Module):
         self.use_grouped_topk = False
 
         self.num_experts = (
-            config.num_experts +
+            config.num_experts + get_global_server_args().ep_num_redundant_experts
         )
 
         self.gate = BailingMoEGate(

@@ -293,7 +292,7 @@ class BailingMoESparseMoeBlock(nn.Module):
             num_local_experts=config.num_experts // self.tp_size,
             hidden_size=config.hidden_size,
             params_dtype=config.torch_dtype,
-            deepep_mode=
+            deepep_mode=get_deepep_mode(),
             async_finish=True,  # TODO
             return_recv_hook=True,
         )

@@ -381,7 +380,7 @@ class BailingMoESparseMoeBlock(nn.Module):
         if self.num_shared_experts > 0:
             shared_output = self.shared_experts(hidden_states)
 
-
+            topk_output = self.topk(
                 hidden_states,
                 router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,

@@ -390,53 +389,15 @@ class BailingMoESparseMoeBlock(nn.Module):
                 ),
             )
         else:
-
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            topk_weights = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-            )
-
-        if self.ep_size > 1:
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                reorder_topk_ids,
-                num_recv_tokens_per_expert,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            ) = self.deepep_dispatcher.dispatch(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_batch=forward_batch,
-            )
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
 
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
-
-            topk_weights=topk_weights,
-            reorder_topk_ids=reorder_topk_ids,
-            seg_indptr=seg_indptr,
-            masked_m=masked_m,
-            expected_m=expected_m,
-            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-            forward_batch=forward_batch,
+            topk_output=topk_output,
         )
-        if self.ep_size > 1:
-            final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_batch=forward_batch,
-            )
-
-        final_hidden_states *= self.routed_scaling_factor
 
         if shared_output is not None:
-            final_hidden_states
+            final_hidden_states += shared_output
         return final_hidden_states

@@ -824,7 +785,7 @@ class BailingMoEForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
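The largest change in bailing_moe.py is the forward path of `BailingMoESparseMoeBlock`: the hand-rolled DeepEP dispatch/combine calls and the manually built empty `topk_idx`/`topk_weights` tensors are gone, and the experts module now consumes a single routing object. A condensed, illustrative sketch of the new control flow; the guard condition and the free-standing function are stand-ins, not the actual method:

```python
import torch

def sparse_moe_forward(block, hidden_states: torch.Tensor, router_logits, forward_batch):
    """Condensed sketch of the new BailingMoESparseMoeBlock forward path."""
    shared_output = None
    if block.num_shared_experts > 0:
        shared_output = block.shared_experts(hidden_states)

    if hidden_states.shape[0] > 0:  # stand-in for the real non-idle check
        topk_output = block.topk(hidden_states, router_logits)
    else:
        # Idle batch: one call replaces the old hand-built empty tensors.
        topk_output = block.topk.empty_topk_output(hidden_states.device)

    # The experts layer now owns the DeepEP dispatch, expert compute, combine,
    # and routed scaling that the old forward performed inline.
    final_hidden_states = block.experts(
        hidden_states=hidden_states, topk_output=topk_output
    )

    if shared_output is not None:
        final_hidden_states += shared_output
    return final_hidden_states
```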
sglang/srt/models/bailing_moe_nextn.py
CHANGED

@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
+"""SGLang BailingMoENextN model."""
 import logging
 from typing import Iterable, Optional, Tuple
 
@@ -29,15 +29,14 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.bailing_moe import BailingMoEBlock, BailingMoEForCausalLM
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix
 
 LoraConfig = None

@@ -145,7 +144,7 @@ class BailingMoeForCausalLMNextN(BailingMoEForCausalLM):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("model.shared_head.head", prefix),
-            use_attn_tp_group=
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
 
sglang/srt/models/deepseek_nextn.py
CHANGED

@@ -25,14 +25,19 @@ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization import Fp8Config
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.models.deepseek_v2 import
+from sglang.srt.models.deepseek_v2 import (
+    DeepseekV2DecoderLayer,
+    DeepseekV3ForCausalLM,
+    enable_nextn_moe_bf16_cast_to_fp8,
+)
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda
 
 logger = logging.getLogger(__name__)

@@ -49,6 +54,16 @@ class DeepseekModelNextN(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        if enable_nextn_moe_bf16_cast_to_fp8(quant_config):
+            # refer to real DeepSeek V3 quant config
+            moe_quant_config = Fp8Config(
+                is_checkpoint_fp8_serialized=True,
+                weight_block_size=[128, 128],
+            )
+        else:
+            moe_quant_config = None
+
         if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
             logger.warning(
                 "Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."

@@ -74,6 +89,7 @@ class DeepseekModelNextN(nn.Module):
             config,
             0,
             quant_config=quant_config,
+            moe_quant_config=moe_quant_config,
             is_nextn=True,
             prefix=add_prefix("decoder", prefix),
             alt_stream=self.alt_stream,

@@ -152,7 +168,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("model.shared_head.head", prefix),
-            use_attn_tp_group=
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
        )
         self.logits_processor = LogitsProcessor(config)
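The deepseek_nextn.py change lets a BF16 NextN draft head run its MoE layers in FP8 when `enable_nextn_moe_bf16_cast_to_fp8` (imported from deepseek_v2) says so, mirroring the block quantization of the real DeepSeek V3 checkpoint. A sketch of the selection logic added in `DeepseekModelNextN.__init__`; the wrapper function is ours for illustration:

```python
from sglang.srt.layers.quantization import Fp8Config
from sglang.srt.models.deepseek_v2 import enable_nextn_moe_bf16_cast_to_fp8

def select_moe_quant_config(quant_config):
    """Hypothetical wrapper around the logic added in DeepseekModelNextN.__init__."""
    if enable_nextn_moe_bf16_cast_to_fp8(quant_config):
        # Mirror the real DeepSeek V3 quant config: FP8 weights in 128x128 blocks.
        return Fp8Config(
            is_checkpoint_fp8_serialized=True,
            weight_block_size=[128, 128],
        )
    # Otherwise the decoder layer keeps the layer-wide quant_config.
    return None
```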