sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
CHANGED
@@ -20,7 +20,7 @@ import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -45,6 +45,8 @@ class GenerateReqInput:
     # The image input. It can be a file name, a url, or base64 encoded string.
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
+    # The audio input. Like image data, tt can be a file name, a url, or base64 encoded string.
+    audio_data: Optional[Union[List[str], str]] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
@@ -103,6 +105,8 @@ class GenerateReqInput:
             self.batch_size = len(self.text)
             self.input_embeds = None
         elif self.input_ids is not None:
+            if len(self.input_ids) == 0:
+                raise ValueError("input_ids cannot be empty.")
             if isinstance(self.input_ids[0], int):
                 self.is_single = True
                 self.batch_size = 1
@@ -165,6 +169,13 @@ class GenerateReqInput:
         elif isinstance(self.image_data, list):
             pass
 
+        if self.audio_data is None:
+            self.audio_data = [None] * num
+        elif not isinstance(self.audio_data, list):
+            self.audio_data = [self.audio_data] * num
+        elif isinstance(self.audio_data, list):
+            pass
+
         if self.sampling_params is None:
             self.sampling_params = [{}] * num
         elif not isinstance(self.sampling_params, list):
@@ -229,6 +240,7 @@ class GenerateReqInput:
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
             image_data=self.image_data[i],
+            audio_data=self.audio_data[i],
             sampling_params=self.sampling_params[i],
             rid=self.rid[i],
             return_logprob=self.return_logprob[i],
@@ -257,8 +269,8 @@ class TokenizedGenerateReqInput:
     input_text: str
     # The input token ids
     input_ids: List[int]
-    # The
-
+    # The multimodal inputs
+    mm_inputs: dict
     # The sampling parameters
     sampling_params: SamplingParams
     # Whether to return the logprobs
@@ -538,7 +550,8 @@ class UpdateWeightsFromDistributedReqOutput:
 
 @dataclass
 class UpdateWeightsFromTensorReqInput:
-
+    # List containing one serialized Dict[str, torch.Tensor] per TP worker
+    serialized_named_tensors: List[bytes]
     load_format: Optional[str]
     flush_cache: bool
 
@@ -637,7 +650,7 @@ class ProfileReqInput:
     # If it is set, profiling is automatically stopped after this step, and
     # the caller doesn't need to run stop_profile.
     num_steps: Optional[int] = None
-    activities: Optional[List[
+    activities: Optional[List[Literal["CPU", "GPU", "MEM", "CUDA_PROFILER"]]] = None
 
 
 class ProfileReqType(Enum):
@@ -645,12 +658,25 @@ class ProfileReqType(Enum):
     STOP_PROFILE = 2
 
 
+class ExpertDistributionReq(Enum):
+    START_RECORD = 1
+    STOP_RECORD = 2
+    DUMP_RECORD = 3
+
+
+@dataclass
+class ExpertDistributionReqOutput:
+    pass
+
+
 @dataclass
 class ProfileReq:
     type: ProfileReqType
     output_dir: Optional[str] = None
     num_steps: Optional[int] = None
     activities: Optional[List[str]] = None
+    with_stack: Optional[bool] = None
+    record_shapes: Optional[bool] = None
 
 
 @dataclass
@@ -723,3 +749,15 @@ class SeparateReasoningReqInput:
 class VertexGenerateReqInput:
     instances: List[dict]
     parameters: Optional[dict] = None
+
+
+@dataclass
+class RpcReqInput:
+    method: str
+    parameters: Optional[Dict] = None
+
+
+@dataclass
+class RpcReqOutput:
+    success: bool
+    message: str
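The batching logic above broadcasts the new audio_data field with the same rule it already applies to image_data and sampling_params: a missing value becomes a list of None, a single value is repeated for every request, and a list is taken as-is. A minimal standalone sketch of that rule; the helper name and sample values are illustrative and not part of sglang:

from typing import Any, List, Optional, Union

def broadcast_to_batch(value: Optional[Union[List[Any], Any]], num: int) -> List[Any]:
    # Expand a scalar or missing per-request field to one entry per request.
    if value is None:
        return [None] * num
    if not isinstance(value, list):
        return [value] * num
    return value

assert broadcast_to_batch(None, 3) == [None, None, None]
assert broadcast_to_batch("clip.wav", 2) == ["clip.wav", "clip.wav"]
assert broadcast_to_batch(["a.wav", "b.wav"], 2) == ["a.wav", "b.wav"]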
sglang/srt/managers/mm_utils.py
ADDED
@@ -0,0 +1,373 @@
+"""
+Multimodality utils
+"""
+
+from abc import abstractmethod
+from typing import Callable, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.managers.schedule_batch import (
+    MultimodalInputs,
+    global_server_args_dict,
+    logger,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.utils import logger
+
+
+class MultiModalityDataPaddingPattern:
+    """
+    Data tokens (like image tokens) often need special handling during padding
+    to maintain model compatibility. This class provides the interface for
+    implementing different padding strategies for data tokens
+    """
+
+    @abstractmethod
+    def pad_input_tokens(
+        self, input_ids: List[int], image_inputs: MultimodalInputs
+    ) -> List[int]:
+        """
+        Pad the input ids sequence containing data tokens, and replace them with pad_values
+        """
+        pass
+
+
+class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
+    """In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)
+
+    This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
+    """
+
+    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
+        self.data_token_id_pairs = data_token_pairs
+
+    def pad_input_tokens(
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
+    ) -> List[int]:
+        """
+        This function will replace the data-tokens inbetween with pad_values accordingly
+        """
+        pad_values = mm_inputs.pad_values
+        data_token_pairs = self.data_token_id_pairs
+        mm_inputs.image_offsets = []
+        if data_token_pairs is None:
+            data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id]
+        if data_token_pairs is None:
+            logger.warning(
+                "No data_token_pairs provided, RadixAttention might be influenced."
+            )
+            return input_ids
+        start_token_ids = [s for s, _e in data_token_pairs]
+        end_tokens_ids = [e for _s, e in data_token_pairs]
+
+        padded_ids = []
+        last_idx = 0
+        data_idx = -1
+
+        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
+        end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]
+
+        if len(start_indices) != len(end_indices):
+            return input_ids
+
+        for start_idx, end_idx in zip(start_indices, end_indices):
+            padded_ids.extend(input_ids[last_idx : start_idx + 1])
+
+            if input_ids[start_idx] in start_token_ids:
+                data_idx += 1
+                mm_inputs.image_offsets += [start_idx]
+
+            if data_idx >= len(mm_inputs.pad_values):
+                data_idx = len(mm_inputs.pad_values) - 1
+
+            num_tokens = end_idx - start_idx - 1
+            pad_value = pad_values[data_idx]
+            padded_ids.extend([pad_value] * num_tokens)
+
+            last_idx = end_idx
+
+        padded_ids.extend(input_ids[last_idx:])
+
+        assert len(input_ids) == len(padded_ids), "Length validation fails"
+        return padded_ids
+
+
+class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
+    """In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
+    which needs first to be expanded to multiple tokens, then replaced with their padding values
+
+    This strategy should be used when a single data token represents content that should
+    be expanded to multiple tokens during processing.
+    """
+
+    def __init__(
+        self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
+    ) -> None:
+        self.num_data_token_calc_func = num_data_token_calc_func
+
+    def pad_input_tokens(
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
+    ) -> List[int]:
+        """
+        This function will follow the procedure of:
+        1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
+        2. the padded data tokens will be replaced with their pad_values
+        """
+        image_grid_thws = mm_inputs.image_grid_thws
+        pad_values = mm_inputs.pad_values
+
+        image_indices = [
+            idx for idx, token in enumerate(input_ids) if token == mm_inputs.im_token_id
+        ]
+
+        mm_inputs.image_offsets = []
+
+        input_ids_with_image = []
+        for image_cnt, _ in enumerate(image_grid_thws):
+            # print(f"image_cnt {image_cnt}")
+            num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
+            if image_cnt == 0:
+                non_image_tokens = input_ids[: image_indices[image_cnt]]
+            else:
+                non_image_tokens = input_ids[
+                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
+                ]
+            input_ids_with_image.extend(non_image_tokens)
+            mm_inputs.image_offsets.append(len(input_ids_with_image))
+            pad_ids = pad_values * (
+                (num_image_tokens + len(pad_values)) // len(pad_values)
+            )
+            input_ids_with_image.extend(pad_ids[:num_image_tokens])
+        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
+
+        return input_ids_with_image
+
+
+class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
+    """In this pattern, data tokens should be represented as image tokens (e.g. <image><image>....<image>)"""
+
+    def __init__(self, image_token_id: torch.Tensor) -> None:
+        self.image_token_id = image_token_id
+
+    def pad_input_tokens(self, input_ids: List[int], image_inputs) -> List[int]:
+        """
+        This function will replace the data-tokens in between with pad_values accordingly
+        """
+        pad_values = image_inputs.pad_values
+        assert len(pad_values) != 0
+
+        input_ids_tensor = torch.tensor(input_ids)
+        mask = torch.isin(input_ids_tensor, self.image_token_id)
+
+        num_image_tokens = mask.sum().item()
+        repeated_pad_values = torch.tensor(pad_values).repeat(
+            num_image_tokens // len(pad_values) + 1
+        )[:num_image_tokens]
+
+        input_ids_tensor[mask] = repeated_pad_values
+        return input_ids_tensor.tolist()
+
+
+def embed_mm_inputs(
+    mm_input: MultimodalInputs,
+    input_ids: torch.Tensor,
+    input_embedding: nn.Embedding,
+    mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
+    placeholder_token_ids: List[int] = None,
+) -> Optional[torch.Tensor]:
+    """
+    Calculate the image embeddings if necessary, then scatter the result with
+    the help of a boolean mask denoting the embed locations
+
+    Returns:
+        final embedding: Optional[torch.Tensor]
+    """
+    if mm_input is None:
+        return None
+
+    placeholder_token_ids = placeholder_token_ids or mm_input.pad_values
+
+    # boolean masking the special tokens
+    special_image_mask = torch.isin(
+        input_ids,
+        torch.tensor(placeholder_token_ids, device=input_ids.device),
+    ).unsqueeze(-1)
+
+    num_image_tokens_in_input_ids = special_image_mask.sum()
+    # print(f"{num_image_tokens_in_input_ids}")
+    # print(f"{input_ids}")
+
+    # return
+    if num_image_tokens_in_input_ids == 0:
+        # unexpected
+        inputs_embeds = input_embedding(input_ids)
+    else:
+        # print(f"Getting image feature")
+        image_embedding = mm_data_embedding_func(mm_input)
+
+        # print(f"image_embedding: {image_embedding.shape}")
+
+        if image_embedding.dim() == 2:
+            num_image_tokens_in_embedding = image_embedding.shape[0]
+        else:
+            num_image_tokens_in_embedding = (
+                image_embedding.shape[0] * image_embedding.shape[1]
+            )
+        if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
+            num_image = num_image_tokens_in_input_ids // image_embedding.shape[1]
+            image_embedding = image_embedding[:num_image, :]
+            logger.warning(
+                f"Number of images does not match number of special image tokens in the input text. "
+                f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
+                "tokens from image embeddings."
+            )
+
+        # TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
+        # a fix may be cache the unfinished image embedding for future reuse, determine the tokens to embed with
+        # extend_start_loc and extend_seq_lens
+        if num_image_tokens_in_input_ids > num_image_tokens_in_embedding:
+            chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
+            if chunked_prefill_size != -1:
+                logger.warning(
+                    "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked_prefill"
+                )
+
+        vocab_size = input_embedding.num_embeddings
+        # Important: clamp after getting original image regions
+        # Clamp input ids. This is because the input_ids for the image tokens are
+        # filled with the hash values of the image for the prefix matching in the radix attention.
+        # There values are useless because their embeddings will be replaced by vision embeddings anyway.
+        input_ids.clamp_(min=0, max=vocab_size - 1)
+        inputs_embeds = input_embedding(input_ids)
+
+        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
+            inputs_embeds.device
+        )
+
+        inputs_embeds = inputs_embeds.masked_scatter(
+            special_image_mask,
+            image_embedding.to(inputs_embeds.device, inputs_embeds.dtype),
+        )
+    return inputs_embeds
+
+
+def embed_image_embedding(
+    inputs_embeds: torch.Tensor,
+    image_embedding: torch.Tensor,
+    image_bounds: torch.Tensor,
+) -> torch.Tensor:
+    """
+    scatter image_embedding into inputs_embeds according to image_bounds
+    """
+    if len(image_bounds) > 0:
+        image_indices = torch.stack(
+            [
+                torch.arange(start, end, dtype=torch.long)
+                for start, end in image_bounds.tolist()
+            ]
+        ).to(inputs_embeds.device)
+
+        inputs_embeds.scatter_(
+            0,
+            image_indices.view(-1, 1).repeat(1, inputs_embeds.shape[-1]),
+            image_embedding.view(-1, image_embedding.shape[-1]),
+        )
+    return inputs_embeds
+
+
+def general_mm_embed_routine(
+    input_ids: torch.Tensor,
+    forward_batch: ForwardBatch,
+    embed_tokens: nn.Embedding,
+    mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
+    placeholder_token_ids: List[int] = None,
+):
+    """
+    a general wrapper function to get final input embeds from multimodal models
+    with a language model as causal model
+
+    Args:
+        placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
+
+    """
+    if (
+        not forward_batch.forward_mode.is_decode()
+        and forward_batch.contains_mm_inputs()
+    ):
+        image = forward_batch.merge_mm_inputs()
+        inputs_embeds = embed_mm_inputs(
+            mm_input=image,
+            input_ids=input_ids,
+            input_embedding=embed_tokens,
+            mm_data_embedding_func=mm_data_embedding_func,
+            placeholder_token_ids=placeholder_token_ids,
+        )
+        # once used, mm_inputs is useless
+        # just being defensive here
+        forward_batch.mm_inputs = None
+    else:
+        inputs_embeds = embed_tokens(input_ids)
+
+    return inputs_embeds
+
+
+def get_multimodal_data_bounds(
+    input_ids: torch.Tensor, pad_values: List[int], token_pairs: List[Tuple[int, int]]
+) -> torch.Tensor:
+    """
+    Returns a tensor indicating the bounds of multimodal data (images, video, audio, etc.)
+
+    Returns:
+        [bounds_count, 2]
+    """
+    # All the images in the batch should share the same special image
+    # bound token ids.
+    start_tokens = [s for s, _e in token_pairs]
+    end_tokens = [e for _s, e in token_pairs]
+
+    assert all(isinstance(t, int) for t in start_tokens)
+    assert all(isinstance(t, int) for t in end_tokens)
+
+    # print(input_ids)
+    start_cond = torch.isin(
+        input_ids, torch.tensor(start_tokens, device=input_ids.device)
+    )
+    end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))
+
+    (data_start_tokens,) = torch.where(start_cond)
+    (data_end_tokens,) = torch.where(end_cond)
+
+    # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the images
+    if len(data_start_tokens) != len(data_end_tokens):
+        if (
+            len(data_start_tokens) + 1 == len(data_end_tokens)
+            and input_ids[0] in pad_values
+            and data_end_tokens[0] < data_start_tokens[0]
+        ):
+            data_start_tokens = torch.cat(
+                [
+                    torch.tensor([0], device=data_start_tokens.device),
+                    data_start_tokens,
+                ]
+            )
+    valid_image_nums = min(len(data_start_tokens), len(data_end_tokens))
+
+    if valid_image_nums == 0:
+        return torch.zeros((0, 2), device=input_ids.device)
+
+    # Filter out pairs where start_token >= end_token
+    valid_pairs = []
+    for i in range(valid_image_nums):
+        start_token = data_start_tokens[i]
+        end_token = data_end_tokens[i]
+        if start_token < end_token:
+            valid_pairs.append((start_token + 1, end_token - 1))
+
+    if not valid_pairs:
+        return torch.zeros((0, 2), device=input_ids.device)
+
+    # Convert valid pairs to tensor
+    valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
+    return valid_pairs_tensor
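The heart of embed_mm_inputs above is a two-step technique: locate the multimodal placeholder positions in input_ids with torch.isin, then overwrite the corresponding rows of the text embeddings with the vision embeddings via masked_scatter. A toy sketch of that technique; the embedding table, placeholder id, and shapes below are made up for illustration and are not sglang code:

import torch
import torch.nn as nn

vocab_size, hidden = 32, 4
embed_tokens = nn.Embedding(vocab_size, hidden)

pad_values = [7]                              # ids that mark multimodal placeholders
input_ids = torch.tensor([1, 7, 7, 7, 2])     # 3 placeholder tokens
image_embedding = torch.randn(3, hidden)      # one row per placeholder token

# Boolean mask of placeholder positions, broadcast over the hidden dimension.
mask = torch.isin(input_ids, torch.tensor(pad_values)).unsqueeze(-1)
inputs_embeds = embed_tokens(input_ids)
# Overwrite the masked rows with the vision embeddings.
inputs_embeds = inputs_embeds.masked_scatter(
    mask.expand_as(inputs_embeds), image_embedding.to(inputs_embeds.dtype)
)
assert inputs_embeds.shape == (5, hidden)

masked_scatter fills the True positions in order, so the vision embeddings must be laid out in the same order as the placeholder tokens appear in the sequence.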
sglang/srt/managers/multimodal_processor.py
ADDED
@@ -0,0 +1,68 @@
+# TODO: also move pad_input_ids into this module
+import importlib
+import inspect
+import logging
+import pkgutil
+from functools import lru_cache
+
+from transformers import PROCESSOR_MAPPING
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+)
+from sglang.srt.server_args import ServerArgs
+
+logger = logging.getLogger(__name__)
+
+PROCESSOR_MAPPING = {}
+
+
+class DummyMultimodalProcessor(BaseMultimodalProcessor):
+    def __init__(self):
+        pass
+
+    async def process_mm_data_async(self, *args, **kwargs):
+        return None
+
+
+def get_dummy_processor():
+    return DummyMultimodalProcessor()
+
+
+@lru_cache()
+def import_processors():
+    package_name = "sglang.srt.managers.multimodal_processors"
+    package = importlib.import_module(package_name)
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
+        if not ispkg:
+            try:
+                module = importlib.import_module(name)
+            except Exception as e:
+                logger.warning(f"Ignore import error when loading {name}: " f"{e}")
+                continue
+            all_members = inspect.getmembers(module, inspect.isclass)
+            classes = [
+                member
+                for name, member in all_members
+                if member.__module__ == module.__name__
+            ]
+            for cls in (
+                cls for cls in classes if issubclass(cls, BaseMultimodalProcessor)
+            ):
+                assert hasattr(cls, "models")
+                for arch in getattr(cls, "models"):
+                    PROCESSOR_MAPPING[arch] = cls
+
+
+def get_mm_processor(
+    hf_config, server_args: ServerArgs, processor
+) -> BaseMultimodalProcessor:
+    for model_cls, processor_cls in PROCESSOR_MAPPING.items():
+        if model_cls.__name__ in hf_config.architectures:
+            return processor_cls(hf_config, server_args, processor)
+    raise ValueError(
+        f"No processor registered for architecture: {hf_config.architectures}.\n"
+        f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
+    )
+
self.image_proce
|