sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/model_loader/loader.py
CHANGED

@@ -4,7 +4,6 @@ from __future__ import annotations
 
 # ruff: noqa: SIM117
 import collections
-import concurrent
 import dataclasses
 import fnmatch
 import glob
@@ -12,13 +11,11 @@ import json
 import logging
 import math
 import os
-import re
 import socket
 import threading
 import time
 from abc import ABC, abstractmethod
-from
-from contextlib import contextmanager
+from contextlib import contextmanager, suppress
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -30,17 +27,28 @@ from typing import (
     Tuple,
     cast,
 )
-from urllib.parse import urlparse
 
 import huggingface_hub
 import numpy as np
-import requests
-import safetensors.torch
 import torch
+
+from sglang.srt.server_args import get_global_server_args
+
+# Try to import accelerate (optional dependency)
+try:
+    from accelerate import infer_auto_device_map, init_empty_weights
+    from accelerate.utils import get_max_memory
+
+    HAS_ACCELERATE = True
+except ImportError:
+    HAS_ACCELERATE = False
+    infer_auto_device_map = None
+    init_empty_weights = None
+    get_max_memory = None
+
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
-from
-from transformers import AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
@@ -54,6 +62,8 @@ from sglang.srt.distributed (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.layers.modelopt_utils import QUANT_CFG_CHOICES
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
     trigger_transferring_weights_request,
 )
@@ -62,9 +72,13 @@ from sglang.srt.model_loader.utils import (
     post_load_weights,
     set_default_torch_dtype,
 )
