sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +7 -0
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +16 -1
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mooncake/conn.py +16 -0
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +13 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -1
- sglang/srt/entrypoints/openai/serving_base.py +5 -2
- sglang/srt/entrypoints/openai/serving_chat.py +132 -79
- sglang/srt/function_call/ebnf_composer.py +10 -3
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/qwen3_coder_detector.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +14 -3
- sglang/srt/layers/moe/ep_moe/layer.py +323 -242
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
- sglang/srt/layers/moe/topk.py +90 -24
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +27 -10
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/lora/lora_registry.py +93 -29
- sglang/srt/managers/cache_controller.py +9 -7
- sglang/srt/managers/data_parallel_controller.py +4 -0
- sglang/srt/managers/io_struct.py +12 -0
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +14 -8
- sglang/srt/managers/scheduler.py +64 -1
- sglang/srt/managers/scheduler_input_blocker.py +106 -0
- sglang/srt/managers/tokenizer_manager.py +80 -15
- sglang/srt/managers/tp_worker.py +8 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -2
- sglang/srt/model_executor/model_runner.py +83 -27
- sglang/srt/models/deepseek_v2.py +75 -84
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/qwen2_moe.py +2 -2
- sglang/srt/models/qwen3_moe.py +17 -71
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/poll_based_barrier.py +31 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +65 -6
- sglang/srt/two_batch_overlap.py +8 -3
- sglang/srt/utils.py +96 -1
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_utils.py +118 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +5 -4
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +97 -80
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
sglang/srt/models/internvl.py
CHANGED
@@ -1,16 +1,3 @@
-# Copyright 2023-2024 SGLang Team
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
 from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
@@ -23,7 +10,9 @@ from transformers import PretrainedConfig, PreTrainedModel
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 
+from sglang.srt.distributed import parallel_state
 from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
@@ -39,6 +28,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_janus_pro import DropPath
 from sglang.srt.models.internlm2 import InternLM2ForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
 from sglang.utils import logger
 
 
@@ -53,7 +43,6 @@ class InternAttention(nn.Module):
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
-
         self.scale = self.head_dim**-0.5
 
         self.attn = VisionAttention(
@@ -64,18 +53,16 @@ class InternAttention(nn.Module):
             use_qkv_parallel=True,
             quant_config=quant_config,
             dropout=getattr(config, "dropout", 0.0),
-            qkv_bias=getattr(config, "qkv_bias", False),
+            qkv_bias=getattr(config, "qkv_bias", False)
+            or getattr(config, "attention_bias", False),
+            num_dummy_heads=getattr(config, "num_dummy_heads", 0),
+            qk_normalization=getattr(config, "qk_normalization", False)
+            or getattr(config, "use_qk_norm", False),
             flatten_batch=False,
         )
 
         self.proj_drop = nn.Dropout(config.dropout)
 
-        self.qk_normalization = config.qk_normalization
-
-        if self.qk_normalization:
-            self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
-            self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -91,8 +78,16 @@ class InternVisionEmbeddings(nn.Module):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
-        self.image_size = config.image_size
-        self.patch_size = config.patch_size
+        self.image_size = (
+            config.image_size
+            if isinstance(config.image_size, int)
+            else config.image_size[0]
+        )
+        self.patch_size = (
+            config.patch_size
+            if isinstance(config.patch_size, int)
+            else config.patch_size[0]
+        )
 
         self.class_embedding = nn.Parameter(
             torch.randn(1, 1, self.embed_dim),
@@ -199,7 +194,7 @@ class InternVisionEncoderLayer(nn.Module):
         self.embed_dim = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.norm_type = config.norm_type
-        self.attn = InternAttention(config)
+        self.attn = InternAttention(config=config, quant_config=quant_config)
         self.mlp = InternMLP(config)
         self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
         self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
@@ -417,7 +412,7 @@ class InternVLChatModel(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-
+        self._update_vision_config()
         image_size = config.force_image_size or config.vision_config.image_size
         patch_size = config.vision_config.patch_size
         self.patch_size = patch_size
@@ -446,6 +441,10 @@ class InternVLChatModel(nn.Module):
             self.language_model = InternLM2ForCausalLM(
                 config=config.llm_config, quant_config=quant_config
             )
+        elif config.llm_config.architectures[0] == "Qwen3MoeForCausalLM":
+            self.language_model = Qwen3MoeForCausalLM(
+                config=config.llm_config, quant_config=quant_config
+            )
         else:
             raise NotImplementedError(
                 f"{config.llm_config.architectures[0]} is not implemented."
@@ -463,6 +462,21 @@ class InternVLChatModel(nn.Module):
             nn.Linear(llm_hidden_size, llm_hidden_size),
         )
 
+    def _update_vision_config(self):
+        """update vision config to support tp"""
+        world_size = parallel_state.get_tensor_model_parallel_world_size()
+        num_heads = self.config.vision_config.num_attention_heads
+        head_dim = self.config.vision_config.hidden_size // num_heads
+        num_dummy_heads = 0
+
+        if num_heads % world_size != 0:
+            num_dummy_heads = (
+                (num_heads + world_size) // world_size
+            ) * world_size - num_heads
+
+        setattr(self.config.vision_config, "head_dim", head_dim)
+        setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
+
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
         # N, W, H, C --> N, W, H * scale, C // scale
@@ -545,7 +559,38 @@ class InternVLChatModel(nn.Module):
 
         return helper.pad_input_tokens(input_ids, mm_inputs)
 
+    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
+        """pad attn qkv weights for dummy heads"""
+        num_dummy_heads = self.config.vision_config.num_dummy_heads
+        if num_dummy_heads == 0:
+            return loaded_weight
+        head_dim = self.config.vision_config.head_dim
+
+        if "attn.qkv_proj" in name:
+            wq, wk, wv = loaded_weight.chunk(3, dim=0)
+            if name.endswith(".weight"):
+                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
+            elif name.endswith(".bias"):
+                dummy_shape = [num_dummy_heads, head_dim]
+            else:
+                raise RuntimeError(f"Unsupported weight with name={name}")
+            pad_func = lambda x: torch.cat(
+                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
+            ).flatten(0, 1)
+            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
+            loaded_weight = torch.cat([wq, wk, wv], dim=0)
+        if "attn.proj.weight" in name:
+            padded_weight = loaded_weight.new_zeros(
+                loaded_weight.shape[0], head_dim * num_dummy_heads
+            )
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+        return loaded_weight
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        expert_params_mapping = []
         if "InternLM2ForCausalLM" in self.config.llm_config.architectures:
             stacked_params_mapping = [
                 # (param_name, shard_name, shard_id)
@@ -561,15 +606,41 @@ class InternVLChatModel(nn.Module):
                 ("gate_up_proj", "gate_proj", 0),
                 ("gate_up_proj", "up_proj", 1),
             ]
+        elif "Qwen3MoeForCausalLM" in self.config.llm_config.architectures:
+            stacked_params_mapping = [
+                # (param_name, shard_name, shard_id)
+                ("qkv_proj", "q_proj", "q"),
+                ("qkv_proj", "k_proj", "k"),
+                ("qkv_proj", "v_proj", "v"),
+                ("gate_up_proj", "gate_proj", 0),
+                ("gate_up_proj", "up_proj", 1),
+            ]
+
+            expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+                ckpt_gate_proj_name="gate_proj",
+                ckpt_down_proj_name="down_proj",
+                ckpt_up_proj_name="up_proj",
+                num_experts=self.config.num_experts,
+            )
+
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
@@ -584,30 +655,55 @@ class InternVLChatModel(nn.Module):
                 name = name.replace(r"attn.", r"attn.attn.")
                 name = name.replace(r"qkv.", r"qkv_proj.")
 
-            # Skip loading extra bias for GPTQ models.
-            if name.endswith(".bias") and name not in params_dict:
-                continue
-            param = params_dict[name]
-            if "wqkv" in name:
-                config = self.config
-                kv_groups = config.num_attention_heads // config.num_key_value_heads
-                head_dim = config.hidden_size // config.num_attention_heads
-                loaded_weight = loaded_weight.view(
-                    -1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
-                )
-                wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1], dim=1)
-                wq = wq.reshape(-1, wq.shape[-1])
-                wk = wk.reshape(-1, wk.shape[-1])
-                wv = wv.reshape(-1, wv.shape[-1])
+            for mapping in expert_params_mapping:
+                param_name, weight_name, expert_id, shard_id = mapping
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
                 weight_loader = param.weight_loader
-                weight_loader(
-
-
-
-
-
+                weight_loader(
+                    param,
+                    loaded_weight,
+                    name,
+                    shard_id=shard_id,
+                    expert_id=expert_id,
                 )
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                if "wqkv" in name:
+                    config = self.config
+                    kv_groups = (
+                        config.num_attention_heads // config.num_key_value_heads
+                    )
+                    head_dim = config.hidden_size // config.num_attention_heads
+                    loaded_weight = loaded_weight.view(
+                        -1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
+                    )
+                    wq, wk, wv = torch.split(
+                        loaded_weight, [kv_groups, 1, 1], dim=1
+                    )
+                    wq = wq.reshape(-1, wq.shape[-1])
+                    wk = wk.reshape(-1, wk.shape[-1])
+                    wv = wv.reshape(-1, wv.shape[-1])
+                    weight_loader = param.weight_loader
+                    weight_loader(param, wq, "q")
+                    weight_loader(param, wk, "k")
+                    weight_loader(param, wv, "v")
+                else:
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    if "vision_model" in name:
+                        loaded_weight = self._pad_vit_attn_dummy_heads(
+                            name, loaded_weight
+                        )
+                    weight_loader(param, loaded_weight)
 
             loaded_params.add(name)
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
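Note: the `_update_vision_config` and `_pad_vit_attn_dummy_heads` changes above let the ViT run under tensor parallelism when the head count does not divide the world size, by rounding the head count up with zero-filled "dummy" heads. A minimal standalone sketch of the same arithmetic, assuming a stacked [num_heads * head_dim, hidden] projection weight (function and variable names here are illustrative, not part of the diff):

    import torch

    def pad_dummy_heads(w: torch.Tensor, num_heads: int, head_dim: int, world_size: int) -> torch.Tensor:
        # Round num_heads up to the next multiple of world_size
        # (same result as the guarded formula in _update_vision_config).
        num_dummy_heads = -num_heads % world_size
        if num_dummy_heads == 0:
            return w
        heads = w.unflatten(0, (num_heads, head_dim))           # [num_heads, head_dim, hidden]
        zeros = w.new_zeros(num_dummy_heads, head_dim, w.shape[-1])
        return torch.cat([heads, zeros], dim=0).flatten(0, 1)   # [(num_heads + dummy) * head_dim, hidden]

    # e.g. a 25-head ViT on 4 GPUs gets 3 dummy heads: 28 heads total, 7 per rank.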
sglang/srt/models/llava.py
CHANGED
@@ -656,11 +656,15 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
         self, auto_model_type: Type[AutoModel]
     ) -> Dict[str, str]:
         mapping = {}
-        for config_cls
-
-
-
-
+        for config_cls in auto_model_type._model_mapping.keys():
+            archs = auto_model_type._model_mapping.get(config_cls, None)
+            if archs is not None:
+                if isinstance(archs, tuple):
+                    mapping[config_cls.__name__] = tuple(
+                        arch.__name__ for arch in archs
+                    )
+                else:
+                    mapping[config_cls.__name__] = archs.__name__
         return mapping
 
     def __init__(
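Note: the rewritten loop above derives a config-class-name → architecture-name mapping from a transformers auto class, tolerating both single classes and tuples in `_model_mapping`. A rough standalone sketch of the same idea (`AutoModelForCausalLM` is just an example input; `_model_mapping` is a private transformers attribute, used here exactly as the diff uses it):

    from transformers import AutoModelForCausalLM

    mapping = {}
    for config_cls in AutoModelForCausalLM._model_mapping.keys():
        archs = AutoModelForCausalLM._model_mapping.get(config_cls, None)
        if archs is not None:
            if isinstance(archs, tuple):
                mapping[config_cls.__name__] = tuple(a.__name__ for a in archs)
            else:
                mapping[config_cls.__name__] = archs.__name__

    # e.g. mapping.get("LlamaConfig") -> "LlamaForCausalLM"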
sglang/srt/models/minicpmo.py
CHANGED
@@ -1134,7 +1134,10 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
         """
         residual = hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
+        # TODO (lifuhuang): confirmed with Mick that the logic for past_key_values is copied from minicpmo official code,
+        # currently we are not using past_key_values at all. We need to redesign the caching logic when we support streaming
+        # in the future.
         hidden_states, attn_weights = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             layer_head_mask=layer_head_mask,
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -147,10 +147,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             # Additional args for FusedMoE
             **(
                 dict(
-                    enable_flashinfer_moe=True,
+                    enable_flashinfer_cutlass_moe=True,
                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
-                if global_server_args_dict["enable_flashinfer_moe"]
+                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
                 else {}
             ),
         )
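Note: both MoE blocks rely on the same conditional-kwargs idiom: a dict of extra keyword arguments is splatted into the constructor only when the corresponding server flag is set, so flag-specific arguments never reach backends that do not accept them. A tiny illustration of the pattern (the function and flag dict below are invented for the example):

    def build_moe(num_experts, **extra):
        return {"num_experts": num_experts, **extra}

    flags = {"enable_flashinfer_cutlass_moe": False}
    moe = build_moe(
        8,
        **(
            dict(enable_flashinfer_cutlass_moe=True)
            if flags["enable_flashinfer_cutlass_moe"]
            else {}
        ),
    )
    # -> {"num_experts": 8}; with the flag set, the extra kwarg is included.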
sglang/srt/models/qwen3_moe.py
CHANGED
@@ -120,10 +120,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             # Additional args for FusedMoE
             **(
                 dict(
-                    enable_flashinfer_moe=True,
+                    enable_flashinfer_cutlass_moe=True,
                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
-                if global_server_args_dict["enable_flashinfer_moe"]
+                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
                 else {}
             ),
         )
@@ -144,19 +144,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         )
         self.top_k = config.num_experts_per_tok
 
-        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
-            group=parallel_state.get_tp_group().device_group,
-            router_topk=self.top_k,
-            permute_fusion=True,
-            num_experts=self.num_experts,
-            num_local_experts=config.num_experts // self.tp_size,
-            hidden_size=config.hidden_size,
-            params_dtype=config.torch_dtype,
-            deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
-            async_finish=True,  # TODO
-            return_recv_hook=True,
-        )
-
     def forward(
         self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
@@ -207,41 +194,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             topk_weights = torch.empty(
                 (0, self.top_k), dtype=torch.float32, device=hidden_states.device
             )
-        if self.ep_size > 1:
-            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                reorder_topk_ids,
-                num_recv_tokens_per_expert,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            ) = self.deepep_dispatcher.dispatch(
-                hidden_states=hidden_states,
-                topk_idx=topk_idx,
-                topk_weights=topk_weights,
-                forward_batch=forward_batch,
-            )
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
-            reorder_topk_ids=reorder_topk_ids,
-            seg_indptr=seg_indptr,
-            masked_m=masked_m,
-            expected_m=expected_m,
-            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
             forward_batch=forward_batch,
         )
-        if self.ep_size > 1:
-            final_hidden_states = self.deepep_dispatcher.combine(
-                hidden_states=final_hidden_states,
-                topk_idx=topk_idx,
-                topk_weights=topk_weights,
-                forward_batch=forward_batch,
-            )
         return final_hidden_states
 
     def op_gate(self, state):
@@ -278,8 +236,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
 
     def op_dispatch_a(self, state):
         if self.ep_size > 1:
-            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
-            self.deepep_dispatcher.dispatch_a(
+            self.experts.deepep_dispatcher.dispatch_a(
                 hidden_states=state.pop("hidden_states_mlp_input"),
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
@@ -292,46 +249,32 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         with get_global_expert_distribution_recorder().with_current_layer(
             self.layer_id
         ):
-            (
-                state.hidden_states_experts_input,
-                state.topk_idx_dispatched,
-                state.topk_weights_dispatched,
-                state.reorder_topk_ids,
-                state.num_recv_tokens_per_expert,
-                state.seg_indptr,
-                state.masked_m,
-                state.expected_m,
-            ) = self.deepep_dispatcher.dispatch_b(
+            state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
 
     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts(
-            hidden_states=state.hidden_states_experts_input,
-            topk_idx=state.topk_idx_dispatched,
-            topk_weights=state.topk_weights_dispatched,
-            reorder_topk_ids=state.pop("reorder_topk_ids"),
-            seg_indptr=state.pop("seg_indptr"),
-            masked_m=state.pop("masked_m"),
-            expected_m=state.pop("expected_m"),
-            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-            forward_batch=state.forward_batch,
+        state.hidden_states_experts_output = self.experts.moe_impl(
+            dispatch_output=state.dispatch_output,
         )
 
     def op_combine_a(self, state):
         if self.ep_size > 1:
-            self.deepep_dispatcher.combine_a(
+            self.experts.deepep_dispatcher.combine_a(
                 hidden_states=state.pop("hidden_states_experts_output"),
-                topk_idx=state.pop("topk_idx_dispatched"),
-                topk_weights=state.pop("topk_weights_dispatched"),
+                topk_idx=state.dispatch_output.topk_idx,
+                topk_weights=state.dispatch_output.topk_weights,
                 forward_batch=state.forward_batch,
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
+            state.pop("dispatch_output")
 
     def op_combine_b(self, state):
         if self.ep_size > 1:
-            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
-                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            state.hidden_states_after_combine = (
+                self.experts.deepep_dispatcher.combine_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )
             )
 
     def op_output(self, state):
@@ -707,6 +650,9 @@ class Qwen3MoeForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.capture_aux_hidden_states = False
 
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     @torch.no_grad()
     def forward(
         self,
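Note: the qwen3_moe changes above move the DeepEP dispatcher into `self.experts` and collapse the eight values `dispatch_b` used to return into one `dispatch_output` value that later ops read fields from. A minimal sketch of that shape (the class and field names below are inferred from the fields visible in the removed code, not the actual sglang types):

    from typing import NamedTuple
    import torch

    class DispatchOutput(NamedTuple):
        hidden_states: torch.Tensor
        topk_idx: torch.Tensor
        topk_weights: torch.Tensor
        # ...plus whatever bookkeeping the dispatch backend needs

    def op_combine_a(state: dict, dispatcher, ep_size: int) -> None:
        # Mirrors the new flow: topk_idx/topk_weights come off the single
        # dispatch_output instead of separate state fields, which is then dropped.
        if ep_size > 1:
            out: DispatchOutput = state["dispatch_output"]
            dispatcher.combine_a(
                hidden_states=state.pop("hidden_states_experts_output"),
                topk_idx=out.topk_idx,
                topk_weights=out.topk_weights,
            )
            state.pop("dispatch_output")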
sglang/srt/multimodal/processors/base_processor.py
CHANGED
@@ -12,6 +12,7 @@ import torch
 from PIL import Image
 from transformers import BaseImageProcessorFast
 
+from sglang.srt.managers.mm_utils import TransportProxyTensor
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.utils import load_audio, load_image, load_video, logger
 
@@ -142,11 +143,14 @@ class MultimodalSpecialTokens:
 class BaseMultimodalProcessor(ABC):
     models = []
 
-    def __init__(self, hf_config, server_args, _processor):
+    def __init__(
+        self, hf_config, server_args, _processor, transport_mode, *args, **kwargs
+    ):
         self.hf_config = hf_config
         self._processor = _processor
         self.arch = hf_config.architectures[0]
         self.server_args = server_args
+        self.transport_mode = transport_mode
 
         # FIXME: not accurate, model and image specific
         self.NUM_TOKEN_PER_FRAME = 330
@@ -217,10 +221,6 @@ class BaseMultimodalProcessor(ABC):
             return_tensors="pt",
             **kwargs,
         )
-        if "pixel_values" in result and isinstance(
-            result["pixel_values"], torch.Tensor
-        ):
-            result["pixel_values"] = result["pixel_values"].to("cpu")
         return result
 
     @abstractmethod
@@ -500,7 +500,6 @@ class BaseMultimodalProcessor(ABC):
     ) -> List[MultimodalDataItem]:
         """Create mm_items directly from processor output."""
         items: dict[Modality, MultimodalDataItem] = {}
-
        for attr_name, value in data_dict.items():
             if attr_name == "input_ids":
                 continue
@@ -624,4 +623,19 @@ class BaseMultimodalProcessor(ABC):
             mm_token_id=mm_token_id,
         )
 
+        # post-process
+        for item in all_collected_items:
+            # replace the feature tensor with a proxy
+            if isinstance(item.feature, torch.Tensor) and item.feature.is_cuda:
+                item.feature = TransportProxyTensor(
+                    transport_mode=self.transport_mode, data=item.feature
+                )
+            elif (
+                isinstance(item.precomputed_embeddings, torch.Tensor)
+                and item.precomputed_embeddings.is_cuda
+            ):
+                item.precomputed_embeddings = TransportProxyTensor(
+                    transport_mode=self.transport_mode, data=item.precomputed_embeddings
+                )
+
         return all_collected_items, input_ids, ret
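Note: instead of forcing `pixel_values` back to the CPU, the base processor now keeps CUDA tensors and wraps them in a `TransportProxyTensor` tagged with the processor's `transport_mode`, deferring the transport decision to the consumer. A condensed sketch of that post-processing step (the `Item` container and `make_proxy` callable stand in for the real sglang types):

    import torch
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Item:
        feature: Optional[torch.Tensor] = None
        precomputed_embeddings: Optional[torch.Tensor] = None

    def wrap_cuda_tensors(items, transport_mode, make_proxy):
        # make_proxy(transport_mode, tensor) stands in for
        # TransportProxyTensor(transport_mode=..., data=...).
        for item in items:
            if isinstance(item.feature, torch.Tensor) and item.feature.is_cuda:
                item.feature = make_proxy(transport_mode, item.feature)
            elif (
                isinstance(item.precomputed_embeddings, torch.Tensor)
                and item.precomputed_embeddings.is_cuda
            ):
                item.precomputed_embeddings = make_proxy(
                    transport_mode, item.precomputed_embeddings
                )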
sglang/srt/multimodal/processors/clip.py
CHANGED
@@ -10,8 +10,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class ClipImageProcessor(BaseMultimodalProcessor):
     models = [CLIPModel]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
             _processor
         )
sglang/srt/multimodal/processors/deepseek_vl_v2.py
CHANGED
@@ -31,8 +31,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
     models = [DeepseekVL2ForCausalLM]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(
             image_token="<image>", image_token_id=self._processor.image_token_id
         ).build(_processor)
sglang/srt/multimodal/processors/gemma3.py
CHANGED
@@ -14,8 +14,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok
 class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
     models = [Gemma3ForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.IM_START_TOKEN_ID = hf_config.boi_token_index
         self.IM_END_TOKEN_ID = hf_config.eoi_token_index
         self.mm_tokens = MultimodalSpecialTokens(
sglang/srt/multimodal/processors/gemma3n.py
CHANGED
@@ -27,8 +27,8 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
 
     models = [Gemma3nForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
 
         self.IM_START_TOKEN_ID = hf_config.boi_token_id
         self.IM_END_TOKEN_ID = hf_config.eoi_token_id
sglang/srt/multimodal/processors/internvl.py
CHANGED
@@ -6,6 +6,7 @@ from decord import VideoReader, cpu
 from PIL import Image
 
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.interns1 import InternS1ForConditionalGeneration
 from sglang.srt.models.internvl import InternVLChatModel
 from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor,
@@ -14,12 +15,19 @@ from sglang.srt.multimodal.processors.base_processor import (
 
 
 class InternVLImageProcessor(BaseMultimodalProcessor):
-    models = [InternVLChatModel]
+    models = [InternVLChatModel, InternS1ForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _image_processor):
-        super().__init__(hf_config, server_args, _image_processor)
-        image_size = hf_config.force_image_size or hf_config.vision_config.image_size
+    def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
+        image_size = (
+            getattr(hf_config, "force_image_size", None)
+            or hf_config.vision_config.image_size
+        )
         patch_size = hf_config.vision_config.patch_size
+        if isinstance(image_size, list):
+            image_size = image_size[0]
+        if isinstance(patch_size, list):
+            patch_size = patch_size[0]
 
         self.IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
         self.IMG_START_TOKEN = "<img>"
@@ -27,8 +35,12 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
         self.num_image_token = int(
             (image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2)
         )
+        if hasattr(self._processor, "tokenizer"):
+            tokenizer = self._processor.tokenizer
+        else:
+            tokenizer = self._processor
+        self.tokenizer = tokenizer
 
-        tokenizer = self._processor
         self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
         self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
         self.mm_tokens = MultimodalSpecialTokens(
@@ -195,7 +207,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
                 try:
                     # TODO: video input
                     raw_image = process_image_internvl(image)
-                    pixel_value = [raw_image.to(torch.bfloat16)
+                    pixel_value = [raw_image.to(torch.bfloat16)]
                     pixel_values += pixel_value
                     num_patches = raw_image.shape[0]
                     num_patches_list += [num_patches]
@@ -214,8 +226,9 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             )
             input_text = input_text.replace("<image>", image_tokens, 1)
 
-
-
+        input_ids = self.tokenizer(input_text, return_tensors="pt")[
+            "input_ids"
+        ].flatten()
         image_offsets = self.get_mm_items_offset(
             input_ids=input_ids,
             mm_token_id=self.mm_tokens.image_token_id,