sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,373 @@
|
|
1
|
+
"""
|
2
|
+
Multimodality utils
|
3
|
+
"""
|
4
|
+
|
5
|
+
from abc import abstractmethod
|
6
|
+
from typing import Callable, List, Optional, Tuple
|
7
|
+
|
8
|
+
import torch
|
9
|
+
from torch import nn
|
10
|
+
|
11
|
+
from sglang.srt.managers.schedule_batch import (
|
12
|
+
MultimodalInputs,
|
13
|
+
global_server_args_dict,
|
14
|
+
logger,
|
15
|
+
)
|
16
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
17
|
+
from sglang.utils import logger
|
18
|
+
|
19
|
+
|
20
|
+
class MultiModalityDataPaddingPattern:
    """
    Common interface for strategies that rewrite data tokens (such as image
    tokens) inside an input-id sequence into pad values.

    Concrete subclasses decide how data regions are located and how they are
    replaced or expanded.

    NOTE: this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` marker below is not enforced when instantiating.
    """

    @abstractmethod
    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Return ``input_ids`` with its data tokens replaced by pad values.

        Subclasses override this; the base implementation does nothing and
        returns ``None``.
        """
|
35
|
+
|
36
|
+
|
37
|
+
class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens are enclosed by special token pairs
    (e.g. ``<image>...</image>``), supplied as ``data_token_pairs``.

    This strategy should be applied when data content is marked by start/end
    token pairs in the input sequence.
    """

    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
        # ``None`` means: use ``(mm_inputs.im_start_id, mm_inputs.im_end_id)``
        # at padding time.
        self.data_token_id_pairs = data_token_pairs

    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Replace the tokens strictly between each start/end token pair with pad values.

        The start and end tokens themselves are kept, so the result has the
        same length as ``input_ids``. The i-th region is filled with
        ``mm_inputs.pad_values[i]``; if there are more regions than pad values,
        the last pad value is reused. ``mm_inputs.image_offsets`` is reset and
        then receives the index of every start token.

        ``input_ids`` is returned unchanged when:
        - no token pairs are available;
        - the numbers of start and end tokens differ;
        - an end token does not come after its paired start token.
        """
        pad_values = mm_inputs.pad_values
        data_token_pairs = self.data_token_id_pairs
        mm_inputs.image_offsets = []
        if data_token_pairs is None:
            # Fall back to the single pair carried by the inputs. It must be
            # wrapped as a list of (start, end) tuples so it can be unpacked
            # below, and it is only usable when both ids are known.
            im_start_id = mm_inputs.im_start_id
            im_end_id = mm_inputs.im_end_id
            if im_start_id is not None and im_end_id is not None:
                data_token_pairs = [(im_start_id, im_end_id)]
        if data_token_pairs is None:
            logger.warning(
                "No data_token_pairs provided, RadixAttention might be influenced."
            )
            return input_ids

        start_token_ids = [s for s, _e in data_token_pairs]
        end_token_ids = [e for _s, e in data_token_pairs]

        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
        end_indices = [i for i, x in enumerate(input_ids) if x in end_token_ids]

        # Pairing is positional (k-th start with k-th end). Bail out on any
        # inconsistency instead of emitting a corrupted sequence.
        if len(start_indices) != len(end_indices):
            return input_ids
        if any(end <= start for start, end in zip(start_indices, end_indices)):
            return input_ids

        padded_ids: List[int] = []
        last_idx = 0
        last_pad_idx = len(pad_values) - 1
        for data_idx, (start_idx, end_idx) in enumerate(
            zip(start_indices, end_indices)
        ):
            # Copy everything up to and including the start token.
            padded_ids.extend(input_ids[last_idx : start_idx + 1])
            mm_inputs.image_offsets.append(start_idx)

            # Reuse the last pad value once the pad values are exhausted.
            pad_value = pad_values[min(data_idx, last_pad_idx)]
            padded_ids.extend([pad_value] * (end_idx - start_idx - 1))

            # Resume at the end token; it is copied by the next extend.
            last_idx = end_idx

        padded_ids.extend(input_ids[last_idx:])

        assert len(input_ids) == len(padded_ids), "Length validation fails"
        return padded_ids
|
95
|
+
|
96
|
+
|
97
|
+
class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
    """In this pattern, each data item is represented by a single special token
    (``mm_inputs.im_token_id``). That token is first expanded to multiple
    tokens, which are then filled with pad values.

    This strategy should be used when a single data token represents content
    that should be expanded to multiple tokens during processing.
    """

    def __init__(
        self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
    ) -> None:
        # Maps one entry of ``mm_inputs.image_grid_thws`` to the number of
        # tokens the corresponding data token expands to.
        self.num_data_token_calc_func = num_data_token_calc_func

    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Expand each ``im_token_id`` occurrence into a run of pad values.

        1. The k-th data token is expanded to
           ``num_data_token_calc_func(image_grid_thws[k])`` tokens.
        2. Those tokens are filled by cycling through ``mm_inputs.pad_values``.

        Tokens before, between and after the data tokens are preserved. The
        start index of each expanded run in the output is appended to
        ``mm_inputs.image_offsets``, which is reset first.

        If there are no data tokens or no ``image_grid_thws`` entries,
        ``input_ids`` is returned unchanged. If the two counts differ, a
        warning is logged and only the first ``min(...)`` data tokens are
        expanded. Any remaining data tokens are kept as-is.
        """
        image_grid_thws = mm_inputs.image_grid_thws
        pad_values = mm_inputs.pad_values
        im_token_id = mm_inputs.im_token_id

        image_indices = [
            idx for idx, token in enumerate(input_ids) if token == im_token_id
        ]

        mm_inputs.image_offsets = []

        num_images = 0 if image_grid_thws is None else len(image_grid_thws)
        if num_images != len(image_indices):
            logger.warning(
                "Number of image_grid_thws (%d) does not match number of image "
                "tokens (%d) in input_ids; only the first %d will be expanded.",
                num_images,
                len(image_indices),
                min(num_images, len(image_indices)),
            )
        if num_images == 0 or not image_indices:
            # Nothing to expand. Without this guard, the slicing below either
            # raised IndexError or dropped every token before the last data token.
            return input_ids

        input_ids_with_image: List[int] = []
        copy_from = 0
        for grid_thw, image_idx in zip(image_grid_thws, image_indices):
            num_image_tokens = self.num_data_token_calc_func(grid_thw)
            # Plain tokens between the previous data token and this one.
            input_ids_with_image.extend(input_ids[copy_from:image_idx])
            mm_inputs.image_offsets.append(len(input_ids_with_image))
            pad_ids = pad_values * (
                (num_image_tokens + len(pad_values)) // len(pad_values)
            )
            input_ids_with_image.extend(pad_ids[:num_image_tokens])
            copy_from = image_idx + 1
        # Tail after the last expanded data token.
        input_ids_with_image.extend(input_ids[copy_from:])

        return input_ids_with_image
