sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
CHANGED
@@ -20,19 +20,18 @@ import copy
|
|
20
20
|
import uuid
|
21
21
|
from dataclasses import dataclass, field
|
22
22
|
from enum import Enum
|
23
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
23
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
|
24
24
|
|
25
|
-
from sglang.srt.mm_utils import has_valid_data
|
25
|
+
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
26
|
+
from sglang.srt.multimodal.mm_utils import has_valid_data
|
27
|
+
from sglang.srt.sampling.sampling_params import SamplingParams
|
26
28
|
|
27
|
-
#
|
29
|
+
# Handle serialization of Image for pydantic
|
28
30
|
if TYPE_CHECKING:
|
29
31
|
from PIL.Image import Image
|
30
32
|
else:
|
31
33
|
Image = Any
|
32
34
|
|
33
|
-
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
34
|
-
from sglang.srt.sampling.sampling_params import SamplingParams
|
35
|
-
|
36
35
|
|
37
36
|
@dataclass
|
38
37
|
class SessionParams:
|
@@ -40,6 +39,7 @@ class SessionParams:
|
|
40
39
|
rid: Optional[str] = None
|
41
40
|
offset: Optional[int] = None
|
42
41
|
replace: Optional[bool] = None
|
42
|
+
drop_previous_output: Optional[bool] = None
|
43
43
|
|
44
44
|
|
45
45
|
AudioDataItem = Union[str, Dict]
|
@@ -182,6 +182,7 @@ class GenerateReqInput:
|
|
182
182
|
# Determine parallel sample count
|
183
183
|
if self.sampling_params is None:
|
184
184
|
self.parallel_sample_num = 1
|
185
|
+
return
|
185
186
|
elif isinstance(self.sampling_params, dict):
|
186
187
|
self.parallel_sample_num = self.sampling_params.get("n", 1)
|
187
188
|
else: # isinstance(self.sampling_params, list):
|
@@ -199,6 +200,8 @@ class GenerateReqInput:
|
|
199
200
|
self.text = [self.text]
|
200
201
|
if self.input_ids is not None:
|
201
202
|
self.input_ids = [self.input_ids]
|
203
|
+
if self.input_embeds is not None:
|
204
|
+
self.input_embeds = [self.input_embeds]
|
202
205
|
|
203
206
|
def _normalize_single_inputs(self):
|
204
207
|
"""Normalize inputs for a single example."""
|
@@ -323,7 +326,9 @@ class GenerateReqInput:
|
|
323
326
|
new_rids = [f"{self.rid}_{i}" for i in range(num)]
|
324
327
|
self.rid = new_rids
|
325
328
|
elif isinstance(self.rid, list):
|
326
|
-
|
329
|
+
# Note: the length of rid shall be the same as the batch_size,
|
330
|
+
# as the rid would be expanded for parallel sampling in tokenizer_manager
|
331
|
+
if len(self.rid) != self.batch_size:
|
327
332
|
raise ValueError(
|
328
333
|
"The specified rids length mismatch with the batch_size for batch processing."
|
329
334
|
)
|
@@ -399,6 +404,9 @@ class GenerateReqInput:
|
|
399
404
|
return GenerateReqInput(
|
400
405
|
text=self.text[i] if self.text is not None else None,
|
401
406
|
input_ids=self.input_ids[i] if self.input_ids is not None else None,
|
407
|
+
input_embeds=(
|
408
|
+
self.input_embeds[i] if self.input_embeds is not None else None
|
409
|
+
),
|
402
410
|
image_data=self.image_data[i],
|
403
411
|
audio_data=self.audio_data[i],
|
404
412
|
sampling_params=self.sampling_params[i],
|
@@ -516,9 +524,6 @@ class EmbeddingReqInput:
|
|
516
524
|
# For cross-encoder requests
|
517
525
|
is_cross_encoder_request: bool = False
|
518
526
|
|
519
|
-
def contains_mm_input(self) -> bool:
|
520
|
-
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
|
521
|
-
|
522
527
|
def normalize_batch_and_arguments(self):
|
523
528
|
# at least one of text, input_ids, or image should be provided
|
524
529
|
if self.text is None and self.input_ids is None and self.image_data is None:
|
@@ -572,6 +577,9 @@ class EmbeddingReqInput:
|
|
572
577
|
self.rid = uuid.uuid4().hex
|
573
578
|
return self.rid
|
574
579
|
|
580
|
+
def contains_mm_input(self) -> bool:
|
581
|
+
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
|
582
|
+
|
575
583
|
def __getitem__(self, i):
|
576
584
|
if self.is_cross_encoder_request:
|
577
585
|
return EmbeddingReqInput(
|
@@ -740,6 +748,8 @@ class UpdateWeightFromDiskReqInput:
|
|
740
748
|
model_path: str
|
741
749
|
# The format to load the weights
|
742
750
|
load_format: Optional[str] = None
|
751
|
+
# Whether to abort all requests before updating weights
|
752
|
+
abort_all_requests: bool = False
|
743
753
|
|
744
754
|
|
745
755
|
@dataclass
|
@@ -752,9 +762,15 @@ class UpdateWeightFromDiskReqOutput:
|
|
752
762
|
|
753
763
|
@dataclass
|
754
764
|
class UpdateWeightsFromDistributedReqInput:
|
755
|
-
|
756
|
-
|
757
|
-
|
765
|
+
names: List[str]
|
766
|
+
dtypes: List[str]
|
767
|
+
shapes: List[List[int]]
|
768
|
+
# The group name
|
769
|
+
group_name: str = "weight_update_group"
|
770
|
+
# Whether to flush the cache after updating weights
|
771
|
+
flush_cache: bool = True
|
772
|
+
# Whether to abort all requests before updating weights
|
773
|
+
abort_all_requests: bool = False
|
758
774
|
|
759
775
|
|
760
776
|
@dataclass
|
@@ -776,6 +792,8 @@ class UpdateWeightsFromTensorReqInput:
|
|
776
792
|
load_format: Optional[str] = None
|
777
793
|
# Whether to flush the cache after updating weights
|
778
794
|
flush_cache: bool = True
|
795
|
+
# Whether to abort all requests before updating weights
|
796
|
+
abort_all_requests: bool = False
|
779
797
|
|
780
798
|
|
781
799
|
@dataclass
|
@@ -854,7 +872,9 @@ class SlowDownReqOutput:
|
|
854
872
|
@dataclass
|
855
873
|
class AbortReq:
|
856
874
|
# The request id
|
857
|
-
rid: str
|
875
|
+
rid: str = ""
|
876
|
+
# Whether to abort all requests
|
877
|
+
abort_all: bool = False
|
858
878
|
|
859
879
|
|
860
880
|
@dataclass
|
@@ -1002,3 +1022,27 @@ class RpcReqInput:
|
|
1002
1022
|
class RpcReqOutput:
|
1003
1023
|
success: bool
|
1004
1024
|
message: str
|
1025
|
+
|
1026
|
+
|
1027
|
+
@dataclass
|
1028
|
+
class LoadLoRAAdapterReqInput:
|
1029
|
+
# The name of the lora module to newly loaded.
|
1030
|
+
lora_name: str
|
1031
|
+
# The path of loading.
|
1032
|
+
lora_path: str
|
1033
|
+
|
1034
|
+
|
1035
|
+
@dataclass
|
1036
|
+
class UnloadLoRAAdapterReqInput:
|
1037
|
+
# The name of lora module to unload.
|
1038
|
+
lora_name: str
|
1039
|
+
|
1040
|
+
|
1041
|
+
@dataclass
|
1042
|
+
class LoRAUpdateResult:
|
1043
|
+
success: bool
|
1044
|
+
error_message: Optional[str] = None
|
1045
|
+
loaded_adapters: Dict[str, str] = field(default_factory=dict)
|
1046
|
+
|
1047
|
+
|
1048
|
+
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
|
sglang/srt/managers/mm_utils.py
CHANGED
@@ -2,14 +2,15 @@
|
|
2
2
|
Multi-modality utils
|
3
3
|
"""
|
4
4
|
|
5
|
-
import
|
6
|
-
import logging
|
5
|
+
import hashlib
|
7
6
|
from abc import abstractmethod
|
8
7
|
from typing import Callable, List, Optional, Tuple
|
9
8
|
|
9
|
+
import numpy as np
|
10
10
|
import torch
|
11
11
|
from torch import nn
|
12
12
|
|
13
|
+
from sglang.srt.layers.multimodal import gpu_tensor_hash
|
13
14
|
from sglang.srt.managers.schedule_batch import (
|
14
15
|
Modality,
|
15
16
|
MultimodalDataItem,
|
@@ -124,74 +125,38 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
|
|
124
125
|
e.g. <image><image>....<image>, or <audio><audio>...<audio>
|
125
126
|
"""
|
126
127
|
|
127
|
-
def __init__(self, token_ids: List[int]) -> None:
|
128
|
-
self.token_ids = token_ids
|
129
|
-
|
130
128
|
def pad_input_tokens(
|
131
129
|
self, input_ids: List[int], mm_inputs: MultimodalInputs
|
132
130
|
) -> List[int]:
|
133
131
|
"""
|
134
|
-
|
135
|
-
|
132
|
+
Replaces multimodal tokens in input_ids with corresponding pad_values from mm_items.
|
133
|
+
Each modality (image, audio, video) is handled separately based on its token_id.
|
136
134
|
"""
|
137
|
-
|
138
|
-
if not pad_values:
|
139
|
-
# No multimodal items, return original input_ids
|
135
|
+
if not input_ids or not mm_inputs.mm_items:
|
140
136
|
return input_ids
|
141
|
-
if not input_ids:
|
142
|
-
return []
|
143
137
|
|
144
138
|
input_ids_tensor = torch.tensor(input_ids)
|
145
|
-
device = input_ids_tensor.device
|
146
|
-
token_ids_tensor = torch.tensor(self.token_ids, device=device)
|
147
|
-
mask = torch.isin(input_ids_tensor, token_ids_tensor)
|
148
139
|
|
149
|
-
|
150
|
-
|
151
|
-
return input_ids
|
140
|
+
# Create mapping of token_ids to pad_values for each modality
|
141
|
+
token_to_pad_mapping = {}
|
152
142
|
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
)
|
161
|
-
# Find indices where the mask value changes
|
162
|
-
diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
|
163
|
-
|
164
|
-
# Start indices are where False changes to True
|
165
|
-
starts = diff_indices[::2]
|
166
|
-
# End indices are where True changes to False (exclusive index)
|
167
|
-
ends = diff_indices[1::2]
|
168
|
-
|
169
|
-
# Check if the number of regions matches the number of pad values
|
170
|
-
if len(starts) != len(pad_values):
|
171
|
-
# Maybe log a warning here?
|
172
|
-
num_regions = len(starts)
|
173
|
-
num_pad_values = len(pad_values)
|
174
|
-
if num_regions > 0 and num_pad_values > 0:
|
175
|
-
pad_values = (pad_values * (num_regions // num_pad_values + 1))[
|
176
|
-
:num_regions
|
177
|
-
]
|
178
|
-
else: # If no regions or no pad_values, this loop won't run anyway.
|
179
|
-
pad_values = [] # Ensure pad_values is empty if starts is empty
|
180
|
-
|
181
|
-
# Create a copy to modify
|
182
|
-
output_ids_tensor = input_ids_tensor.clone()
|
183
|
-
|
184
|
-
# Replace tokens in each region with the corresponding pad value
|
185
|
-
# Ensure we don't iterate if pad_values became empty due to mismatch and num_regions=0
|
186
|
-
for i in range(min(len(starts), len(pad_values))):
|
187
|
-
start_idx = starts[i]
|
188
|
-
end_idx = ends[i]
|
189
|
-
pad_value = pad_values[i]
|
190
|
-
if pad_value is not None: # Ensure pad_value is not None before assignment
|
191
|
-
output_ids_tensor[start_idx:end_idx] = pad_value
|
143
|
+
for item in mm_inputs.mm_items:
|
144
|
+
if item.is_image() and mm_inputs.im_token_id is not None:
|
145
|
+
token_to_pad_mapping[mm_inputs.im_token_id] = item.pad_value
|
146
|
+
elif item.is_audio() and mm_inputs.audio_token_id is not None:
|
147
|
+
token_to_pad_mapping[mm_inputs.audio_token_id] = item.pad_value
|
148
|
+
elif item.is_video() and mm_inputs.video_token_id is not None:
|
149
|
+
token_to_pad_mapping[mm_inputs.video_token_id] = item.pad_value
|
192
150
|
else:
|
193
|
-
|
194
|
-
|
151
|
+
raise ValueError(f"No multimodal token id provided for {item.modality}")
|
152
|
+
|
153
|
+
# Apply replacements for all tokens at once
|
154
|
+
for token_id, pad_value in token_to_pad_mapping.items():
|
155
|
+
input_ids_tensor[input_ids_tensor == token_id] = pad_value
|
156
|
+
|
157
|
+
ret_input_ids = input_ids_tensor.tolist()
|
158
|
+
|
159
|
+
return ret_input_ids
|
195
160
|
|
196
161
|
|
197
162
|
embedding_cache = None
|
@@ -283,7 +248,9 @@ def _get_chunked_prefill_embedding(
|
|
283
248
|
) -> Optional[torch.Tensor]:
|
284
249
|
# Calculate embedding for each request, try to get it from cache to avoid repeated calculation
|
285
250
|
embedding_list = []
|
286
|
-
for i in range(len(items_size) - 1):
|
251
|
+
# FIXME(Xinyuan): temporary workaround for eagle3, which may have len(items_size) > len(prefix_length)
|
252
|
+
max_iterations = min(len(items_size) - 1, len(prefix_length))
|
253
|
+
for i in range(max_iterations):
|
287
254
|
if items_size[i] == items_size[i + 1]:
|
288
255
|
continue
|
289
256
|
embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
|
@@ -304,7 +271,7 @@ def _get_chunked_prefill_embedding(
|
|
304
271
|
embedding_per_req_chunk, _, end_index = get_embedding_chunk(
|
305
272
|
embedding=embedding_per_req,
|
306
273
|
extend_prefix_len=prefix_length[i],
|
307
|
-
extend_seq_len=extend_length[i],
|
274
|
+
extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
|
308
275
|
items_offset=items_offset,
|
309
276
|
)
|
310
277
|
# remove this item from cache if chunk reaches to the end
|
@@ -680,3 +647,52 @@ def get_multimodal_data_bounds(
|
|
680
647
|
# Convert valid pairs to tensor
|
681
648
|
valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
|
682
649
|
return valid_pairs_tensor
|
650
|
+
|
651
|
+
|
652
|
+
def data_hash(data) -> int:
|
653
|
+
hash_bytes = hashlib.sha256(data).digest()[:8]
|
654
|
+
return int.from_bytes(hash_bytes, byteorder="big", signed=False)
|
655
|
+
|
656
|
+
|
657
|
+
def tensor_hash(tensor_list) -> int:
|
658
|
+
"""
|
659
|
+
hash a tensor or a tensor list
|
660
|
+
"""
|
661
|
+
tensor = tensor_list
|
662
|
+
if isinstance(tensor_list, list):
|
663
|
+
tensor_list = flatten_nested_list(tensor_list)
|
664
|
+
tensor_list = [
|
665
|
+
x.flatten() if isinstance(x, torch.Tensor) else x for x in tensor_list
|
666
|
+
]
|
667
|
+
tensor = torch.concat(tensor_list)
|
668
|
+
if tensor.is_cuda:
|
669
|
+
return gpu_tensor_hash(tensor)
|
670
|
+
tensor = tensor.detach().contiguous()
|
671
|
+
|
672
|
+
if tensor.dtype == torch.bfloat16:
|
673
|
+
# memoryview() doesn't support PyTorch's BFloat16 dtype
|
674
|
+
tensor = tensor.float()
|
675
|
+
|
676
|
+
assert isinstance(tensor, torch.Tensor)
|
677
|
+
if tensor.is_cuda:
|
678
|
+
# TODO: improve this
|
679
|
+
tensor_cpu = tensor.cpu()
|
680
|
+
else:
|
681
|
+
tensor_cpu = tensor
|
682
|
+
|
683
|
+
mv = memoryview(tensor_cpu.numpy())
|
684
|
+
return data_hash(mv.tobytes())
|
685
|
+
|
686
|
+
|
687
|
+
def hash_feature(f):
|
688
|
+
if isinstance(f, list):
|
689
|
+
if isinstance(f[0], torch.Tensor):
|
690
|
+
return tensor_hash(f)
|
691
|
+
return data_hash(tuple(flatten_nested_list(f)))
|
692
|
+
elif isinstance(f, np.ndarray):
|
693
|
+
arr = np.ascontiguousarray(f)
|
694
|
+
arr_bytes = arr.tobytes()
|
695
|
+
return data_hash(arr_bytes)
|
696
|
+
elif isinstance(f, torch.Tensor):
|
697
|
+
return tensor_hash([f])
|
698
|
+
return data_hash(f)
|
sglang/srt/managers/multimodal_processor.py
CHANGED
@@ -3,11 +3,8 @@ import importlib
|
|
3
3
|
import inspect
|
4
4
|
import logging
|
5
5
|
import pkgutil
|
6
|
-
from functools import lru_cache
|
7
6
|
|
8
|
-
from sglang.srt.managers.multimodal_processors.base_processor import (
|
9
|
-
BaseMultimodalProcessor,
|
10
|
-
)
|
7
|
+
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
|
11
8
|
from sglang.srt.server_args import ServerArgs
|
12
9
|
|
13
10
|
logger = logging.getLogger(__name__)
|
@@ -27,9 +24,8 @@ def get_dummy_processor():
|
|
27
24
|
return DummyMultimodalProcessor()
|
28
25
|
|
29
26
|
|
30
|
-
@lru_cache()
|
31
27
|
def import_processors():
|
32
|
-
package_name = "sglang.srt.managers.multimodal_processors"
|
28
|
+
package_name = "sglang.srt.multimodal.processors"
|
33
29
|
package = importlib.import_module(package_name)
|
34
30
|
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
|
35
31
|
if not ispkg:
|
sglang/srt/managers/multimodal_processors/qwen_audio.py
ADDED
@@ -0,0 +1,94 @@
|
|
1
|
+
import re
|
2
|
+
from typing import List, Union
|
3
|
+
|
4
|
+
import torch
|
5
|
+
|
6
|
+
from sglang.srt.managers.multimodal_processors.base_processor import (
|
7
|
+
BaseMultimodalProcessor,
|
8
|
+
MultimodalSpecialTokens,
|
9
|
+
)
|
10
|
+
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
11
|
+
from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration
|
12
|
+
|
13
|
+
|
14
|
+
class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
|
15
|
+
models = [Qwen2AudioForConditionalGeneration]
|
16
|
+
|
17
|
+
def __init__(self, hf_config, server_args, _processor):
|
18
|
+
super().__init__(hf_config, server_args, _processor)
|
19
|
+
self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
|
20
|
+
self.AUDIO_TOKEN_REGEX = re.compile(
|
21
|
+
r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
|
22
|
+
)
|
23
|
+
|
24
|
+
async def process_mm_data_async(
|
25
|
+
self,
|
26
|
+
image_data: List[Union[str, bytes]],
|
27
|
+
input_text,
|
28
|
+
request_obj,
|
29
|
+
max_req_input_len,
|
30
|
+
**kwargs,
|
31
|
+
):
|
32
|
+
audio_data = request_obj.audio_data
|
33
|
+
if not isinstance(audio_data, list):
|
34
|
+
audio_data = [audio_data]
|
35
|
+
|
36
|
+
base_output = self.load_mm_data(
|
37
|
+
prompt=input_text,
|
38
|
+
max_req_input_len=max_req_input_len,
|
39
|
+
audio_data=audio_data,
|
40
|
+
multimodal_tokens=MultimodalSpecialTokens(
|
41
|
+
audio_token=self.AUDIO_TOKEN,
|
42
|
+
audio_token_regex=self.AUDIO_TOKEN_REGEX,
|
43
|
+
),
|
44
|
+
)
|
45
|
+
if base_output is None:
|
46
|
+
return None
|
47
|
+
|
48
|
+
res = self.process_mm_data(
|
49
|
+
input_text=base_output.input_text,
|
50
|
+
audio=base_output.audios,
|
51
|
+
)
|
52
|
+
|
53
|
+
# Collect special token ids
|
54
|
+
tokenizer = self._processor.tokenizer
|
55
|
+
audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_bos|>")
|
56
|
+
audio_token_id = tokenizer.convert_tokens_to_ids("<|AUDIO|>")
|
57
|
+
audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>")
|
58
|
+
|
59
|
+
items = []
|
60
|
+
input_ids = res["input_ids"].flatten()
|
61
|
+
|
62
|
+
if (
|
63
|
+
"input_features" in res
|
64
|
+
and res["input_features"] is not None
|
65
|
+
and len(res["input_features"]) != 0
|
66
|
+
):
|
67
|
+
if audio_start_id is not None and audio_end_id is not None:
|
68
|
+
audio_offsets = self.get_mm_items_offset_by_pair(
|
69
|
+
input_ids=input_ids,
|
70
|
+
mm_start_id=audio_start_id,
|
71
|
+
mm_end_id=audio_end_id,
|
72
|
+
)
|
73
|
+
else:
|
74
|
+
audio_offsets = None
|
75
|
+
|
76
|
+
input_lengths = res["feature_attention_mask"].sum(dim=-1)
|
77
|
+
input_lengths = (input_lengths - 1) // 2 + 1
|
78
|
+
output_lengths = (input_lengths - 2) // 2 + 1
|
79
|
+
|
80
|
+
item = MultimodalDataItem(
|
81
|
+
audio_features=res["input_features"],
|
82
|
+
audio_feature_lens=output_lengths,
|
83
|
+
audio_offsets=audio_offsets,
|
84
|
+
modality=Modality.AUDIO,
|
85
|
+
)
|
86
|
+
items += [item]
|
87
|
+
|
88
|
+
return {
|
89
|
+
"mm_items": items,
|
90
|
+
"input_ids": input_ids.tolist(),
|
91
|
+
"audio_start_id": audio_start_id,
|
92
|
+
"audio_token_id": audio_token_id,
|
93
|
+
"audio_end_id": audio_end_id,
|
94
|
+
}
|