sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl
This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +4 -0
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +16 -1
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mooncake/conn.py +16 -0
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +132 -79
- sglang/srt/function_call/ebnf_composer.py +10 -3
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/qwen3_coder_detector.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +14 -3
- sglang/srt/layers/moe/ep_moe/layer.py +172 -206
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/topk.py +84 -22
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +25 -10
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/lora/lora_registry.py +93 -29
- sglang/srt/managers/cache_controller.py +9 -7
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +14 -8
- sglang/srt/managers/scheduler.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +37 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -2
- sglang/srt/model_executor/model_runner.py +68 -14
- sglang/srt/models/deepseek_v2.py +62 -28
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/qwen2_moe.py +2 -2
- sglang/srt/models/qwen3_moe.py +5 -2
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +57 -6
- sglang/srt/utils.py +96 -1
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_utils.py +65 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +4 -4
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +83 -73
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/models/internvl.py
CHANGED
@@ -1,16 +1,3 @@
-# Copyright 2023-2024 SGLang Team
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
 from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
@@ -23,7 +10,9 @@ from transformers import PretrainedConfig, PreTrainedModel
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 
+from sglang.srt.distributed import parallel_state
 from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
@@ -39,6 +28,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_janus_pro import DropPath
 from sglang.srt.models.internlm2 import InternLM2ForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
 from sglang.utils import logger
 
 
@@ -53,7 +43,6 @@ class InternAttention(nn.Module):
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
-
         self.scale = self.head_dim**-0.5
 
         self.attn = VisionAttention(
@@ -64,18 +53,16 @@ class InternAttention(nn.Module):
             use_qkv_parallel=True,
             quant_config=quant_config,
             dropout=getattr(config, "dropout", 0.0),
-
+            qkv_bias=getattr(config, "qkv_bias", False)
+            or getattr(config, "attention_bias", False),
+            num_dummy_heads=getattr(config, "num_dummy_heads", 0),
+            qk_normalization=getattr(config, "qk_normalization", False)
+            or getattr(config, "use_qk_norm", False),
             flatten_batch=False,
         )
 
         self.proj_drop = nn.Dropout(config.dropout)
 
-        self.qk_normalization = config.qk_normalization
-
-        if self.qk_normalization:
-            self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
-            self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -91,8 +78,16 @@ class InternVisionEmbeddings(nn.Module):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
-        self.image_size =
-
+        self.image_size = (
+            config.image_size
+            if isinstance(config.image_size, int)
+            else config.image_size[0]
+        )
+        self.patch_size = (
+            config.patch_size
+            if isinstance(config.patch_size, int)
+            else config.patch_size[0]
+        )
 
         self.class_embedding = nn.Parameter(
             torch.randn(1, 1, self.embed_dim),
@@ -199,7 +194,7 @@ class InternVisionEncoderLayer(nn.Module):
         self.embed_dim = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.norm_type = config.norm_type
-        self.attn = InternAttention(config)
+        self.attn = InternAttention(config=config, quant_config=quant_config)
         self.mlp = InternMLP(config)
         self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
         self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
@@ -417,7 +412,7 @@ class InternVLChatModel(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-
+        self._update_vision_config()
         image_size = config.force_image_size or config.vision_config.image_size
         patch_size = config.vision_config.patch_size
         self.patch_size = patch_size
@@ -446,6 +441,10 @@ class InternVLChatModel(nn.Module):
             self.language_model = InternLM2ForCausalLM(
                 config=config.llm_config, quant_config=quant_config
             )
+        elif config.llm_config.architectures[0] == "Qwen3MoeForCausalLM":
+            self.language_model = Qwen3MoeForCausalLM(
+                config=config.llm_config, quant_config=quant_config
+            )
         else:
             raise NotImplementedError(
                 f"{config.llm_config.architectures[0]} is not implemented."
@@ -463,6 +462,21 @@ class InternVLChatModel(nn.Module):
             nn.Linear(llm_hidden_size, llm_hidden_size),
         )
 
+    def _update_vision_config(self):
+        """update vision config to support tp"""
+        world_size = parallel_state.get_tensor_model_parallel_world_size()
+        num_heads = self.config.vision_config.num_attention_heads
+        head_dim = self.config.vision_config.hidden_size // num_heads
+        num_dummy_heads = 0
+
+        if num_heads % world_size != 0:
+            num_dummy_heads = (
+                (num_heads + world_size) // world_size
+            ) * world_size - num_heads
+
+        setattr(self.config.vision_config, "head_dim", head_dim)
+        setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
+
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
         # N, W, H, C --> N, W, H * scale, C // scale
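Note: the new _update_vision_config hook pads the ViT attention head count up to a multiple of the tensor-parallel world size by registering "dummy" heads. A minimal standalone sketch of that rounding (the helper name and the example numbers are illustrative, not part of the diff):

def dummy_heads_needed(num_heads: int, world_size: int) -> int:
    # Same arithmetic as _update_vision_config: round num_heads up to the next
    # multiple of world_size and report how many extra heads that adds.
    if num_heads % world_size == 0:
        return 0
    return ((num_heads + world_size) // world_size) * world_size - num_heads

# InternViT-6B uses 25 attention heads; with tp_size=4 it needs 3 dummy heads,
# so each rank ends up holding (25 + 3) / 4 = 7 heads.
assert dummy_heads_needed(25, 4) == 3
assert dummy_heads_needed(16, 4) == 0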
@@ -545,7 +559,38 @@ class InternVLChatModel(nn.Module):
 
         return helper.pad_input_tokens(input_ids, mm_inputs)
 
+    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
+        """pad attn qkv weights for dummy heads"""
+        num_dummy_heads = self.config.vision_config.num_dummy_heads
+        if num_dummy_heads == 0:
+            return loaded_weight
+        head_dim = self.config.vision_config.head_dim
+
+        if "attn.qkv_proj" in name:
+            wq, wk, wv = loaded_weight.chunk(3, dim=0)
+            if name.endswith(".weight"):
+                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
+            elif name.endswith(".bias"):
+                dummy_shape = [num_dummy_heads, head_dim]
+            else:
+                raise RuntimeError(f"Unsupported weight with name={name}")
+            pad_func = lambda x: torch.cat(
+                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
+            ).flatten(0, 1)
+            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
+            loaded_weight = torch.cat([wq, wk, wv], dim=0)
+        if "attn.proj.weight" in name:
+            padded_weight = loaded_weight.new_zeros(
+                loaded_weight.shape[0], head_dim * num_dummy_heads
+            )
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+        return loaded_weight
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        expert_params_mapping = []
         if "InternLM2ForCausalLM" in self.config.llm_config.architectures:
             stacked_params_mapping = [
                 # (param_name, shard_name, shard_id)
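Note: _pad_vit_attn_dummy_heads zero-pads the q, k and v chunks of a fused qkv weight so the padded head count divides evenly across ranks. A shape-only sketch of the same padding, using deliberately tiny made-up dimensions:

import torch

num_heads, num_dummy_heads, head_dim, hidden = 5, 3, 4, 16

# A fused qkv weight stacks q, k and v along dim 0.
qkv = torch.randn(3 * num_heads * head_dim, hidden)
wq, wk, wv = qkv.chunk(3, dim=0)

def pad(x: torch.Tensor) -> torch.Tensor:
    # Append num_dummy_heads all-zero heads after the real heads of one chunk.
    dummy = x.new_zeros([num_dummy_heads, head_dim, x.shape[-1]])
    return torch.cat([x.unflatten(0, (-1, head_dim)), dummy], dim=0).flatten(0, 1)

padded = torch.cat([pad(wq), pad(wk), pad(wv)], dim=0)
assert padded.shape == (3 * (num_heads + num_dummy_heads) * head_dim, hidden)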
@@ -561,15 +606,41 @@ class InternVLChatModel(nn.Module):
                 ("gate_up_proj", "gate_proj", 0),
                 ("gate_up_proj", "up_proj", 1),
             ]
+        elif "Qwen3MoeForCausalLM" in self.config.llm_config.architectures:
+            stacked_params_mapping = [
+                # (param_name, shard_name, shard_id)
+                ("qkv_proj", "q_proj", "q"),
+                ("qkv_proj", "k_proj", "k"),
+                ("qkv_proj", "v_proj", "v"),
+                ("gate_up_proj", "gate_proj", 0),
+                ("gate_up_proj", "up_proj", 1),
+            ]
+
+            expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+                ckpt_gate_proj_name="gate_proj",
+                ckpt_down_proj_name="down_proj",
+                ckpt_up_proj_name="up_proj",
+                num_experts=self.config.num_experts,
+            )
+
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
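Note: the stacked-params mapping renames checkpoint weight names into the fused parameters before loading, e.g. a q_proj weight is folded into qkv_proj under shard id "q". A tiny sketch of that renaming step (the weight path below is illustrative):

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

name = "model.layers.0.self_attn.q_proj.weight"
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name not in name:
        continue
    # The renamed parameter is then loaded via weight_loader(param, w, shard_id).
    name = name.replace(weight_name, param_name)
    break

assert name == "model.layers.0.self_attn.qkv_proj.weight"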
@@ -584,30 +655,55 @@ class InternVLChatModel(nn.Module):
                 name = name.replace(r"attn.", r"attn.attn.")
                 name = name.replace(r"qkv.", r"qkv_proj.")
 
-
-
-
-
-
-
-                kv_groups = config.num_attention_heads // config.num_key_value_heads
-                head_dim = config.hidden_size // config.num_attention_heads
-                loaded_weight = loaded_weight.view(
-                    -1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
-                )
-                wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1], dim=1)
-                wq = wq.reshape(-1, wq.shape[-1])
-                wk = wk.reshape(-1, wk.shape[-1])
-                wv = wv.reshape(-1, wv.shape[-1])
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
                     weight_loader = param.weight_loader
-                weight_loader(
-
-
-
-
-
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
                     )
-
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    if "wqkv" in name:
+                        config = self.config
+                        kv_groups = (
+                            config.num_attention_heads // config.num_key_value_heads
+                        )
+                        head_dim = config.hidden_size // config.num_attention_heads
+                        loaded_weight = loaded_weight.view(
+                            -1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
+                        )
+                        wq, wk, wv = torch.split(
+                            loaded_weight, [kv_groups, 1, 1], dim=1
+                        )
+                        wq = wq.reshape(-1, wq.shape[-1])
+                        wk = wk.reshape(-1, wk.shape[-1])
+                        wv = wv.reshape(-1, wv.shape[-1])
+                        weight_loader = param.weight_loader
+                        weight_loader(param, wq, "q")
+                        weight_loader(param, wk, "k")
+                        weight_loader(param, wv, "v")
+                    else:
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        if "vision_model" in name:
+                            loaded_weight = self._pad_vit_attn_dummy_heads(
+                                name, loaded_weight
+                            )
+                        weight_loader(param, loaded_weight)
+
             loaded_params.add(name)
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
sglang/srt/models/llava.py
CHANGED
@@ -656,11 +656,15 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
         self, auto_model_type: Type[AutoModel]
     ) -> Dict[str, str]:
         mapping = {}
-        for config_cls
-
-
-
-
+        for config_cls in auto_model_type._model_mapping.keys():
+            archs = auto_model_type._model_mapping.get(config_cls, None)
+            if archs is not None:
+                if isinstance(archs, tuple):
+                    mapping[config_cls.__name__] = tuple(
+                        arch.__name__ for arch in archs
+                    )
+                else:
+                    mapping[config_cls.__name__] = archs.__name__
         return mapping
 
     def __init__(
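Note: the rewritten helper builds a mapping from HF config class names to model architecture names out of auto_model_type._model_mapping, handling entries whose value is a tuple of model classes. A minimal sketch of the same normalization over a plain dict (the stand-in classes are hypothetical):

from typing import Dict, Tuple, Union

class FakeConfig: ...
class FakeModel: ...
class FakeModelAlt: ...

# Stand-in for auto_model_type._model_mapping: config class -> model class or tuple.
model_mapping = {FakeConfig: (FakeModel, FakeModelAlt)}

mapping: Dict[str, Union[str, Tuple[str, ...]]] = {}
for config_cls in model_mapping.keys():
    archs = model_mapping.get(config_cls, None)
    if archs is not None:
        if isinstance(archs, tuple):
            mapping[config_cls.__name__] = tuple(arch.__name__ for arch in archs)
        else:
            mapping[config_cls.__name__] = archs.__name__

assert mapping == {"FakeConfig": ("FakeModel", "FakeModelAlt")}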
sglang/srt/models/minicpmo.py
CHANGED
@@ -1134,7 +1134,10 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
         """
         residual = hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
-
+        # TODO (lifuhuang): confirmed with Mick that the logic for past_key_values is copied from minicpmo official code,
+        # currently we are not using past_key_values at all. We need to redesign the caching logic when we support streaming
+        # in the future.
+        hidden_states, attn_weights = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             layer_head_mask=layer_head_mask,
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -147,10 +147,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             # Additional args for FusedMoE
             **(
                 dict(
-
+                    enable_flashinfer_cutlass_moe=True,
                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
-                if global_server_args_dict["
+                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
                 else {}
             ),
         )
sglang/srt/models/qwen3_moe.py
CHANGED
@@ -120,10 +120,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             # Additional args for FusedMoE
             **(
                 dict(
-
+                    enable_flashinfer_cutlass_moe=True,
                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
-                if global_server_args_dict["
+                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
                 else {}
             ),
         )
@@ -707,6 +707,9 @@ class Qwen3MoeForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.capture_aux_hidden_states = False
 
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     @torch.no_grad()
     def forward(
         self,
sglang/srt/multimodal/processors/base_processor.py
CHANGED
@@ -12,6 +12,7 @@ import torch
 from PIL import Image
 from transformers import BaseImageProcessorFast
 
+from sglang.srt.managers.mm_utils import TransportProxyTensor
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.utils import load_audio, load_image, load_video, logger
 
@@ -142,11 +143,14 @@ class BaseMultimodalProcessor(ABC):
 class BaseMultimodalProcessor(ABC):
     models = []
 
-    def __init__(
+    def __init__(
+        self, hf_config, server_args, _processor, transport_mode, *args, **kwargs
+    ):
         self.hf_config = hf_config
         self._processor = _processor
         self.arch = hf_config.architectures[0]
         self.server_args = server_args
+        self.transport_mode = transport_mode
 
         # FIXME: not accurate, model and image specific
         self.NUM_TOKEN_PER_FRAME = 330
@@ -217,10 +221,6 @@ class BaseMultimodalProcessor(ABC):
             return_tensors="pt",
             **kwargs,
         )
-        if "pixel_values" in result and isinstance(
-            result["pixel_values"], torch.Tensor
-        ):
-            result["pixel_values"] = result["pixel_values"].to("cpu")
         return result
 
     @abstractmethod
@@ -500,7 +500,6 @@ class BaseMultimodalProcessor(ABC):
     ) -> List[MultimodalDataItem]:
         """Create mm_items directly from processor output."""
         items: dict[Modality, MultimodalDataItem] = {}
-
         for attr_name, value in data_dict.items():
             if attr_name == "input_ids":
                 continue
@@ -624,4 +623,19 @@ class BaseMultimodalProcessor(ABC):
             mm_token_id=mm_token_id,
         )
 
+        # post-process
+        for item in all_collected_items:
+            # replace the feature tensor with a proxy
+            if isinstance(item.feature, torch.Tensor) and item.feature.is_cuda:
+                item.feature = TransportProxyTensor(
+                    transport_mode=self.transport_mode, data=item.feature
+                )
+            elif (
+                isinstance(item.precomputed_embeddings, torch.Tensor)
+                and item.precomputed_embeddings.is_cuda
+            ):
+                item.precomputed_embeddings = TransportProxyTensor(
+                    transport_mode=self.transport_mode, data=item.precomputed_embeddings
+                )
+
         return all_collected_items, input_ids, ret
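Note: collected multimodal items whose feature tensors live on CUDA are now wrapped in TransportProxyTensor so they can be handed off using the configured transport mode. A hedged usage sketch mirroring the post-process step above (the helper function below is illustrative, not part of sglang):

import torch
from sglang.srt.managers.mm_utils import TransportProxyTensor

def wrap_for_transport(feature, transport_mode):
    # Only CUDA tensors get proxied, matching the hunk above; everything else
    # (CPU tensors, None, lists) is returned unchanged.
    if isinstance(feature, torch.Tensor) and feature.is_cuda:
        return TransportProxyTensor(transport_mode=transport_mode, data=feature)
    return feature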
sglang/srt/multimodal/processors/clip.py
CHANGED
@@ -10,8 +10,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class ClipImageProcessor(BaseMultimodalProcessor):
     models = [CLIPModel]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
             _processor
         )
sglang/srt/multimodal/processors/deepseek_vl_v2.py
CHANGED
@@ -31,8 +31,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
     models = [DeepseekVL2ForCausalLM]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(
             image_token="<image>", image_token_id=self._processor.image_token_id
         ).build(_processor)
sglang/srt/multimodal/processors/gemma3.py
CHANGED
@@ -14,8 +14,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok
 class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
     models = [Gemma3ForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.IM_START_TOKEN_ID = hf_config.boi_token_index
         self.IM_END_TOKEN_ID = hf_config.eoi_token_index
         self.mm_tokens = MultimodalSpecialTokens(
sglang/srt/multimodal/processors/gemma3n.py
CHANGED
@@ -27,8 +27,8 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
 
     models = [Gemma3nForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
 
         self.IM_START_TOKEN_ID = hf_config.boi_token_id
         self.IM_END_TOKEN_ID = hf_config.eoi_token_id
sglang/srt/multimodal/processors/internvl.py
CHANGED
@@ -6,6 +6,7 @@ from decord import VideoReader, cpu
 from PIL import Image
 
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.interns1 import InternS1ForConditionalGeneration
 from sglang.srt.models.internvl import InternVLChatModel
 from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor,
@@ -14,12 +15,19 @@ from sglang.srt.multimodal.processors.base_processor import (
 
 
 class InternVLImageProcessor(BaseMultimodalProcessor):
-    models = [InternVLChatModel]
+    models = [InternVLChatModel, InternS1ForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _image_processor):
-        super().__init__(hf_config, server_args, _image_processor)
-        image_size =
+    def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
+        image_size = (
+            getattr(hf_config, "force_image_size", None)
+            or hf_config.vision_config.image_size
+        )
         patch_size = hf_config.vision_config.patch_size
+        if isinstance(image_size, list):
+            image_size = image_size[0]
+        if isinstance(patch_size, list):
+            patch_size = patch_size[0]
 
         self.IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
         self.IMG_START_TOKEN = "<img>"
@@ -27,8 +35,12 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
         self.num_image_token = int(
             (image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2)
         )
+        if hasattr(self._processor, "tokenizer"):
+            tokenizer = self._processor.tokenizer
+        else:
+            tokenizer = self._processor
+        self.tokenizer = tokenizer
 
-        tokenizer = self._processor
         self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
         self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
         self.mm_tokens = MultimodalSpecialTokens(
@@ -195,7 +207,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
                 try:
                     # TODO: video input
                     raw_image = process_image_internvl(image)
-                    pixel_value = [raw_image.to(torch.bfloat16)
+                    pixel_value = [raw_image.to(torch.bfloat16)]
                     pixel_values += pixel_value
                     num_patches = raw_image.shape[0]
                     num_patches_list += [num_patches]
@@ -214,8 +226,9 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             )
             input_text = input_text.replace("<image>", image_tokens, 1)
 
-
-
+        input_ids = self.tokenizer(input_text, return_tensors="pt")[
+            "input_ids"
+        ].flatten()
         image_offsets = self.get_mm_items_offset(
             input_ids=input_ids,
             mm_token_id=self.mm_tokens.image_token_id,
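Note: the num_image_token computation in the @@ -27,8 +35,12 @@ hunk above follows directly from the tile and patch sizes. A worked example for a typical InternVL configuration (448-pixel tiles, 14-pixel patches, downsample ratio 0.5; the concrete numbers are illustrative):

image_size, patch_size, downsample_ratio = 448, 14, 0.5

# (448 // 14) ** 2 = 1024 patches per tile, reduced 4x by pixel-shuffle downsampling.
num_image_token = int((image_size // patch_size) ** 2 * (downsample_ratio**2))
assert num_image_token == 256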
sglang/srt/multimodal/processors/janus_pro.py
CHANGED
@@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class JanusProImageProcessor(BaseMultimodalProcessor):
     models = [MultiModalityCausalLM]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
 
         self.mm_tokens = MultimodalSpecialTokens(
             image_token=_processor.image_token,
sglang/srt/multimodal/processors/kimi_vl.py
CHANGED
@@ -12,8 +12,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok
 class KimiVLImageProcessor(SGLangBaseProcessor):
     models = [KimiVLForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(
             image_token="<|media_pad|>",
             # TODO: could we convert in MultimodalSpecialTokens?
sglang/srt/multimodal/processors/llava.py
CHANGED
@@ -30,8 +30,8 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
         LlavaMistralForCausalLM,
     ]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
 
     @staticmethod
     def _process_single_image_task(
@@ -187,7 +187,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor):
                 f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
             )
 
-    def __init__(self, hf_config, server_args, _processor):
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         assert hasattr(hf_config, "vision_config")
         assert hasattr(hf_config, "text_config")
         self.vision_config = hf_config.vision_config
@@ -196,7 +196,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor):
 
         if vision_type := getattr(self.vision_config, "model_type"):
             self.inner = self._get_sgl_processor_cls(vision_type)(
-                hf_config, server_args, _processor
+                hf_config, server_args, _processor, *args, **kwargs
             )
         else:
             raise ValueError(
sglang/srt/multimodal/processors/minicpm.py
CHANGED
@@ -15,8 +15,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
     models = [MiniCPMV, MiniCPMO]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         # Collect special token ids
         tokenizer = self._processor.tokenizer
         self.slice_start_id = getattr(tokenizer, "slice_start_id", None)
@@ -26,7 +26,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         self.im_start_id = getattr(tokenizer, "im_start_id", None)
         self.im_end_id = getattr(tokenizer, "im_end_id", None)
         self.im_token_id = getattr(tokenizer, "unk_id", None)
-
         self.mm_tokens = MultimodalSpecialTokens(
             image_token="(<image>./</image>)",
             audio_token="(<audio>./</audio>)",
sglang/srt/multimodal/processors/mlama.py
CHANGED
@@ -10,8 +10,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class MllamaImageProcessor(BaseMultimodalProcessor):
     models = [MllamaForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(
             image_token=self._processor.image_token,
             image_token_id=self._processor.image_token_id,