|
146
|
+
|
147
|
+
|
148
|
+
class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
    """Pattern in which every data position is already spelled out as an image
    token (e.g. ``<image><image>....<image>``). Those tokens are overwritten in
    place with pad values."""

    def __init__(self, image_token_id: torch.Tensor) -> None:
        # Used as the test set for ``torch.isin`` when locating image tokens.
        self.image_token_id = image_token_id

    def pad_input_tokens(self, input_ids: List[int], image_inputs) -> List[int]:
        """
        Overwrite every image token in ``input_ids`` with a pad value.

        Pad values come from ``image_inputs.pad_values`` in order, cycling when
        there are more image tokens than pad values. The returned list has the
        same length as ``input_ids``.
        """
        pad_values = image_inputs.pad_values
        assert len(pad_values) != 0

        ids = torch.tensor(input_ids)
        is_image_token = torch.isin(ids, self.image_token_id)
        image_token_count = is_image_token.sum().item()

        # Tile the pad values so there are at least `image_token_count` of them.
        repeats = image_token_count // len(pad_values) + 1
        cycled = torch.tensor(pad_values).repeat(repeats)
        ids[is_image_token] = cycled[:image_token_count]
        return ids.tolist()
|
171
|
+
|
172
|
+
|
173
|
+
def embed_mm_inputs(
    mm_input: MultimodalInputs,
    input_ids: torch.Tensor,
    input_embedding: nn.Embedding,
    mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
    placeholder_token_ids: List[int] = None,
) -> Optional[torch.Tensor]:
    """
    Embed ``input_ids``, then overwrite the placeholder positions with the
    multimodal data embedding.

    Placeholder positions are found with a boolean mask over ``input_ids``.
    The data embedding is computed only when at least one placeholder is
    present, and it is written into those positions, in order, with
    ``masked_scatter``.

    Args:
        mm_input: merged multimodal inputs; ``None`` short-circuits to ``None``.
        input_ids: token ids. NOTE: when placeholders are present, this tensor
            is clamped *in place* to ``[0, vocab_size - 1]``, because the
            placeholder ids may be hash values outside the vocabulary.
        input_embedding: the text embedding layer.
        mm_data_embedding_func: returns the data embedding for ``mm_input``.
            A 2-D result is counted as ``shape[0]`` tokens; any other rank is
            counted as ``shape[0] * shape[1]`` tokens.
        placeholder_token_ids: ids marking data positions; defaults to
            ``mm_input.pad_values``.

    Returns:
        The final input embeddings, or ``None`` if ``mm_input`` is ``None``.
    """
    if mm_input is None:
        return None

    placeholder_token_ids = placeholder_token_ids or mm_input.pad_values

    # Boolean mask of the placeholder positions, with a trailing unit dim so it
    # can later be expanded over the hidden dimension.
    special_image_mask = torch.isin(
        input_ids,
        torch.tensor(placeholder_token_ids, device=input_ids.device),
    ).unsqueeze(-1)

    # Python int: used for slicing and for readable log messages.
    num_image_tokens_in_input_ids = int(special_image_mask.sum().item())
    if num_image_tokens_in_input_ids == 0:
        # No placeholders in this batch: plain text embedding.
        return input_embedding(input_ids)

    image_embedding = mm_data_embedding_func(mm_input)

    if image_embedding.dim() == 2:
        num_image_tokens_in_embedding = image_embedding.shape[0]
    else:
        num_image_tokens_in_embedding = (
            image_embedding.shape[0] * image_embedding.shape[1]
        )

    if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
        if image_embedding.dim() == 2:
            # One row per token: keep at most one row per placeholder.
            # (Dividing by shape[1] here would divide by the hidden size.)
            image_embedding = image_embedding[:num_image_tokens_in_input_ids]
        else:
            # One leading entry per item: keep the items whose tokens fully fit.
            num_image = num_image_tokens_in_input_ids // image_embedding.shape[1]
            image_embedding = image_embedding[:num_image]
        logger.warning(
            f"Number of images does not match number of special image tokens in the input text. "
            f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
            "tokens from image embeddings."
        )

        # TODO: chunked prefill will split special tokens from input_ids into
        # several passes, failing the embedding. A fix may be to cache the
        # unfinished image embedding for future reuse, and determine the
        # tokens to embed with extend_start_loc and extend_seq_lens.
        if num_image_tokens_in_input_ids > num_image_tokens_in_embedding:
            chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
            if chunked_prefill_size != -1:
                logger.warning(
                    "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked_prefill"
                )

    vocab_size = input_embedding.num_embeddings
    # Clamp only after the placeholder mask has been computed. The placeholder
    # ids hold hash values used for prefix matching in the radix cache; their
    # text embeddings are replaced by the data embedding below anyway.
    input_ids.clamp_(min=0, max=vocab_size - 1)
    inputs_embeds = input_embedding(input_ids)

    special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
        inputs_embeds.device
    )

    inputs_embeds = inputs_embeds.masked_scatter(
        special_image_mask,
        image_embedding.to(inputs_embeds.device, inputs_embeds.dtype),
    )
    return inputs_embeds
