sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_vl.py (new file)

@@ -0,0 +1,787 @@
# Copyright 2025 Qwen Team
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
import logging
from functools import lru_cache, partial
from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.activations import ACT2FN
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VisionRotaryEmbedding,
)

from sglang.srt.configs.qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternMultimodalTokens,
    general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
from sglang.srt.models.qwen3 import Qwen3Model
from sglang.srt.utils import add_prefix
from sglang.srt.utils.hf_transformers_utils import get_processor

logger = logging.getLogger(__name__)

# === Vision Encoder === #


class Qwen3_VisionMLP(nn.Module):

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = True,
        hidden_act="silu",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.linear_fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("linear_fc1", prefix),
        )
        self.linear_fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("linear_fc2", prefix),
        )
        self.act = ACT2FN[hidden_act]

    def forward(self, x: torch.Tensor):
        x_fc1, _ = self.linear_fc1(x)
        mlp_output, _ = self.linear_fc2(self.act(x_fc1))
        return mlp_output


class Qwen3VLVisionPatchEmbed(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size

        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = nn.Conv3d(
            self.in_channels,
            self.embed_dim,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=True,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        target_dtype = self.proj.weight.dtype
        hidden_states = hidden_states.view(
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(
            -1, self.embed_dim
        )
        return hidden_states


class Qwen3_VisionBlock(nn.Module):

    def __init__(
        self,
        dim: int,
        num_heads: int,
        intermediate_dim: int,
        hidden_act="silu",
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        attn_implementation: Optional[str] = "sdpa",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)

        if attn_implementation == "sdpa":
            softmax_in_single_precision = False
            qkv_backend = "sdpa"
            flatten_batch = True
        elif attn_implementation == "flash_attention_2":
            softmax_in_single_precision = False
            qkv_backend = "triton_attn"
            flatten_batch = True
        elif attn_implementation == "eager":
            softmax_in_single_precision = True
            qkv_backend = "sdpa"
            flatten_batch = True
        elif attn_implementation == "flash_attention_3":
            softmax_in_single_precision = False
            qkv_backend = "fa3"
            flatten_batch = True

        self.attn = VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            use_qkv_parallel=True,
            rotary_embed="normal",
            proj_bias=True,
            qkv_backend=qkv_backend,
            softmax_in_single_precision=softmax_in_single_precision,
            flatten_batch=flatten_batch,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )
        self.mlp = Qwen3_VisionMLP(
            dim,
            intermediate_dim,
            hidden_act=hidden_act,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.norm1(x)
        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
        attn = self.attn(
            hidden_states,
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
        )
        attn = rearrange(attn, "b s ... -> s b ...")
        x = x + attn
        norm2 = self.norm2(x)
        mlp = self.mlp(norm2)
        x = x + mlp
        return x


class Qwen3_VisionPatchMerger(nn.Module):

    def __init__(
        self,
        dim: int,
        context_dim: int,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        spatial_merge_size: int = 2,
        use_postshuffle_norm: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)

        self.use_postshuffle_norm = use_postshuffle_norm

        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm = norm_layer(
            self.hidden_size if use_postshuffle_norm else context_dim
        )
        self.linear_fc1 = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("linear_fc1", prefix),
        )
        self.act_fn = nn.GELU()
        self.linear_fc2 = RowParallelLinear(
            self.hidden_size,
            dim,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("linear_fc2", prefix),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_postshuffle_norm:
            x = self.norm(x.view(-1, self.hidden_size))
        else:
            x = self.norm(x).view(-1, self.hidden_size)

        x_parallel, _ = self.linear_fc1(x)
        x_parallel = self.act_fn(x_parallel)
        out, _ = self.linear_fc2(x_parallel)
        return out


class Qwen3_VisionTransformer(nn.Module):

    def __init__(
        self,
        vision_config: Qwen3VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads
        self.num_position_embeddings = vision_config.num_position_embeddings
        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.spatial_merge_unit = self.spatial_merge_size**2
        self.temporal_patch_size = vision_config.temporal_patch_size
        self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
        self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config)
        self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = self.hidden_size // self.num_heads
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            [
                Qwen3_VisionBlock(
                    dim=self.hidden_size,
                    num_heads=self.num_heads,
                    intermediate_dim=vision_config.intermediate_size,
                    hidden_act=vision_config.hidden_act,
                    norm_layer=norm_layer,
                    attn_implementation="flash_attention_3",
                    quant_config=quant_config,
                    prefix=add_prefix(f"blocks.{layer_idx}", prefix),
                )
                for layer_idx in range(vision_config.depth)
            ]
        )
        self.merger = Qwen3_VisionPatchMerger(
            dim=vision_config.out_hidden_size,
            context_dim=self.hidden_size,
            norm_layer=norm_layer,
            spatial_merge_size=self.spatial_merge_size,
            quant_config=quant_config,
            prefix=add_prefix("merger", prefix),
        )

        self.deepstack_merger_list = nn.ModuleList(
            [
                Qwen3_VisionPatchMerger(
                    dim=vision_config.out_hidden_size,
                    context_dim=self.hidden_size,
                    spatial_merge_size=self.spatial_merge_size,
                    use_postshuffle_norm=True,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=add_prefix(f"deepstack_merger_list.{layer_idx}", prefix),
                )
                for layer_idx in range(len(self.deepstack_visual_indexes))
            ]
        )

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(self, grid_thw):
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def fast_pos_embed_interpolate(self, grid_thw):
        num_grid_per_side = int(self.num_position_embeddings**0.5)

        idx_list = [[] for _ in range(4)]
        weight_list = [[] for _ in range(4)]

        # TODO: use torch instand of np
        for t, h, w in grid_thw:
            h_idxs = np.linspace(0, num_grid_per_side - 1, h)
            w_idxs = np.linspace(0, num_grid_per_side - 1, w)

            h_idxs_floor = h_idxs.astype(int)
            w_idxs_floor = w_idxs.astype(int)
            h_idxs_ceil = (h_idxs.astype(int) + 1).clip(max=num_grid_per_side - 1)
            w_idxs_ceil = (w_idxs.astype(int) + 1).clip(max=num_grid_per_side - 1)

            dh = h_idxs - h_idxs_floor
            dw = w_idxs - w_idxs_floor

            idx_list[0].extend(
                ((h_idxs_floor * num_grid_per_side)[None].T + w_idxs_floor[None])
                .flatten()
                .tolist()
                * t
            )
            idx_list[1].extend(
                ((h_idxs_floor * num_grid_per_side)[None].T + w_idxs_ceil[None])
                .flatten()
                .tolist()
                * t
            )
            idx_list[2].extend(
                ((h_idxs_ceil * num_grid_per_side)[None].T + w_idxs_floor[None])
                .flatten()
                .tolist()
                * t
            )
            idx_list[3].extend(
                ((h_idxs_ceil * num_grid_per_side)[None].T + w_idxs_ceil[None])
                .flatten()
                .tolist()
                * t
            )

            weight_list[0].extend(
                ((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t
            )
            weight_list[1].extend(((1 - dh)[None].T * dw[None]).flatten().tolist() * t)
            weight_list[2].extend((dh[None].T * (1 - dw)[None]).flatten().tolist() * t)
            weight_list[3].extend((dh[None].T * dw[None]).flatten().tolist() * t)

        device = self.pos_embed.weight.device
        dtype = self.pos_embed.weight.dtype

        p0 = (
            self.pos_embed(torch.tensor(idx_list[0], dtype=torch.long, device=device))
            * torch.tensor(weight_list[0], dtype=dtype, device=device)[:, None]
        )
        p1 = (
            self.pos_embed(torch.tensor(idx_list[1], dtype=torch.long, device=device))
            * torch.tensor(weight_list[1], dtype=dtype, device=device)[:, None]
        )
        p2 = (
            self.pos_embed(torch.tensor(idx_list[2], dtype=torch.long, device=device))
            * torch.tensor(weight_list[2], dtype=dtype, device=device)[:, None]
        )
        p3 = (
            self.pos_embed(torch.tensor(idx_list[3], dtype=torch.long, device=device))
            * torch.tensor(weight_list[3], dtype=dtype, device=device)[:, None]
        )

        patch_pos_embeds = p0 + p1 + p2 + p3
        patch_pos_embeds = patch_pos_embeds.split([t * h * w for t, h, w in grid_thw])
        patch_pos_embeds_permute = []
        m_size = self.spatial_merge_size
        for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw):
            pos_embed = (
                pos_embed.view(t, h // m_size, m_size, w // m_size, m_size, -1)
                .permute(0, 1, 3, 2, 4, 5)
                .flatten(0, 4)
            )
            patch_pos_embeds_permute.append(pos_embed)
        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
        return patch_pos_embeds

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
        x = x + pos_embeds
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        seq_len, _ = x.size()
        rotary_pos_emb = rotary_pos_emb.to(x.device)

        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # compute cu_seqlens
        cu_seqlens = torch.cat(
            [
                torch.tensor([0], device=grid_thw.device),
                (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
            ]
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
        x = x.unsqueeze(1)

        deepstack_feature_lists = []
        num_deepstack_captured = 0
        for layer_num, blk in enumerate(self.blocks):
            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
            if layer_num in self.deepstack_visual_indexes:
                deepstack_feature = self.deepstack_merger_list[num_deepstack_captured](
                    x
                )
                deepstack_feature_lists.append(deepstack_feature)
                num_deepstack_captured += 1
        x = self.merger(x)
        hidden_states = torch.cat(
            [x] + deepstack_feature_lists, dim=1
        )  # [seq_len, hidden_size * (1 + depth_of_deepstack)]
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("attn.qkv.", "attn.q.", "q"),
            ("attn.qkv.", "attn.k.", "k"),
            ("attn.qkv.", "attn.v.", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


cached_get_processor = lru_cache(get_processor)


class Qwen3LLMModel(Qwen3Model):

    def __init__(
        self,
        *,
        config: Qwen3VLConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
        if not self.pp_group.is_first_rank:
            assert self.start_layer >= len(
                config.vision_config.deepstack_visual_indexes
            ), "start_layer should be greater than or equal to len(deepstack_visual_indexes)"

        self.hidden_size = config.hidden_size
        self.deepstack_embed_to_decoder_layer = range(
            len(config.vision_config.deepstack_visual_indexes)
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        input_deepstack_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:

        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

        aux_hidden_states = []
        for layer_idx, layer in enumerate(
            self.layers[self.start_layer : self.end_layer]
        ):
            layer_idx = layer_idx + self.start_layer
            if layer_idx in self.layers_to_capture:
                aux_hidden_states.append(
                    hidden_states + residual if residual is not None else hidden_states
                )

            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )

            # process deepstack
            if (
                input_deepstack_embeds is not None
                and layer_idx in self.deepstack_embed_to_decoder_layer
            ):
                sep = self.hidden_size * layer_idx
                hidden_states = (
                    hidden_states
                    + input_deepstack_embeds[:, sep : sep + self.hidden_size]
                )

        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        else:
            if hidden_states.shape[0] != 0:
                if residual is None:
                    hidden_states = self.norm(hidden_states)
                else:
                    hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) == 0:
            return hidden_states

        return hidden_states, aux_hidden_states


class Qwen3VLForConditionalGeneration(nn.Module):
    def __init__(
        self,
        config: Qwen3VLConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.visual = Qwen3_VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
            # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
            quant_config=quant_config,
            prefix=add_prefix("visual", prefix),
        )

        self.model = Qwen3LLMModel(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("model", prefix),
        )

        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling

        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        # like {8:0, 16:1, 24:2}, which stands for the captured deepstack features on
        # 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states

        # deepstack
        self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
        self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)

    @property
    def use_deepstack(self) -> bool:
        return hasattr(self, "deepstack_visual_indexes")

    def separate_deepstack_embeds(self, embedding):
        assert (
            embedding.shape[-1] % (1 + self.num_deepstack_embeddings) == 0
        ), f"hidden_state of {embedding.shape} should be divisible by ({1 + self.num_deepstack_embeddings})"

        separate_index = self.config.hidden_size
        input_embeds = embedding[:, :separate_index]
        input_deepstack_embeds = embedding[:, separate_index:]
        return input_embeds, input_deepstack_embeds

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        # in qwen-vl, last dim is the same
        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
            self.visual.dtype
        )
        image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
        assert pixel_values.dim() == 2, pixel_values.dim()
        assert image_grid_thw.dim() == 2, image_grid_thw.dim()
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        return image_embeds

    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        # in qwen-vl, last dim is the same
        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
            self.visual.dtype
        )
        video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
        assert pixel_values.dim() == 2, pixel_values.dim()
        assert video_grid_thw.dim() == 2, video_grid_thw.dim()
        video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
        return video_embeds

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ):
        """Run forward pass for Qwen3-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,).
                (Use input_metadata.mrope_positions to replace it)
        """
        if self.is_mrope_enabled:
            positions = forward_batch.mrope_positions

        if not (
            forward_batch.forward_mode.is_decode()
            or not forward_batch.contains_image_inputs()
        ):
            if self.is_mrope_enabled:
                assert positions.ndim == 2 and positions.size(0) == 3, (
                    "multimodal section rotary embedding requires "
                    f"(3, seq_len) positions, but got {positions.size()}"
                )

        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.model,
            multimodal_model=self,
            positions=positions,
            use_deepstack=self.use_deepstack,
        )

        if not get_embedding:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch
            )
        else:
            return self.pooler(hidden_states, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            ("gate_up_proj", "up_proj", 1),
            ("gate_up_proj", "gate_proj", 0),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "language_model" in name:
                name = name.replace(r"model.language_model.", r"model.")

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if "visual" in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "visual" in name:
                    # adapt to VisionAttention
                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
                    name = name.replace(r"model.visual.", r"visual.")

                try:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                except KeyError:
                    print(params_dict.keys())
                    raise

                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


EntryClass = Qwen3VLForConditionalGeneration