sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +48 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +34 -0
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +36 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +11 -7
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +50 -13
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +77 -84
- sglang/srt/managers/scheduler.py +113 -59
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +181 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +69 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +200 -27
- sglang/srt/utils.py +306 -146
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_manager.py
CHANGED
@@ -35,6 +35,7 @@ from sglang.srt.lora.utils import (
     get_normalized_lora_weight_names,
     get_weight_name,
 )
+from sglang.srt.managers.io_struct import LoRAUpdateResult
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import replace_submodule
 
@@ -98,44 +99,96 @@ class LoRAManager:
             ],
         )
 
-    def
+    def create_lora_update_result(
+        self, success: bool, error_message: str = ""
+    ) -> LoRAUpdateResult:
+        return LoRAUpdateResult(
+            success=success,
+            error_message=error_message,
+            loaded_adapters={
+                name: config.path for name, config in self.configs.items()
+            },
+        )
+
+    def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
         """
         Load LoRA adapters from the specified paths.
-        TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
 
         Args:
             lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
             If a LoRA adapter is already loaded, it will be skipped with a warning.
         """
 
+        results = []
         for lora_name, lora_path in lora_paths.items():
-
-
-
-
-
-
+            result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
+            results.append(result)
+
+        self.update_state_from_configs()
+
+        return self.create_lora_update_result(
+            success=all(result.success for result in results),
+            error_message="\n".join(
+                result.error_message for result in results if not result.success
+            ),
+        )
+
+    def load_lora_adapter(
+        self, lora_name: str, lora_path: str, update_state: bool = True
+    ) -> LoRAUpdateResult:
+        """
+        Load a single LoRA adapter from the specified path.
+
+        Args:
+            lora_name (str): The name of the LoRA adapter.
+            lora_path (str): The file path to the LoRA adapter.
+            update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
+        """
 
+        success = True
+        error_message = ""
+
+        if lora_name in self.loras:
+            success = False
+            error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
+
+        try:
             self.configs[lora_name] = LoRAConfig(lora_path)
+        except Exception as e:
+            success = False
+            error_message = (
+                f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
+            )
 
-
+        if update_state:
+            self.update_state_from_configs()
+
+        return self.create_lora_update_result(
+            success=success,
+            error_message=error_message,
+        )
 
-    def
+    def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
         """
         Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
         delete the corresponding LoRA modules.
-
-        Args:
-            lora_names (Set[str]): A set of LoRA adapter names to unload.
         """
-
-
-
-
-
+
+        success = True
+        error_message = ""
+        if lora_name in self.loras:
+            del self.configs[lora_name]
+        else:
+            error_message = f"LoRA adapter {lora_name} is not loaded."
+            success = False
 
         self.update_state_from_configs()
 
+        return self.create_lora_update_result(
+            success=success,
+            error_message=error_message,
+        )
+
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)
@@ -372,8 +425,8 @@ class LoRAManager:
             lora_adapter.initialize_weights()
             self.loras[name] = lora_adapter
 
-        # Clean up unused LoRA adapters
-        for name in self.loras:
+        # Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
+        for name in list(self.loras):
             if name not in self.configs:
                 logger.info(f"Unloading LoRA adapter {name}")
                 del self.loras[name]
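
The reworked manager now returns structured results instead of failing opaquely, and a batch load refreshes internal state once at the end rather than per adapter. A minimal sketch of consuming the new contract (the manager argument is assumed to be an initialized LoRAManager; adapter names and paths are illustrative):

from typing import Dict

def refresh_adapters(manager, adapters: Dict[str, str]) -> None:
    # Batch-load adapters; the manager calls update_state_from_configs()
    # once after the loop (see the diff above).
    result = manager.load_lora_adapters(adapters)
    if not result.success:
        # error_message aggregates the failed adapters, newline-joined
        print(result.error_message)
    # loaded_adapters maps every currently loaded adapter name to its path
    for name, path in result.loaded_adapters.items():
        print(f"loaded {name} from {path}")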
sglang/srt/managers/configure_logging.py
CHANGED
@@ -28,7 +28,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--url", type=str, default="http://localhost:30000")
     parser.add_argument("--log-requests", action="store_true")
-    parser.add_argument("--log-requests-level", type=int, default=
+    parser.add_argument("--log-requests-level", type=int, default=3)
     parser.add_argument(
         "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
     )
sglang/srt/managers/io_struct.py
CHANGED
@@ -20,19 +20,18 @@ import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
 
-from sglang.srt.
+from sglang.srt.managers.schedule_batch import BaseFinishReason
+from sglang.srt.multimodal.mm_utils import has_valid_data
+from sglang.srt.sampling.sampling_params import SamplingParams
 
-#
+# Handle serialization of Image for pydantic
 if TYPE_CHECKING:
     from PIL.Image import Image
 else:
     Image = Any
 
-from sglang.srt.managers.schedule_batch import BaseFinishReason
-from sglang.srt.sampling.sampling_params import SamplingParams
-
 
 @dataclass
 class SessionParams:
@@ -40,6 +39,7 @@ class SessionParams:
     rid: Optional[str] = None
    offset: Optional[int] = None
    replace: Optional[bool] = None
+    drop_previous_output: Optional[bool] = None
 
 
 AudioDataItem = Union[str, Dict]
@@ -182,6 +182,7 @@ class GenerateReqInput:
         # Determine parallel sample count
         if self.sampling_params is None:
             self.parallel_sample_num = 1
+            return
         elif isinstance(self.sampling_params, dict):
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
@@ -516,9 +517,6 @@ class EmbeddingReqInput:
     # For cross-encoder requests
     is_cross_encoder_request: bool = False
 
-    def contains_mm_input(self) -> bool:
-        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
-
     def normalize_batch_and_arguments(self):
         # at least one of text, input_ids, or image should be provided
         if self.text is None and self.input_ids is None and self.image_data is None:
@@ -572,6 +570,9 @@ class EmbeddingReqInput:
             self.rid = uuid.uuid4().hex
         return self.rid
 
+    def contains_mm_input(self) -> bool:
+        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+
     def __getitem__(self, i):
         if self.is_cross_encoder_request:
             return EmbeddingReqInput(
@@ -740,6 +741,8 @@ class UpdateWeightFromDiskReqInput:
     model_path: str
     # The format to load the weights
     load_format: Optional[str] = None
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False
 
 
 @dataclass
@@ -752,9 +755,15 @@ class UpdateWeightFromDiskReqOutput:
 
 @dataclass
 class UpdateWeightsFromDistributedReqInput:
-
-
-
+    names: List[str]
+    dtypes: List[str]
+    shapes: List[List[int]]
+    # The group name
+    group_name: str = "weight_update_group"
+    # Whether to flush the cache after updating weights
+    flush_cache: bool = True
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False
 
 
 @dataclass
@@ -776,6 +785,8 @@ class UpdateWeightsFromTensorReqInput:
     load_format: Optional[str] = None
     # Whether to flush the cache after updating weights
     flush_cache: bool = True
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False
 
 
 @dataclass
@@ -854,7 +865,9 @@ class SlowDownReqOutput:
 @dataclass
 class AbortReq:
     # The request id
-    rid: str
+    rid: str = ""
+    # Whether to abort all requests
+    abort_all: bool = False
 
 
 @dataclass
@@ -1002,3 +1015,27 @@ class RpcReqInput:
 class RpcReqOutput:
     success: bool
     message: str
+
+
+@dataclass
+class LoadLoRAAdapterReqInput:
+    # The name of the lora module to newly loaded.
+    lora_name: str
+    # The path of loading.
+    lora_path: str
+
+
+@dataclass
+class UnloadLoRAAdapterReqInput:
+    # The name of lora module to unload.
+    lora_name: str
+
+
+@dataclass
+class LoRAUpdateResult:
+    success: bool
+    error_message: Optional[str] = None
+    loaded_adapters: Dict[str, str] = field(default_factory=dict)
+
+
+LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
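
The widened AbortReq is what backs the new abort_all_requests flags above: rid now defaults to empty, so an abort no longer has to name a specific request. A small illustration of the contract (a standalone re-declaration for demonstration, not an import from sglang):

from dataclasses import dataclass

@dataclass
class AbortReq:
    # The request id; defaults to "" so abort-all requests need no id
    rid: str = ""
    # Whether to abort all requests
    abort_all: bool = False

abort_one = AbortReq(rid="req-42")           # target a single request by id
abort_everything = AbortReq(abort_all=True)  # abort all in-flight requests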
sglang/srt/managers/mm_utils.py
CHANGED
@@ -2,14 +2,15 @@
 Multi-modality utils
 """
 
-import
-import logging
+import hashlib
 from abc import abstractmethod
 from typing import Callable, List, Optional, Tuple
 
+import numpy as np
 import torch
 from torch import nn
 
+from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.managers.schedule_batch import (
     Modality,
     MultimodalDataItem,
@@ -124,74 +125,38 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
     e.g. <image><image>....<image>, or <audio><audio>...<audio>
     """
 
-    def __init__(self, token_ids: List[int]) -> None:
-        self.token_ids = token_ids
-
     def pad_input_tokens(
         self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
         """
-
-
+        Replaces multimodal tokens in input_ids with corresponding pad_values from mm_items.
+        Each modality (image, audio, video) is handled separately based on its token_id.
         """
-
-        if not pad_values:
-            # No multimodal items, return original input_ids
+        if not input_ids or not mm_inputs.mm_items:
             return input_ids
-        if not input_ids:
-            return []
 
         input_ids_tensor = torch.tensor(input_ids)
-        device = input_ids_tensor.device
-        token_ids_tensor = torch.tensor(self.token_ids, device=device)
-        mask = torch.isin(input_ids_tensor, token_ids_tensor)
 
-
-
-        return input_ids
+        # Create mapping of token_ids to pad_values for each modality
+        token_to_pad_mapping = {}
 
-
-
-
-
-
-
-
-        )
-        # Find indices where the mask value changes
-        diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
-
-        # Start indices are where False changes to True
-        starts = diff_indices[::2]
-        # End indices are where True changes to False (exclusive index)
-        ends = diff_indices[1::2]
-
-        # Check if the number of regions matches the number of pad values
-        if len(starts) != len(pad_values):
-            # Maybe log a warning here?
-            num_regions = len(starts)
-            num_pad_values = len(pad_values)
-            if num_regions > 0 and num_pad_values > 0:
-                pad_values = (pad_values * (num_regions // num_pad_values + 1))[
-                    :num_regions
-                ]
-            else:  # If no regions or no pad_values, this loop won't run anyway.
-                pad_values = []  # Ensure pad_values is empty if starts is empty
-
-        # Create a copy to modify
-        output_ids_tensor = input_ids_tensor.clone()
-
-        # Replace tokens in each region with the corresponding pad value
-        # Ensure we don't iterate if pad_values became empty due to mismatch and num_regions=0
-        for i in range(min(len(starts), len(pad_values))):
-            start_idx = starts[i]
-            end_idx = ends[i]
-            pad_value = pad_values[i]
-            if pad_value is not None:  # Ensure pad_value is not None before assignment
-                output_ids_tensor[start_idx:end_idx] = pad_value
+        for item in mm_inputs.mm_items:
+            if item.is_image() and mm_inputs.im_token_id is not None:
+                token_to_pad_mapping[mm_inputs.im_token_id] = item.pad_value
+            elif item.is_audio() and mm_inputs.audio_token_id is not None:
+                token_to_pad_mapping[mm_inputs.audio_token_id] = item.pad_value
+            elif item.is_video() and mm_inputs.video_token_id is not None:
+                token_to_pad_mapping[mm_inputs.video_token_id] = item.pad_value
             else:
-
-
+                raise ValueError(f"No multimodal token id provided for {item.modality}")
+
+        # Apply replacements for all tokens at once
+        for token_id, pad_value in token_to_pad_mapping.items():
+            input_ids_tensor[input_ids_tensor == token_id] = pad_value
+
+        ret_input_ids = input_ids_tensor.tolist()
+
+        return ret_input_ids
 
 
 embedding_cache = None
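
The rewrite drops the old region-scanning logic in favor of direct per-token replacement: each modality's special token id maps to its item's pad value, applied with one vectorized masked assignment per modality. A standalone sketch of the core idea (toy token ids and pad values, outside sglang):

import torch

# Every occurrence of a modality's token id is overwritten with that
# modality's pad value in a single masked assignment (ids are made up).
IMAGE_TOKEN, AUDIO_TOKEN = 9001, 9002
input_ids = torch.tensor([1, 9001, 9001, 2, 9002, 3])

token_to_pad_mapping = {IMAGE_TOKEN: -100, AUDIO_TOKEN: -200}
for token_id, pad_value in token_to_pad_mapping.items():
    input_ids[input_ids == token_id] = pad_value

print(input_ids.tolist())  # [1, -100, -100, 2, -200, 3]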
@@ -680,3 +645,52 @@ def get_multimodal_data_bounds(
     # Convert valid pairs to tensor
     valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
     return valid_pairs_tensor
+
+
+def data_hash(data) -> int:
+    hash_bytes = hashlib.sha256(data).digest()[:8]
+    return int.from_bytes(hash_bytes, byteorder="big", signed=False)
+
+
+def tensor_hash(tensor_list) -> int:
+    """
+    hash a tensor or a tensor list
+    """
+    tensor = tensor_list
+    if isinstance(tensor_list, list):
+        tensor_list = flatten_nested_list(tensor_list)
+        tensor_list = [
+            x.flatten() if isinstance(x, torch.Tensor) else x for x in tensor_list
+        ]
+        tensor = torch.concat(tensor_list)
+    if tensor.is_cuda:
+        return gpu_tensor_hash(tensor)
+    tensor = tensor.detach().contiguous()
+
+    if tensor.dtype == torch.bfloat16:
+        # memoryview() doesn't support PyTorch's BFloat16 dtype
+        tensor = tensor.float()
+
+    assert isinstance(tensor, torch.Tensor)
+    if tensor.is_cuda:
+        # TODO: improve this
+        tensor_cpu = tensor.cpu()
+    else:
+        tensor_cpu = tensor
+
+    mv = memoryview(tensor_cpu.numpy())
+    return data_hash(mv.tobytes())
+
+
+def hash_feature(f):
+    if isinstance(f, list):
+        if isinstance(f[0], torch.Tensor):
+            return tensor_hash(f)
+        return data_hash(tuple(flatten_nested_list(f)))
+    elif isinstance(f, np.ndarray):
+        arr = np.ascontiguousarray(f)
+        arr_bytes = arr.tobytes()
+        return data_hash(arr_bytes)
+    elif isinstance(f, torch.Tensor):
+        return tensor_hash([f])
+    return data_hash(f)
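
data_hash appears to provide the stable key behind hash_feature: it truncates a SHA-256 digest to its first 8 bytes and reads them as an unsigned big-endian integer, giving a deterministic 64-bit value. A quick standalone check of that arithmetic (independent of sglang):

import hashlib

def data_hash(data) -> int:
    # First 8 bytes of SHA-256, read as an unsigned big-endian int in [0, 2**64)
    hash_bytes = hashlib.sha256(data).digest()[:8]
    return int.from_bytes(hash_bytes, byteorder="big", signed=False)

h = data_hash(b"example feature bytes")
assert 0 <= h < 2**64
print(hex(h))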
sglang/srt/managers/multimodal_processor.py
CHANGED
@@ -3,11 +3,8 @@ import importlib
 import inspect
 import logging
 import pkgutil
-from functools import lru_cache
 
-from sglang.srt.
-    BaseMultimodalProcessor,
-)
+from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
 from sglang.srt.server_args import ServerArgs
 
 logger = logging.getLogger(__name__)
@@ -27,9 +24,8 @@ def get_dummy_processor():
     return DummyMultimodalProcessor()
 
 
-@lru_cache()
 def import_processors():
-    package_name = "sglang.srt.
+    package_name = "sglang.srt.multimodal.processors"
     package = importlib.import_module(package_name)
     for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
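
import_processors discovers processor modules by walking a package with pkgutil.iter_modules and importing every non-package entry; dropping @lru_cache means each call re-scans the (now relocated) sglang.srt.multimodal.processors package. A generic standalone sketch of the same discovery pattern (the stdlib json package stands in as the target):

import importlib
import pkgutil

def import_submodules(package_name: str):
    # Import every direct, non-package submodule of a package and return
    # the loaded modules (the same pattern import_processors uses).
    package = importlib.import_module(package_name)
    modules = []
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            modules.append(importlib.import_module(name))
    return modules

print([m.__name__ for m in import_submodules("json")])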
sglang/srt/managers/multimodal_processors/qwen_audio.py
ADDED
@@ -0,0 +1,94 @@
+import re
+from typing import List, Union
+
+import torch
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration
+
+
+class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
+    models = [Qwen2AudioForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
+        self.AUDIO_TOKEN_REGEX = re.compile(
+            r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
+        )
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        **kwargs,
+    ):
+        audio_data = request_obj.audio_data
+        if not isinstance(audio_data, list):
+            audio_data = [audio_data]
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            max_req_input_len=max_req_input_len,
+            audio_data=audio_data,
+            multimodal_tokens=MultimodalSpecialTokens(
+                audio_token=self.AUDIO_TOKEN,
+                audio_token_regex=self.AUDIO_TOKEN_REGEX,
+            ),
+        )
+        if base_output is None:
+            return None
+
+        res = self.process_mm_data(
+            input_text=base_output.input_text,
+            audio=base_output.audios,
+        )
+
+        # Collect special token ids
+        tokenizer = self._processor.tokenizer
+        audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_bos|>")
+        audio_token_id = tokenizer.convert_tokens_to_ids("<|AUDIO|>")
+        audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>")
+
+        items = []
+        input_ids = res["input_ids"].flatten()
+
+        if (
+            "input_features" in res
+            and res["input_features"] is not None
+            and len(res["input_features"]) != 0
+        ):
+            if audio_start_id is not None and audio_end_id is not None:
+                audio_offsets = self.get_mm_items_offset_by_pair(
+                    input_ids=input_ids,
+                    mm_start_id=audio_start_id,
+                    mm_end_id=audio_end_id,
+                )
+            else:
+                audio_offsets = None
+
+            input_lengths = res["feature_attention_mask"].sum(dim=-1)
+            input_lengths = (input_lengths - 1) // 2 + 1
+            output_lengths = (input_lengths - 2) // 2 + 1
+
+            item = MultimodalDataItem(
+                audio_features=res["input_features"],
+                audio_feature_lens=output_lengths,
+                audio_offsets=audio_offsets,
+                modality=Modality.AUDIO,
+            )
+            items += [item]
+
+        return {
+            "mm_items": items,
+            "input_ids": input_ids.tolist(),
+            "audio_start_id": audio_start_id,
+            "audio_token_id": audio_token_id,
+            "audio_end_id": audio_end_id,
+        }
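
The length bookkeeping in the new processor halves the mel-frame count twice: (L - 1) // 2 + 1 matches a stride-2 convolution (kernel 3, padding 1) and (L - 2) // 2 + 1 a stride-2 pooling stage (kernel 2), which is how Qwen2-Audio's Whisper-style encoder appears to shorten the sequence to roughly a quarter of its frames. A worked check of the arithmetic (frame counts are illustrative):

# Each step is the standard output-size formula for a stride-2 stage.
for frames in (30, 100, 3000):
    after_conv = (frames - 1) // 2 + 1      # first stride-2 stage
    after_pool = (after_conv - 2) // 2 + 1  # second stride-2 stage
    print(frames, after_conv, after_pool)
# 30 -> 15 -> 7; 100 -> 50 -> 25; 3000 -> 1500 -> 750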