sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,357 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/starcoder2.py
+""" PyTorch Starcoder2 model."""
+from collections.abc import Iterable
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import Starcoder2Config
+
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix, make_layers
+
+
+class Starcoder2Attention(nn.Module):
+
+    def __init__(
+        self,
+        config: Starcoder2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        layer_id: int = 0,
+    ):
+        super().__init__()
+        self.config = config
+
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = config.num_key_value_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = self.hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = config.rope_theta
+        self.max_position_embeddings = config.max_position_embeddings
+        self.use_bias = config.use_bias
+
+        self.qkv_proj = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=self.use_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            self.hidden_size,
+            bias=self.use_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=int(self.rope_theta),
+            is_neox_style=True,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class Starcoder2MLP(nn.Module):
+
+    def __init__(
+        self,
+        config: Starcoder2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.c_fc = ColumnParallelLinear(
+            config.hidden_size,
+            config.intermediate_size,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_fc",
+        )
+        self.c_proj = RowParallelLinear(
+            config.intermediate_size,
+            config.hidden_size,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
+        )
+        self.act = get_act_fn(config.hidden_act)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        hidden_states, _ = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.c_proj(hidden_states)
+        return hidden_states
+
+
+class Starcoder2DecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: Starcoder2Config,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = Starcoder2Attention(
+            config=config,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+        )
+        self.mlp = Starcoder2MLP(
+            config, quant_config=quant_config, prefix=f"{prefix}.mlp"
+        )
+        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
+        self.post_attention_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.norm_epsilon
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class Starcoder2Model(nn.Module):
+
+    def __init__(
+        self,
+        config: Starcoder2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.config = config
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.embed_tokens",
+        )
+
+        pp_group = get_pp_group()
+        pp_size = pp_group.world_size
+        pp_rank = pp_group.rank
+        self.start_layer = pp_rank * config.num_hidden_layers // pp_size
+        self.end_layer = (pp_rank + 1) * config.num_hidden_layers // pp_size
+
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Starcoder2DecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            prefix=f"{prefix}.layers",
+        )
+        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if inputs_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = inputs_embeds
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+            )
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class Starcoder2ForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: Starcoder2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.model = Starcoder2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.vocab_size = config.vocab_size
+        self.unpadded_vocab_size = config.vocab_size
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.unpadded_vocab_size = config.vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+                quant_config=quant_config,
+                prefix=f"{prefix}.lm_head",
+            )
+        self.logits_processor = LogitsProcessor(config=config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            inputs_embeds=inputs_embeds,
+        )
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freqs" in name:
+                continue
+
+            is_stacked = False
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name in name:
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight, shard_id)
+                    is_stacked = True
+                    break
+            if is_stacked:
+                continue
+
+            param = params_dict.get(name)
+            if param is None:
+                continue
+
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+
+EntryClass = Starcoder2ForCausalLM
@@ -66,8 +66,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix

-tp_size =
-tp_rank =
+tp_size: Optional[int] = None
+tp_rank: Optional[int] = None


 def gate_up_proj_weight_loader(
@@ -341,6 +341,13 @@ class LlamaModel(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
+
+        global tp_size, tp_rank
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -0,0 +1,51 @@
+# Copyright 2023-2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import torch
+
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+
+if _is_cuda:
+    from sgl_kernel import FusedSetKVBufferArg
+
+
+def enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
+    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
+    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
+
+
+def create_fused_set_kv_buffer_arg(
+    value: torch.Tensor,
+    layer: RadixAttention,
+    forward_batch: ForwardBatch,
+):
+    layer_id = layer.layer_id
+    token_to_kv_pool = forward_batch.token_to_kv_pool
+
+    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
+    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
+
+    return FusedSetKVBufferArg(
+        value=value,
+        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
+        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
+        k_scale=layer.k_scale,
+        v_scale=layer.v_scale,
+        cache_loc=forward_batch.out_cache_loc,
+    )
@@ -234,19 +234,27 @@ class BaseMultimodalProcessor(ABC):
             and isinstance(processor.image_processor, BaseImageProcessorFast)
             and not self.server_args.disable_fast_image_processor
         ):
-
+            if not _is_npu:
+                kwargs["device"] = "cuda"
+            elif processor.__class__.__name__ not in {
+                "Qwen2_5_VLProcessor",
+                "Qwen3VLProcessor",
+            }:
+                # Note: for qwen-vl, processor has some reshape issue because of dims restriction on Ascend.
+                kwargs["device"] = "npu"
         result = processor.__call__(
            text=[input_text],
            padding=True,
            return_tensors="pt",
            **kwargs,
        )
-
-
-
-
-
-
+        if not self.server_args.keep_mm_feature_on_device:
+            # move feature tensors to cpu
+            for feature_name in self.FEATURE_NAMES:
+                if feature_name in result and isinstance(
+                    result[feature_name], torch.Tensor
+                ):
+                    result[feature_name] = result[feature_name].to("cpu")

         return result

@@ -5,6 +5,7 @@ from typing import Dict, List, Union

 from PIL import Image

+from sglang.srt.models.dots_ocr import DotsOCRForCausalLM
 from sglang.srt.models.dots_vlm import DotsVLMForCausalLM
 from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor,
@@ -14,7 +15,7 @@ from sglang.srt.multimodal.processors.qwen_vl import resize_image_async


 class DotsVLMImageProcessor(BaseMultimodalProcessor):
-    models = [DotsVLMForCausalLM]
+    models = [DotsVLMForCausalLM, DotsOCRForCausalLM]

     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _processor, *args, **kwargs)
@@ -82,11 +83,9 @@ class DotsVLMImageProcessor(BaseMultimodalProcessor):
                 for image in base_output.images
             ]
             base_output.images = await asyncio.gather(*resize_tasks)
-
         combined_mm_item, input_ids, _ = self.process_and_combine_mm_data(
             base_output, self.mm_tokens
         )
-
         if combined_mm_item is None:
             return None

@@ -1,5 +1,7 @@
 # Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py

+from functools import lru_cache
+
 import numpy as np
 import torch
 import torchvision.transforms as T
