sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_classification.py
ADDED
@@ -0,0 +1,84 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import Qwen2Config  # Qwen3 uses Qwen2Config
+
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen3 import Qwen3ForCausalLM, Qwen3Model
+from sglang.srt.utils import add_prefix
+
+
+class Qwen3ForSequenceClassification(nn.Module):
+    def __init__(
+        self,
+        config: Qwen2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Qwen3Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.score = nn.Linear(config.hidden_size, config.num_labels)
+        # Use normalize=True for qwen3 embedding based on official implementation
+        # Reference: https://github.com/QwenLM/Qwen3-Embedding/blob/main/examples/qwen3_embedding_transformers.py#L55
+        # Official code: output = F.normalize(output, p=2, dim=1)
+        normalize = True
+
+        # We don't want to normalize the embedding if we have a classification head
+        if config.id2label is not None or config.label2id is not None:
+            normalize = False
+
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=normalize)
+
+        self.eos_token_id = config.eos_token_id
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: Optional[torch.Tensor] = None,
+        get_embedding: bool = True,
+    ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "Qwen3ForSequenceClassification is only used for embedding"
+
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        logits = self.score(hidden_states)
+        pooled_logits = self.pooler(logits, forward_batch).embeddings
+
+        return EmbeddingPoolerOutput(pooled_logits)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        # Filter out lm_head weights of Qwen3ForCausalLM
+        filtered_weights = [
+            (name, w) for name, w in weights if not name.startswith("lm_head")
+        ]
+        return Qwen3ForCausalLM.load_weights(self, filtered_weights)
+
+
+EntryClass = [
+    Qwen3ForSequenceClassification,
+]
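
Note: the new file pools the last token's hidden state and L2-normalizes only when no classification head (id2label/label2id) is configured. Below is a minimal standalone sketch of that pooling rule; the function name and tensor layout are illustrative, not the sglang Pooler API.

import torch
import torch.nn.functional as F

def pool_last_token(
    hidden_states: torch.Tensor,  # (batch, max_seq_len, hidden)
    seq_lens: torch.Tensor,       # (batch,) actual sequence lengths
    has_classification_head: bool,
) -> torch.Tensor:
    # PoolingType.LAST: take the hidden state of each sequence's final token.
    idx = (seq_lens - 1).clamp(min=0)
    pooled = hidden_states[torch.arange(hidden_states.size(0)), idx]
    # Embedding use normalizes (matching the referenced Qwen3-Embedding code);
    # a classification head skips normalization so logits keep their scale.
    if not has_classification_head:
        pooled = F.normalize(pooled, p=2, dim=1)
    return pooled

h = torch.randn(2, 5, 4)
lens = torch.tensor([5, 3])
print(pool_last_token(h, lens, has_classification_head=False).norm(dim=1))  # ~1.0 each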
sglang/srt/models/qwen3_moe.py
CHANGED
@@ -28,50 +28,35 @@ from sglang.srt.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    parallel_state,
-    split_tensor_along_last_dim,
-    tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
-from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    get_attention_tp_size,
-    get_local_attention_dp_size,
-)
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
-    MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.utils import get_layer_id
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatch,
-    ForwardMode,
-    PPProxyTensors,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty
 
 Qwen3MoeConfig = None
@@ -112,19 +97,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
-            **(
-                dict(deepep_mode=global_server_args_dict["deepep_mode"])
-                if global_server_args_dict["moe_a2a_backend"].is_deepep()
-                else {}
-            ),
-            # Additional args for FusedMoE
-            **(
-                dict(
-                    enable_flashinfer_cutlass_moe=True,
-                )
-                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
-                else {}
-            ),
         )
 
         self.gate = ReplicatedLinear(
@@ -135,7 +107,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             prefix=add_prefix("gate", prefix),
         )
 
-        if global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if get_moe_a2a_backend().is_deepep():
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
@@ -144,11 +116,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self.top_k = config.num_experts_per_tok
 
     def forward(
-        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
-        if not global_server_args_dict["moe_a2a_backend"].is_deepep():
-            return self.forward_normal(hidden_states)
+        if not get_moe_a2a_backend().is_deepep():
+            return self.forward_normal(hidden_states, use_reduce_scatter)
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
@@ -159,7 +134,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             if name not in ["correction_bias"]
         ]
 
-    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward_normal(
+        self,
+        hidden_states: torch.Tensor,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
@@ -167,7 +146,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(hidden_states, topk_output)
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -484,7 +463,6 @@ class Qwen3MoeDecoderLayer(nn.Module):
 
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
-        self.local_dp_size = get_local_attention_dp_size()
 
         # Qwen3MoE all layers are sparse and have no nextn now
         self.is_layer_sparse = True
@@ -521,6 +499,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
     def forward(
@@ -546,7 +525,12 @@ class Qwen3MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch)
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+
+        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
 
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
@@ -765,7 +749,7 @@ class Qwen3MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
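
Note: the use_reduce_scatter plumbing above skips the MoE all-reduce when the layer communicator will reduce-scatter instead. The safety argument, inferred from the hunk's own comment ("For DP with padding, reduce scatter can be used instead of all-reduce"), is the standard identity that an all-reduce equals a reduce-scatter followed by an all-gather once every rank holds equal-length shards, which padding guarantees. A self-contained sketch of that equivalence, simulated on plain tensors with no process group:

import torch

world_size, tokens, hidden = 4, 8, 16
# One partial MoE output per simulated rank (tokens padded to a common length).
per_rank = [torch.randn(tokens, hidden) for _ in range(world_size)]

# What tensor_model_parallel_all_reduce would leave on every rank:
all_reduced = sum(per_rank)

# Reduce-scatter leaves rank i one summed shard of the token dimension;
# the communicator's later all-gather reassembles the full tensor:
shards = [
    sum(r.chunk(world_size, dim=0)[i] for r in per_rank)
    for i in range(world_size)
]
regathered = torch.cat(shards, dim=0)

assert torch.allclose(all_reduced, regathered, atol=1e-5)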
sglang/srt/models/step3_vl.py
CHANGED
@@ -25,7 +25,11 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
-from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -34,6 +38,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -146,7 +151,7 @@ class Step3TextMoEMLP(nn.Module):
             prefix=add_prefix("gate", prefix),
         )
 
-        if global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if get_moe_a2a_backend().is_deepep():
             raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -437,7 +442,7 @@ class Step3TextModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
             prefix=add_prefix("embed_tokens", prefix),
         )
 
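
Note: a pattern recurring across these model diffs is the replacement of stringly-typed global_server_args_dict lookups with typed accessors such as get_moe_a2a_backend() and is_dp_attention_enabled(). The sketch below shows the shape of that pattern; it is illustrative only, not sglang's actual internals.

from enum import Enum

class MoeA2ABackend(Enum):
    NONE = "none"
    DEEPEP = "deepep"

    def is_deepep(self) -> bool:
        return self is MoeA2ABackend.DEEPEP

_moe_a2a_backend = MoeA2ABackend.NONE  # resolved once from server args at startup

def get_moe_a2a_backend() -> MoeA2ABackend:
    # Call sites ask a typed accessor instead of indexing a global dict by string,
    # so they survive config-layout changes and typo silently-returns-None bugs.
    return _moe_a2a_backend

if get_moe_a2a_backend().is_deepep():
    raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")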
sglang/srt/models/xverse_moe.py
CHANGED
@@ -33,7 +33,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.fused_moe_triton import fused_moe
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -121,6 +123,7 @@ class XverseMoE(nn.Module):
             ]
         )
         self.pack_params()
+        self.moe_runner_config = MoeRunnerConfig(inplace=True)
 
         self.router = ReplicatedLinear(
             config.hidden_size,
@@ -129,6 +132,10 @@ class XverseMoE(nn.Module):
             quant_config=None,
             prefix=add_prefix("router", prefix),
         )
+        self.topk = TopK(
+            top_k=self.top_k,
+            renormalize=getattr(self.config, "norm_topk_prob", False),
+        )
 
         if config.num_shared_experts is not None:
             intermediate_size = config.intermediate_size * config.num_shared_experts
@@ -167,14 +174,13 @@ class XverseMoE(nn.Module):
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.router(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = fused_moe(
             hidden_states,
             self.w1,
             self.w2,
-            router_logits,
-            self.top_k,
-            renormalize=getattr(self.config, "norm_topk_prob", False),
-            inplace=True,
+            topk_output,
+            self.moe_runner_config,
         )
 
         if self.config.num_shared_experts is not None:
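
Note: fused_moe now consumes a precomputed topk_output from the TopK module plus a MoeRunnerConfig, instead of raw router logits and per-call flags. A minimal reference sketch of the routing math TopK performs under these settings (only the math; the real sglang module returns a structured topk_output object):

import torch

def topk_routing(router_logits: torch.Tensor, top_k: int, renormalize: bool):
    # Softmax over experts, then keep the top-k weights and expert ids per token.
    probs = torch.softmax(router_logits, dim=-1)
    weights, expert_ids = torch.topk(probs, top_k, dim=-1)
    if renormalize:  # getattr(config, "norm_topk_prob", False) in the hunk above
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, expert_ids

logits = torch.randn(4, 8)  # 4 tokens, 8 experts
w, ids = topk_routing(logits, top_k=2, renormalize=True)
assert torch.allclose(w.sum(-1), torch.ones(4))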
sglang/srt/multimodal/processors/base_processor.py
CHANGED
@@ -217,9 +217,9 @@ class BaseMultimodalProcessor(ABC):
         if videos:
             kwargs["videos"] = videos
         if audios:
-            if self.arch in {
-                "Gemma3nForConditionalGeneration",
-                "Qwen2AudioForConditionalGeneration",
+            if self._processor.__class__.__name__ in {
+                "Gemma3nProcessor",
+                "Qwen2AudioProcessor",
             }:
                 # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
                 kwargs["audio"] = audios
sglang/srt/multimodal/processors/internvl.py
CHANGED
@@ -44,7 +44,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
         self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
         self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
         self.mm_tokens = MultimodalSpecialTokens(
-            image_token="<image>",
+            image_token="<IMG_CONTEXT>",
             image_token_id=tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT_TOKEN),
         ).build(_image_processor)
@@ -218,13 +218,18 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
 
         pixel_values = torch.cat(pixel_values, dim=0)
 
+        original_placeholder = "<<<__IMG_CONTEXT_PLACEHOLDER__>>>"
+        input_text = input_text.replace(self.IMG_CONTEXT_TOKEN, original_placeholder)
+
         for idx, num_patches in enumerate(num_patches_list):
             image_tokens = (
                 self.IMG_START_TOKEN
                 + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
                 + self.IMG_END_TOKEN
             )
-            input_text = input_text.replace("<image>", image_tokens, 1)
+            input_text = input_text.replace(original_placeholder, image_tokens, 1)
+
+        input_text = input_text.replace(original_placeholder, self.IMG_CONTEXT_TOKEN)
 
         input_ids = self.tokenizer(input_text, return_tensors="pt")[
             "input_ids"
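
Note: the placeholder indirection added above guards against a re-matching hazard. Now that the image token is "<IMG_CONTEXT>" and each expansion inserts runs of that very token, a naive sequential replace could land inside a previously expanded image block. A self-contained repro sketch of the hazard and the fix (the "<s>"/"</s>" markers stand in for IMG_START/IMG_END):

TOKEN = "<IMG_CONTEXT>"
PLACEHOLDER = "<<<__IMG_CONTEXT_PLACEHOLDER__>>>"
text = f"a {TOKEN} b {TOKEN}"
patches = [2, 3]  # expansion width per image slot (illustrative)

# Naive in-place expansion: the second replace(..., 1) matches a TOKEN that the
# first expansion just inserted, corrupting the prompt.
buggy = text
for n in patches:
    buggy = buggy.replace(TOKEN, "<s>" + TOKEN * n + "</s>", 1)

# Placeholder indirection, as in the hunk above: each replace(..., 1) can only
# consume an original slot, never an expanded one.
fixed = text.replace(TOKEN, PLACEHOLDER)
for n in patches:
    fixed = fixed.replace(PLACEHOLDER, "<s>" + TOKEN * n + "</s>", 1)
fixed = fixed.replace(PLACEHOLDER, TOKEN)  # restore any unexpanded slots

assert fixed == f"a <s>{TOKEN * 2}</s> b <s>{TOKEN * 3}</s>"
assert buggy != fixed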
sglang/srt/multimodal/processors/llava.py
CHANGED
@@ -18,7 +18,7 @@ from sglang.srt.models.llavavid import LlavaVidForCausalLM
 from sglang.srt.models.mistral import Mistral3ForConditionalGeneration
 from sglang.srt.multimodal.mm_utils import expand2square, process_anyres_image
 from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
-from sglang.srt.utils import load_image, logger
+from sglang.srt.utils import ImageData, load_image, logger
 from sglang.utils import get_exception_traceback
 
 
@@ -35,7 +35,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
 
     @staticmethod
     def _process_single_image_task(
-        image_data: Union[str, bytes],
+        image_data: Union[str, bytes, ImageData],
         image_aspect_ratio: Optional[str] = None,
         image_grid_pinpoints: Optional[str] = None,
         processor=None,
@@ -44,10 +44,11 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
         image_processor = processor.image_processor
 
         try:
-            image, image_size = load_image(image_data)
+            url = image_data.url if isinstance(image_data, ImageData) else image_data
+            image, image_size = load_image(url)
             if image_size is not None:
                 # It is a video with multiple images
-                image_hash = hash(image_data)
+                image_hash = hash(url)
                 pixel_values = image_processor(image)["pixel_values"]
                 for _ in range(len(pixel_values)):
                     pixel_values[_] = pixel_values[_].astype(np.float16)
@@ -55,7 +56,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
                 return pixel_values, image_hash, image_size
             else:
                 # It is an image
-                image_hash = hash(image_data)
+                image_hash = hash(url)
                 if image_aspect_ratio == "pad":
                     image = expand2square(
                         image,
@@ -82,7 +83,10 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
             logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
 
     async def _process_single_image(
-        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
+        self,
+        image_data: Union[bytes, str, ImageData],
+        aspect_ratio: str,
+        grid_pinpoints: str,
     ):
         if self.cpu_executor is not None:
             loop = asyncio.get_event_loop()
@@ -104,7 +108,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
 
     async def process_mm_data_async(
         self,
-        image_data: List[Union[str, bytes]],
+        image_data: List[Union[str, bytes, ImageData]],
         input_text,
         request_obj,
         *args,