sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_vl.py
CHANGED
|
@@ -15,12 +15,11 @@
|
|
|
15
15
|
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
|
|
16
16
|
import logging
|
|
17
17
|
from functools import lru_cache, partial
|
|
18
|
-
from typing import Callable, Iterable, List,
|
|
18
|
+
from typing import Callable, Iterable, List, Optional, Tuple, Union
|
|
19
19
|
|
|
20
20
|
import numpy as np
|
|
21
21
|
import torch
|
|
22
22
|
import torch.nn as nn
|
|
23
|
-
import torch.nn.functional as F
|
|
24
23
|
from einops import rearrange
|
|
25
24
|
from transformers.activations import ACT2FN
|
|
26
25
|
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
|
@@ -38,16 +37,20 @@ from sglang.srt.managers.mm_utils import (
|
|
|
38
37
|
MultiModalityDataPaddingPatternMultimodalTokens,
|
|
39
38
|
general_mm_embed_routine,
|
|
40
39
|
)
|
|
41
|
-
from sglang.srt.managers.schedule_batch import
|
|
40
|
+
from sglang.srt.managers.schedule_batch import (
|
|
41
|
+
Modality,
|
|
42
|
+
MultimodalDataItem,
|
|
43
|
+
MultimodalInputs,
|
|
44
|
+
)
|
|
42
45
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
|
43
46
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
|
44
|
-
from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
|
|
45
47
|
from sglang.srt.models.qwen3 import Qwen3Model
|
|
46
48
|
from sglang.srt.utils import add_prefix
|
|
47
49
|
from sglang.srt.utils.hf_transformers_utils import get_processor
|
|
48
50
|
|
|
49
51
|
logger = logging.getLogger(__name__)
|
|
50
52
|
|
|
53
|
+
|
|
51
54
|
# === Vision Encoder === #
|
|
52
55
|
|
|
53
56
|
|
|
@@ -189,14 +192,14 @@ class Qwen3_VisionBlock(nn.Module):
|
|
|
189
192
|
position_embeddings=position_embeddings,
|
|
190
193
|
)
|
|
191
194
|
attn = rearrange(attn, "b s ... -> s b ...")
|
|
192
|
-
x
|
|
195
|
+
x += attn
|
|
193
196
|
norm2 = self.norm2(x)
|
|
194
197
|
mlp = self.mlp(norm2)
|
|
195
|
-
x
|
|
198
|
+
x += mlp
|
|
196
199
|
return x
|
|
197
200
|
|
|
198
201
|
|
|
199
|
-
class
|
|
202
|
+
class Qwen3VLMoeVisionPatchMerger(nn.Module):
|
|
200
203
|
|
|
201
204
|
def __init__(
|
|
202
205
|
self,
|
|
@@ -246,7 +249,7 @@ class Qwen3_VisionPatchMerger(nn.Module):
|
|
|
246
249
|
return out
|
|
247
250
|
|
|
248
251
|
|
|
249
|
-
class
|
|
252
|
+
class Qwen3VLMoeVisionModel(nn.Module):
|
|
250
253
|
|
|
251
254
|
def __init__(
|
|
252
255
|
self,
|
|
@@ -263,10 +266,10 @@ class Qwen3_VisionTransformer(nn.Module):
|
|
|
263
266
|
self.spatial_merge_size = vision_config.spatial_merge_size
|
|
264
267
|
self.spatial_merge_unit = self.spatial_merge_size**2
|
|
265
268
|
self.temporal_patch_size = vision_config.temporal_patch_size
|
|
269
|
+
# layer indexes of which layer's output should be deep-stacked
|
|
266
270
|
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
|
|
267
271
|
self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config)
|
|
268
272
|
self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
|
|
269
|
-
|
|
270
273
|
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
|
|
271
274
|
head_dim = self.hidden_size // self.num_heads
|
|
272
275
|
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
|
|
@@ -286,7 +289,7 @@ class Qwen3_VisionTransformer(nn.Module):
|
|
|
286
289
|
for layer_idx in range(vision_config.depth)
|
|
287
290
|
]
|
|
288
291
|
)
|
|
289
|
-
self.merger =
|
|
292
|
+
self.merger = Qwen3VLMoeVisionPatchMerger(
|
|
290
293
|
dim=vision_config.out_hidden_size,
|
|
291
294
|
context_dim=self.hidden_size,
|
|
292
295
|
norm_layer=norm_layer,
|
|
@@ -297,7 +300,7 @@ class Qwen3_VisionTransformer(nn.Module):
|
|
|
297
300
|
|
|
298
301
|
self.deepstack_merger_list = nn.ModuleList(
|
|
299
302
|
[
|
|
300
|
-
|
|
303
|
+
Qwen3VLMoeVisionPatchMerger(
|
|
301
304
|
dim=vision_config.out_hidden_size,
|
|
302
305
|
context_dim=self.hidden_size,
|
|
303
306
|
spatial_merge_size=self.spatial_merge_size,
|
|
@@ -441,7 +444,7 @@ class Qwen3_VisionTransformer(nn.Module):
|
|
|
441
444
|
x = self.patch_embed(x)
|
|
442
445
|
|
|
443
446
|
pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
|
|
444
|
-
x
|
|
447
|
+
x += pos_embeds
|
|
445
448
|
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
|
446
449
|
|
|
447
450
|
seq_len, _ = x.size()
|
|
@@ -452,15 +455,16 @@ class Qwen3_VisionTransformer(nn.Module):
|
|
|
452
455
|
position_embeddings = (emb.cos(), emb.sin())
|
|
453
456
|
|
|
454
457
|
# compute cu_seqlens
|
|
458
|
+
cu_seqlens = torch.repeat_interleave(
|
|
459
|
+
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
|
460
|
+
).cumsum(dim=0)
|
|
455
461
|
cu_seqlens = torch.cat(
|
|
456
462
|
[
|
|
457
|
-
torch.
|
|
458
|
-
(
|
|
463
|
+
torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
|
|
464
|
+
cu_seqlens.to(torch.int32),
|
|
459
465
|
]
|
|
460
466
|
)
|
|
461
|
-
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
|
|
462
467
|
|
|
463
|
-
# max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
|
464
468
|
x = x.unsqueeze(1)
|
|
465
469
|
|
|
466
470
|
deepstack_feature_lists = []
|
|
@@ -574,10 +578,7 @@ class Qwen3LLMModel(Qwen3Model):
|
|
|
574
578
|
and layer_idx in self.deepstack_embed_to_decoder_layer
|
|
575
579
|
):
|
|
576
580
|
sep = self.hidden_size * layer_idx
|
|
577
|
-
hidden_states
|
|
578
|
-
hidden_states
|
|
579
|
-
+ input_deepstack_embeds[:, sep : sep + self.hidden_size]
|
|
580
|
-
)
|
|
581
|
+
hidden_states += input_deepstack_embeds[:, sep : sep + self.hidden_size]
|
|
581
582
|
|
|
582
583
|
if not self.pp_group.is_last_rank:
|
|
583
584
|
return PPProxyTensors(
|
|
@@ -605,37 +606,43 @@ class Qwen3VLForConditionalGeneration(nn.Module):
|
|
|
605
606
|
config: Qwen3VLConfig,
|
|
606
607
|
quant_config: Optional[QuantizationConfig] = None,
|
|
607
608
|
prefix: str = "",
|
|
609
|
+
language_model_cls=Qwen3LLMModel,
|
|
608
610
|
) -> None:
|
|
609
611
|
super().__init__()
|
|
610
612
|
|
|
611
|
-
self.
|
|
612
|
-
self.visual = Qwen3_VisionTransformer(
|
|
613
|
+
self.visual = Qwen3VLMoeVisionModel(
|
|
613
614
|
config.vision_config,
|
|
614
|
-
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
|
615
615
|
# NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
|
|
616
616
|
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
|
|
617
617
|
quant_config=quant_config,
|
|
618
|
+
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
|
618
619
|
prefix=add_prefix("visual", prefix),
|
|
619
620
|
)
|
|
620
621
|
|
|
621
|
-
|
|
622
|
-
|
|
622
|
+
# TODO: make it more elegant
|
|
623
|
+
if language_model_cls is Qwen3LLMModel:
|
|
624
|
+
self.config: Qwen3VLConfig = config # for qwen3-vl
|
|
625
|
+
else:
|
|
626
|
+
self.config = config.text_config # for qwen3-omni
|
|
627
|
+
|
|
628
|
+
self.model = language_model_cls(
|
|
629
|
+
config=self.config,
|
|
623
630
|
quant_config=quant_config,
|
|
624
631
|
prefix=add_prefix("model", prefix),
|
|
625
632
|
)
|
|
626
633
|
|
|
627
|
-
if config.tie_word_embeddings:
|
|
634
|
+
if self.config.tie_word_embeddings:
|
|
628
635
|
self.lm_head = self.model.embed_tokens
|
|
629
636
|
else:
|
|
630
637
|
self.lm_head = ParallelLMHead(
|
|
631
|
-
config.vocab_size,
|
|
632
|
-
config.hidden_size,
|
|
638
|
+
self.config.vocab_size,
|
|
639
|
+
self.config.hidden_size,
|
|
633
640
|
quant_config=quant_config,
|
|
634
641
|
prefix=add_prefix("lm_head", prefix),
|
|
635
642
|
)
|
|
636
643
|
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
|
|
637
644
|
|
|
638
|
-
self.logits_processor = LogitsProcessor(config)
|
|
645
|
+
self.logits_processor = LogitsProcessor(self.config)
|
|
639
646
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
|
640
647
|
# like {8:0, 16:1, 24:2}, which stands for the captured deepstack features on
|
|
641
648
|
# 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states
|
|
@@ -643,10 +650,7 @@ class Qwen3VLForConditionalGeneration(nn.Module):
|
|
|
643
650
|
# deepstack
|
|
644
651
|
self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
|
|
645
652
|
self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
|
|
646
|
-
|
|
647
|
-
@property
|
|
648
|
-
def use_deepstack(self) -> bool:
|
|
649
|
-
return hasattr(self, "deepstack_visual_indexes")
|
|
653
|
+
self.use_deepstack = {Modality.IMAGE: True, Modality.VIDEO: True}
|
|
650
654
|
|
|
651
655
|
def separate_deepstack_embeds(self, embedding):
|
|
652
656
|
assert (
|
|
@@ -14,49 +14,23 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
|
|
16
16
|
import logging
|
|
17
|
-
from functools import lru_cache
|
|
18
|
-
from typing import
|
|
17
|
+
from functools import lru_cache
|
|
18
|
+
from typing import Iterable, Optional, Tuple, Union
|
|
19
19
|
|
|
20
|
-
import numpy as np
|
|
21
20
|
import torch
|
|
22
21
|
import torch.nn as nn
|
|
23
|
-
import torch.nn.functional as F
|
|
24
|
-
from einops import rearrange
|
|
25
|
-
from transformers import BatchFeature
|
|
26
|
-
from transformers.activations import ACT2FN
|
|
27
|
-
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
|
28
|
-
Qwen2_5_VisionRotaryEmbedding,
|
|
29
|
-
)
|
|
30
22
|
|
|
31
|
-
from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig,
|
|
23
|
+
from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
|
|
32
24
|
from sglang.srt.distributed import (
|
|
33
25
|
get_moe_expert_parallel_world_size,
|
|
34
|
-
get_pp_group,
|
|
35
26
|
get_tensor_model_parallel_rank,
|
|
36
27
|
)
|
|
37
|
-
from sglang.srt.layers.logits_processor import LogitsProcessor
|
|
38
28
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
|
39
|
-
from sglang.srt.layers.pooler import Pooler, PoolingType
|
|
40
29
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
|
41
|
-
from sglang.srt.layers.utils import get_layer_id
|
|
42
|
-
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
|
43
|
-
from sglang.srt.managers.mm_utils import (
|
|
44
|
-
MultiModalityDataPaddingPatternMultimodalTokens,
|
|
45
|
-
general_mm_embed_routine,
|
|
46
|
-
)
|
|
47
|
-
from sglang.srt.managers.schedule_batch import (
|
|
48
|
-
MultimodalDataItem,
|
|
49
|
-
MultimodalInputs,
|
|
50
|
-
global_server_args_dict,
|
|
51
|
-
)
|
|
52
30
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
|
53
31
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
|
54
|
-
from sglang.srt.models.qwen3_moe import
|
|
55
|
-
from sglang.srt.models.qwen3_vl import
|
|
56
|
-
Qwen3_VisionTransformer,
|
|
57
|
-
Qwen3VLForConditionalGeneration,
|
|
58
|
-
)
|
|
59
|
-
from sglang.srt.utils import add_prefix
|
|
32
|
+
from sglang.srt.models.qwen3_moe import Qwen3MoeModel
|
|
33
|
+
from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
|
|
60
34
|
from sglang.srt.utils.hf_transformers_utils import get_processor
|
|
61
35
|
|
|
62
36
|
logger = logging.getLogger(__name__)
|
|
@@ -68,28 +42,16 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
|
|
68
42
|
def __init__(
|
|
69
43
|
self,
|
|
70
44
|
*,
|
|
71
|
-
config:
|
|
45
|
+
config: Qwen3VLMoeTextConfig,
|
|
72
46
|
quant_config: Optional[QuantizationConfig] = None,
|
|
73
47
|
prefix: str = "",
|
|
74
48
|
):
|
|
75
49
|
super().__init__(config=config, quant_config=quant_config, prefix=prefix)
|
|
76
|
-
|
|
77
50
|
self.hidden_size = config.hidden_size
|
|
78
51
|
|
|
79
52
|
def get_input_embeddings(self) -> nn.Embedding:
|
|
80
53
|
return self.embed_tokens
|
|
81
54
|
|
|
82
|
-
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
|
83
|
-
# in qwen-vl, last dim is the same
|
|
84
|
-
pixel_values = torch.cat([item.feature for item in items], dim=0).type(
|
|
85
|
-
self.visual.dtype
|
|
86
|
-
)
|
|
87
|
-
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
|
|
88
|
-
assert pixel_values.dim() == 2, pixel_values.dim()
|
|
89
|
-
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
|
|
90
|
-
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
|
91
|
-
return image_embeds
|
|
92
|
-
|
|
93
55
|
def forward(
|
|
94
56
|
self,
|
|
95
57
|
input_ids: torch.Tensor,
|
|
@@ -114,7 +76,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
|
|
114
76
|
for layer_idx, layer in enumerate(
|
|
115
77
|
self.layers[self.start_layer : self.end_layer]
|
|
116
78
|
):
|
|
117
|
-
layer_idx
|
|
79
|
+
layer_idx += self.start_layer
|
|
118
80
|
if layer_idx in self.layers_to_capture:
|
|
119
81
|
aux_hidden_states.append(
|
|
120
82
|
hidden_states + residual if residual is not None else hidden_states
|
|
@@ -128,11 +90,10 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
|
|
128
90
|
)
|
|
129
91
|
|
|
130
92
|
# process deepstack
|
|
131
|
-
if input_deepstack_embeds is not None and layer_idx
|
|
93
|
+
if input_deepstack_embeds is not None and layer_idx < 3:
|
|
132
94
|
sep = self.hidden_size * layer_idx
|
|
133
|
-
hidden_states
|
|
134
|
-
|
|
135
|
-
+ input_deepstack_embeds[:, sep : sep + self.hidden_size]
|
|
95
|
+
hidden_states.add_(
|
|
96
|
+
input_deepstack_embeds[:, sep : sep + self.hidden_size]
|
|
136
97
|
)
|
|
137
98
|
|
|
138
99
|
if not self.pp_group.is_last_rank:
|
|
@@ -155,144 +116,56 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
|
|
155
116
|
return hidden_states, aux_hidden_states
|
|
156
117
|
|
|
157
118
|
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
config=config,
|
|
180
|
-
quant_config=quant_config,
|
|
181
|
-
prefix=add_prefix("model", prefix),
|
|
182
|
-
)
|
|
183
|
-
|
|
184
|
-
if config.tie_word_embeddings:
|
|
185
|
-
self.lm_head = self.model.embed_tokens
|
|
186
|
-
else:
|
|
187
|
-
self.lm_head = ParallelLMHead(
|
|
188
|
-
config.vocab_size,
|
|
189
|
-
config.hidden_size,
|
|
190
|
-
quant_config=quant_config,
|
|
191
|
-
prefix=add_prefix("lm_head", prefix),
|
|
119
|
+
def load_fused_expert_weights(
|
|
120
|
+
name: str,
|
|
121
|
+
params_dict: dict,
|
|
122
|
+
loaded_weight: torch.Tensor,
|
|
123
|
+
shard_id: str,
|
|
124
|
+
num_experts: int,
|
|
125
|
+
):
|
|
126
|
+
param = params_dict[name]
|
|
127
|
+
# weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
|
|
128
|
+
weight_loader = param.weight_loader
|
|
129
|
+
ep_rank = get_tensor_model_parallel_rank()
|
|
130
|
+
ep_size = get_moe_expert_parallel_world_size()
|
|
131
|
+
if ep_size == 1:
|
|
132
|
+
for expert_id in range(num_experts):
|
|
133
|
+
curr_expert_weight = loaded_weight[expert_id]
|
|
134
|
+
weight_loader(
|
|
135
|
+
param,
|
|
136
|
+
curr_expert_weight,
|
|
137
|
+
name,
|
|
138
|
+
shard_id,
|
|
139
|
+
expert_id,
|
|
192
140
|
)
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
# deepstack
|
|
199
|
-
self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
|
|
200
|
-
self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
|
|
201
|
-
|
|
202
|
-
@property
|
|
203
|
-
def use_deepstack(self) -> bool:
|
|
204
|
-
return hasattr(self, "deepstack_visual_indexes")
|
|
205
|
-
|
|
206
|
-
def forward(
|
|
207
|
-
self,
|
|
208
|
-
input_ids: torch.Tensor,
|
|
209
|
-
positions: torch.Tensor,
|
|
210
|
-
forward_batch: ForwardBatch,
|
|
211
|
-
get_embedding: bool = False,
|
|
212
|
-
):
|
|
213
|
-
"""Run forward pass for Qwen3-VL.
|
|
214
|
-
|
|
215
|
-
Args:
|
|
216
|
-
input_ids: Flattened (concatenated) input_ids corresponding to a
|
|
217
|
-
batch.
|
|
218
|
-
positions: Flattened (concatenated) position ids corresponding to a
|
|
219
|
-
batch.
|
|
220
|
-
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
|
|
221
|
-
opensource models), the shape will be `(3, seq_len)`,
|
|
222
|
-
otherwise it will be `(seq_len,).
|
|
223
|
-
(Use input_metadata.mrope_positions to replace it)
|
|
224
|
-
"""
|
|
225
|
-
if self.is_mrope_enabled:
|
|
226
|
-
positions = forward_batch.mrope_positions
|
|
227
|
-
|
|
228
|
-
if not (
|
|
229
|
-
forward_batch.forward_mode.is_decode()
|
|
230
|
-
or not forward_batch.contains_image_inputs()
|
|
231
|
-
):
|
|
232
|
-
if self.is_mrope_enabled:
|
|
233
|
-
assert positions.ndim == 2 and positions.size(0) == 3, (
|
|
234
|
-
"multimodal section rotary embedding requires "
|
|
235
|
-
f"(3, seq_len) positions, but got {positions.size()}"
|
|
236
|
-
)
|
|
237
|
-
|
|
238
|
-
hidden_states = general_mm_embed_routine(
|
|
239
|
-
input_ids=input_ids,
|
|
240
|
-
forward_batch=forward_batch,
|
|
241
|
-
language_model=self.model,
|
|
242
|
-
multimodal_model=self,
|
|
243
|
-
positions=positions,
|
|
244
|
-
use_deepstack=self.use_deepstack,
|
|
141
|
+
else:
|
|
142
|
+
experts_per_ep = num_experts // ep_size
|
|
143
|
+
start_expert = ep_rank * experts_per_ep
|
|
144
|
+
end_expert = (
|
|
145
|
+
(ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts
|
|
245
146
|
)
|
|
246
147
|
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
148
|
+
for idx, expert_id in enumerate(range(start_expert, end_expert)):
|
|
149
|
+
curr_expert_weight = loaded_weight[expert_id]
|
|
150
|
+
weight_loader(
|
|
151
|
+
param,
|
|
152
|
+
curr_expert_weight,
|
|
153
|
+
name,
|
|
154
|
+
shard_id,
|
|
155
|
+
idx,
|
|
250
156
|
)
|
|
251
|
-
|
|
252
|
-
return self.pooler(hidden_states, forward_batch)
|
|
157
|
+
return True
|
|
253
158
|
|
|
254
|
-
|
|
159
|
+
|
|
160
|
+
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
|
161
|
+
def __init__(
|
|
255
162
|
self,
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
num_experts: int,
|
|
163
|
+
config: Qwen3VLMoeConfig,
|
|
164
|
+
quant_config: Optional[QuantizationConfig] = None,
|
|
165
|
+
prefix: str = "",
|
|
166
|
+
language_model_cls=Qwen3MoeLLMModel,
|
|
261
167
|
):
|
|
262
|
-
|
|
263
|
-
# weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
|
|
264
|
-
weight_loader = param.weight_loader
|
|
265
|
-
ep_rank = get_tensor_model_parallel_rank()
|
|
266
|
-
ep_size = get_moe_expert_parallel_world_size()
|
|
267
|
-
if ep_size == 1:
|
|
268
|
-
for expert_id in range(num_experts):
|
|
269
|
-
curr_expert_weight = loaded_weight[expert_id]
|
|
270
|
-
weight_loader(
|
|
271
|
-
param,
|
|
272
|
-
curr_expert_weight,
|
|
273
|
-
name,
|
|
274
|
-
shard_id,
|
|
275
|
-
expert_id,
|
|
276
|
-
)
|
|
277
|
-
else:
|
|
278
|
-
experts_per_ep = num_experts // ep_size
|
|
279
|
-
start_expert = ep_rank * experts_per_ep
|
|
280
|
-
end_expert = (
|
|
281
|
-
(ep_rank + 1) * experts_per_ep
|
|
282
|
-
if ep_rank != ep_size - 1
|
|
283
|
-
else num_experts
|
|
284
|
-
)
|
|
285
|
-
|
|
286
|
-
for idx, expert_id in enumerate(range(start_expert, end_expert)):
|
|
287
|
-
curr_expert_weight = loaded_weight[expert_id]
|
|
288
|
-
weight_loader(
|
|
289
|
-
param,
|
|
290
|
-
curr_expert_weight,
|
|
291
|
-
name,
|
|
292
|
-
shard_id,
|
|
293
|
-
idx,
|
|
294
|
-
)
|
|
295
|
-
return True
|
|
168
|
+
super().__init__(config, quant_config, prefix, language_model_cls)
|
|
296
169
|
|
|
297
170
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
|
298
171
|
stacked_params_mapping = [
|
|
@@ -338,8 +211,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
|
|
338
211
|
self._cached_params_dict = dict(self.named_parameters())
|
|
339
212
|
params_dict = self._cached_params_dict
|
|
340
213
|
for name, loaded_weight in weights:
|
|
341
|
-
|
|
342
|
-
name = name.replace(r"model.language_model.", r"model.")
|
|
214
|
+
name = name.replace(r"model.language_model.", r"model.")
|
|
343
215
|
|
|
344
216
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
|
345
217
|
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
|
|
@@ -393,14 +265,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
|
|
393
265
|
loaded_weight = loaded_weight.transpose(-1, -2) # no bias
|
|
394
266
|
if "experts.gate_up_proj" in name:
|
|
395
267
|
loaded_weight = loaded_weight.chunk(2, dim=-2)
|
|
396
|
-
|
|
268
|
+
load_fused_expert_weights(
|
|
397
269
|
name_mapped,
|
|
398
270
|
params_dict,
|
|
399
271
|
loaded_weight[0],
|
|
400
272
|
"w1",
|
|
401
273
|
num_experts,
|
|
402
274
|
)
|
|
403
|
-
|
|
275
|
+
load_fused_expert_weights(
|
|
404
276
|
name_mapped,
|
|
405
277
|
params_dict,
|
|
406
278
|
loaded_weight[1],
|
|
@@ -408,7 +280,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
|
|
408
280
|
num_experts,
|
|
409
281
|
)
|
|
410
282
|
else:
|
|
411
|
-
|
|
283
|
+
load_fused_expert_weights(
|
|
412
284
|
name_mapped,
|
|
413
285
|
params_dict,
|
|
414
286
|
loaded_weight,
|
sglang/srt/models/roberta.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# SPDX-License-Identifier: Apache-2.0
|
|
2
2
|
|
|
3
|
-
import
|
|
3
|
+
import os
|
|
4
4
|
from typing import Iterable, Optional, Tuple
|
|
5
5
|
|
|
6
6
|
import torch
|
|
@@ -8,10 +8,12 @@ from torch import nn
|
|
|
8
8
|
|
|
9
9
|
from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
|
|
10
10
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
|
11
|
+
from sglang.srt.layers.sparse_pooler import SparsePooler
|
|
11
12
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
|
12
13
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
13
14
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
|
14
15
|
from sglang.srt.models.bert import BertEncoder
|
|
16
|
+
from sglang.srt.utils.hf_transformers_utils import download_from_hf
|
|
15
17
|
|
|
16
18
|
RobertaConfig = None
|
|
17
19
|
|
|
@@ -206,12 +208,29 @@ class XLMRobertaModel(nn.Module):
|
|
|
206
208
|
config: RobertaConfig,
|
|
207
209
|
quant_config: Optional[QuantizationConfig] = None,
|
|
208
210
|
prefix: str = "",
|
|
211
|
+
sparse_head: Optional[str] = None,
|
|
212
|
+
model_path: Optional[str] = None,
|
|
209
213
|
):
|
|
210
214
|
super().__init__()
|
|
211
215
|
self.roberta = XLMRobertaBaseModel(
|
|
212
216
|
config=config, quant_config=quant_config, prefix=prefix
|
|
213
217
|
)
|
|
214
|
-
|
|
218
|
+
if sparse_head is not None:
|
|
219
|
+
self._is_sparse = True
|
|
220
|
+
self._model_path = model_path
|
|
221
|
+
self._sparse_head = sparse_head
|
|
222
|
+
self.pooler = SparsePooler(config=config)
|
|
223
|
+
# Zero out special tokens
|
|
224
|
+
self._special_tokens = [
|
|
225
|
+
config.bos_token_id,
|
|
226
|
+
config.eos_token_id,
|
|
227
|
+
config.pad_token_id,
|
|
228
|
+
# self.config.unk_token_id # not available in the XLMRobertaConfig
|
|
229
|
+
]
|
|
230
|
+
self._special_tokens = [t for t in self._special_tokens if t is not None]
|
|
231
|
+
else:
|
|
232
|
+
self._is_sparse = False
|
|
233
|
+
self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
|
|
215
234
|
|
|
216
235
|
def forward(
|
|
217
236
|
self,
|
|
@@ -224,11 +243,44 @@ class XLMRobertaModel(nn.Module):
|
|
|
224
243
|
hidden_states = self.roberta(
|
|
225
244
|
input_ids, positions, forward_batch, input_embeds, get_embedding
|
|
226
245
|
)
|
|
227
|
-
|
|
246
|
+
embeddings = self.pooler(hidden_states, forward_batch)
|
|
247
|
+
|
|
248
|
+
if self._is_sparse:
|
|
249
|
+
for token_id in self._special_tokens:
|
|
250
|
+
embeddings.embeddings[:, token_id] = 0.0
|
|
251
|
+
embeddings.embeddings = embeddings.embeddings.to_sparse()
|
|
252
|
+
|
|
253
|
+
return embeddings
|
|
228
254
|
|
|
229
255
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
|
230
256
|
self.roberta.load_weights(weights)
|
|
231
257
|
|
|
258
|
+
if self._is_sparse:
|
|
259
|
+
sparse_dict = XLMRobertaModel._load_sparse_linear(
|
|
260
|
+
self._model_path, self._sparse_head
|
|
261
|
+
)
|
|
262
|
+
self.pooler.load_weights(sparse_dict)
|
|
263
|
+
|
|
264
|
+
@staticmethod
|
|
265
|
+
def _load_sparse_linear(model_path_or_dir: str, sparse_head: str) -> dict:
|
|
266
|
+
"""
|
|
267
|
+
Load sparse_head from local dir or HF Hub.
|
|
268
|
+
Returns a state_dict suitable for nn.Linear.load_state_dict().
|
|
269
|
+
"""
|
|
270
|
+
if os.path.isdir(model_path_or_dir):
|
|
271
|
+
path = os.path.join(model_path_or_dir, sparse_head)
|
|
272
|
+
if not os.path.exists(path):
|
|
273
|
+
raise FileNotFoundError(
|
|
274
|
+
f"'{sparse_head}' not found in {model_path_or_dir}"
|
|
275
|
+
)
|
|
276
|
+
else:
|
|
277
|
+
# remote → use SGLang HF utility
|
|
278
|
+
local_dir = download_from_hf(model_path_or_dir, allow_patterns=sparse_head)
|
|
279
|
+
path = os.path.join(local_dir, sparse_head)
|
|
280
|
+
|
|
281
|
+
state_dict = torch.load(path)
|
|
282
|
+
return state_dict
|
|
283
|
+
|
|
232
284
|
|
|
233
285
|
class XLMRobertaForSequenceClassification(nn.Module):
|
|
234
286
|
def __init__(
|
sglang/srt/models/step3_vl.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import math
|
|
3
|
-
from collections.abc import Iterable
|
|
4
3
|
from math import sqrt
|
|
5
|
-
from typing import Any, Dict, Iterable, List,
|
|
4
|
+
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
|
6
5
|
|
|
7
6
|
import torch
|
|
8
7
|
from torch import nn
|
|
@@ -57,7 +56,6 @@ from sglang.srt.managers.schedule_batch import (
|
|
|
57
56
|
Modality,
|
|
58
57
|
MultimodalDataItem,
|
|
59
58
|
MultimodalInputs,
|
|
60
|
-
global_server_args_dict,
|
|
61
59
|
)
|
|
62
60
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
63
61
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
|
@@ -300,7 +298,7 @@ class Step3TextDecoderLayer(nn.Module):
|
|
|
300
298
|
# self.n_shared_experts = 1
|
|
301
299
|
# self.num_fused_shared_experts = (
|
|
302
300
|
# 0
|
|
303
|
-
# if
|
|
301
|
+
# if global_server_args.disable_shared_experts_fusion
|
|
304
302
|
# else self.n_shared_experts
|
|
305
303
|
# )
|
|
306
304
|
self.num_fused_shared_experts = 0
|
|
@@ -774,7 +772,7 @@ class Step3VLForConditionalGeneration(nn.Module):
|
|
|
774
772
|
# self.n_shared_experts = 1
|
|
775
773
|
# self.num_fused_shared_experts = (
|
|
776
774
|
# 0
|
|
777
|
-
# if
|
|
775
|
+
# if global_server_args.disable_shared_experts_fusion
|
|
778
776
|
# else self.n_shared_experts
|
|
779
777
|
# )
|
|
780
778
|
self.num_fused_shared_experts = 0
|