sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_vl_moe.py
ADDED
@@ -0,0 +1,471 @@
+# Copyright 2025 Qwen Team
+# Copyright 2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
+import logging
+from functools import lru_cache, partial
+from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from transformers import BatchFeature
+from transformers.activations import ACT2FN
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VisionRotaryEmbedding,
+)
+
+from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeVisionConfig
+from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.utils import get_layer_id
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import (
+    MultimodalDataItem,
+    MultimodalInputs,
+    global_server_args_dict,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
+from sglang.srt.models.qwen3_vl import (
+    Qwen3_VisionTransformer,
+    Qwen3VLForConditionalGeneration,
+)
+from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
+
+logger = logging.getLogger(__name__)
+
+cached_get_processor = lru_cache(get_processor)
+
+
+class Qwen3MoeLLMModel(Qwen3MoeModel):
+    def __init__(
+        self,
+        *,
+        config: Qwen3VLMoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
+
+        self.hidden_size = config.hidden_size
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.embed_tokens
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
+        image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert image_grid_thw.dim() == 2, image_grid_thw.dim()
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+        return image_embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+        input_deepstack_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        aux_hidden_states = []
+        for layer_idx, layer in enumerate(
+            self.layers[self.start_layer : self.end_layer]
+        ):
+            layer_idx = layer_idx + self.start_layer
+            if layer_idx in self.layers_to_capture:
+                aux_hidden_states.append(
+                    hidden_states + residual if residual is not None else hidden_states
+                )
+
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+
+            # process deepstack
+            if input_deepstack_embeds is not None and layer_idx in range(3):
+                sep = self.hidden_size * layer_idx
+                hidden_states = (
+                    hidden_states
+                    + input_deepstack_embeds[:, sep : sep + self.hidden_size]
+                )
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if hidden_states.shape[0] != 0:
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
+
+
+class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
+    def __init__(
+        self,
+        *,
+        config: Qwen3VLMoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super(Qwen3VLForConditionalGeneration, self).__init__()
+        self.config = config
+
+        self.visual = Qwen3_VisionTransformer(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+            # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
+            # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
+            quant_config=quant_config,
+            prefix=add_prefix("visual", prefix),
+        )
+
+        self.model = Qwen3MoeLLMModel(
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("model", prefix),
+        )
+
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+        # deepstack
+        self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
+        self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
+
+    @property
+    def use_deepstack(self) -> bool:
+        return hasattr(self, "deepstack_visual_indexes")
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ):
+        """Run forward pass for Qwen3-VL.
+
+        Args:
+            input_ids: Flattened (concatenated) input_ids corresponding to a
+                batch.
+            positions: Flattened (concatenated) position ids corresponding to a
+                batch.
+                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
+                opensource models), the shape will be `(3, seq_len)`,
+                otherwise it will be `(seq_len,).
+                (Use input_metadata.mrope_positions to replace it)
+        """
+        if self.is_mrope_enabled:
+            positions = forward_batch.mrope_positions
+
+        if not (
+            forward_batch.forward_mode.is_decode()
+            or not forward_batch.contains_image_inputs()
+        ):
+            if self.is_mrope_enabled:
+                assert positions.ndim == 2 and positions.size(0) == 3, (
+                    "multimodal section rotary embedding requires "
+                    f"(3, seq_len) positions, but got {positions.size()}"
+                )
+
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.model,
+            multimodal_model=self,
+            positions=positions,
+            use_deepstack=self.use_deepstack,
+        )
+
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return self.pooler(hidden_states, forward_batch)
+
+    def load_fused_expert_weights(
+        self,
+        name: str,
+        params_dict: dict,
+        loaded_weight: torch.Tensor,
+        shard_id: str,
+        num_experts: int,
+    ):
+        param = params_dict[name]
+        # weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
+        weight_loader = param.weight_loader
+        ep_rank = get_tensor_model_parallel_rank()
+        ep_size = get_moe_expert_parallel_world_size()
+        if ep_size == 1:
+            for expert_id in range(num_experts):
+                curr_expert_weight = loaded_weight[expert_id]
+                weight_loader(
+                    param,
+                    curr_expert_weight,
+                    name,
+                    shard_id,
+                    expert_id,
+                )
+        else:
+            experts_per_ep = num_experts // ep_size
+            start_expert = ep_rank * experts_per_ep
+            end_expert = (
+                (ep_rank + 1) * experts_per_ep
+                if ep_rank != ep_size - 1
+                else num_experts
+            )
+
+            for idx, expert_id in enumerate(range(start_expert, end_expert)):
+                curr_expert_weight = loaded_weight[expert_id]
+                weight_loader(
+                    param,
+                    curr_expert_weight,
+                    name,
+                    shard_id,
+                    idx,
+                )
+        return True
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            ("gate_up_proj", "up_proj", 1),
+            ("gate_up_proj", "gate_proj", 0),
+        ]
+
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
+        # Skip loading extra parameters for GPTQ/modelopt models.
+        ignore_suffixes = (
+            ".bias",
+            "_bias",
+            ".k_scale",
+            "_k_scale",
+            ".v_scale",
+            "_v_scale",
+            ".weight_scale",
+            "_weight_scale",
+            ".input_scale",
+            "_input_scale",
+        )
+
+        is_fused_expert = False
+        fused_expert_params_mapping = [
+            ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
+            ("experts.w2_weight", "experts.down_proj", 0, "w2"),
+        ]
+
+        num_experts = self.config.num_experts
+
+        # Cache params_dict to avoid repeated expensive traversal of model parameters
+        if not hasattr(self, "_cached_params_dict"):
+            self._cached_params_dict = dict(self.named_parameters())
+        params_dict = self._cached_params_dict
+        for name, loaded_weight in weights:
+            if "language_model" in name:
+                name = name.replace(r"model.language_model.", r"model.")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if "experts.gate_up_proj" in name or "experts.down_proj" in name:
+                    is_fused_expert = True
+                    expert_params_mapping = fused_expert_params_mapping
+
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                if "visual" in name:
+                    continue
+
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra parameters for GPTQ/modelopt models.
+                if name.endswith(ignore_suffixes) and name not in params_dict:
+                    continue
+                # [TODO] Skip layers that are on other devices (check if sglang has a similar function)
+                # if is_pp_missing_parameter(name, self):
+                #     continue
+
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Track if this is an expert weight to enable early skipping
+                is_expert_weight = False
+
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    if "visual" in name:
+                        continue
+                    # Anyway, this is an expert weight and should not be
+                    # attempted to load as other weights later
+                    is_expert_weight = True
+                    name_mapped = name.replace(weight_name, param_name)
+                    if is_fused_expert:
+                        loaded_weight = loaded_weight.transpose(-1, -2)  # no bias
+                        if "experts.gate_up_proj" in name:
+                            loaded_weight = loaded_weight.chunk(2, dim=-2)
+                            self.load_fused_expert_weights(
+                                name_mapped,
+                                params_dict,
+                                loaded_weight[0],
+                                "w1",
+                                num_experts,
+                            )
+                            self.load_fused_expert_weights(
+                                name_mapped,
+                                params_dict,
+                                loaded_weight[1],
+                                "w3",
+                                num_experts,
+                            )
+                        else:
+                            self.load_fused_expert_weights(
+                                name_mapped,
+                                params_dict,
+                                loaded_weight,
+                                shard_id,
+                                num_experts,
+                            )
+                    else:
+                        # Skip loading extra parameters for GPTQ/modelopt models.
+                        if (
+                            name_mapped.endswith(ignore_suffixes)
+                            and name_mapped not in params_dict
+                        ):
+                            continue
+                        param = params_dict[name_mapped]
+                        # We should ask the weight loader to return success or
+                        # not here since otherwise we may skip experts with
+                        # # other available replicas.
+                        weight_loader = param.weight_loader
+                        weight_loader(
+                            param,
+                            loaded_weight,
+                            name_mapped,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                        )
+                    name = name_mapped
+                    break
+                else:
+                    if is_expert_weight:
+                        # This is an expert weight but not mapped to this rank, skip all remaining processing
+                        continue
+                    if "visual" in name:
+                        # adapt to VisionAttention
+                        name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+                        name = name.replace(r"model.visual.", r"visual.")
+
+                    # Skip loading extra parameters for GPTQ/modelopt models.
+                    if name.endswith(ignore_suffixes) and name not in params_dict:
+                        continue
+
+                    if name in params_dict.keys():
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+                    else:
+                        logger.warning(f"Parameter {name} not found in params_dict")
+
+        # TODO mimic deepseek
+        # Lazy initialization of expert weights cache to avoid slowing down load_weights
+        # if not hasattr(self, "routed_experts_weights_of_layer"):
+        #     self.routed_experts_weights_of_layer = {
+        #         layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
+        #         for layer_id in range(self.start_layer, self.end_layer)
+        #         if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
+        #     }
+
+
+EntryClass = Qwen3VLMoeForConditionalGeneration
sglang/srt/models/registry.py
CHANGED
@@ -17,6 +17,18 @@ class _ModelRegistry:
     # Keyed by model_arch
     models: Dict[str, Union[Type[nn.Module], str]] = field(default_factory=dict)
 
+    def register(self, package_name: str, overwrite: bool = False):
+        new_models = import_model_classes(package_name)
+        if overwrite:
+            self.models.update(new_models)
+        else:
+            for arch, cls in new_models.items():
+                if arch in self.models:
+                    raise ValueError(
+                        f"Model architecture {arch} already registered. Set overwrite=True to replace."
+                    )
+                self.models[arch] = cls
+
     def get_supported_archs(self) -> AbstractSet[str]:
         return self.models.keys()
 
@@ -74,9 +86,8 @@ class _ModelRegistry:
 
 
 @lru_cache()
-def import_model_classes():
+def import_model_classes(package_name: str):
     model_arch_name_to_cls = {}
-    package_name = "sglang.srt.models"
     package = importlib.import_module(package_name)
     for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
@@ -104,4 +115,5 @@ def import_model_classes():
     return model_arch_name_to_cls
 
 
-ModelRegistry = _ModelRegistry(
+ModelRegistry = _ModelRegistry()
+ModelRegistry.register("sglang.srt.models")