|
254
|
+
|
255
|
+
|
256
|
+
def embed_image_embedding(
    inputs_embeds: torch.Tensor,
    image_embedding: torch.Tensor,
    image_bounds: torch.Tensor,
) -> torch.Tensor:
    """
    Write ``image_embedding`` rows into ``inputs_embeds`` at the positions
    covered by ``image_bounds``.

    Each row ``(start, end)`` of ``image_bounds`` selects the positions
    ``start, ..., end - 1``; ``torch.arange`` excludes ``end``. Every row must
    span the same number of positions, because the ranges are combined with
    ``torch.stack``. ``image_embedding`` is flattened to ``[-1, hidden]`` and
    scattered along dim 0 in that order.

    NOTE(review): ``get_multimodal_data_bounds`` in this module reports the
    last data index inclusively. If its output is passed here directly, the
    last position of each region would be skipped. Confirm against callers.

    ``inputs_embeds`` is modified in place and also returned. When there are no
    bounds it is returned untouched.
    """
    if len(image_bounds) == 0:
        return inputs_embeds

    position_ranges = [
        torch.arange(first, stop, dtype=torch.long)
        for first, stop in image_bounds.tolist()
    ]
    positions = torch.stack(position_ranges).to(inputs_embeds.device)

    hidden_size = inputs_embeds.shape[-1]
    scatter_index = positions.view(-1, 1).repeat(1, hidden_size)
    flat_embedding = image_embedding.view(-1, image_embedding.shape[-1])

    inputs_embeds.scatter_(0, scatter_index, flat_embedding)
    return inputs_embeds