+
+# Constants for memory management
+DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = (
+    0.8  # Reserve 20% GPU memory headroom for ModelOpt calibration
+)
+from sglang.srt.environ import envs
 from sglang.srt.model_loader.weight_utils import (
-    _BAR_FORMAT,
-    default_weight_loader,
     download_safetensors_index_file_from_hf,
     download_weights_from_hf,
     filter_duplicate_safetensors_files,
@@ -85,6 +99,7 @@ from sglang.srt.utils import (
     get_device_capability,
     is_npu,
     is_pin_memory_available,
+    rank0_log,
     set_weight_attrs,
 )
 
@@ -94,6 +109,8 @@ if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig
 
 _is_npu = is_npu()
+# ModelOpt: QUANT_CFG_CHOICES is imported from modelopt_utils.py
+# which contains the complete mapping of quantization config choices
 
 
 @contextmanager
@@ -163,11 +180,12 @@ def _get_quantization_config(
     model_config: ModelConfig,
     load_config: LoadConfig,
     packed_modules_mapping: Dict[str, List[str]],
+    remap_prefix: Dict[str, str] | None = None,
 ) -> Optional[QuantizationConfig]:
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(
-            model_config, load_config, packed_modules_mapping
+            model_config, load_config, packed_modules_mapping, remap_prefix
         )
         # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
         if quant_config is None:
@@ -203,6 +221,7 @@ def _initialize_model(
     """Initialize a model with the given configurations."""
     model_class, _ = get_model_architecture(model_config)
     packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
+    remap_prefix = getattr(model_class, "remap_prefix", None)
     if _is_npu:
         packed_modules_mapping.update(
             {
@@ -226,13 +245,22 @@ def _initialize_model(
     )
 
     quant_config = _get_quantization_config(
-        model_config, load_config, packed_modules_mapping
-    )
-    return model_class(
-        config=model_config.hf_config,
-        quant_config=quant_config,
+        model_config, load_config, packed_modules_mapping, remap_prefix
     )
 
+    # Build kwargs conditionally
+    kwargs = {
+        "config": model_config.hf_config,
+        "quant_config": quant_config,
+    }
+
+    # Only add sparse head kwargs if envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set()
+    if envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set():
+        kwargs["sparse_head"] = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.value
+        kwargs["model_path"] = model_config.model_path
+
+    return model_class(**kwargs)
+
 
 class BaseModelLoader(ABC):
     """Base class for model loaders."""
@@ -424,10 +452,8 @@ class DefaultModelLoader(BaseModelLoader):
                 hf_weights_files,
             )
         elif use_safetensors:
-
-
-            weight_loader_disable_mmap = global_server_args_dict.get(
-                "weight_loader_disable_mmap"
+            weight_loader_disable_mmap = (
+                get_global_server_args().weight_loader_disable_mmap
             )
 
             if extra_config.get("enable_multithread_load"):
@@ -477,12 +503,87 @@ class DefaultModelLoader(BaseModelLoader):
             model_config.model_path, model_config.revision, fall_back_to_pt=True
         )
 
+    def _load_modelopt_base_model(self, model_config: ModelConfig) -> nn.Module:
+        """Load and prepare the base model for ModelOpt quantization.
+
+        This method handles the common model loading logic shared between
+        DefaultModelLoader (conditional) and ModelOptModelLoader (dedicated).
+        """
+        if not HAS_ACCELERATE:
+            raise ImportError(
+                "accelerate is required for ModelOpt quantization. "
+                "Please install it with: pip install accelerate"
+            )
+
+        hf_config = AutoConfig.from_pretrained(
+            model_config.model_path, trust_remote_code=True
+        )
+        with init_empty_weights():
+            torch_dtype = getattr(hf_config, "torch_dtype", torch.float16)
+            model = AutoModelForCausalLM.from_config(
+                hf_config, torch_dtype=torch_dtype, trust_remote_code=True
+            )
+        max_memory = get_max_memory()
+        inferred_device_map = infer_auto_device_map(model, max_memory=max_memory)
+
+        on_cpu = "cpu" in inferred_device_map.values()
+        model_kwargs = {"torch_dtype": "auto"}
+        device_map = "auto"
+
+        if on_cpu:
+            for device in max_memory.keys():
+                if isinstance(device, int):
+                    max_memory[device] *= DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION
+
+            logger.warning(
+                "Model does not fit to the GPU mem. "
+                f"We apply the following memory limit for calibration: \n{max_memory}\n"
+                f"If you hit GPU OOM issue, please adjust the memory fraction "
+                f"(currently {DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION}) or "
+                "reduce the calibration `batch_size` manually."
+            )
+            model_kwargs["max_memory"] = max_memory
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_config.model_path,
+            device_map=device_map,
+            **model_kwargs,
+            trust_remote_code=True,
+        )
+        # Handle both legacy modelopt_quant and unified quantization flags
+        if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
+            # Legacy approach
+            quant_choice_str = model_config.modelopt_quant
+            rank0_log(f"ModelOpt quantization requested (legacy): {quant_choice_str}")
+        else:
+            # Unified approach - extract quantization type
+            quant_choice_str = model_config._get_modelopt_quant_type()
+            rank0_log(
+                f"ModelOpt quantization requested (unified): {model_config.quantization} -> {quant_choice_str}"
+            )
+
+        if not isinstance(quant_choice_str, str):
+            raise TypeError(
+                f"Quantization type must be a string (e.g., 'fp8'), "
+                f"got {type(quant_choice_str)}"
+            )
+
+        return model
+
     def load_model(
         self,
         *,
         model_config: ModelConfig,
         device_config: DeviceConfig,
     ) -> nn.Module:
+
+        if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
+            # Load base model using shared method
+            model = self._load_modelopt_base_model(model_config)
+            # Note: DefaultModelLoader doesn't do additional quantization processing
+            # For full ModelOpt quantization, use ModelOptModelLoader
+            return model.eval()
+
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
            with target_device:
@@ -491,9 +592,9 @@ class DefaultModelLoader(BaseModelLoader):
                     self.load_config,
                 )
 
-
-
-
+            self.load_weights_and_postprocess(
+                model, self._get_all_weights(model_config, model), target_device
+            )
 
         return model.eval()
 
@@ -511,6 +612,8 @@ class DefaultModelLoader(BaseModelLoader):
             # parameters onto device for processing and back off after.
             with device_loading_context(module, target_device):
                 quant_method.process_weights_after_loading(module)
+        if _is_npu:
+            torch.npu.empty_cache()
 
 
 class LayeredModelLoader(DefaultModelLoader):
@@ -529,9 +632,9 @@ class LayeredModelLoader(DefaultModelLoader):
         device_config: DeviceConfig,
     ) -> nn.Module:
         from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
-        from sglang.srt.
+        from sglang.srt.server_args import get_global_server_args
 
-        torchao_config =
+        torchao_config = get_global_server_args().torchao_config
         target_device = torch.device(device_config.device)
 
         with set_default_torch_dtype(model_config.dtype):
