sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +49 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +35 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -6
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +100 -52
- sglang/srt/disaggregation/prefill.py +5 -4
- sglang/srt/disaggregation/utils.py +13 -12
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +45 -9
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +51 -6
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +56 -23
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +41 -0
- sglang/srt/layers/linear.py +99 -12
- sglang/srt/layers/logits_processor.py +15 -6
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +115 -25
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +36 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +6 -6
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +105 -13
- sglang/srt/layers/vocab_parallel_embedding.py +19 -2
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +60 -15
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +80 -79
- sglang/srt/managers/scheduler.py +153 -63
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -2
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +302 -58
- sglang/srt/model_loader/loader.py +86 -10
- sglang/srt/model_loader/weight_utils.py +160 -3
- sglang/srt/models/deepseek_nextn.py +5 -4
- sglang/srt/models/deepseek_v2.py +305 -26
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1010 -0
- sglang/srt/models/gemma3n_mm.py +495 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/multimodal/processors/gemma3n.py +82 -0
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +85 -24
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +204 -28
- sglang/srt/utils.py +369 -138
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
CHANGED
@@ -20,19 +20,18 @@ import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union

-from sglang.srt.mm_utils import has_valid_data
+from sglang.srt.managers.schedule_batch import BaseFinishReason
+from sglang.srt.multimodal.mm_utils import has_valid_data
+from sglang.srt.sampling.sampling_params import SamplingParams

-#
+# Handle serialization of Image for pydantic
 if TYPE_CHECKING:
     from PIL.Image import Image
 else:
     Image = Any

-from sglang.srt.managers.schedule_batch import BaseFinishReason
-from sglang.srt.sampling.sampling_params import SamplingParams
-

 @dataclass
 class SessionParams:
@@ -40,6 +39,7 @@ class SessionParams:
     rid: Optional[str] = None
     offset: Optional[int] = None
     replace: Optional[bool] = None
+    drop_previous_output: Optional[bool] = None


 AudioDataItem = Union[str, Dict]
@@ -182,6 +182,7 @@ class GenerateReqInput:
         # Determine parallel sample count
         if self.sampling_params is None:
             self.parallel_sample_num = 1
+            return
         elif isinstance(self.sampling_params, dict):
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
@@ -319,8 +320,16 @@
         """Normalize request IDs for batch processing."""
         if self.rid is None:
             self.rid = [uuid.uuid4().hex for _ in range(num)]
-        elif
-
+        elif isinstance(self.rid, str):
+            new_rids = [f"{self.rid}_{i}" for i in range(num)]
+            self.rid = new_rids
+        elif isinstance(self.rid, list):
+            if len(self.rid) != num:
+                raise ValueError(
+                    "The specified rids length mismatch with the batch_size for batch processing."
+                )
+        else:
+            raise ValueError("The rid should be a string or a list of strings.")

     def _normalize_logprob_params(self, num):
         """Normalize logprob-related parameters for batch processing."""
@@ -508,9 +517,6 @@ class EmbeddingReqInput:
     # For cross-encoder requests
     is_cross_encoder_request: bool = False

-    def contains_mm_input(self) -> bool:
-        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
-
     def normalize_batch_and_arguments(self):
         # at least one of text, input_ids, or image should be provided
         if self.text is None and self.input_ids is None and self.image_data is None:
@@ -564,6 +570,9 @@ class EmbeddingReqInput:
         self.rid = uuid.uuid4().hex
         return self.rid

+    def contains_mm_input(self) -> bool:
+        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+
     def __getitem__(self, i):
         if self.is_cross_encoder_request:
             return EmbeddingReqInput(
@@ -732,6 +741,8 @@ class UpdateWeightFromDiskReqInput:
     model_path: str
     # The format to load the weights
     load_format: Optional[str] = None
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False


 @dataclass
@@ -744,9 +755,15 @@ class UpdateWeightFromDiskReqOutput:

 @dataclass
 class UpdateWeightsFromDistributedReqInput:
-
-
-
+    names: List[str]
+    dtypes: List[str]
+    shapes: List[List[int]]
+    # The group name
+    group_name: str = "weight_update_group"
+    # Whether to flush the cache after updating weights
+    flush_cache: bool = True
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False


 @dataclass
@@ -768,6 +785,8 @@ class UpdateWeightsFromTensorReqInput:
     load_format: Optional[str] = None
     # Whether to flush the cache after updating weights
     flush_cache: bool = True
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False


 @dataclass
@@ -846,7 +865,9 @@ class SlowDownReqOutput:
 @dataclass
 class AbortReq:
     # The request id
-    rid: str
+    rid: str = ""
+    # Whether to abort all requests
+    abort_all: bool = False


 @dataclass
@@ -994,3 +1015,27 @@ class RpcReqInput:
 class RpcReqOutput:
     success: bool
     message: str
+
+
+@dataclass
+class LoadLoRAAdapterReqInput:
+    # The name of the lora module to newly loaded.
+    lora_name: str
+    # The path of loading.
+    lora_path: str
+
+
+@dataclass
+class UnloadLoRAAdapterReqInput:
+    # The name of lora module to unload.
+    lora_name: str
+
+
+@dataclass
+class LoRAUpdateResult:
+    success: bool
+    error_message: Optional[str] = None
+    loaded_adapters: Dict[str, str] = field(default_factory=dict)
+
+
+LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
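
The reworked rid normalization above is the behavioral change worth calling out: a single string id is now fanned out across a batch as "{rid}_0", "{rid}_1", ..., and an explicit list must match the batch size. A minimal standalone sketch of that logic (the normalize_rid helper name is illustrative, not part of the package):

import uuid
from typing import List, Optional, Union

def normalize_rid(rid: Optional[Union[str, List[str]]], num: int) -> List[str]:
    if rid is None:
        # Generate a fresh id for every request in the batch
        return [uuid.uuid4().hex for _ in range(num)]
    if isinstance(rid, str):
        # A single string is fanned out with an index suffix per request
        return [f"{rid}_{i}" for i in range(num)]
    if isinstance(rid, list):
        # An explicit list must already match the batch size
        if len(rid) != num:
            raise ValueError(
                "The specified rids length mismatch with the batch_size for batch processing."
            )
        return rid
    raise ValueError("The rid should be a string or a list of strings.")

print(normalize_rid("req", 3))  # ['req_0', 'req_1', 'req_2']
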
sglang/srt/managers/mm_utils.py
CHANGED
@@ -2,14 +2,15 @@
 Multi-modality utils
 """

-import
-import logging
+import hashlib
 from abc import abstractmethod
 from typing import Callable, List, Optional, Tuple

+import numpy as np
 import torch
 from torch import nn

+from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.managers.schedule_batch import (
     Modality,
     MultimodalDataItem,
@@ -124,74 +125,38 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
     e.g. <image><image>....<image>, or <audio><audio>...<audio>
     """

-    def __init__(self, token_ids: List[int]) -> None:
-        self.token_ids = token_ids
-
     def pad_input_tokens(
         self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
         """
-
-
+        Replaces multimodal tokens in input_ids with corresponding pad_values from mm_items.
+        Each modality (image, audio, video) is handled separately based on its token_id.
         """
-
-        if not pad_values:
-            # No multimodal items, return original input_ids
+        if not input_ids or not mm_inputs.mm_items:
             return input_ids
-        if not input_ids:
-            return []

         input_ids_tensor = torch.tensor(input_ids)
-        device = input_ids_tensor.device
-        token_ids_tensor = torch.tensor(self.token_ids, device=device)
-        mask = torch.isin(input_ids_tensor, token_ids_tensor)

-
-
-        return input_ids
+        # Create mapping of token_ids to pad_values for each modality
+        token_to_pad_mapping = {}

-
-
-
-
-
-
-
-        )
-        # Find indices where the mask value changes
-        diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
-
-        # Start indices are where False changes to True
-        starts = diff_indices[::2]
-        # End indices are where True changes to False (exclusive index)
-        ends = diff_indices[1::2]
-
-        # Check if the number of regions matches the number of pad values
-        if len(starts) != len(pad_values):
-            # Maybe log a warning here?
-            num_regions = len(starts)
-            num_pad_values = len(pad_values)
-            if num_regions > 0 and num_pad_values > 0:
-                pad_values = (pad_values * (num_regions // num_pad_values + 1))[
-                    :num_regions
-                ]
-            else:  # If no regions or no pad_values, this loop won't run anyway.
-                pad_values = []  # Ensure pad_values is empty if starts is empty
-
-        # Create a copy to modify
-        output_ids_tensor = input_ids_tensor.clone()
-
-        # Replace tokens in each region with the corresponding pad value
-        # Ensure we don't iterate if pad_values became empty due to mismatch and num_regions=0
-        for i in range(min(len(starts), len(pad_values))):
-            start_idx = starts[i]
-            end_idx = ends[i]
-            pad_value = pad_values[i]
-            if pad_value is not None:  # Ensure pad_value is not None before assignment
-                output_ids_tensor[start_idx:end_idx] = pad_value
+        for item in mm_inputs.mm_items:
+            if item.is_image() and mm_inputs.im_token_id is not None:
+                token_to_pad_mapping[mm_inputs.im_token_id] = item.pad_value
+            elif item.is_audio() and mm_inputs.audio_token_id is not None:
+                token_to_pad_mapping[mm_inputs.audio_token_id] = item.pad_value
+            elif item.is_video() and mm_inputs.video_token_id is not None:
+                token_to_pad_mapping[mm_inputs.video_token_id] = item.pad_value
             else:
-
-
+                raise ValueError(f"No multimodal token id provided for {item.modality}")
+
+        # Apply replacements for all tokens at once
+        for token_id, pad_value in token_to_pad_mapping.items():
+            input_ids_tensor[input_ids_tensor == token_id] = pad_value
+
+        ret_input_ids = input_ids_tensor.tolist()
+
+        return ret_input_ids


 embedding_cache = None
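
The rewritten pad_input_tokens above drops the old region-scanning logic in favor of one vectorized replacement per modality token id. A toy illustration of the replacement step; every token id and pad value below is made up:

import torch

input_ids = [101, 9999, 9999, 102, 8888, 103]  # 9999 ~ <image>, 8888 ~ <audio> (hypothetical ids)
token_to_pad_mapping = {9999: 77001, 8888: 77002}  # token id -> that item's pad_value

input_ids_tensor = torch.tensor(input_ids)
for token_id, pad_value in token_to_pad_mapping.items():
    # Replace every occurrence of the modality token in one masked assignment
    input_ids_tensor[input_ids_tensor == token_id] = pad_value

print(input_ids_tensor.tolist())  # [101, 77001, 77001, 102, 77002, 103]
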
@@ -680,3 +645,52 @@ def get_multimodal_data_bounds(
     # Convert valid pairs to tensor
     valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
     return valid_pairs_tensor
+
+
+def data_hash(data) -> int:
+    hash_bytes = hashlib.sha256(data).digest()[:8]
+    return int.from_bytes(hash_bytes, byteorder="big", signed=False)
+
+
+def tensor_hash(tensor_list) -> int:
+    """
+    hash a tensor or a tensor list
+    """
+    tensor = tensor_list
+    if isinstance(tensor_list, list):
+        tensor_list = flatten_nested_list(tensor_list)
+        tensor_list = [
+            x.flatten() if isinstance(x, torch.Tensor) else x for x in tensor_list
+        ]
+        tensor = torch.concat(tensor_list)
+    if tensor.is_cuda:
+        return gpu_tensor_hash(tensor)
+    tensor = tensor.detach().contiguous()
+
+    if tensor.dtype == torch.bfloat16:
+        # memoryview() doesn't support PyTorch's BFloat16 dtype
+        tensor = tensor.float()
+
+    assert isinstance(tensor, torch.Tensor)
+    if tensor.is_cuda:
+        # TODO: improve this
+        tensor_cpu = tensor.cpu()
+    else:
+        tensor_cpu = tensor
+
+    mv = memoryview(tensor_cpu.numpy())
+    return data_hash(mv.tobytes())
+
+
+def hash_feature(f):
+    if isinstance(f, list):
+        if isinstance(f[0], torch.Tensor):
+            return tensor_hash(f)
+        return data_hash(tuple(flatten_nested_list(f)))
+    elif isinstance(f, np.ndarray):
+        arr = np.ascontiguousarray(f)
+        arr_bytes = arr.tobytes()
+        return data_hash(arr_bytes)
+    elif isinstance(f, torch.Tensor):
+        return tensor_hash([f])
+    return data_hash(f)
sglang/srt/managers/multimodal_processor.py
CHANGED
@@ -3,11 +3,8 @@ import importlib
 import inspect
 import logging
 import pkgutil
-from functools import lru_cache

-from sglang.srt.managers.multimodal_processors.base_processor import (
-    BaseMultimodalProcessor,
-)
+from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
 from sglang.srt.server_args import ServerArgs

 logger = logging.getLogger(__name__)
@@ -27,9 +24,8 @@ def get_dummy_processor():
     return DummyMultimodalProcessor()


-@lru_cache()
 def import_processors():
-    package_name = "sglang.srt.managers.multimodal_processors"
+    package_name = "sglang.srt.multimodal.processors"
     package = importlib.import_module(package_name)
     for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
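
import_processors discovers processor modules by walking the package with pkgutil, now without the lru_cache wrapper. A minimal sketch of the same discovery pattern, generalized to any package (the import_submodules helper is illustrative):

import importlib
import pkgutil

def import_submodules(package_name: str) -> list:
    # Import every non-package module directly under package_name
    package = importlib.import_module(package_name)
    return [
        importlib.import_module(name)
        for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + ".")
        if not ispkg
    ]

# e.g. import_submodules("sglang.srt.multimodal.processors")
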
sglang/srt/managers/multimodal_processors/qwen_audio.py
ADDED
@@ -0,0 +1,94 @@
+import re
+from typing import List, Union
+
+import torch
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration
+
+
+class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
+    models = [Qwen2AudioForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
+        self.AUDIO_TOKEN_REGEX = re.compile(
+            r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
+        )
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        **kwargs,
+    ):
+        audio_data = request_obj.audio_data
+        if not isinstance(audio_data, list):
+            audio_data = [audio_data]
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            max_req_input_len=max_req_input_len,
+            audio_data=audio_data,
+            multimodal_tokens=MultimodalSpecialTokens(
+                audio_token=self.AUDIO_TOKEN,
+                audio_token_regex=self.AUDIO_TOKEN_REGEX,
+            ),
+        )
+        if base_output is None:
+            return None
+
+        res = self.process_mm_data(
+            input_text=base_output.input_text,
+            audio=base_output.audios,
+        )
+
+        # Collect special token ids
+        tokenizer = self._processor.tokenizer
+        audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_bos|>")
+        audio_token_id = tokenizer.convert_tokens_to_ids("<|AUDIO|>")
+        audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>")
+
+        items = []
+        input_ids = res["input_ids"].flatten()
+
+        if (
+            "input_features" in res
+            and res["input_features"] is not None
+            and len(res["input_features"]) != 0
+        ):
+            if audio_start_id is not None and audio_end_id is not None:
+                audio_offsets = self.get_mm_items_offset_by_pair(
+                    input_ids=input_ids,
+                    mm_start_id=audio_start_id,
+                    mm_end_id=audio_end_id,
+                )
+            else:
+                audio_offsets = None
+
+            input_lengths = res["feature_attention_mask"].sum(dim=-1)
+            input_lengths = (input_lengths - 1) // 2 + 1
+            output_lengths = (input_lengths - 2) // 2 + 1
+
+            item = MultimodalDataItem(
+                audio_features=res["input_features"],
+                audio_feature_lens=output_lengths,
+                audio_offsets=audio_offsets,
+                modality=Modality.AUDIO,
+            )
+            items += [item]
+
+        return {
+            "mm_items": items,
+            "input_ids": input_ids.tolist(),
+            "audio_start_id": audio_start_id,
+            "audio_token_id": audio_token_id,
+            "audio_end_id": audio_end_id,
+        }