|
278
|
+
|
279
|
+
|
280
|
+
def general_mm_embed_routine(
    input_ids: torch.Tensor,
    forward_batch: ForwardBatch,
    embed_tokens: nn.Embedding,
    mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
    placeholder_token_ids: List[int] = None,
):
    """
    Compute the input embeddings for a multimodal model whose language part is
    a causal LM.

    Outside decode mode, and only when the batch carries multimodal inputs,
    the batch's inputs are merged and embedded through ``embed_mm_inputs``.
    ``forward_batch.mm_inputs`` is then cleared. In every other case only
    ``embed_tokens`` is applied.

    Args:
        input_ids: token ids of the batch.
        forward_batch: the batch being run; its ``mm_inputs`` is set to ``None``
            after the multimodal path runs.
        embed_tokens: the text embedding layer.
        mm_data_embedding_func: computes the embedding of the merged
            multimodal inputs.
        placeholder_token_ids (List[int]): the ids of mm data placeholder
            tokens, forwarded to ``embed_mm_inputs``.

    Returns:
        The input embeddings produced by whichever path was taken.
    """
    needs_mm_embedding = (
        not forward_batch.forward_mode.is_decode()
        and forward_batch.contains_mm_inputs()
    )
    if not needs_mm_embedding:
        return embed_tokens(input_ids)

    merged_mm_inputs = forward_batch.merge_mm_inputs()
    inputs_embeds = embed_mm_inputs(
        mm_input=merged_mm_inputs,
        input_ids=input_ids,
        input_embedding=embed_tokens,
        mm_data_embedding_func=mm_data_embedding_func,
        placeholder_token_ids=placeholder_token_ids,
    )
    # The merged inputs have been consumed; drop the reference defensively.
    forward_batch.mm_inputs = None
    return inputs_embeds
|
314
|
+
|
315
|
+
|
316
|
+
def get_multimodal_data_bounds(
    input_ids: torch.Tensor, pad_values: List[int], token_pairs: List[Tuple[int, int]]
) -> torch.Tensor:
    """
    Locate the multimodal data regions (images, video, audio, etc.) in ``input_ids``.

    A region is the span strictly between a start token and the end token it
    is paired with. Pairing is positional: the k-th start with the k-th end.
    When the start token of the leading region is absent from ``input_ids``
    (e.g. it was served from the prefix cache) and ``input_ids[0]`` is a pad
    value, that leading region is reported as starting at index 0.

    Args:
        input_ids: 1-D tensor of token ids.
        pad_values: pad ids used for data tokens.
        token_pairs: ``(start_id, end_id)`` pairs; all ids must be Python ints.
            All images in the batch should share the same bound token ids.

    Returns:
        A long tensor of shape ``[bounds_count, 2]`` on ``input_ids.device``.
        Each row is ``(first_data_index, last_data_index)``: the index after
        the start token and the index before the end token (both inclusive).
        Pairs whose end does not come after their start are dropped.
    """
    start_tokens = [s for s, _e in token_pairs]
    end_tokens = [e for _s, e in token_pairs]

    assert all(isinstance(t, int) for t in start_tokens)
    assert all(isinstance(t, int) for t in end_tokens)

    device = input_ids.device
    start_cond = torch.isin(input_ids, torch.tensor(start_tokens, device=device))
    end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=device))

    # Move positions to the host once, rather than comparing 0-d tensors
    # element by element in the Python loop below.
    data_start_positions = torch.where(start_cond)[0].tolist()
    data_end_positions = torch.where(end_cond)[0].tolist()

    # The start token may have been cached as prefix, but it is still needed
    # to embed the leading region.
    if len(data_start_positions) != len(data_end_positions):
        if (
            len(data_start_positions) + 1 == len(data_end_positions)
            and int(input_ids[0]) in pad_values
            # With no start token left at all, there is nothing to compare
            # against (previously an IndexError on an empty tensor).
            and (
                not data_start_positions
                or data_end_positions[0] < data_start_positions[0]
            )
        ):
            # Virtual start marker at -1, so the region begins at index 0,
            # which is itself a data token.
            data_start_positions = [-1] + data_start_positions

    # zip() truncates to the shorter list, like the former min()-bounded loop.
    valid_pairs = [
        (start + 1, end - 1)
        for start, end in zip(data_start_positions, data_end_positions)
        if start < end
    ]

    if not valid_pairs:
        return torch.zeros((0, 2), dtype=torch.long, device=device)

    return torch.tensor(valid_pairs, dtype=torch.long, device=device)
