sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +7 -0
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +16 -1
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mooncake/conn.py +16 -0
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +13 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -1
- sglang/srt/entrypoints/openai/serving_base.py +5 -2
- sglang/srt/entrypoints/openai/serving_chat.py +132 -79
- sglang/srt/function_call/ebnf_composer.py +10 -3
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/qwen3_coder_detector.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +14 -3
- sglang/srt/layers/moe/ep_moe/layer.py +323 -242
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
- sglang/srt/layers/moe/topk.py +90 -24
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +27 -10
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/lora/lora_registry.py +93 -29
- sglang/srt/managers/cache_controller.py +9 -7
- sglang/srt/managers/data_parallel_controller.py +4 -0
- sglang/srt/managers/io_struct.py +12 -0
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +14 -8
- sglang/srt/managers/scheduler.py +64 -1
- sglang/srt/managers/scheduler_input_blocker.py +106 -0
- sglang/srt/managers/tokenizer_manager.py +80 -15
- sglang/srt/managers/tp_worker.py +8 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -2
- sglang/srt/model_executor/model_runner.py +83 -27
- sglang/srt/models/deepseek_v2.py +75 -84
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/qwen2_moe.py +2 -2
- sglang/srt/models/qwen3_moe.py +17 -71
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/poll_based_barrier.py +31 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +65 -6
- sglang/srt/two_batch_overlap.py +8 -3
- sglang/srt/utils.py +96 -1
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_utils.py +118 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +5 -4
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +97 -80
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
sglang/srt/managers/mm_utils.py
CHANGED
@@ -3,8 +3,9 @@ Multi-modality utils
 """
 
 import hashlib
+import pickle
 from abc import abstractmethod
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
 
 import numpy as np
 import torch
@@ -27,6 +28,128 @@ from sglang.utils import logger
 # propagation that can cause some log messages (like 'server is fired up') to not appear
 # in the console when multimodal support is enabled.
 
+# TODO(mick): nccl
+# cuda_ipc: for intranode tensor sharing
+TensorTransportMode = Literal["cuda_ipc", "auto", "default"]
+
+
+class TransportProxyTensor(torch.Tensor):
+    """
+    A convenient torch.Tensor subclass that carries extra metadata and supports
+    efficient inter-process communications
+    """
+
+    @staticmethod
+    def __new__(
+        cls,
+        data: torch.Tensor,
+        name: Optional[str] = None,
+        fields: Optional[Dict[str, Any]] = None,
+        transport_mode: TensorTransportMode = "default",
+        *args,
+        **kwargs,
+    ):
+
+        if not isinstance(data, torch.Tensor):
+            raise TypeError(
+                f"Input 'data' must be a torch.Tensor, but got {type(data)}"
+            )
+
+        instance = data.as_subclass(cls)
+
+        instance._metadata = {
+            "name": name,
+            "fields": fields if fields is not None else {},
+            "transport_mode": transport_mode,
+        }
+
+        return instance
+
+    def __getstate__(self):
+        """
+        Called during pickling. Implements the serialization logic.
+        """
+        # acquire all serialize metadata from _metadata
+        state = {
+            "metadata": self._metadata,
+            "tensor_data": None,
+            "ipc_extra": None,
+        }
+
+        transport_mode = self._metadata.get("transport_mode", "default")
+
+        if transport_mode == "cuda_ipc" and self.is_cuda:
+            try:
+                storage = self.untyped_storage()
+                handle = storage._share_cuda_()
+
+                state["ipc_extra"] = {
+                    "handle": handle,
+                    "shape": self.shape,
+                    "dtype": self.dtype,
+                    "stride": self.stride(),
+                    "device_index": self.device.index,
+                }
+                state["tensor_data"] = None
+            except Exception as e:
+                # Failed to get CUDA IPC handle (possibly tp). Falling back to default transport.
+                state["metadata"]["transport_mode"] = "default"
+                state["tensor_data"] = self.as_subclass(torch.Tensor)
+        else:
+            state["metadata"]["transport_mode"] = "default"
+            state["tensor_data"] = self.as_subclass(torch.Tensor)
+
+        return state
+
+    def __setstate__(self, state: Dict[str, Any]):
+        """
+        Called during unpickling. Implements the deserialization logic.
+        """
+        self._metadata = state["metadata"]
+
+        transport_mode = self._metadata.get("transport_mode", "default")
+
+        if transport_mode == "cuda_ipc" and state["ipc_extra"] is not None:
+            ipc_extra = state["ipc_extra"]
+            handle, shape, dtype, stride, source_device_index = (
+                ipc_extra["handle"],
+                ipc_extra["shape"],
+                ipc_extra["dtype"],
+                ipc_extra["stride"],
+                ipc_extra["device_index"],
+            )
+
+            try:
+                target_device = torch.device(f"cuda:{source_device_index}")
+                with torch.cuda.device(target_device):
+                    storage = torch.UntypedStorage._new_shared_cuda(*handle)
+                    reconstructed_tensor = torch.empty(
+                        0, dtype=dtype, device=target_device
+                    ).set_(storage, storage_offset=0, size=shape, stride=stride)
+                self.set_(reconstructed_tensor)
+            except Exception as e:
+                print(f"Error: Failed to deserialize from CUDA IPC handle ({e}).")
+                raise e
+
+        elif state["tensor_data"] is not None:
+            self.set_(state["tensor_data"])
+        else:
+            raise pickle.UnpicklingError(
+                "Invalid state for TransportProxyTensor: no tensor data found."
+            )
+
+    @property
+    def name(self) -> Optional[str]:
+        return self._metadata.get("name")
+
+    @property
+    def fields(self) -> Dict[str, Any]:
+        return self._metadata.get("fields", {})
+
+    @property
+    def transport_mode(self) -> TensorTransportMode:
+        return self._metadata.get("transport_mode", "default")
+
 
 class MultiModalityDataPaddingPattern:
     """
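The new TransportProxyTensor lets multimodal feature tensors cross process boundaries either via a CUDA IPC handle (transport_mode="cuda_ipc", intranode only) or by falling back to pickling the tensor data. Below is a minimal round-trip sketch, not part of the diff: the import path mirrors where the class is added above, and a CPU tensor is used so the example runs without a GPU (exercising the fallback path).

import pickle
import torch
from sglang.srt.managers.mm_utils import TransportProxyTensor

data = torch.randn(2, 3)  # use a CUDA tensor to exercise the cuda_ipc fast path
proxy = TransportProxyTensor(
    data, name="pixel_values", fields={"modality": "image"}, transport_mode="cuda_ipc"
)

# Normally this crosses a process boundary (e.g. multimodal processor -> worker);
# for a CPU tensor, __getstate__ ships the tensor data itself instead of an IPC handle.
restored = pickle.loads(pickle.dumps(proxy))
print(restored.name, restored.fields, restored.transport_mode, tuple(restored.shape))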
@@ -85,8 +208,8 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
                 "No data_token_pairs provided, RadixAttention might be influenced."
             )
             return input_ids
-        start_token_ids =
-        end_tokens_ids =
+        start_token_ids = {s for s, _e in data_token_pairs}
+        end_tokens_ids = {e for _s, e in data_token_pairs}
 
         padded_ids = []
         last_idx = 0
@@ -135,7 +258,7 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
         if not input_ids or not mm_inputs.mm_items:
             return input_ids
 
-        input_ids_tensor = torch.
+        input_ids_tensor = torch.as_tensor(input_ids)
 
         # Create mapping of token_ids to pad_values for each modality
         token_to_pad_mapping = {}
@@ -211,7 +334,7 @@ def get_embedding_chunk(
            end_index += extend_end_index - start + 1
        elif extend_end_index > end:
            end_index += end - start + 1
-    # some models embedding is 3-dim, reshape it to 2-dim
+    # some models' embedding is 3-dim, reshape it to 2-dim
    embedding = embedding.reshape(-1, embedding.shape[-1])
    embedding_chunk = embedding[start_index:end_index]
    return embedding_chunk, start_index, end_index
@@ -428,7 +551,7 @@ def embed_mm_inputs(
        modality_id = modality.name.lower()
        embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
        if len(items) != 0 and embedder is not None:
-            placeholder_tensor = torch.
+            placeholder_tensor = torch.as_tensor(
                [item.pad_value for item in items],
                device=input_ids.device,
            )
@@ -473,11 +596,9 @@ def embed_mm_inputs(
    for embedding, mask in zip(embeddings, masks):
        if embedding is None or mask is None:
            continue
-
-
-
-            embedding.to(inputs_embeds.device, inputs_embeds.dtype),
-        )
+        # in-place update
+        indices = torch.where(mask.squeeze(dim=-1))[0]
+        inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
    return inputs_embeds
 
 
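The rewritten loop above replaces the previous scatter with an explicit index assignment: positions flagged by the multimodal mask receive the corresponding rows of the modality embedding, written in place into inputs_embeds. A small standalone illustration with made-up shapes:

import torch

inputs_embeds = torch.zeros(6, 4)                                # [num_tokens, hidden]
mask = torch.tensor([0, 1, 1, 0, 1, 0], dtype=torch.bool).unsqueeze(-1)
embedding = torch.arange(12, dtype=torch.float32).reshape(3, 4)  # one row per masked token

indices = torch.where(mask.squeeze(dim=-1))[0]                   # -> tensor([1, 2, 4])
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
print(inputs_embeds)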
@@ -561,34 +682,36 @@ def get_multimodal_data_bounds(
        [bounds_count, 2]
    """
    # All the multimodal data in the batch should share the same special bound token ids.
-    start_tokens =
-    end_tokens =
+    start_tokens = {s for s, _e in token_pairs}
+    end_tokens = {e for _s, e in token_pairs}
 
    assert all(isinstance(t, int) for t in start_tokens)
    assert all(isinstance(t, int) for t in end_tokens)
 
    start_cond = torch.isin(
-        input_ids, torch.
+        input_ids, torch.as_tensor(start_tokens, device=input_ids.device)
+    )
+    end_cond = torch.isin(
+        input_ids, torch.as_tensor(end_tokens, device=input_ids.device)
    )
-    end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))
 
    (data_start_tokens,) = torch.where(start_cond)
    (data_end_tokens,) = torch.where(end_cond)
 
+    data_start_tokens_cpu = data_start_tokens.cpu().tolist()
+    data_end_tokens_cpu = data_end_tokens.cpu().tolist()
+
    # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
-    if len(
+    if len(data_start_tokens_cpu) != len(data_end_tokens_cpu):
        if (
-            len(
-            and input_ids[0] in pad_values
-            and
+            len(data_start_tokens_cpu) + 1 == len(data_end_tokens_cpu)
+            and input_ids[0].item() in pad_values
+            and data_end_tokens_cpu
+            and data_start_tokens_cpu
+            and data_end_tokens_cpu[0] < data_start_tokens_cpu[0]
        ):
-
-
-                torch.tensor([0], device=data_start_tokens.device),
-                data_start_tokens,
-            ]
-        )
-    valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))
+            data_start_tokens_cpu.insert(0, 0)
+    valid_mm_data_nums = min(len(data_start_tokens_cpu), len(data_end_tokens_cpu))
 
    if valid_mm_data_nums == 0:
        return torch.zeros((0, 2), device=input_ids.device)
@@ -596,8 +719,8 @@ def get_multimodal_data_bounds(
    # Filter out pairs where start_token >= end_token
    valid_pairs = []
    for i in range(valid_mm_data_nums):
-        start_token =
-        end_token =
+        start_token = data_start_tokens_cpu[i]
+        end_token = data_end_tokens_cpu[i]
        if start_token < end_token:
            valid_pairs.append((start_token + 1, end_token - 1))
 
@@ -605,7 +728,7 @@ def get_multimodal_data_bounds(
        return torch.zeros((0, 2), device=input_ids.device)
 
    # Convert valid pairs to tensor
-    valid_pairs_tensor = torch.
+    valid_pairs_tensor = torch.as_tensor(valid_pairs, device=input_ids.device)
    return valid_pairs_tensor
 
 
@@ -626,7 +749,7 @@ def tensor_hash(tensor_list) -> int:
    ]
    tensor = torch.concat(tensor_list)
    if tensor.is_cuda:
-        return gpu_tensor_hash(tensor)
+        return gpu_tensor_hash(tensor.cuda())
    tensor = tensor.detach().contiguous()
 
    if tensor.dtype == torch.bfloat16:
@@ -634,11 +757,7 @@ def tensor_hash(tensor_list) -> int:
        tensor = tensor.float()
 
    assert isinstance(tensor, torch.Tensor)
-
-        # TODO: improve this
-        tensor_cpu = tensor.cpu()
-    else:
-        tensor_cpu = tensor
+    tensor_cpu = tensor.cpu()
 
    mv = memoryview(tensor_cpu.numpy())
    return data_hash(mv.tobytes())
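For reference, the CPU branch of tensor_hash after this change boils down to the steps below. This is only a sketch: data_hash is approximated with hashlib because its definition is not part of this diff, and the gpu_tensor_hash fast path is omitted.

import hashlib
import torch

def cpu_tensor_hash(tensor_list) -> int:
    tensor = torch.concat([t.flatten() for t in tensor_list])
    tensor = tensor.detach().contiguous()
    if tensor.dtype == torch.bfloat16:
        # numpy has no bfloat16, so cast before taking a memoryview
        tensor = tensor.float()
    tensor_cpu = tensor.cpu()
    mv = memoryview(tensor_cpu.numpy())
    return int.from_bytes(hashlib.sha256(mv.tobytes()).digest()[:8], "big")

print(cpu_tensor_hash([torch.arange(4, dtype=torch.float32), torch.ones(3)]))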
sglang/srt/managers/multimodal_processor.py
CHANGED
@@ -12,18 +12,6 @@ logger = logging.getLogger(__name__)
 PROCESSOR_MAPPING = {}
 
 
-class DummyMultimodalProcessor(BaseMultimodalProcessor):
-    def __init__(self):
-        pass
-
-    async def process_mm_data_async(self, *args, **kwargs):
-        return None
-
-
-def get_dummy_processor():
-    return DummyMultimodalProcessor()
-
-
 def import_processors():
     package_name = "sglang.srt.multimodal.processors"
     package = importlib.import_module(package_name)
@@ -49,11 +37,12 @@ def import_processors():
 
 
 def get_mm_processor(
-    hf_config, server_args: ServerArgs, processor
+    hf_config, server_args: ServerArgs, processor, transport_mode
 ) -> BaseMultimodalProcessor:
     for model_cls, processor_cls in PROCESSOR_MAPPING.items():
         if model_cls.__name__ in hf_config.architectures:
-            return processor_cls(hf_config, server_args, processor)
+            return processor_cls(hf_config, server_args, processor, transport_mode)
+
     raise ValueError(
         f"No processor registered for architecture: {hf_config.architectures}.\n"
         f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -88,7 +88,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
    "enable_deepep_moe",
    "deepep_mode",
    "enable_ep_moe",
-    "
+    "enable_flashinfer_cutlass_moe",
+    "enable_flashinfer_trtllm_moe",
    "enable_flashinfer_allreduce_fusion",
    "moe_dense_tp_size",
    "ep_dispatch_algorithm",
@@ -209,10 +210,11 @@ class MultimodalDataItem:
    hash: int = None
    pad_value: int = None
    offsets: Optional[list] = None
+
    # the raw features returned by processor, e.g. pixel_values or audio_features
    feature: Union[torch.Tensor, np.ndarray] = None
-
-    #
+    # the precomputed embeddings, passed as final encoder embeddings
+    # One and only one of the feature and precomputed_embeddings will be empty
    precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None
 
    # Model-specific data stored in a dictionary
@@ -1688,16 +1690,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        extend_prefix_lens = self.prefix_lens
        extend_logprob_start_lens = self.extend_logprob_start_lens
 
+        if self.forward_mode.is_decode_or_idle():
+            attention_backend_str = global_server_args_dict["decode_attention_backend"]
+        else:
+            attention_backend_str = global_server_args_dict["prefill_attention_backend"]
        # Create seq_lens_cpu when needed
        if (
-
+            attention_backend_str == "fa3"
            or (
                global_server_args_dict["use_mla_backend"]
-                and
+                and attention_backend_str == "flashinfer"
            )
-            or
-            or
-            or
+            or attention_backend_str == "flashmla"
+            or attention_backend_str == "cutlass_mla"
+            or attention_backend_str == "ascend"
            or global_server_args_dict["enable_two_batch_overlap"]
        ):
            seq_lens_cpu = (
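The ScheduleBatch change above first resolves the attention backend by forward mode and then decides whether seq_lens_cpu needs to be materialized. A minimal sketch of that selection logic, with a plain dict standing in for global_server_args_dict and a boolean standing in for forward_mode.is_decode_or_idle():

def pick_attention_backend(server_args: dict, is_decode_or_idle: bool) -> str:
    key = "decode_attention_backend" if is_decode_or_idle else "prefill_attention_backend"
    return server_args[key]

def needs_seq_lens_cpu(backend: str, use_mla: bool, two_batch_overlap: bool) -> bool:
    return (
        backend == "fa3"
        or (use_mla and backend == "flashinfer")
        or backend in ("flashmla", "cutlass_mla", "ascend")
        or two_batch_overlap
    )

args = {"decode_attention_backend": "flashinfer", "prefill_attention_backend": "fa3"}
backend = pick_attention_backend(args, is_decode_or_idle=False)
print(backend, needs_seq_lens_cpu(backend, use_mla=False, two_batch_overlap=False))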
sglang/srt/managers/scheduler.py
CHANGED
@@ -24,6 +24,7 @@ import time
 from collections import defaultdict, deque
 from concurrent import futures
 from dataclasses import dataclass
+from http import HTTPStatus
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Dict, List, Optional, Tuple, Union
@@ -122,6 +123,7 @@ from sglang.srt.managers.schedule_policy import (
     PrefillAdder,
     SchedulePolicy,
 )
+from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
 from sglang.srt.managers.scheduler_output_processor_mixin import (
     SchedulerOutputProcessorMixin,
 )
@@ -370,6 +372,7 @@ class Scheduler(
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
+            self.max_queued_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
@@ -458,7 +461,10 @@ class Scheduler(
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
-                server_args,
+                server_args,
+                self.tokenizer,
+                self.model_config.vocab_size,
+                self.model_config.hf_eos_token_id,
            )
        else:
            self.grammar_backend = None
@@ -499,6 +505,12 @@ class Scheduler(
        )
        self.init_profier()
 
+        self.input_blocker = (
+            SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
+            if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
+            else None
+        )
+
        # Init metrics stats
        self.init_metrics(tp_rank, pp_rank, dp_rank)
        self.init_kv_events(server_args.kv_events_config)
@@ -1030,6 +1042,9 @@ class Scheduler(
        else:
            recv_reqs = None
 
+        if self.input_blocker is not None:
+            recv_reqs = self.input_blocker.handle(recv_reqs)
+
        if self.server_args.enable_dp_attention:
            if self.attn_tp_rank == 0:
                work_reqs = [
@@ -1083,6 +1098,19 @@ class Scheduler(
                self.return_health_check_ct += 1
                continue
 
+            # If it is a work request, accept or reject the request based on the request queue size.
+            if is_work_request(recv_req):
+                if len(self.waiting_queue) + 1 > self.max_queued_requests:
+                    abort_req = AbortReq(
+                        recv_req.rid,
+                        finished_reason={
+                            "type": "abort",
+                            "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
+                            "message": "The request queue is full.",
+                        },
+                    )
+                    self.send_to_tokenizer.send_pyobj(abort_req)
+                    continue
            output = self._request_dispatcher(recv_req)
            if output is not None:
                if isinstance(output, RpcReqOutput):
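With the new admission check, a work request that would push the waiting queue past max_queued_requests is aborted and reported with HTTP 503 (SERVICE_UNAVAILABLE). A hedged client-side sketch of how a caller might react; the /generate endpoint and payload shape follow sglang's native HTTP API but are assumptions as far as this diff is concerned:

import time
import requests

def generate_with_retry(prompt: str, base_url: str = "http://localhost:30000", retries: int = 5):
    payload = {"text": prompt, "sampling_params": {"max_new_tokens": 64}}
    for attempt in range(retries):
        resp = requests.post(f"{base_url}/generate", json=payload, timeout=600)
        if resp.status_code != 503:      # 503 <=> "The request queue is full."
            resp.raise_for_status()
            return resp.json()
        time.sleep(2 ** attempt)         # back off, then try to re-enter the queue
    raise RuntimeError("request queue stayed full after all retries")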
@@ -2437,6 +2465,37 @@ class Scheduler(
                req.grammar.cancel()
            req.set_finish_with_abort("Aborted by AbortReq.")
 
+        # Delete requests not in the waiting queue when PD disaggregation is enabled
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # Abort requests that have not yet been bootstrapped
+            for i, req in enumerate(self.disagg_prefill_bootstrap_queue.queue):
+                logger.debug(f"Abort bootstrap queue request. {req.rid=}")
+                if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    if hasattr(req.disagg_kv_sender, "abort"):
+                        req.disagg_kv_sender.abort()
+
+            # Abort in-flight requests
+            for i, req in enumerate(self.disagg_prefill_inflight_queue):
+                logger.debug(f"Abort inflight queue request. {req.rid=}")
+                if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    if hasattr(req.disagg_kv_sender, "abort"):
+                        req.disagg_kv_sender.abort()
+
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            # Abort requests that have not yet finished preallocation
+            for i, decode_req in enumerate(self.disagg_decode_prealloc_queue.queue):
+                logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
+                if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    if hasattr(decode_req.kv_receiver, "abort"):
+                        decode_req.kv_receiver.abort()
+
+            # Abort requests waiting for kvcache to release tree cache
+            for i, decode_req in enumerate(self.disagg_decode_transfer_queue.queue):
+                logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
+                if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    if hasattr(decode_req.kv_receiver, "abort"):
+                        decode_req.kv_receiver.abort()
+
        # Delete requests in the running batch
        if self.cur_batch is self.running_batch or self.cur_batch is None:
            reqs = self.running_batch.reqs
@@ -2868,6 +2927,10 @@ def is_health_check_generate_req(recv_req):
    return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
 
 
+def is_work_request(recv_req):
+    return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
+
+
 def _export_static_state(model):
    return dict(
        buffers=[
sglang/srt/managers/scheduler_input_blocker.py
ADDED
@@ -0,0 +1,106 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import logging
+from contextlib import contextmanager
+from enum import Enum, auto
+from typing import Any, List, Optional
+
+from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
+from sglang.srt.poll_based_barrier import PollBasedBarrier
+
+logger = logging.getLogger(__name__)
+
+
+class SchedulerInputBlocker:
+    def __init__(self, noop: bool):
+        self._state = _State.UNBLOCKED
+        self._pending_reqs = []
+        self._noop = noop
+        self._global_unblock_barrier = PollBasedBarrier(noop=noop)
+
+    def handle(self, recv_reqs: Optional[List[Any]]):
+        assert (recv_reqs is None) == self._noop
+
+        if not self._noop:
+            output_reqs = []
+            for recv_req in recv_reqs:
+                output_reqs += self._handle_recv_req(recv_req)
+
+        global_arrived_unblock_barrier = (
+            self._global_unblock_barrier.poll_global_arrived()
+        )
+        if (
+            self._state == _State.GLOBAL_UNBLOCK_BARRIER
+            and global_arrived_unblock_barrier
+        ):
+            output_reqs += self._handle_arrive_unblock_barrier()
+
+        if not self._noop:
+            return output_reqs
+
+    def _handle_recv_req(self, recv_req):
+        if isinstance(recv_req, BlockReqInput):
+            if recv_req.type == BlockReqType.BLOCK:
+                self._execute_block_req()
+                return []
+            elif recv_req.type == BlockReqType.UNBLOCK:
+                self._execute_unblock_req()
+                return []
+            else:
+                raise NotImplementedError(f"{recv_req=}")
+        else:
+            if self._state == _State.UNBLOCKED:
+                return [recv_req]
+            else:
+                self._pending_reqs.append(recv_req)
+                return []
+
+    def _execute_block_req(self):
+        logger.info("Handle block req")
+        self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED)
+
+    def _execute_unblock_req(self):
+        logger.info("Handle unblock req")
+        self._change_state(
+            original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER
+        )
+        self._global_unblock_barrier.local_arrive()
+
+    def _handle_arrive_unblock_barrier(self):
+        logger.info(f"Arrived at unblock barrier ({len(self._pending_reqs)=})")
+        self._change_state(
+            original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED
+        )
+        output_reqs = [*self._pending_reqs]
+        self._pending_reqs.clear()
+        return output_reqs
+
+    def _change_state(self, original: "_State", target: "_State"):
+        assert self._state == original, f"{self._state=} {original=} {target=}"
+        self._state = target
+
+
+class _State(Enum):
+    UNBLOCKED = auto()
+    BLOCKED = auto()
+    GLOBAL_UNBLOCK_BARRIER = auto()
+
+
+@contextmanager
+def input_blocker_guard_region(send_to_scheduler):
+    send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.BLOCK))
+    try:
+        yield
+    finally:
+        send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.UNBLOCK))