sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +4 -0
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +16 -1
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mooncake/conn.py +16 -0
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +132 -79
- sglang/srt/function_call/ebnf_composer.py +10 -3
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/qwen3_coder_detector.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +14 -3
- sglang/srt/layers/moe/ep_moe/layer.py +172 -206
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/topk.py +84 -22
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +25 -10
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/lora/lora_registry.py +93 -29
- sglang/srt/managers/cache_controller.py +9 -7
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +14 -8
- sglang/srt/managers/scheduler.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +37 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -2
- sglang/srt/model_executor/model_runner.py +68 -14
- sglang/srt/models/deepseek_v2.py +62 -28
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/qwen2_moe.py +2 -2
- sglang/srt/models/qwen3_moe.py +5 -2
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +57 -6
- sglang/srt/utils.py +96 -1
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_utils.py +65 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +4 -4
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +83 -73
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/mm_utils.py
CHANGED
@@ -3,8 +3,9 @@ Multi-modality utils
 """
 
 import hashlib
+import pickle
 from abc import abstractmethod
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
 
 import numpy as np
 import torch
@@ -27,6 +28,128 @@ from sglang.utils import logger
 # propagation that can cause some log messages (like 'server is fired up') to not appear
 # in the console when multimodal support is enabled.
 
+# TODO(mick): nccl
+# cuda_ipc: for intranode tensor sharing
+TensorTransportMode = Literal["cuda_ipc", "auto", "default"]
+
+
+class TransportProxyTensor(torch.Tensor):
+    """
+    A convenient torch.Tensor subclass that carries extra metadata and supports
+    efficient inter-process communications
+    """
+
+    @staticmethod
+    def __new__(
+        cls,
+        data: torch.Tensor,
+        name: Optional[str] = None,
+        fields: Optional[Dict[str, Any]] = None,
+        transport_mode: TensorTransportMode = "default",
+        *args,
+        **kwargs,
+    ):
+
+        if not isinstance(data, torch.Tensor):
+            raise TypeError(
+                f"Input 'data' must be a torch.Tensor, but got {type(data)}"
+            )
+
+        instance = data.as_subclass(cls)
+
+        instance._metadata = {
+            "name": name,
+            "fields": fields if fields is not None else {},
+            "transport_mode": transport_mode,
+        }
+
+        return instance
+
+    def __getstate__(self):
+        """
+        Called during pickling. Implements the serialization logic.
+        """
+        # acquire all serialize metadata from _metadata
+        state = {
+            "metadata": self._metadata,
+            "tensor_data": None,
+            "ipc_extra": None,
+        }
+
+        transport_mode = self._metadata.get("transport_mode", "default")
+
+        if transport_mode == "cuda_ipc" and self.is_cuda:
+            try:
+                storage = self.untyped_storage()
+                handle = storage._share_cuda_()
+
+                state["ipc_extra"] = {
+                    "handle": handle,
+                    "shape": self.shape,
+                    "dtype": self.dtype,
+                    "stride": self.stride(),
+                    "device_index": self.device.index,
+                }
+                state["tensor_data"] = None
+            except Exception as e:
+                # Failed to get CUDA IPC handle (possibly tp). Falling back to default transport.
+                state["metadata"]["transport_mode"] = "default"
+                state["tensor_data"] = self.as_subclass(torch.Tensor)
+        else:
+            state["metadata"]["transport_mode"] = "default"
+            state["tensor_data"] = self.as_subclass(torch.Tensor)
+
+        return state
+
+    def __setstate__(self, state: Dict[str, Any]):
+        """
+        Called during unpickling. Implements the deserialization logic.
+        """
+        self._metadata = state["metadata"]
+
+        transport_mode = self._metadata.get("transport_mode", "default")
+
+        if transport_mode == "cuda_ipc" and state["ipc_extra"] is not None:
+            ipc_extra = state["ipc_extra"]
+            handle, shape, dtype, stride, source_device_index = (
+                ipc_extra["handle"],
+                ipc_extra["shape"],
+                ipc_extra["dtype"],
+                ipc_extra["stride"],
+                ipc_extra["device_index"],
+            )
+
+            try:
+                target_device = torch.device(f"cuda:{source_device_index}")
+                with torch.cuda.device(target_device):
+                    storage = torch.UntypedStorage._new_shared_cuda(*handle)
+                    reconstructed_tensor = torch.empty(
+                        0, dtype=dtype, device=target_device
+                    ).set_(storage, storage_offset=0, size=shape, stride=stride)
+                    self.set_(reconstructed_tensor)
+            except Exception as e:
+                print(f"Error: Failed to deserialize from CUDA IPC handle ({e}).")
+                raise e
+
+        elif state["tensor_data"] is not None:
+            self.set_(state["tensor_data"])
+        else:
+            raise pickle.UnpicklingError(
+                "Invalid state for TransportProxyTensor: no tensor data found."
+            )
+
+    @property
+    def name(self) -> Optional[str]:
+        return self._metadata.get("name")
+
+    @property
+    def fields(self) -> Dict[str, Any]:
+        return self._metadata.get("fields", {})
+
+    @property
+    def transport_mode(self) -> TensorTransportMode:
+        return self._metadata.get("transport_mode", "default")
+
 
 class MultiModalityDataPaddingPattern:
     """
@@ -85,8 +208,8 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
                 "No data_token_pairs provided, RadixAttention might be influenced."
             )
             return input_ids
-        start_token_ids = …
-        end_tokens_ids = …
+        start_token_ids = {s for s, _e in data_token_pairs}
+        end_tokens_ids = {e for _s, e in data_token_pairs}
 
         padded_ids = []
         last_idx = 0
@@ -135,7 +258,7 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern)
         if not input_ids or not mm_inputs.mm_items:
             return input_ids
 
-        input_ids_tensor = torch.…
+        input_ids_tensor = torch.as_tensor(input_ids)
 
         # Create mapping of token_ids to pad_values for each modality
         token_to_pad_mapping = {}
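Several hunks in this file replace torch.tensor(...) with torch.as_tensor(...). The practical difference, in a tiny standalone snippet (not part of the diff): as_tensor avoids a copy when the input is already a tensor or a compatible NumPy array, whereas torch.tensor always copies.

import numpy as np
import torch

ids = np.array([101, 7, 7, 102], dtype=np.int64)

copied = torch.tensor(ids)     # always allocates new storage and copies
shared = torch.as_tensor(ids)  # reuses the NumPy buffer when dtype/device allow

ids[0] = 999
print(copied[0].item(), shared[0].item())  # 101 999 -> only `shared` sees the edit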
@@ -211,7 +334,7 @@ def get_embedding_chunk(
             end_index += extend_end_index - start + 1
         elif extend_end_index > end:
             end_index += end - start + 1
-    # some models embedding is 3-dim, reshape it to 2-dim
+    # some models' embedding is 3-dim, reshape it to 2-dim
     embedding = embedding.reshape(-1, embedding.shape[-1])
     embedding_chunk = embedding[start_index:end_index]
     return embedding_chunk, start_index, end_index
@@ -428,7 +551,7 @@ def embed_mm_inputs(
         modality_id = modality.name.lower()
         embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
         if len(items) != 0 and embedder is not None:
-            placeholder_tensor = torch.…
+            placeholder_tensor = torch.as_tensor(
                 [item.pad_value for item in items],
                 device=input_ids.device,
             )
@@ -473,11 +596,9 @@ def embed_mm_inputs(
     for embedding, mask in zip(embeddings, masks):
         if embedding is None or mask is None:
             continue
-        …
-        …
-        …
-            embedding.to(inputs_embeds.device, inputs_embeds.dtype),
-        )
+        # in-place update
+        indices = torch.where(mask.squeeze(dim=-1))[0]
+        inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
     return inputs_embeds
 
 
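The rewritten loop writes each modality's embeddings into inputs_embeds in place by turning the boolean placeholder mask into row indices. A toy reproduction of that indexing pattern with made-up shapes (not part of the diff):

import torch

inputs_embeds = torch.zeros(6, 4)   # 6 tokens, hidden size 4
embedding = torch.ones(2, 4)        # features for 2 image placeholder tokens
mask = torch.tensor([[False], [False], [True], [True], [False], [False]])  # [tokens, 1]

indices = torch.where(mask.squeeze(dim=-1))[0]  # tensor([2, 3])
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
print(inputs_embeds[2:4])  # rows 2 and 3 now hold the image features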
@@ -561,34 +682,36 @@ def get_multimodal_data_bounds(
     [bounds_count, 2]
     """
     # All the multimodal data in the batch should share the same special bound token ids.
-    start_tokens = …
-    end_tokens = …
+    start_tokens = {s for s, _e in token_pairs}
+    end_tokens = {e for _s, e in token_pairs}
 
     assert all(isinstance(t, int) for t in start_tokens)
     assert all(isinstance(t, int) for t in end_tokens)
 
     start_cond = torch.isin(
-        input_ids, torch.…
+        input_ids, torch.as_tensor(start_tokens, device=input_ids.device)
+    )
+    end_cond = torch.isin(
+        input_ids, torch.as_tensor(end_tokens, device=input_ids.device)
     )
-    end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))
 
     (data_start_tokens,) = torch.where(start_cond)
     (data_end_tokens,) = torch.where(end_cond)
 
+    data_start_tokens_cpu = data_start_tokens.cpu().tolist()
+    data_end_tokens_cpu = data_end_tokens.cpu().tolist()
+
     # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
-    if len(…
+    if len(data_start_tokens_cpu) != len(data_end_tokens_cpu):
         if (
-            len(…
-            and input_ids[0] in pad_values
-            and …
+            len(data_start_tokens_cpu) + 1 == len(data_end_tokens_cpu)
+            and input_ids[0].item() in pad_values
+            and data_end_tokens_cpu
+            and data_start_tokens_cpu
+            and data_end_tokens_cpu[0] < data_start_tokens_cpu[0]
         ):
-            …
-            …
-                torch.tensor([0], device=data_start_tokens.device),
-                data_start_tokens,
-            ]
-        )
-    valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))
+            data_start_tokens_cpu.insert(0, 0)
+    valid_mm_data_nums = min(len(data_start_tokens_cpu), len(data_end_tokens_cpu))
 
     if valid_mm_data_nums == 0:
         return torch.zeros((0, 2), device=input_ids.device)
@@ -596,8 +719,8 @@ def get_multimodal_data_bounds(
     # Filter out pairs where start_token >= end_token
     valid_pairs = []
     for i in range(valid_mm_data_nums):
-        start_token = …
-        end_token = …
+        start_token = data_start_tokens_cpu[i]
+        end_token = data_end_tokens_cpu[i]
         if start_token < end_token:
             valid_pairs.append((start_token + 1, end_token - 1))
 
@@ -605,7 +728,7 @@ def get_multimodal_data_bounds(
         return torch.zeros((0, 2), device=input_ids.device)
 
     # Convert valid pairs to tensor
-    valid_pairs_tensor = torch.…
+    valid_pairs_tensor = torch.as_tensor(valid_pairs, device=input_ids.device)
     return valid_pairs_tensor
 
 
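For orientation, a self-contained toy version of the bound-pair extraction above, with made-up token ids (it skips the prefix-cache special case the real function handles):

import torch

IM_START, IM_END = 100, 101
input_ids = torch.tensor([5, IM_START, 7, 7, 7, IM_END, 9])

start_cond = torch.isin(input_ids, torch.as_tensor([IM_START]))
end_cond = torch.isin(input_ids, torch.as_tensor([IM_END]))
(starts,) = torch.where(start_cond)
(ends,) = torch.where(end_cond)

# Keep only the positions strictly between each start/end marker pair.
bounds = [(s + 1, e - 1) for s, e in zip(starts.tolist(), ends.tolist()) if s < e]
print(bounds)  # [(2, 4)] -> the image placeholder tokens sit at positions 2..4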
@@ -626,7 +749,7 @@ def tensor_hash(tensor_list) -> int:
     ]
     tensor = torch.concat(tensor_list)
     if tensor.is_cuda:
-        return gpu_tensor_hash(tensor)
+        return gpu_tensor_hash(tensor.cuda())
     tensor = tensor.detach().contiguous()
 
     if tensor.dtype == torch.bfloat16:
@@ -634,11 +757,7 @@ def tensor_hash(tensor_list) -> int:
         tensor = tensor.float()
 
     assert isinstance(tensor, torch.Tensor)
-    …
-        # TODO: improve this
-        tensor_cpu = tensor.cpu()
-    else:
-        tensor_cpu = tensor
+    tensor_cpu = tensor.cpu()
 
     mv = memoryview(tensor_cpu.numpy())
     return data_hash(mv.tobytes())
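On the CPU path, tensor_hash now unconditionally moves the tensor to host memory before hashing its raw bytes. A standalone approximation of that idea using hashlib; the real code calls sglang's data_hash and gpu_tensor_hash helpers, which are not shown in this diff:

import hashlib

import torch

def cpu_tensor_digest(tensor: torch.Tensor) -> str:
    """Hash a tensor's raw bytes, e.g. to deduplicate multimodal features."""
    t = tensor.detach().contiguous()
    if t.dtype == torch.bfloat16:
        t = t.float()  # NumPy has no bfloat16, so widen before .numpy()
    mv = memoryview(t.cpu().numpy())
    return hashlib.sha256(mv.tobytes()).hexdigest()

print(cpu_tensor_digest(torch.randn(4, 8).to(torch.bfloat16)))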
sglang/srt/managers/multimodal_processor.py
CHANGED
@@ -12,18 +12,6 @@ logger = logging.getLogger(__name__)
 PROCESSOR_MAPPING = {}
 
 
-class DummyMultimodalProcessor(BaseMultimodalProcessor):
-    def __init__(self):
-        pass
-
-    async def process_mm_data_async(self, *args, **kwargs):
-        return None
-
-
-def get_dummy_processor():
-    return DummyMultimodalProcessor()
-
-
 def import_processors():
     package_name = "sglang.srt.multimodal.processors"
     package = importlib.import_module(package_name)
@@ -49,11 +37,12 @@ def import_processors():
 
 
 def get_mm_processor(
-    hf_config, server_args: ServerArgs, processor
+    hf_config, server_args: ServerArgs, processor, transport_mode
 ) -> BaseMultimodalProcessor:
     for model_cls, processor_cls in PROCESSOR_MAPPING.items():
         if model_cls.__name__ in hf_config.architectures:
-            return processor_cls(hf_config, server_args, processor)
+            return processor_cls(hf_config, server_args, processor, transport_mode)
+
     raise ValueError(
         f"No processor registered for architecture: {hf_config.architectures}.\n"
         f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -88,7 +88,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_deepep_moe",
     "deepep_mode",
     "enable_ep_moe",
-    "…
+    "enable_flashinfer_cutlass_moe",
+    "enable_flashinfer_trtllm_moe",
     "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
@@ -209,10 +210,11 @@ class MultimodalDataItem:
     hash: int = None
     pad_value: int = None
     offsets: Optional[list] = None
+
     # the raw features returned by processor, e.g. pixel_values or audio_features
     feature: Union[torch.Tensor, np.ndarray] = None
-
-    # …
+    # the precomputed embeddings, passed as final encoder embeddings
+    # One and only one of the feature and precomputed_embeddings will be empty
     precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None
 
     # Model-specific data stored in a dictionary
@@ -1688,16 +1690,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         extend_prefix_lens = self.prefix_lens
         extend_logprob_start_lens = self.extend_logprob_start_lens
 
+        if self.forward_mode.is_decode_or_idle():
+            attention_backend_str = global_server_args_dict["decode_attention_backend"]
+        else:
+            attention_backend_str = global_server_args_dict["prefill_attention_backend"]
         # Create seq_lens_cpu when needed
         if (
-            …
+            attention_backend_str == "fa3"
             or (
                 global_server_args_dict["use_mla_backend"]
-                and …
+                and attention_backend_str == "flashinfer"
             )
-            or …
-            or …
-            or …
+            or attention_backend_str == "flashmla"
+            or attention_backend_str == "cutlass_mla"
+            or attention_backend_str == "ascend"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
sglang/srt/managers/scheduler.py
CHANGED
@@ -458,7 +458,10 @@ class Scheduler(
         self.grammar_queue: List[Req] = []
         if not server_args.skip_tokenizer_init:
             self.grammar_backend = create_grammar_backend(
-                server_args, …
+                server_args,
+                self.tokenizer,
+                self.model_config.vocab_size,
+                self.model_config.hf_eos_token_id,
             )
         else:
             self.grammar_backend = None
@@ -2437,6 +2440,37 @@ class Scheduler(
                 req.grammar.cancel()
             req.set_finish_with_abort("Aborted by AbortReq.")
 
+        # Delete requests not in the waiting queue when PD disaggregation is enabled
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # Abort requests that have not yet been bootstrapped
+            for i, req in enumerate(self.disagg_prefill_bootstrap_queue.queue):
+                logger.debug(f"Abort bootstrap queue request. {req.rid=}")
+                if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    if hasattr(req.disagg_kv_sender, "abort"):
+                        req.disagg_kv_sender.abort()
+
+            # Abort in-flight requests
+            for i, req in enumerate(self.disagg_prefill_inflight_queue):
+                logger.debug(f"Abort inflight queue request. {req.rid=}")
+                if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    if hasattr(req.disagg_kv_sender, "abort"):
+                        req.disagg_kv_sender.abort()
+
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            # Abort requests that have not yet finished preallocation
+            for i, decode_req in enumerate(self.disagg_decode_prealloc_queue.queue):
+                logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
+                if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    if hasattr(decode_req.kv_receiver, "abort"):
+                        decode_req.kv_receiver.abort()
+
+            # Abort requests waiting for kvcache to release tree cache
+            for i, decode_req in enumerate(self.disagg_decode_transfer_queue.queue):
+                logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
+                if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    if hasattr(decode_req.kv_receiver, "abort"):
+                        decode_req.kv_receiver.abort()
+
         # Delete requests in the running batch
         if self.cur_batch is self.running_batch or self.cur_batch is None:
             reqs = self.running_batch.reqs
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -112,6 +112,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
+from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -166,6 +167,16 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
 
 
+def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
+    is_cross_node = server_args.dist_init_addr
+
+    if is_cross_node:
+        # Fallback to default CPU transport for multi-node
+        return "default"
+    else:
+        return "cuda_ipc"
+
+
 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""
 
@@ -216,12 +227,13 @@ class TokenizerManager:
                 revision=server_args.revision,
                 use_fast=not server_args.disable_fast_image_processor,
             )
+            transport_mode = _determine_tensor_transport_mode(self.server_args)
 
             # We want to parallelize the image pre-processing so we create an executor for it
             # We create mm_processor for any skip_tokenizer_init to make sure we still encode
             # images even with skip_tokenizer_init=False.
             self.mm_processor = get_mm_processor(
-                self.model_config.hf_config, server_args, _processor
+                self.model_config.hf_config, server_args, _processor, transport_mode
             )
 
             if server_args.skip_tokenizer_init:
@@ -270,6 +282,11 @@ class TokenizerManager:
             None
         )
 
+        # Lock to serialize LoRA update operations.
+        # Please note that, unlike `model_update_lock`, this does not block inference, allowing
+        # LoRA updates and inference to overlap.
+        self.lora_update_lock = asyncio.Lock()
+
         # For pd disaggregtion
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
@@ -525,7 +542,8 @@ class TokenizerManager:
             mm_inputs = None
 
         if self.server_args.enable_lora and obj.lora_path:
-            # …
+            # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
+            # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
            obj.lora_path = await self.lora_registry.acquire(obj.lora_path)

        self._validate_one_request(obj, input_ids)
@@ -735,6 +753,10 @@ class TokenizerManager:
                msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                logger.info(msg)

+            # Mark ongoing LoRA request as finished.
+            if self.server_args.enable_lora and obj.lora_path:
+                await self.lora_registry.release(obj.lora_path)
+
            # Check if this was an abort/error created by scheduler
            if isinstance(out["meta_info"].get("finish_reason"), dict):
                finish_reason = out["meta_info"]["finish_reason"]
@@ -1041,16 +1063,18 @@ class TokenizerManager:
            obj.lora_path,
        )

-        async with self.…
+        async with self.lora_update_lock:
            # Generate new uniquely identifiable LoRARef object.
            new_adapter = LoRARef(
                lora_name=obj.lora_name,
                lora_path=obj.lora_path,
            )

-            # …
+            # Trigger the actual loading operation at the backend processes.
            obj.lora_id = new_adapter.lora_id
            result = (await self.update_lora_adapter_communicator(obj))[0]
+
+            # Register the LoRA adapter only after loading is successful.
            if result.success:
                await self.lora_registry.register(new_adapter)

@@ -1081,8 +1105,15 @@ class TokenizerManager:
            obj.lora_name,
        )

-        async with self.…
-            …
+        async with self.lora_update_lock:
+            # Unregister the LoRA adapter from the registry to stop new requests for this adapter
+            # from being started.
+            lora_id = await self.lora_registry.unregister(obj.lora_name)
+            obj.lora_id = lora_id
+
+            # Initiate the actual unloading operation at the backend processes only after all
+            # ongoing requests using this LoRA adapter are finished.
+            await self.lora_registry.wait_for_unload(lora_id)
            result = (await self.update_lora_adapter_communicator(obj))[0]

        return result
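The load/unload paths above are both guarded by the new lora_update_lock, while the registry's acquire/release counting lets in-flight inference continue during an update. A condensed sketch of that coordination pattern; FakeRegistry and handle_request are illustrative stand-ins, not sglang APIs:

import asyncio

class FakeRegistry:
    """Stand-in for sglang's LoRARegistry: counts in-flight requests per adapter."""

    def __init__(self):
        self.counts = {}

    async def acquire(self, name):
        self.counts[name] = self.counts.get(name, 0) + 1
        return name

    async def release(self, name):
        self.counts[name] -= 1

    async def wait_for_unload(self, name):
        # Block the unload until every in-flight request has released the adapter.
        while self.counts.get(name, 0) > 0:
            await asyncio.sleep(0.01)

registry = FakeRegistry()
update_lock = asyncio.Lock()  # serializes load/unload, but never blocks inference

async def handle_request(adapter):
    adapter = await registry.acquire(adapter)  # runs outside update_lock
    try:
        await asyncio.sleep(0.05)  # pretend to run inference
    finally:
        await registry.release(adapter)

async def unload(adapter):
    async with update_lock:  # only one load/unload at a time
        await registry.wait_for_unload(adapter)
        print(f"{adapter} unloaded after in-flight requests drained")

async def main():
    await asyncio.gather(handle_request("demo-lora"), unload("demo-lora"))

asyncio.run(main())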
sglang/srt/managers/tp_worker.py
CHANGED
@@ -41,6 +41,7 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
 
@@ -278,6 +279,8 @@ class TpModelWorker:
         return success, message
 
     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+
+        monkey_patch_torch_reductions()
         success, message = self.model_runner.update_weights_from_tensor(
             named_tensors=MultiprocessingSerializer.deserialize(
                 recv_req.serialized_named_tensors[self.tp_rank]
sglang/srt/mem_cache/hiradix_cache.py
CHANGED
@@ -365,10 +365,12 @@ class HiRadixCache(RadixCache):
         for _ in range(queue_size.item()):
             req_id = self.cache_controller.prefetch_revoke_queue.get()
             if req_id in self.ongoing_prefetch:
-                last_host_node, _, …
+                last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
                 last_host_node.release_host()
-                self.cache_controller.mem_pool_host.free(host_indices)
                 del self.ongoing_prefetch[req_id]
+            else:
+                # the revoked operation already got terminated
+                pass
 
     def check_backup_progress(self):
         queue_size = torch.tensor(
@@ -403,6 +405,7 @@ class HiRadixCache(RadixCache):
         last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
             req_id
         ]
+
         completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
             operation
         )