@@ -19,6 +21,20 @@ from sglang.srt.multimodal.processors.base_processor import (
 class InternVLImageProcessor(BaseMultimodalProcessor):
     models = [InternVLChatModel, InternS1ForConditionalGeneration]

+    IMAGENET_MEAN = [0.485, 0.456, 0.406]
+    IMAGENET_STD = [0.229, 0.224, 0.225]
+
+    @staticmethod
+    @lru_cache(maxsize=1)
+    def _get_normalize_tensors(device="cuda", dtype=torch.float32):
+        mean = torch.tensor(
+            InternVLImageProcessor.IMAGENET_MEAN, device=device, dtype=dtype
+        ).view(-1, 1, 1)
+        std = torch.tensor(
+            InternVLImageProcessor.IMAGENET_STD, device=device, dtype=dtype
+        ).view(-1, 1, 1)
+        return mean, std
+
     def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
         image_size = (
@@ -88,6 +104,8 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             bound, fps, max_frame, first_idx=0, num_segments=num_segments
         )

+        mean, std = InternVLImageProcessor._get_normalize_tensors(device="cuda")
+
         for frame_index in frame_indices:
             # Load frame
             frame = vr[frame_index]
@@ -97,10 +115,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             img_np = frame.asnumpy()
             img = torch.from_numpy(img_np).permute(2, 0, 1).cuda().float() / 255.0

-            # Using the mean and variance of the ImageNet dataset for all input images can lead to accuracy issues, while using the mean and variance of each input image is a more accurate choice.
-            mean = img.mean(dim=[1, 2], keepdim=True)
-            # Prevent division by zero; clamp to minimum value of 1e-6
-            std = img.std(dim=[1, 2], keepdim=True).clamp(min=1e-6)
             img = (img - mean) / std

             tiles = InternVLImageProcessor.dynamic_preprocess(
@@ -188,6 +202,8 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
         num_patches_list = []
         pixel_values = []

+        mean, std = InternVLImageProcessor._get_normalize_tensors(device="cuda")
+
         # Process each input with allocated frames
         for image_index, image in enumerate(base_output.images):
             try:
@@ -201,10 +217,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
                 else:
                     tensor = image.cuda()  # assume already tensor

-                # Using the mean and variance of the ImageNet dataset for all input images can lead to accuracy issues, while using the mean and variance of each input image is a more accurate choice.
-                mean = tensor.mean(dim=[1, 2], keepdim=True)
-                # Prevent division by zero; clamp to minimum value of 1e-6
-                std = tensor.std(dim=[1, 2], keepdim=True).clamp(min=1e-6)
                 tensor = (tensor - mean) / std
                 tiles = self.dynamic_preprocess(
                     tensor, image_size=448, max_num=12, use_thumbnail=True
@@ -12,6 +12,8 @@ from torchvision.transforms import InterpolationMode
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
 from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
+from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
+from sglang.srt.models.qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
 from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
@@ -209,7 +211,12 @@ async def preprocess_video(

 # Compatible with Qwen2VL and Qwen2_5VL
 class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
-    models = [
+    models = [
+        Qwen2VLForConditionalGeneration,
+        Qwen2_5_VLForConditionalGeneration,
+        Qwen3VLForConditionalGeneration,
+        Qwen3VLMoeForConditionalGeneration,
+    ]

     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _processor, *args, **kwargs)
@@ -0,0 +1,81 @@
+from typing import List, Union
+
+from sglang.srt.models.sarashina2_vision import Sarashina2VisionForCausalLM
+from sglang.srt.multimodal.processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+
+
+class Sarashina2VisionProcessor(BaseMultimodalProcessor):
+    models = [Sarashina2VisionForCausalLM]
+
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
+
+        # Sarashina2Vision specific tokens (default is <|file|>)
+        self.IMAGE_TOKEN = "<|file|>"
+        self.IM_TOKEN_ID = getattr(hf_config, "image_token_index", 14)
+        self.IM_START_ID = getattr(hf_config, "start_image_token_index", 102397)
+        self.IM_END_ID = getattr(hf_config, "end_image_token_index", 102398)
+
+        self.mm_tokens = MultimodalSpecialTokens(
+            image_token=self.IMAGE_TOKEN,
+            image_token_id=self.IM_TOKEN_ID,
+        ).build(_processor)
+
+        # Patch the processor's image processor to handle parameter compatibility
+        if hasattr(_processor, "image_processor") and hasattr(
+            _processor.image_processor, "_preprocess"
+        ):
+            original_preprocess = _processor.image_processor._preprocess
+
+            def patched_preprocess(*args, **kwargs):
+                # Filter kwargs to only include parameters that the custom _preprocess method accepts
+                # Based on Sarashina2VisionImageProcessor._preprocess signature
+                allowed_params = {
+                    "do_resize",
+                    "resample",
+                    "do_rescale",
+                    "rescale_factor",
+                    "do_normalize",
+                    "image_mean",
+                    "image_std",
+                    "do_convert_rgb",
+                    "data_format",
+                    "input_data_format",
+                }
+                filtered_kwargs = {
+                    k: v for k, v in kwargs.items() if k in allowed_params
+                }
+                return original_preprocess(*args, **filtered_kwargs)
+
+            _processor.image_processor._preprocess = patched_preprocess
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
+    ):
+        """Process image data for Sarashina2Vision model using standard SGLang pattern."""
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            multimodal_tokens=self.mm_tokens,
+        )
+
+        mm_items, input_ids, ret = self.process_and_combine_mm_data(
+            base_output=base_output,
+            mm_tokens=self.mm_tokens,
+        )
+
+        return {
+            "mm_items": mm_items,
+            "input_ids": input_ids.tolist(),
+            "im_token_id": self.mm_tokens.image_token_id,
+            "im_start_id": self.IM_START_ID,
+            "im_end_id": self.IM_END_ID,
+        }
@@ -89,6 +89,12 @@ def detect_jinja_template_content_format(chat_template: str) -> str:
     - If template has loops like {%- for content in message['content'] -%} → 'openai'
     - Otherwise → 'string'
     """
+    # Shortcut for multimodal templates
+    if any(
+        keyword in chat_template for keyword in ["image", "audio", "video", "vision"]
+    ):
+        return "openai"
+
     jinja_ast = _try_extract_ast(chat_template)
     if jinja_ast is None:
         return "string"