sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registries, and is provided for informational purposes only. A minimal sketch for reproducing the per-file line counts appears after the file list.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
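As mentioned above, a per-file added/removed line count like this listing can be reproduced with a minimal sketch using only the Python standard library. It assumes both wheels have been downloaded locally (a wheel is a zip archive), e.g. via `pip download sglang==0.5.4 --no-deps`; the counts come from `difflib` and may differ slightly from whatever tool generated this page.

```python
# Minimal sketch: count added/removed lines per .py file between two wheels.
import difflib
import zipfile

OLD = "sglang-0.5.3rc2-py3-none-any.whl"  # assumed local path
NEW = "sglang-0.5.4-py3-none-any.whl"     # assumed local path

with zipfile.ZipFile(OLD) as old, zipfile.ZipFile(NEW) as new:
    old_names, new_names = set(old.namelist()), set(new.namelist())
    for name in sorted(old_names | new_names):
        if not name.endswith(".py"):
            continue  # skip dist-info and binary payloads for brevity
        old_text = old.read(name).decode().splitlines() if name in old_names else []
        new_text = new.read(name).decode().splitlines() if name in new_names else []
        diff = difflib.unified_diff(old_text, new_text, lineterm="")
        added = removed = 0
        for line in diff:
            if line.startswith("+") and not line.startswith("+++"):
                added += 1
            elif line.startswith("-") and not line.startswith("---"):
                removed += 1
        if added or removed:
            print(f"{name} +{added} -{removed}")
```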
Full content of the one file shown in detail on this page, sglang/srt/models/qwen3_omni_moe.py (new in 0.5.4; +661 -0 in the listing above):

@@ -0,0 +1,661 @@
+# Copyright 2025 Qwen Team
+# Copyright 2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
+import math
+from typing import Iterable, List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import PreTrainedModel
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutput
+
+from sglang.srt.configs.qwen3_omni import (
+    Qwen3OmniMoeAudioEncoderConfig,
+    Qwen3OmniMoeThinkerConfig,
+    Qwen3OmniMoeVisionEncoderConfig,
+)
+from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.schedule_batch import MultimodalDataItem
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen3_vl import Qwen3VLMoeVisionModel
+from sglang.srt.models.qwen3_vl_moe import (
+    Qwen3MoeLLMModel,
+    Qwen3VLMoeForConditionalGeneration,
+    load_fused_expert_weights,
+)
+from sglang.srt.utils import add_prefix, logger
+
+
+class Qwen3OmniMoeAudioEncoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: Qwen3OmniMoeAudioEncoderConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        embed_dim = config.d_model
+        self.embed_dim = config.d_model
+        self.self_attn = VisionAttention(
+            embed_dim=embed_dim,
+            num_heads=config.encoder_attention_heads,
+            projection_size=embed_dim,
+            use_qkv_parallel=True,
+            rotary_embed="normal",
+            proj_bias=True,
+            qkv_backend="fa3",
+            softmax_in_single_precision=False,
+            flatten_batch=True,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(encoder_attention_heads,)`.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states = self.self_attn(
+            x=hidden_states,
+            cu_seqlens=cu_seqlens,
+        )
+        hidden_states = residual + hidden_states
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = residual + hidden_states
+
+        if hidden_states.dtype == torch.float16:
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(
+                hidden_states, min=-clamp_value, max=clamp_value
+            )
+
+        outputs = (hidden_states,)
+
+        return outputs
+
+
+class SinusoidsPositionEmbedding(nn.Module):
+    def __init__(self, length, channels, max_timescale=10000):
+        super().__init__()
+        if channels % 2 != 0:
+            raise ValueError("SinusoidsPositionEmbedding needs even channels input")
+        log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+        inv_timescales = torch.exp(
+            -log_timescale_increment * torch.arange(channels // 2).float()
+        )
+        scaled_time = (
+            torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+        )
+        self.register_buffer(
+            "positional_embedding",
+            torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
+            persistent=False,
+        )
+
+    def forward(self, seqlen: int):
+        return self.positional_embedding[:seqlen, :]
+
+
+def _get_feat_extract_output_lengths(input_lengths):
+    """
+    Computes the output length of the convolutional layers and the output length of the audio encoder
+    """
+
+    input_lengths_leave = input_lengths % 100
+    feat_lengths = (input_lengths_leave - 1) // 2 + 1
+    output_lengths = (
+        ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
+    )
+    return output_lengths
+
+
+class Qwen3OmniMoeAudioEncoder(PreTrainedModel):
+    config: Qwen3OmniMoeAudioEncoderConfig
+
+    def __init__(self, config: Qwen3OmniMoeAudioEncoderConfig):
+        super().__init__(config)
+        self.dropout = config.dropout
+
+        embed_dim = config.d_model
+        self.num_mel_bins = config.num_mel_bins
+        self.max_source_positions = config.max_source_positions
+        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+        self.n_window = config.n_window
+        self.positional_embedding = SinusoidsPositionEmbedding(
+            self.max_source_positions, embed_dim
+        )
+        self.layers = nn.ModuleList(
+            [
+                Qwen3OmniMoeAudioEncoderLayer(config)
+                for _ in range(config.encoder_layers)
+            ]
+        )
+        self.ln_post = nn.LayerNorm(config.d_model)
+        self.gradient_checkpointing = False
+        self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1)
+        self.conv2d2 = nn.Conv2d(
+            config.downsample_hidden_size,
+            config.downsample_hidden_size,
+            3,
+            2,
+            padding=1,
+        )
+        self.conv2d3 = nn.Conv2d(
+            config.downsample_hidden_size,
+            config.downsample_hidden_size,
+            3,
+            2,
+            padding=1,
+        )
+        self.conv_out = nn.Linear(
+            config.downsample_hidden_size
+            * ((((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2),
+            config.d_model,
+            bias=False,
+        )
+        self.proj1 = nn.Linear(config.d_model, config.d_model)
+        self.act = ACT2FN[config.activation_function]
+        self.proj2 = nn.Linear(config.d_model, config.output_dim)
+        self.n_window_infer = self.config.n_window_infer
+        self.conv_chunksize = self.config.conv_chunksize
+
+    def _freeze_parameters(self):
+        for param in self.parameters():
+            param.requires_grad = False
+        self._requires_grad = False
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.conv1
+
+    def set_input_embeddings(self, value: nn.Module):
+        self.conv1 = value
+
+    def forward(
+        self,
+        input_features,
+        feature_lens=None,
+        aftercnn_lens=None,
+    ):
+        r"""
+        feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
+            mel length
+        aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`):
+            mel length after cnn
+        """
+        aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)
+        chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
+
+        chunk_lengths = torch.tensor(
+            [self.n_window * 2] * chunk_num.sum(),
+            dtype=torch.long,
+            device=feature_lens.device,
+        )
+        tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
+        chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
+        chunk_lengths[chunk_lengths == 0] = self.n_window * 2
+
+        chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
+        padded_feature = nn.utils.rnn.pad_sequence(
+            chunk_list, batch_first=True
+        ).transpose(1, 2)
+        feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths)
+        padded_mask_after_cnn = nn.utils.rnn.pad_sequence(
+            [
+                torch.ones(length, dtype=torch.bool, device=padded_feature.device)
+                for length in feature_lens_after_cnn
+            ],
+            batch_first=True,
+        )
+        padded_feature = padded_feature.unsqueeze(1)
+        # Split to chunk to avoid OOM during convolution
+        padded_embeds = []
+        for chunk in padded_feature.split(self.conv_chunksize, dim=0):
+            padded_embed = F.gelu(self.conv2d1(chunk))
+            padded_embed = F.gelu(self.conv2d2(padded_embed))
+            padded_embed = F.gelu(self.conv2d3(padded_embed))
+            padded_embeds.append(padded_embed)
+        padded_embed = torch.cat(padded_embeds, dim=0)
+        b, c, f, t = padded_embed.size()
+        padded_embed = self.conv_out(
+            padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)
+        )
+
+        positional_embedding = (
+            self.positional_embedding.positional_embedding[: padded_embed.shape[1], :]
+            .unsqueeze(0)
+            .to(padded_embed.dtype)
+        )
+        padded_embed = padded_embed + positional_embedding
+        hidden_states = padded_embed[padded_mask_after_cnn]
+        cu_chunk_lens = [0]
+        window_aftercnn = padded_mask_after_cnn.shape[-1] * (
+            self.n_window_infer // (self.n_window * 2)
+        )
+        for cnn_len in aftercnn_lens:
+            cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn)
+            remainder = cnn_len % window_aftercnn
+            if remainder != 0:
+                cu_chunk_lens += [remainder]
+        cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(
+            -1, dtype=torch.int32
+        )
+
+        for encoder_layer in self.layers:
+            layer_outputs = encoder_layer(
+                hidden_states,
+                cu_seqlens,
+            )
+
+            hidden_states = layer_outputs[0]
+
+        hidden_states = self.ln_post(hidden_states)
+        hidden_states = self.proj1(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.proj2(hidden_states)
+        return BaseModelOutput(last_hidden_state=hidden_states)
+
+    # Ignore copy
+    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
+        """
+        Computes the output length of the convolutional layers and the output length of the audio encoder
+        """
+        input_lengths = (input_lengths - 1) // 2 + 1
+        output_lengths = (input_lengths - 2) // 2 + 1
+        return input_lengths, output_lengths
+
+
+class Qwen3OmniMoeVisionPatchMerger(nn.Module):
+
+    def __init__(
+        self,
+        dim: int,
+        context_dim: int,
+        spatial_merge_size: int = 2,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        use_postshuffle_norm=False,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = context_dim * (spatial_merge_size**2)
+        self.use_postshuffle_norm = use_postshuffle_norm
+        self.ln_q = RMSNorm(
+            self.hidden_size if use_postshuffle_norm else context_dim, eps=1e-6
+        )
+        self.mlp = nn.ModuleList(
+            [
+                ColumnParallelLinear(
+                    self.hidden_size,
+                    self.hidden_size,
+                    bias=True,
+                    quant_config=quant_config,
+                    prefix=add_prefix("mlp.0", prefix),
+                ),
+                nn.GELU(),
+                RowParallelLinear(
+                    self.hidden_size,
+                    dim,
+                    bias=True,
+                    quant_config=quant_config,
+                    prefix=add_prefix("mlp.2", prefix),
+                ),
+            ]
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = (
+            x.view(-1, self.hidden_size)
+            if self.use_postshuffle_norm
+            else x.view(-1, x.shape[-1])
+        )
+        hidden = self.ln_q(x).view(-1, self.hidden_size)
+        for layer in self.mlp:
+            if isinstance(hidden, tuple):
+                hidden = hidden[0]
+            hidden = layer(hidden)
+
+        if isinstance(hidden, tuple):
+            hidden = hidden[0]
+
+        return hidden
+
+
+class Qwen3OmniMoeVisionEncoder(Qwen3VLMoeVisionModel):
+    config: Qwen3OmniMoeVisionEncoderConfig
+
+    def __init__(
+        self,
+        config: Qwen3OmniMoeVisionEncoderConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = None,
+        **kwargs,
+    ):
+        super().__init__(
+            vision_config=config,
+            quant_config=quant_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+        )
+
+        self.merger = Qwen3OmniMoeVisionPatchMerger(
+            dim=config.out_hidden_size,
+            context_dim=config.hidden_size,
+            spatial_merge_size=config.spatial_merge_size,
+            quant_config=quant_config,
+            use_postshuffle_norm=False,
+            prefix=add_prefix("merger", prefix),
+        )
+        self.merger_list = nn.ModuleList(
+            [
+                Qwen3OmniMoeVisionPatchMerger(
+                    dim=config.out_hidden_size,
+                    context_dim=config.hidden_size,
+                    spatial_merge_size=config.spatial_merge_size,
+                    use_postshuffle_norm=True,
+                    quant_config=quant_config,
+                    prefix=add_prefix("merger_list", prefix),
+                )
+                for _ in range(len(config.deepstack_visual_indexes))
+            ]
+        )
+        del self.deepstack_merger_list
+
+    @property
+    def deepstack_merger_list(self):
+        return self.merger_list
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.patch_embed.proj.weight.dtype
+
+    @property
+    def device(self) -> torch.device:
+        return self.patch_embed.proj.weight.device
+
+
+class Qwen3OmniMoeThinkerForConditionalGeneration(Qwen3VLMoeForConditionalGeneration):
+    config: Qwen3OmniMoeThinkerConfig
+
+    def __init__(
+        self,
+        config: Qwen3OmniMoeThinkerConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__(
+            config, quant_config, prefix, language_model_cls=Qwen3MoeLLMModel
+        )
+        self.audio_tower = Qwen3OmniMoeAudioEncoder(config.audio_config)
+        self.visual = Qwen3OmniMoeVisionEncoder(
+            config.vision_config,
+            quant_config=quant_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+            prefix=add_prefix("visual", prefix),
+        )
+        self.pad_token_id = (
+            self.config.pad_token_id if self.config.pad_token_id is not None else -1
+        )
+
+    def get_audio_feature(self, items: List[MultimodalDataItem]):
+        feature_attention_mask = torch.cat(
+            [item.feature_attention_mask for item in items], dim=0
+        ).type(torch.long)
+        input_features = (
+            torch.cat([item.feature for item in items])
+            .type(self.audio_tower.dtype)
+            .to(next(self.audio_tower.parameters()).device)
+        )
+        if feature_attention_mask is not None:
+            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
+            input_features = input_features.permute(0, 2, 1)[
+                feature_attention_mask.bool()
+            ].permute(1, 0)
+        else:
+            audio_feature_lengths = None
+
+        feature_lens = (
+            audio_feature_lengths
+            if audio_feature_lengths is not None
+            else feature_attention_mask.sum(-1)
+        )
+        audio_outputs = self.audio_tower(
+            input_features,
+            feature_lens=feature_lens,
+        )
+        audio_features = audio_outputs.last_hidden_state
+
+        return audio_features
+
+
+class Qwen3OmniMoeForConditionalGeneration(PreTrainedModel):
+    def __init__(
+        self,
+        config: Qwen3VLMoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__(config)
+        self.config = config
+
+        self.thinker = Qwen3OmniMoeThinkerForConditionalGeneration(
+            config.thinker_config, quant_config=quant_config, prefix=prefix
+        )
+        self.enable_talker = False
+        self.pad_input_ids = self.thinker.pad_input_ids
+        self.forward = self.thinker.forward
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            ("gate_up_proj", "up_proj", 1),
+            ("gate_up_proj", "gate_proj", 0),
+        ]
+
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
+        # Skip loading extra parameters for GPTQ/modelopt models.
+        ignore_suffixes = (
+            ".bias",
+            "_bias",
+            ".k_scale",
+            "_k_scale",
+            ".v_scale",
+            "_v_scale",
+            ".weight_scale",
+            "_weight_scale",
+            ".input_scale",
+            "_input_scale",
+        )
+
+        is_fused_expert = False
+        fused_expert_params_mapping = [
+            ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
+            ("experts.w2_weight", "experts.down_proj", 0, "w2"),
+        ]
+
+        num_experts = self.config.num_experts
+
+        # Cache params_dict to avoid repeated expensive traversal of model parameters
+        if not hasattr(self, "_cached_params_dict"):
+            self._cached_params_dict = dict(self.named_parameters())
+        params_dict = self._cached_params_dict
+
+        for name, loaded_weight in weights:
+            name = name.replace(r"model.language_model.", r"model.")
+
+            if ("talker" in name or "code2wav" in name) and not self.enable_talker:
+                continue
+
+            name = name.replace(".self_attn.out_proj", ".self_attn.proj")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if "experts.gate_up_proj" in name or "experts.down_proj" in name:
+                    is_fused_expert = True
+                    expert_params_mapping = fused_expert_params_mapping
+
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                if "visual" in name:
+                    continue
+
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra parameters for GPTQ/modelopt models.
+                if name.endswith(ignore_suffixes) and name not in params_dict:
+                    continue
+                # [TODO] Skip layers that are on other devices (check if sglang has a similar function)
+                # if is_pp_missing_parameter(name, self):
+                #     continue
+
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Track if this is an expert weight to enable early skipping
+                is_expert_weight = False
+
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    if "visual" in name or "audio_tower" in name:
+                        continue
+                    # Anyway, this is an expert weight and should not be
+                    # attempted to load as other weights later
+                    is_expert_weight = True
+                    name_mapped = name.replace(weight_name, param_name)
+                    if is_fused_expert:
+                        loaded_weight = loaded_weight.transpose(-1, -2)  # no bias
+                        if "experts.gate_up_proj" in name:
+                            loaded_weight = loaded_weight.chunk(2, dim=-2)
+                            load_fused_expert_weights(
+                                name_mapped,
+                                params_dict,
+                                loaded_weight[0],
+                                "w1",
+                                num_experts,
+                            )
+                            load_fused_expert_weights(
+                                name_mapped,
+                                params_dict,
+                                loaded_weight[1],
+                                "w3",
+                                num_experts,
+                            )
+                        else:
+                            load_fused_expert_weights(
+                                name_mapped,
+                                params_dict,
+                                loaded_weight,
+                                shard_id,
+                                num_experts,
+                            )
+                    else:
+                        # Skip loading extra parameters for GPTQ/modelopt models.
+                        if (
+                            name_mapped.endswith(ignore_suffixes)
+                            and name_mapped not in params_dict
+                        ):
+                            continue
+                        param = params_dict[name_mapped]
+                        # We should ask the weight loader to return success or
+                        # not here since otherwise we may skip experts with
+                        # # other available replicas.
+                        weight_loader = param.weight_loader
+                        weight_loader(
+                            param,
+                            loaded_weight,
+                            name_mapped,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                        )
+                    name = name_mapped
+                    break
+                else:
+                    if is_expert_weight:
+                        # This is an expert weight but not mapped to this rank, skip all remaining processing
+                        continue
+                    if "visual" in name or "audio_tower" in name:
+                        # adapt to VisionAttention
+                        name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+                        name = name.replace(r"model.visual.", r"visual.")
+                        name = name.replace(r"attn.out_proj.", r"attn.proj.")
+
+                    # Skip loading extra parameters for GPTQ/modelopt models.
+                    if name.endswith(ignore_suffixes) and name not in params_dict:
+                        continue
+
+                    if name in params_dict.keys():
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+                    else:
+                        logger.warning(
+                            f"Loaded weight with {name=} not found in params_dict"
+                        )
+
+
+EntryClass = Qwen3OmniMoeForConditionalGeneration