|
@@ -0,0 +1,68 @@
|
|
1
|
+
# TODO: also move pad_input_ids into this module
|
2
|
+
import importlib
|
3
|
+
import inspect
|
4
|
+
import logging
|
5
|
+
import pkgutil
|
6
|
+
from functools import lru_cache
|
7
|
+
|
8
|
+
from transformers import PROCESSOR_MAPPING
|
9
|
+
|
10
|
+
from sglang.srt.managers.multimodal_processors.base_processor import (
|
11
|
+
BaseMultimodalProcessor,
|
12
|
+
)
|
13
|
+
from sglang.srt.server_args import ServerArgs
|
14
|
+
|
15
|
+
logger = logging.getLogger(__name__)

# Registry filled by import_processors(): each entry of a processor class's
# ``models`` attribute maps to that processor class. get_mm_processor() matches
# ``key.__name__`` against ``hf_config.architectures``.
# NOTE(review): this rebinding shadows the ``PROCESSOR_MAPPING`` imported from
# ``transformers`` above, so that import is effectively unused here.
PROCESSOR_MAPPING = {}
|
18
|
+
|
19
|
+
|
20
|
+
class DummyMultimodalProcessor(BaseMultimodalProcessor):
    """Placeholder processor whose ``process_mm_data_async`` always yields ``None``."""

    def __init__(self):
        """Create the processor without running ``BaseMultimodalProcessor.__init__``."""

    async def process_mm_data_async(self, *args, **kwargs):
        """Accept any arguments and produce no multimodal data (``None``)."""
        return
|
26
|
+
|
27
|
+
|
28
|
+
def get_dummy_processor():
    """Return a new no-op ``DummyMultimodalProcessor``."""
    dummy = DummyMultimodalProcessor()
    return dummy
|
30
|
+
|
31
|
+
|
32
|
+
@lru_cache()
def import_processors():
    """
    Import every non-package module under
    ``sglang.srt.managers.multimodal_processors`` and register its processors.

    A module that fails to import is skipped with a warning. Each class defined
    in a module itself (not imported into it) that subclasses
    ``BaseMultimodalProcessor`` must have a ``models`` attribute. Every entry of
    that attribute becomes a key in ``PROCESSOR_MAPPING`` pointing to the
    class. Because of ``lru_cache``, the scan runs only on the first call.
    """
    package_name = "sglang.srt.managers.multimodal_processors"
    package = importlib.import_module(package_name)
    for _finder, module_name, is_package in pkgutil.iter_modules(
        package.__path__, package_name + "."
    ):
        if is_package:
            continue
        try:
            module = importlib.import_module(module_name)
        except Exception as e:
            logger.warning(f"Ignore import error when loading {module_name}: " f"{e}")
            continue
        for _member_name, member in inspect.getmembers(module, inspect.isclass):
            # Skip classes merely imported into this module.
            if member.__module__ != module.__name__:
                continue
            if not issubclass(member, BaseMultimodalProcessor):
                continue
            assert hasattr(member, "models")
            for arch in getattr(member, "models"):
                PROCESSOR_MAPPING[arch] = member
|
55
|
+
|
56
|
+
|
57
|
+
def get_mm_processor(
    hf_config, server_args: ServerArgs, processor
) -> BaseMultimodalProcessor:
    """
    Instantiate the registered multimodal processor for ``hf_config``.

    The first ``PROCESSOR_MAPPING`` entry whose key's ``__name__`` appears in
    ``hf_config.architectures`` is used. Its processor class is called as
    ``processor_cls(hf_config, server_args, processor)``. This function does
    not call ``import_processors()`` itself; the mapping is read as-is.

    Raises:
        ValueError: if no registered architecture matches, including when
            ``hf_config.architectures`` is missing or ``None``.
    """
    # transformers' PretrainedConfig defaults ``architectures`` to None. Treat
    # that as "no architectures" so the ValueError below is raised, rather
    # than a TypeError from ``in None``.
    declared_architectures = getattr(hf_config, "architectures", None)
    architectures = declared_architectures or []
    for model_cls, processor_cls in PROCESSOR_MAPPING.items():
        if model_cls.__name__ in architectures:
            return processor_cls(hf_config, server_args, processor)
    raise ValueError(
        f"No processor registered for architecture: {declared_architectures}.\n"
        f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
    )
|
67
|
+
|
68
|
+
self.image_proce
|