@@ -1668,9 +1771,303 @@ def load_model_with_cpu_quantization(
     return model.eval()
 
 
-def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
+class ModelOptModelLoader(DefaultModelLoader):
+    """
+    Model loader that applies NVIDIA Model Optimizer quantization
+    """
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        # Any ModelOpt specific initialization if needed
+
+    def _setup_modelopt_quantization(
+        self,
+        model,
+        tokenizer,
+        quant_cfg,
+        quantized_ckpt_restore_path: str | None = None,
+        quantized_ckpt_save_path: str | None = None,
+        export_path: str | None = None,
+    ) -> None:
+        """
+        Set up ModelOpt quantization for the given model.
+
+        Args:
+            model: The model to quantize
+            tokenizer: The tokenizer associated with the model
+            quant_cfg: The quantization configuration
+            quantized_ckpt_restore_path: Path to restore quantized checkpoint from
+            quantized_ckpt_save_path: Path to save quantized checkpoint to
+            export_path: Path to export the quantized model in HuggingFace format
+
+        Raises:
+            ImportError: If ModelOpt is not available
+            Exception: If quantization setup fails
+        """
+        try:
+            import modelopt.torch.opt as mto
+            import modelopt.torch.quantization as mtq
+            from modelopt.torch.quantization.utils import is_quantized
+        except ImportError as e:
+            raise ImportError(
+                "ModelOpt is not available. Please install modelopt."
+            ) from e
+
+        if is_quantized(model):
+            rank0_log("Model is already quantized, skipping quantization setup.")
+            return
+        # Restore from checkpoint if provided
+        if quantized_ckpt_restore_path:
+            try:
+                mto.restore(model, quantized_ckpt_restore_path)
+                rank0_log(
+                    f"Restored quantized model from {quantized_ckpt_restore_path}"
+                )
+
+                # Export model if path provided (even when restoring from checkpoint)
+                self._maybe_export_modelopt(model, export_path)
+                return
+            except Exception as e:
+                logger.warning(
+                    f"Failed to restore from {quantized_ckpt_restore_path}: {e}"
+                )
+                rank0_log("Proceeding with calibration-based quantization...")
+
+        # Set up calibration-based quantization
+        try:
+            # Left padding tends to work better for batched generation with decoder-only LMs
+            with suppress(Exception):
+                tokenizer.padding_side = "left"
+
+            from modelopt.torch.utils.dataset_utils import (
+                create_forward_loop,
+                get_dataset_dataloader,
+            )
+
+            # Create calibration dataloader
+            calib_dataloader = get_dataset_dataloader(
+                dataset_name="cnn_dailymail",  # TODO: Consider making this configurable
+                tokenizer=tokenizer,
+                batch_size=36,  # TODO: Consider making this configurable
+                num_samples=512,  # TODO: Consider making this configurable
+                device=model.device,
+                include_labels=False,
+            )
+
+            calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
+
+            # Apply quantization
+            mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+
+            if get_tensor_model_parallel_rank() == 0:
+                mtq.print_quant_summary(model)
+
+            # Save checkpoint if path provided
+            if quantized_ckpt_save_path:
+                try:
+                    mto.save(model, quantized_ckpt_save_path)
+                    rank0_log(f"Quantized model saved to {quantized_ckpt_save_path}")
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to save quantized checkpoint to {quantized_ckpt_save_path}: {e}"
+                    )
+
+            # Export model if path provided
+            self._maybe_export_modelopt(model, export_path)
+
+        except Exception as e:
+            raise Exception(f"Failed to set up ModelOpt quantization: {e}") from e
+
+    def _maybe_export_modelopt(self, model, export_path: str | None) -> None:
+        """Export model to HuggingFace format if export_path is provided."""
+        if export_path:
+            try:
+                # Get the original model path from the model config
+                original_model_path = getattr(self, "_original_model_path", None)
+                self._export_modelopt_checkpoint(
+                    model, export_path, original_model_path
+                )
+                rank0_log(
+                    f"Quantized model exported to HuggingFace format at {export_path}"
+                )
+            except Exception as e:
+                rank0_log(
+                    f"Warning: Failed to export quantized model to {export_path}: {e}"
+                )
+
+    def _export_modelopt_checkpoint(
+        self,
+        model,
+        export_path: str,
+        model_path: str = None,
+        trust_remote_code: bool = True,
+    ) -> None:
+        """
+        Export the quantized model to HuggingFace format using ModelOpt export API.
+
+        Args:
+            model: The quantized model to export
+            export_path: Directory path to export the model to
+            model_path: Path to the original model (for tokenizer export)
+            trust_remote_code: Whether to trust remote code for tokenizer loading
+
+        Raises:
+            ImportError: If ModelOpt export functionality is not available
+            Exception: If export fails
+        """
+        try:
+            from modelopt.torch.export import export_hf_checkpoint
+            from transformers import AutoTokenizer
+        except ImportError as e:
+            raise ImportError(
+                "ModelOpt export functionality is not available. "
+                "Please ensure you have the latest version of modelopt installed."
+            ) from e
+
+        # Create export directory if it doesn't exist
+        os.makedirs(export_path, exist_ok=True)
+
+        # Export the quantized model
+        export_hf_checkpoint(model, export_dir=export_path)
+
+        # Export the tokenizer if model_path is provided
+        if model_path:
+            try:
+                tokenizer = AutoTokenizer.from_pretrained(
+                    model_path, trust_remote_code=trust_remote_code
+                )
+                tokenizer.save_pretrained(export_path)
+                rank0_log(f"Tokenizer exported to {export_path}")
+            except Exception as e:
+                rank0_log(f"Warning: Failed to export tokenizer: {e}")
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+
+        logger.info("ModelOptModelLoader: Loading base model...")
+
+        # Store the original model path for tokenizer export
+        self._original_model_path = model_config.model_path
+
+        # Check if model is already quantized
+        if model_config._is_already_quantized():
+            logger.info("Model is already quantized, loading directly...")
+            # Use default loading for pre-quantized models
+            return super().load_model(
+                model_config=model_config, device_config=device_config
+            )
+
+        # TODO: Quantize-and-serve mode has been disabled at the ModelConfig level
+        # All quantization now uses the standard workflow (quantize + export/save)
+        logger.info("Standard quantization mode: Will quantize and export/save")
+        return self._standard_quantization_workflow(model_config, device_config)
+
+    def _standard_quantization_workflow(
+        self, model_config: ModelConfig, device_config: DeviceConfig
+    ) -> nn.Module:
+        """Standard quantization workflow: quantize, save checkpoint, export, then return model."""
+        # Use shared method from parent class to load base model for quantization
+        model = self._load_modelopt_base_model(model_config)
+
+        # Import ModelOpt modules
+        try:
+            import modelopt.torch.quantization as mtq
+        except ImportError:
+            logger.error(
+                "NVIDIA Model Optimizer (modelopt) library not found. "
+                "Please install it to use ModelOpt quantization."
+            )
+            raise
+
+        # Handle both old modelopt_quant and new unified quantization flags
+        if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
+            # Legacy modelopt_quant flag
+            quant_choice_str = model_config.modelopt_quant
+        else:
+            # Unified quantization flag - extract the type (fp8/fp4)
+            quant_choice_str = model_config._get_modelopt_quant_type()
+
+        quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)
+        if not quant_cfg_name:
+            raise ValueError(
+                f"Invalid quantization choice: '{quant_choice_str}'. "
+                f"Available choices: {list(QUANT_CFG_CHOICES.keys())}"
+            )
+
+        try:
+            # getattr will fetch the config object, e.g., mtq.FP8_DEFAULT_CFG
+            quant_cfg = getattr(mtq, quant_cfg_name)
+        except AttributeError:
+            raise AttributeError(
+                f"ModelOpt quantization config '{quant_cfg_name}' not found. "
+                "Please verify the ModelOpt library installation."
+            )
+
+        logger.info(
+            f"Quantizing model with ModelOpt using config: mtq.{quant_cfg_name}"
+        )
+
+        # Get ModelOpt configuration from LoadConfig
+        modelopt_config = self.load_config.modelopt_config
+        quantized_ckpt_restore_path = (
+            modelopt_config.checkpoint_restore_path if modelopt_config else None
+        )
+        quantized_ckpt_save_path = (
+            modelopt_config.checkpoint_save_path if modelopt_config else None
+        )
+        export_path = modelopt_config.export_path if modelopt_config else None
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_config.model_path, use_fast=True
+        )
+
+        try:
+            self._setup_modelopt_quantization(
+                model,
+                tokenizer,
+                quant_cfg,
+                quantized_ckpt_restore_path=quantized_ckpt_restore_path,
+                quantized_ckpt_save_path=quantized_ckpt_save_path,
+                export_path=export_path,
+            )
+        except Exception as e:
+            logger.warning(f"ModelOpt quantization failed: {e}")
+            rank0_log("Proceeding without quantization...")
+
+        return model.eval()
+
+
+def get_model_loader(
+    load_config: LoadConfig, model_config: Optional[ModelConfig] = None
+) -> BaseModelLoader:
     """Get a model loader based on the load format."""
 
+    if model_config and (
+        (hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant)
+        or model_config.quantization in ["modelopt_fp8", "modelopt_fp4", "modelopt"]
+    ):
+        logger.info("Using ModelOptModelLoader due to ModelOpt quantization config.")
+        return ModelOptModelLoader(load_config)
+
+    # Use ModelOptModelLoader for unified quantization flags
+    if (
+        model_config
+        and hasattr(model_config, "quantization")
+        and model_config.quantization in ["modelopt_fp8", "modelopt_fp4"]
+    ):
+        if model_config._is_already_quantized():
+            logger.info(
+                f"Using ModelOptModelLoader for pre-quantized model: {model_config.quantization}"
+            )
+        else:
+            logger.info(
+                f"Using ModelOptModelLoader for quantization: {model_config.quantization}"
+            )
+        return ModelOptModelLoader(load_config)
+
     if isinstance(load_config.load_format, type):
         return load_config.load_format(load_config)
 
sglang/srt/model_loader/utils.py
CHANGED
@@ -99,7 +99,6 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
 
     if not is_native_supported or model_config.model_impl == ModelImpl.TRANSFORMERS:
         architectures = resolve_transformers_arch(model_config, architectures)
-
     return ModelRegistry.resolve_model_cls(architectures)
 
 