sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff shows the changes between two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +54 -37
- sglang/bench_one_batch_server.py +340 -34
- sglang/bench_serving.py +340 -159
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +9 -2
- sglang/profiler.py +20 -3
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +309 -0
- sglang/srt/configs/load_config.py +33 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +284 -118
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +576 -0
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +6 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -15
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +268 -98
- sglang/srt/disaggregation/decode.py +172 -39
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +203 -555
- sglang/srt/disaggregation/nixl/conn.py +217 -63
- sglang/srt/disaggregation/prefill.py +113 -270
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +203 -97
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +85 -65
- sglang/srt/entrypoints/grpc_server.py +632 -305
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +169 -17
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +327 -34
- sglang/srt/entrypoints/openai/serving_base.py +74 -8
- sglang/srt/entrypoints/openai/serving_chat.py +202 -118
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +20 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +47 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +323 -0
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +21 -16
- sglang/srt/function_call/glm4_moe_detector.py +4 -8
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +61 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +98 -7
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/grpc_request_manager.py +915 -0
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
- sglang/srt/layers/activation.py +11 -7
- sglang/srt/layers/attention/aiter_backend.py +17 -18
- sglang/srt/layers/attention/ascend_backend.py +125 -10
- sglang/srt/layers/attention/attention_registry.py +226 -0
- sglang/srt/layers/attention/base_attn_backend.py +32 -4
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +52 -15
- sglang/srt/layers/attention/flashinfer_backend.py +357 -212
- sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
- sglang/srt/layers/attention/flashmla_backend.py +9 -7
- sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
- sglang/srt/layers/attention/mamba/mamba.py +514 -1
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +23 -0
- sglang/srt/layers/attention/nsa_backend.py +1201 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +249 -42
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
- sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +61 -3
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +19 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +28 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +47 -15
- sglang/srt/layers/linear.py +30 -5
- sglang/srt/layers/logits_processor.py +161 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
- sglang/srt/layers/moe/ep_moe/layer.py +243 -448
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +27 -1
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +86 -20
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +43 -15
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +141 -81
- sglang/srt/layers/quantization/mxfp4.py +17 -34
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -24
- sglang/srt/layers/quantization/w8a8_int8.py +45 -27
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +750 -46
- sglang/srt/layers/sampler.py +84 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +23 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +9 -4
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +33 -7
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +41 -17
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +83 -152
- sglang/srt/managers/data_parallel_controller.py +156 -87
- sglang/srt/managers/detokenizer_manager.py +51 -24
- sglang/srt/managers/io_struct.py +223 -129
- sglang/srt/managers/mm_utils.py +49 -10
- sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +130 -0
- sglang/srt/managers/schedule_batch.py +340 -529
- sglang/srt/managers/schedule_policy.py +158 -18
- sglang/srt/managers/scheduler.py +665 -620
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
- sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
- sglang/srt/managers/tokenizer_manager.py +462 -226
- sglang/srt/managers/tp_worker.py +217 -156
- sglang/srt/managers/utils.py +79 -47
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +42 -28
- sglang/srt/mem_cache/base_prefix_cache.py +3 -3
- sglang/srt/mem_cache/chunk_cache.py +20 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +38 -0
- sglang/srt/mem_cache/hicache_storage.py +44 -2
- sglang/srt/mem_cache/hiradix_cache.py +134 -34
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +602 -208
- sglang/srt/mem_cache/memory_pool_host.py +134 -183
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +263 -78
- sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +115 -58
- sglang/srt/metrics/collector.py +113 -120
- sglang/srt/metrics/func_timer.py +3 -8
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +81 -36
- sglang/srt/model_executor/forward_batch_info.py +40 -50
- sglang/srt/model_executor/model_runner.py +507 -319
- sglang/srt/model_executor/npu_graph_runner.py +11 -5
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +438 -37
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +200 -27
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +40 -56
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +25 -4
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +793 -235
- sglang/srt/models/dots_ocr.py +171 -0
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +570 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -3
- sglang/srt/models/glm4_moe.py +17 -40
- sglang/srt/models/glm4_moe_nextn.py +4 -4
- sglang/srt/models/glm4v.py +3 -2
- sglang/srt/models/glm4v_moe.py +6 -6
- sglang/srt/models/gpt_oss.py +12 -35
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +4 -2
- sglang/srt/models/llama.py +6 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +6 -23
- sglang/srt/models/longcat_flash_nextn.py +4 -15
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +27 -6
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +5 -5
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +70 -4
- sglang/srt/models/qwen2_vl.py +6 -3
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +50 -38
- sglang/srt/models/qwen3_next.py +43 -21
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +791 -0
- sglang/srt/models/qwen3_vl_moe.py +343 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +268 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +61 -0
- sglang/srt/multimodal/processors/base_processor.py +21 -9
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +2 -4
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +20 -10
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +83 -17
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +36 -23
- sglang/srt/sampling/sampling_params.py +75 -0
- sglang/srt/server_args.py +1300 -338
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +161 -0
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
- sglang/srt/speculative/eagle_info.py +786 -0
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +113 -1270
- sglang/srt/speculative/eagle_worker.py +120 -285
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/ngram_info.py +433 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +49 -0
- sglang/srt/speculative/spec_utils.py +641 -0
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +35 -18
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/{utils.py → utils/common.py} +583 -113
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +120 -11
- sglang/test/runners.py +3 -1
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +8 -2
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +3 -4
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +430 -0
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +93 -1
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +432 -16
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
- sglang/srt/entrypoints/grpc_request_manager.py +0 -580
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,343 @@
+# Copyright 2025 Qwen Team
+# Copyright 2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
+import logging
+from functools import lru_cache
+from typing import Iterable, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
+from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
+    get_tensor_model_parallel_rank,
+)
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen3_moe import Qwen3MoeModel
+from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
+from sglang.srt.utils.hf_transformers_utils import get_processor
+
+logger = logging.getLogger(__name__)
+
+cached_get_processor = lru_cache(get_processor)
+
+
+class Qwen3MoeLLMModel(Qwen3MoeModel):
+    def __init__(
+        self,
+        *,
+        config: Qwen3VLMoeTextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
+        self.hidden_size = config.hidden_size
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.embed_tokens
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+        input_deepstack_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        aux_hidden_states = []
+        for layer_idx, layer in enumerate(
+            self.layers[self.start_layer : self.end_layer]
+        ):
+            layer_idx += self.start_layer
+            if layer_idx in self.layers_to_capture:
+                aux_hidden_states.append(
+                    hidden_states + residual if residual is not None else hidden_states
+                )
+
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+
+            # process deepstack
+            if input_deepstack_embeds is not None and layer_idx < 3:
+                sep = self.hidden_size * layer_idx
+                hidden_states.add_(
+                    input_deepstack_embeds[:, sep : sep + self.hidden_size]
+                )
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if hidden_states.shape[0] != 0:
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
+
+
+def load_fused_expert_weights(
+    name: str,
+    params_dict: dict,
+    loaded_weight: torch.Tensor,
+    shard_id: str,
+    num_experts: int,
+):
+    param = params_dict[name]
+    # weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
+    weight_loader = param.weight_loader
+    ep_rank = get_tensor_model_parallel_rank()
+    ep_size = get_moe_expert_parallel_world_size()
+    if ep_size == 1:
+        for expert_id in range(num_experts):
+            curr_expert_weight = loaded_weight[expert_id]
+            weight_loader(
+                param,
+                curr_expert_weight,
+                name,
+                shard_id,
+                expert_id,
+            )
+    else:
+        experts_per_ep = num_experts // ep_size
+        start_expert = ep_rank * experts_per_ep
+        end_expert = (
+            (ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts
+        )
+
+        for idx, expert_id in enumerate(range(start_expert, end_expert)):
+            curr_expert_weight = loaded_weight[expert_id]
+            weight_loader(
+                param,
+                curr_expert_weight,
+                name,
+                shard_id,
+                idx,
+            )
+    return True
+
+
+class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
+    def __init__(
+        self,
+        config: Qwen3VLMoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        language_model_cls=Qwen3MoeLLMModel,
+    ):
+        super().__init__(config, quant_config, prefix, language_model_cls)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            ("gate_up_proj", "up_proj", 1),
+            ("gate_up_proj", "gate_proj", 0),
+        ]
+
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
+        # Skip loading extra parameters for GPTQ/modelopt models.
+        ignore_suffixes = (
+            ".bias",
+            "_bias",
+            ".k_scale",
+            "_k_scale",
+            ".v_scale",
+            "_v_scale",
+            ".weight_scale",
+            "_weight_scale",
+            ".input_scale",
+            "_input_scale",
+        )
+
+        is_fused_expert = False
+        fused_expert_params_mapping = [
+            ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
+            ("experts.w2_weight", "experts.down_proj", 0, "w2"),
+        ]
+
+        num_experts = self.config.num_experts
+
+        # Cache params_dict to avoid repeated expensive traversal of model parameters
+        if not hasattr(self, "_cached_params_dict"):
+            self._cached_params_dict = dict(self.named_parameters())
+        params_dict = self._cached_params_dict
+        for name, loaded_weight in weights:
+            name = name.replace(r"model.language_model.", r"model.")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if "experts.gate_up_proj" in name or "experts.down_proj" in name:
+                    is_fused_expert = True
+                    expert_params_mapping = fused_expert_params_mapping
+
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                if "visual" in name:
+                    continue
+
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra parameters for GPTQ/modelopt models.
+                if name.endswith(ignore_suffixes) and name not in params_dict:
+                    continue
+                # [TODO] Skip layers that are on other devices (check if sglang has a similar function)
+                # if is_pp_missing_parameter(name, self):
+                #     continue
+
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Track if this is an expert weight to enable early skipping
+                is_expert_weight = False
+
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    if "visual" in name:
+                        continue
+                    # Anyway, this is an expert weight and should not be
+                    # attempted to load as other weights later
+                    is_expert_weight = True
+                    name_mapped = name.replace(weight_name, param_name)
+                    if is_fused_expert:
+                        loaded_weight = loaded_weight.transpose(-1, -2)  # no bias
+                        if "experts.gate_up_proj" in name:
+                            loaded_weight = loaded_weight.chunk(2, dim=-2)
+                            load_fused_expert_weights(
+                                name_mapped,
+                                params_dict,
+                                loaded_weight[0],
+                                "w1",
+                                num_experts,
+                            )
+                            load_fused_expert_weights(
+                                name_mapped,
+                                params_dict,
+                                loaded_weight[1],
+                                "w3",
+                                num_experts,
+                            )
+                        else:
+                            load_fused_expert_weights(
+                                name_mapped,
+                                params_dict,
+                                loaded_weight,
+                                shard_id,
+                                num_experts,
+                            )
+                    else:
+                        # Skip loading extra parameters for GPTQ/modelopt models.
+                        if (
+                            name_mapped.endswith(ignore_suffixes)
+                            and name_mapped not in params_dict
+                        ):
+                            continue
+                        param = params_dict[name_mapped]
+                        # We should ask the weight loader to return success or
+                        # not here since otherwise we may skip experts with
+                        # other available replicas.
+                        weight_loader = param.weight_loader
+                        weight_loader(
+                            param,
+                            loaded_weight,
+                            name_mapped,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                        )
+                    name = name_mapped
+                    break
+                else:
+                    if is_expert_weight:
+                        # This is an expert weight but not mapped to this rank, skip all remaining processing
+                        continue
+                    if "visual" in name:
+                        # adapt to VisionAttention
+                        name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+                        name = name.replace(r"model.visual.", r"visual.")
+
+                    # Skip loading extra parameters for GPTQ/modelopt models.
+                    if name.endswith(ignore_suffixes) and name not in params_dict:
+                        continue
+
+                    if name in params_dict.keys():
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+                    else:
+                        logger.warning(f"Parameter {name} not found in params_dict")
+
+        # TODO mimic deepseek
+        # Lazy initialization of expert weights cache to avoid slowing down load_weights
+        # if not hasattr(self, "routed_experts_weights_of_layer"):
+        #     self.routed_experts_weights_of_layer = {
+        #         layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
+        #         for layer_id in range(self.start_layer, self.end_layer)
+        #         if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
+        #     }
+
+
+EntryClass = Qwen3VLMoeForConditionalGeneration
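The expert-parallel branch of load_fused_expert_weights above assigns each EP rank a contiguous block of experts and re-indexes them locally, with the last rank absorbing any remainder. Below is a minimal standalone sketch of that slicing (not part of the diff); the rank, world-size, and expert counts are example values, not taken from a running SGLang process.

# Standalone sketch: mirrors the EP slicing used in load_fused_expert_weights.
def expert_slice(ep_rank: int, ep_size: int, num_experts: int):
    experts_per_ep = num_experts // ep_size
    start_expert = ep_rank * experts_per_ep
    # The last rank takes any remainder when num_experts % ep_size != 0.
    end_expert = (
        (ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts
    )
    # Pairs of (local expert index, global expert id) owned by this rank.
    return list(enumerate(range(start_expert, end_expert)))

if __name__ == "__main__":
    # 10 experts over 4 EP ranks: ranks 0-2 own 2 experts each, rank 3 owns 4.
    for rank in range(4):
        print(rank, expert_slice(rank, 4, 10))

This local re-indexing is why the EP branch passes the loop index idx to the weight loader instead of the global expert_id: each rank's expert parameter only holds its own slice.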
sglang/srt/models/registry.py
CHANGED

@@ -17,6 +17,18 @@ class _ModelRegistry:
     # Keyed by model_arch
     models: Dict[str, Union[Type[nn.Module], str]] = field(default_factory=dict)
 
+    def register(self, package_name: str, overwrite: bool = False):
+        new_models = import_model_classes(package_name)
+        if overwrite:
+            self.models.update(new_models)
+        else:
+            for arch, cls in new_models.items():
+                if arch in self.models:
+                    raise ValueError(
+                        f"Model architecture {arch} already registered. Set overwrite=True to replace."
+                    )
+                self.models[arch] = cls
+
     def get_supported_archs(self) -> AbstractSet[str]:
         return self.models.keys()
 
@@ -74,9 +86,8 @@ class _ModelRegistry:
 
 
 @lru_cache()
-def import_model_classes():
+def import_model_classes(package_name: str):
     model_arch_name_to_cls = {}
-    package_name = "sglang.srt.models"
     package = importlib.import_module(package_name)
     for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
@@ -104,4 +115,5 @@ def import_model_classes():
     return model_arch_name_to_cls
 
 
-ModelRegistry = _ModelRegistry(
+ModelRegistry = _ModelRegistry()
+ModelRegistry.register("sglang.srt.models")
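The new register method makes the registry an extension point: the built-in models are now registered through the same call (ModelRegistry.register("sglang.srt.models")), so out-of-tree model packages can be added the same way. A hedged usage sketch follows; "my_models" is a hypothetical package whose modules are assumed to follow the same convention as sglang.srt.models (each module exporting its model class via an EntryClass attribute, as in the qwen3_vl_moe.py file above).

# Hypothetical plugin registration; "my_models" is a made-up package name.
from sglang.srt.models.registry import ModelRegistry

# Scan my_models.* and register every discovered architecture.
# A duplicate architecture name raises ValueError unless overwrite=True.
ModelRegistry.register("my_models")

# Explicitly replace a built-in implementation.
ModelRegistry.register("my_models", overwrite=True)

print(sorted(ModelRegistry.get_supported_archs()))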
sglang/srt/models/roberta.py
CHANGED

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import
+import os
 from typing import Iterable, Optional, Tuple
 
 import torch
@@ -8,10 +8,12 @@ from torch import nn
 
 from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.sparse_pooler import SparsePooler
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.bert import BertEncoder
+from sglang.srt.utils.hf_transformers_utils import download_from_hf
 
 RobertaConfig = None
 
@@ -206,12 +208,29 @@ class XLMRobertaModel(nn.Module):
         config: RobertaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        sparse_head: Optional[str] = None,
+        model_path: Optional[str] = None,
     ):
         super().__init__()
         self.roberta = XLMRobertaBaseModel(
             config=config, quant_config=quant_config, prefix=prefix
         )
-
+        if sparse_head is not None:
+            self._is_sparse = True
+            self._model_path = model_path
+            self._sparse_head = sparse_head
+            self.pooler = SparsePooler(config=config)
+            # Zero out special tokens
+            self._special_tokens = [
+                config.bos_token_id,
+                config.eos_token_id,
+                config.pad_token_id,
+                # self.config.unk_token_id  # not available in the XLMRobertaConfig
+            ]
+            self._special_tokens = [t for t in self._special_tokens if t is not None]
+        else:
+            self._is_sparse = False
+            self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
 
     def forward(
         self,
@@ -224,11 +243,44 @@ class XLMRobertaModel(nn.Module):
         hidden_states = self.roberta(
             input_ids, positions, forward_batch, input_embeds, get_embedding
         )
-
+        embeddings = self.pooler(hidden_states, forward_batch)
+
+        if self._is_sparse:
+            for token_id in self._special_tokens:
+                embeddings.embeddings[:, token_id] = 0.0
+            embeddings.embeddings = embeddings.embeddings.to_sparse()
+
+        return embeddings
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         self.roberta.load_weights(weights)
 
+        if self._is_sparse:
+            sparse_dict = XLMRobertaModel._load_sparse_linear(
+                self._model_path, self._sparse_head
+            )
+            self.pooler.load_weights(sparse_dict)
+
+    @staticmethod
+    def _load_sparse_linear(model_path_or_dir: str, sparse_head: str) -> dict:
+        """
+        Load sparse_head from local dir or HF Hub.
+        Returns a state_dict suitable for nn.Linear.load_state_dict().
+        """
+        if os.path.isdir(model_path_or_dir):
+            path = os.path.join(model_path_or_dir, sparse_head)
+            if not os.path.exists(path):
+                raise FileNotFoundError(
+                    f"'{sparse_head}' not found in {model_path_or_dir}"
+                )
+        else:
+            # remote → use SGLang HF utility
+            local_dir = download_from_hf(model_path_or_dir, allow_patterns=sparse_head)
+            path = os.path.join(local_dir, sparse_head)
+
+        state_dict = torch.load(path)
+        return state_dict
+
 
 class XLMRobertaForSequenceClassification(nn.Module):
     def __init__(