sglang 0.4.9.post5__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/configs/model_config.py +3 -0
- sglang/srt/entrypoints/http_server.py +13 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -1
- sglang/srt/entrypoints/openai/serving_base.py +5 -2
- sglang/srt/layers/moe/ep_moe/layer.py +152 -37
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
- sglang/srt/layers/moe/topk.py +6 -2
- sglang/srt/layers/quantization/modelopt_quant.py +2 -0
- sglang/srt/managers/data_parallel_controller.py +4 -0
- sglang/srt/managers/io_struct.py +12 -0
- sglang/srt/managers/scheduler.py +29 -0
- sglang/srt/managers/scheduler_input_blocker.py +106 -0
- sglang/srt/managers/tokenizer_manager.py +43 -9
- sglang/srt/managers/tp_worker.py +5 -0
- sglang/srt/model_executor/model_runner.py +15 -13
- sglang/srt/models/deepseek_v2.py +13 -56
- sglang/srt/models/qwen3_moe.py +12 -69
- sglang/srt/poll_based_barrier.py +31 -0
- sglang/srt/server_args.py +8 -0
- sglang/srt/two_batch_overlap.py +8 -3
- sglang/test/test_utils.py +53 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +2 -1
- {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +32 -25
- {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
|
|
1
|
+
{
|
2
|
+
"1": {
|
3
|
+
"BLOCK_SIZE_M": 16,
|
4
|
+
"BLOCK_SIZE_N": 128,
|
5
|
+
"BLOCK_SIZE_K": 256,
|
6
|
+
"GROUP_SIZE_M": 64,
|
7
|
+
"num_warps": 4,
|
8
|
+
"num_stages": 3
|
9
|
+
},
|
10
|
+
"2": {
|
11
|
+
"BLOCK_SIZE_M": 16,
|
12
|
+
"BLOCK_SIZE_N": 128,
|
13
|
+
"BLOCK_SIZE_K": 128,
|
14
|
+
"GROUP_SIZE_M": 32,
|
15
|
+
"num_warps": 4,
|
16
|
+
"num_stages": 4
|
17
|
+
},
|
18
|
+
"4": {
|
19
|
+
"BLOCK_SIZE_M": 16,
|
20
|
+
"BLOCK_SIZE_N": 128,
|
21
|
+
"BLOCK_SIZE_K": 128,
|
22
|
+
"GROUP_SIZE_M": 1,
|
23
|
+
"num_warps": 4,
|
24
|
+
"num_stages": 4
|
25
|
+
},
|
26
|
+
"8": {
|
27
|
+
"BLOCK_SIZE_M": 16,
|
28
|
+
"BLOCK_SIZE_N": 128,
|
29
|
+
"BLOCK_SIZE_K": 128,
|
30
|
+
"GROUP_SIZE_M": 1,
|
31
|
+
"num_warps": 4,
|
32
|
+
"num_stages": 4
|
33
|
+
},
|
34
|
+
"16": {
|
35
|
+
"BLOCK_SIZE_M": 16,
|
36
|
+
"BLOCK_SIZE_N": 128,
|
37
|
+
"BLOCK_SIZE_K": 128,
|
38
|
+
"GROUP_SIZE_M": 16,
|
39
|
+
"num_warps": 4,
|
40
|
+
"num_stages": 3
|
41
|
+
},
|
42
|
+
"24": {
|
43
|
+
"BLOCK_SIZE_M": 16,
|
44
|
+
"BLOCK_SIZE_N": 256,
|
45
|
+
"BLOCK_SIZE_K": 128,
|
46
|
+
"GROUP_SIZE_M": 1,
|
47
|
+
"num_warps": 4,
|
48
|
+
"num_stages": 3
|
49
|
+
},
|
50
|
+
"32": {
|
51
|
+
"BLOCK_SIZE_M": 16,
|
52
|
+
"BLOCK_SIZE_N": 128,
|
53
|
+
"BLOCK_SIZE_K": 128,
|
54
|
+
"GROUP_SIZE_M": 32,
|
55
|
+
"num_warps": 4,
|
56
|
+
"num_stages": 3
|
57
|
+
},
|
58
|
+
"48": {
|
59
|
+
"BLOCK_SIZE_M": 16,
|
60
|
+
"BLOCK_SIZE_N": 128,
|
61
|
+
"BLOCK_SIZE_K": 128,
|
62
|
+
"GROUP_SIZE_M": 32,
|
63
|
+
"num_warps": 4,
|
64
|
+
"num_stages": 3
|
65
|
+
},
|
66
|
+
"64": {
|
67
|
+
"BLOCK_SIZE_M": 16,
|
68
|
+
"BLOCK_SIZE_N": 128,
|
69
|
+
"BLOCK_SIZE_K": 128,
|
70
|
+
"GROUP_SIZE_M": 32,
|
71
|
+
"num_warps": 4,
|
72
|
+
"num_stages": 3
|
73
|
+
},
|
74
|
+
"96": {
|
75
|
+
"BLOCK_SIZE_M": 16,
|
76
|
+
"BLOCK_SIZE_N": 128,
|
77
|
+
"BLOCK_SIZE_K": 128,
|
78
|
+
"GROUP_SIZE_M": 32,
|
79
|
+
"num_warps": 4,
|
80
|
+
"num_stages": 3
|
81
|
+
},
|
82
|
+
"128": {
|
83
|
+
"BLOCK_SIZE_M": 16,
|
84
|
+
"BLOCK_SIZE_N": 128,
|
85
|
+
"BLOCK_SIZE_K": 128,
|
86
|
+
"GROUP_SIZE_M": 32,
|
87
|
+
"num_warps": 4,
|
88
|
+
"num_stages": 3
|
89
|
+
},
|
90
|
+
"256": {
|
91
|
+
"BLOCK_SIZE_M": 16,
|
92
|
+
"BLOCK_SIZE_N": 128,
|
93
|
+
"BLOCK_SIZE_K": 128,
|
94
|
+
"GROUP_SIZE_M": 16,
|
95
|
+
"num_warps": 4,
|
96
|
+
"num_stages": 3
|
97
|
+
},
|
98
|
+
"512": {
|
99
|
+
"BLOCK_SIZE_M": 64,
|
100
|
+
"BLOCK_SIZE_N": 128,
|
101
|
+
"BLOCK_SIZE_K": 128,
|
102
|
+
"GROUP_SIZE_M": 1,
|
103
|
+
"num_warps": 4,
|
104
|
+
"num_stages": 3
|
105
|
+
},
|
106
|
+
"1024": {
|
107
|
+
"BLOCK_SIZE_M": 64,
|
108
|
+
"BLOCK_SIZE_N": 128,
|
109
|
+
"BLOCK_SIZE_K": 128,
|
110
|
+
"GROUP_SIZE_M": 16,
|
111
|
+
"num_warps": 4,
|
112
|
+
"num_stages": 3
|
113
|
+
},
|
114
|
+
"1536": {
|
115
|
+
"BLOCK_SIZE_M": 128,
|
116
|
+
"BLOCK_SIZE_N": 128,
|
117
|
+
"BLOCK_SIZE_K": 128,
|
118
|
+
"GROUP_SIZE_M": 1,
|
119
|
+
"num_warps": 8,
|
120
|
+
"num_stages": 3
|
121
|
+
},
|
122
|
+
"2048": {
|
123
|
+
"BLOCK_SIZE_M": 128,
|
124
|
+
"BLOCK_SIZE_N": 128,
|
125
|
+
"BLOCK_SIZE_K": 128,
|
126
|
+
"GROUP_SIZE_M": 16,
|
127
|
+
"num_warps": 8,
|
128
|
+
"num_stages": 3
|
129
|
+
},
|
130
|
+
"3072": {
|
131
|
+
"BLOCK_SIZE_M": 64,
|
132
|
+
"BLOCK_SIZE_N": 128,
|
133
|
+
"BLOCK_SIZE_K": 128,
|
134
|
+
"GROUP_SIZE_M": 16,
|
135
|
+
"num_warps": 4,
|
136
|
+
"num_stages": 3
|
137
|
+
},
|
138
|
+
"4096": {
|
139
|
+
"BLOCK_SIZE_M": 128,
|
140
|
+
"BLOCK_SIZE_N": 128,
|
141
|
+
"BLOCK_SIZE_K": 128,
|
142
|
+
"GROUP_SIZE_M": 1,
|
143
|
+
"num_warps": 8,
|
144
|
+
"num_stages": 3
|
145
|
+
}
|
146
|
+
}
|
@@ -0,0 +1,146 @@
|
|
1
|
+
{
|
2
|
+
"1": {
|
3
|
+
"BLOCK_SIZE_M": 16,
|
4
|
+
"BLOCK_SIZE_N": 128,
|
5
|
+
"BLOCK_SIZE_K": 128,
|
6
|
+
"GROUP_SIZE_M": 1,
|
7
|
+
"num_warps": 4,
|
8
|
+
"num_stages": 4
|
9
|
+
},
|
10
|
+
"2": {
|
11
|
+
"BLOCK_SIZE_M": 16,
|
12
|
+
"BLOCK_SIZE_N": 128,
|
13
|
+
"BLOCK_SIZE_K": 128,
|
14
|
+
"GROUP_SIZE_M": 64,
|
15
|
+
"num_warps": 4,
|
16
|
+
"num_stages": 4
|
17
|
+
},
|
18
|
+
"4": {
|
19
|
+
"BLOCK_SIZE_M": 16,
|
20
|
+
"BLOCK_SIZE_N": 128,
|
21
|
+
"BLOCK_SIZE_K": 128,
|
22
|
+
"GROUP_SIZE_M": 1,
|
23
|
+
"num_warps": 4,
|
24
|
+
"num_stages": 4
|
25
|
+
},
|
26
|
+
"8": {
|
27
|
+
"BLOCK_SIZE_M": 16,
|
28
|
+
"BLOCK_SIZE_N": 128,
|
29
|
+
"BLOCK_SIZE_K": 128,
|
30
|
+
"GROUP_SIZE_M": 1,
|
31
|
+
"num_warps": 4,
|
32
|
+
"num_stages": 3
|
33
|
+
},
|
34
|
+
"16": {
|
35
|
+
"BLOCK_SIZE_M": 16,
|
36
|
+
"BLOCK_SIZE_N": 128,
|
37
|
+
"BLOCK_SIZE_K": 128,
|
38
|
+
"GROUP_SIZE_M": 1,
|
39
|
+
"num_warps": 4,
|
40
|
+
"num_stages": 3
|
41
|
+
},
|
42
|
+
"24": {
|
43
|
+
"BLOCK_SIZE_M": 16,
|
44
|
+
"BLOCK_SIZE_N": 128,
|
45
|
+
"BLOCK_SIZE_K": 256,
|
46
|
+
"GROUP_SIZE_M": 1,
|
47
|
+
"num_warps": 4,
|
48
|
+
"num_stages": 3
|
49
|
+
},
|
50
|
+
"32": {
|
51
|
+
"BLOCK_SIZE_M": 16,
|
52
|
+
"BLOCK_SIZE_N": 128,
|
53
|
+
"BLOCK_SIZE_K": 256,
|
54
|
+
"GROUP_SIZE_M": 1,
|
55
|
+
"num_warps": 4,
|
56
|
+
"num_stages": 3
|
57
|
+
},
|
58
|
+
"48": {
|
59
|
+
"BLOCK_SIZE_M": 16,
|
60
|
+
"BLOCK_SIZE_N": 128,
|
61
|
+
"BLOCK_SIZE_K": 256,
|
62
|
+
"GROUP_SIZE_M": 1,
|
63
|
+
"num_warps": 4,
|
64
|
+
"num_stages": 2
|
65
|
+
},
|
66
|
+
"64": {
|
67
|
+
"BLOCK_SIZE_M": 16,
|
68
|
+
"BLOCK_SIZE_N": 128,
|
69
|
+
"BLOCK_SIZE_K": 256,
|
70
|
+
"GROUP_SIZE_M": 1,
|
71
|
+
"num_warps": 4,
|
72
|
+
"num_stages": 2
|
73
|
+
},
|
74
|
+
"96": {
|
75
|
+
"BLOCK_SIZE_M": 16,
|
76
|
+
"BLOCK_SIZE_N": 128,
|
77
|
+
"BLOCK_SIZE_K": 256,
|
78
|
+
"GROUP_SIZE_M": 16,
|
79
|
+
"num_warps": 4,
|
80
|
+
"num_stages": 3
|
81
|
+
},
|
82
|
+
"128": {
|
83
|
+
"BLOCK_SIZE_M": 16,
|
84
|
+
"BLOCK_SIZE_N": 128,
|
85
|
+
"BLOCK_SIZE_K": 256,
|
86
|
+
"GROUP_SIZE_M": 16,
|
87
|
+
"num_warps": 4,
|
88
|
+
"num_stages": 2
|
89
|
+
},
|
90
|
+
"256": {
|
91
|
+
"BLOCK_SIZE_M": 16,
|
92
|
+
"BLOCK_SIZE_N": 128,
|
93
|
+
"BLOCK_SIZE_K": 128,
|
94
|
+
"GROUP_SIZE_M": 32,
|
95
|
+
"num_warps": 4,
|
96
|
+
"num_stages": 4
|
97
|
+
},
|
98
|
+
"512": {
|
99
|
+
"BLOCK_SIZE_M": 64,
|
100
|
+
"BLOCK_SIZE_N": 128,
|
101
|
+
"BLOCK_SIZE_K": 128,
|
102
|
+
"GROUP_SIZE_M": 16,
|
103
|
+
"num_warps": 4,
|
104
|
+
"num_stages": 3
|
105
|
+
},
|
106
|
+
"1024": {
|
107
|
+
"BLOCK_SIZE_M": 64,
|
108
|
+
"BLOCK_SIZE_N": 128,
|
109
|
+
"BLOCK_SIZE_K": 128,
|
110
|
+
"GROUP_SIZE_M": 16,
|
111
|
+
"num_warps": 4,
|
112
|
+
"num_stages": 3
|
113
|
+
},
|
114
|
+
"1536": {
|
115
|
+
"BLOCK_SIZE_M": 128,
|
116
|
+
"BLOCK_SIZE_N": 256,
|
117
|
+
"BLOCK_SIZE_K": 128,
|
118
|
+
"GROUP_SIZE_M": 16,
|
119
|
+
"num_warps": 8,
|
120
|
+
"num_stages": 4
|
121
|
+
},
|
122
|
+
"2048": {
|
123
|
+
"BLOCK_SIZE_M": 128,
|
124
|
+
"BLOCK_SIZE_N": 256,
|
125
|
+
"BLOCK_SIZE_K": 128,
|
126
|
+
"GROUP_SIZE_M": 16,
|
127
|
+
"num_warps": 8,
|
128
|
+
"num_stages": 4
|
129
|
+
},
|
130
|
+
"3072": {
|
131
|
+
"BLOCK_SIZE_M": 64,
|
132
|
+
"BLOCK_SIZE_N": 128,
|
133
|
+
"BLOCK_SIZE_K": 128,
|
134
|
+
"GROUP_SIZE_M": 1,
|
135
|
+
"num_warps": 4,
|
136
|
+
"num_stages": 3
|
137
|
+
},
|
138
|
+
"4096": {
|
139
|
+
"BLOCK_SIZE_M": 128,
|
140
|
+
"BLOCK_SIZE_N": 256,
|
141
|
+
"BLOCK_SIZE_K": 128,
|
142
|
+
"GROUP_SIZE_M": 1,
|
143
|
+
"num_warps": 8,
|
144
|
+
"num_stages": 4
|
145
|
+
}
|
146
|
+
}
|
File without changes
|
@@ -0,0 +1,48 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from abc import ABC, abstractmethod
|
4
|
+
from enum import Enum, auto
|
5
|
+
from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable
|
6
|
+
|
7
|
+
import torch
|
8
|
+
|
9
|
+
|
10
|
+
class DispatchOutputFormat(Enum):
|
11
|
+
standard = auto()
|
12
|
+
deepep_normal = auto()
|
13
|
+
deepep_ll = auto()
|
14
|
+
|
15
|
+
def is_standard(self) -> bool:
|
16
|
+
return self == DispatchOutputFormat.standard
|
17
|
+
|
18
|
+
def is_deepep_normal(self) -> bool:
|
19
|
+
return self == DispatchOutputFormat.deepep_normal
|
20
|
+
|
21
|
+
def is_deepep_ll(self) -> bool:
|
22
|
+
return self == DispatchOutputFormat.deepep_ll
|
23
|
+
|
24
|
+
|
25
|
+
@runtime_checkable
|
26
|
+
class DispatchOutput(Protocol):
|
27
|
+
"""Protocol for dispatch outputs in different formats."""
|
28
|
+
|
29
|
+
@property
|
30
|
+
def format(self) -> DispatchOutputFormat: ...
|
31
|
+
|
32
|
+
|
33
|
+
class BaseDispatcherConfig(ABC):
|
34
|
+
"""Base class for dispatcher configs."""
|
35
|
+
|
36
|
+
pass
|
37
|
+
|
38
|
+
|
39
|
+
class BaseDispatcher(ABC):
|
40
|
+
"""Base class for dispatchers."""
|
41
|
+
|
42
|
+
@abstractmethod
|
43
|
+
def dispatch(self, *args, **kwargs) -> DispatchOutput:
|
44
|
+
pass
|
45
|
+
|
46
|
+
@abstractmethod
|
47
|
+
def combine(self, *args, **kwargs) -> torch.Tensor:
|
48
|
+
pass
|
@@ -0,0 +1,19 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import NamedTuple
|
4
|
+
|
5
|
+
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
|
6
|
+
DispatchOutput,
|
7
|
+
DispatchOutputFormat,
|
8
|
+
)
|
9
|
+
|
10
|
+
|
11
|
+
class StandardDispatchOutput(NamedTuple):
|
12
|
+
"""Standard dispatch output."""
|
13
|
+
|
14
|
+
@property
|
15
|
+
def format(self) -> DispatchOutputFormat:
|
16
|
+
return DispatchOutputFormat.standard
|
17
|
+
|
18
|
+
|
19
|
+
assert isinstance(StandardDispatchOutput, DispatchOutput)
|
sglang/srt/layers/moe/topk.py
CHANGED
@@ -397,7 +397,9 @@ def grouped_topk_gpu(
|
|
397
397
|
.reshape(num_token, -1)
|
398
398
|
) # [n, e]
|
399
399
|
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
400
|
-
topk_weights, topk_ids = torch.topk(
|
400
|
+
topk_weights, topk_ids = torch.topk(
|
401
|
+
tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
|
402
|
+
)
|
401
403
|
if num_fused_shared_experts:
|
402
404
|
topk_ids[:, -1] = torch.randint(
|
403
405
|
low=num_experts,
|
@@ -486,7 +488,9 @@ def biased_grouped_topk_impl(
|
|
486
488
|
tmp_scores = scores_for_choice.masked_fill(
|
487
489
|
~score_mask.bool(), float("-inf")
|
488
490
|
) # [n, e]
|
489
|
-
_, topk_ids = torch.topk(
|
491
|
+
_, topk_ids = torch.topk(
|
492
|
+
tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
|
493
|
+
)
|
490
494
|
topk_weights = scores.gather(1, topk_ids)
|
491
495
|
|
492
496
|
if num_fused_shared_experts:
|
@@ -900,6 +900,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|
900
900
|
layer.w13_blockscale_swizzled = Parameter(
|
901
901
|
w13_blockscale_swizzled, requires_grad=False
|
902
902
|
)
|
903
|
+
del layer.w13_weight_scale
|
903
904
|
|
904
905
|
# This is for quantization, so we need to invert it.
|
905
906
|
layer.w13_input_scale_quant = Parameter(
|
@@ -935,6 +936,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|
935
936
|
layer.w2_blockscale_swizzled = Parameter(
|
936
937
|
w2_blockscale_swizzled, requires_grad=False
|
937
938
|
)
|
939
|
+
del layer.w2_weight_scale
|
938
940
|
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
|
939
941
|
|
940
942
|
device = layer.w13_weight.device
|
@@ -26,6 +26,7 @@ import zmq
|
|
26
26
|
|
27
27
|
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
28
28
|
from sglang.srt.managers.io_struct import (
|
29
|
+
BlockReqInput,
|
29
30
|
TokenizedEmbeddingReqInput,
|
30
31
|
TokenizedGenerateReqInput,
|
31
32
|
)
|
@@ -282,6 +283,9 @@ class DataParallelController:
|
|
282
283
|
),
|
283
284
|
):
|
284
285
|
self.dispatching(recv_req)
|
286
|
+
elif isinstance(recv_req, BlockReqInput):
|
287
|
+
for worker in self.workers:
|
288
|
+
worker.send_pyobj(recv_req)
|
285
289
|
else:
|
286
290
|
# Send other control messages to first worker of tp group
|
287
291
|
for worker in self.workers[:: self.control_message_step]:
|
sglang/srt/managers/io_struct.py
CHANGED
@@ -911,6 +911,8 @@ class AbortReq:
|
|
911
911
|
rid: str = ""
|
912
912
|
# Whether to abort all requests
|
913
913
|
abort_all: bool = False
|
914
|
+
# The finished reason data
|
915
|
+
finished_reason: Optional[Dict[str, Any]] = None
|
914
916
|
|
915
917
|
|
916
918
|
@dataclass
|
@@ -1101,3 +1103,13 @@ class LoRAUpdateResult:
|
|
1101
1103
|
|
1102
1104
|
|
1103
1105
|
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
|
1106
|
+
|
1107
|
+
|
1108
|
+
class BlockReqType(Enum):
|
1109
|
+
BLOCK = 1
|
1110
|
+
UNBLOCK = 2
|
1111
|
+
|
1112
|
+
|
1113
|
+
@dataclass
|
1114
|
+
class BlockReqInput:
|
1115
|
+
type: BlockReqType
|
sglang/srt/managers/scheduler.py
CHANGED
@@ -24,6 +24,7 @@ import time
|
|
24
24
|
from collections import defaultdict, deque
|
25
25
|
from concurrent import futures
|
26
26
|
from dataclasses import dataclass
|
27
|
+
from http import HTTPStatus
|
27
28
|
from pathlib import Path
|
28
29
|
from types import SimpleNamespace
|
29
30
|
from typing import Dict, List, Optional, Tuple, Union
|
@@ -122,6 +123,7 @@ from sglang.srt.managers.schedule_policy import (
|
|
122
123
|
PrefillAdder,
|
123
124
|
SchedulePolicy,
|
124
125
|
)
|
126
|
+
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
|
125
127
|
from sglang.srt.managers.scheduler_output_processor_mixin import (
|
126
128
|
SchedulerOutputProcessorMixin,
|
127
129
|
)
|
@@ -370,6 +372,7 @@ class Scheduler(
|
|
370
372
|
self.max_total_num_tokens,
|
371
373
|
self.max_prefill_tokens,
|
372
374
|
self.max_running_requests,
|
375
|
+
self.max_queued_requests,
|
373
376
|
self.max_req_len,
|
374
377
|
self.max_req_input_len,
|
375
378
|
self.random_seed,
|
@@ -502,6 +505,12 @@ class Scheduler(
|
|
502
505
|
)
|
503
506
|
self.init_profier()
|
504
507
|
|
508
|
+
self.input_blocker = (
|
509
|
+
SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
|
510
|
+
if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
|
511
|
+
else None
|
512
|
+
)
|
513
|
+
|
505
514
|
# Init metrics stats
|
506
515
|
self.init_metrics(tp_rank, pp_rank, dp_rank)
|
507
516
|
self.init_kv_events(server_args.kv_events_config)
|
@@ -1033,6 +1042,9 @@ class Scheduler(
|
|
1033
1042
|
else:
|
1034
1043
|
recv_reqs = None
|
1035
1044
|
|
1045
|
+
if self.input_blocker is not None:
|
1046
|
+
recv_reqs = self.input_blocker.handle(recv_reqs)
|
1047
|
+
|
1036
1048
|
if self.server_args.enable_dp_attention:
|
1037
1049
|
if self.attn_tp_rank == 0:
|
1038
1050
|
work_reqs = [
|
@@ -1086,6 +1098,19 @@ class Scheduler(
|
|
1086
1098
|
self.return_health_check_ct += 1
|
1087
1099
|
continue
|
1088
1100
|
|
1101
|
+
# If it is a work request, accept or reject the request based on the request queue size.
|
1102
|
+
if is_work_request(recv_req):
|
1103
|
+
if len(self.waiting_queue) + 1 > self.max_queued_requests:
|
1104
|
+
abort_req = AbortReq(
|
1105
|
+
recv_req.rid,
|
1106
|
+
finished_reason={
|
1107
|
+
"type": "abort",
|
1108
|
+
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
|
1109
|
+
"message": "The request queue is full.",
|
1110
|
+
},
|
1111
|
+
)
|
1112
|
+
self.send_to_tokenizer.send_pyobj(abort_req)
|
1113
|
+
continue
|
1089
1114
|
output = self._request_dispatcher(recv_req)
|
1090
1115
|
if output is not None:
|
1091
1116
|
if isinstance(output, RpcReqOutput):
|
@@ -2902,6 +2927,10 @@ def is_health_check_generate_req(recv_req):
|
|
2902
2927
|
return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
|
2903
2928
|
|
2904
2929
|
|
2930
|
+
def is_work_request(recv_req):
|
2931
|
+
return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
|
2932
|
+
|
2933
|
+
|
2905
2934
|
def _export_static_state(model):
|
2906
2935
|
return dict(
|
2907
2936
|
buffers=[
|
@@ -0,0 +1,106 @@
|
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
14
|
+
import logging
|
15
|
+
from contextlib import contextmanager
|
16
|
+
from enum import Enum, auto
|
17
|
+
from typing import Any, List, Optional
|
18
|
+
|
19
|
+
from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
|
20
|
+
from sglang.srt.poll_based_barrier import PollBasedBarrier
|
21
|
+
|
22
|
+
logger = logging.getLogger(__name__)
|
23
|
+
|
24
|
+
|
25
|
+
class SchedulerInputBlocker:
|
26
|
+
def __init__(self, noop: bool):
|
27
|
+
self._state = _State.UNBLOCKED
|
28
|
+
self._pending_reqs = []
|
29
|
+
self._noop = noop
|
30
|
+
self._global_unblock_barrier = PollBasedBarrier(noop=noop)
|
31
|
+
|
32
|
+
def handle(self, recv_reqs: Optional[List[Any]]):
|
33
|
+
assert (recv_reqs is None) == self._noop
|
34
|
+
|
35
|
+
if not self._noop:
|
36
|
+
output_reqs = []
|
37
|
+
for recv_req in recv_reqs:
|
38
|
+
output_reqs += self._handle_recv_req(recv_req)
|
39
|
+
|
40
|
+
global_arrived_unblock_barrier = (
|
41
|
+
self._global_unblock_barrier.poll_global_arrived()
|
42
|
+
)
|
43
|
+
if (
|
44
|
+
self._state == _State.GLOBAL_UNBLOCK_BARRIER
|
45
|
+
and global_arrived_unblock_barrier
|
46
|
+
):
|
47
|
+
output_reqs += self._handle_arrive_unblock_barrier()
|
48
|
+
|
49
|
+
if not self._noop:
|
50
|
+
return output_reqs
|
51
|
+
|
52
|
+
def _handle_recv_req(self, recv_req):
|
53
|
+
if isinstance(recv_req, BlockReqInput):
|
54
|
+
if recv_req.type == BlockReqType.BLOCK:
|
55
|
+
self._execute_block_req()
|
56
|
+
return []
|
57
|
+
elif recv_req.type == BlockReqType.UNBLOCK:
|
58
|
+
self._execute_unblock_req()
|
59
|
+
return []
|
60
|
+
else:
|
61
|
+
raise NotImplementedError(f"{recv_req=}")
|
62
|
+
else:
|
63
|
+
if self._state == _State.UNBLOCKED:
|
64
|
+
return [recv_req]
|
65
|
+
else:
|
66
|
+
self._pending_reqs.append(recv_req)
|
67
|
+
return []
|
68
|
+
|
69
|
+
def _execute_block_req(self):
|
70
|
+
logger.info("Handle block req")
|
71
|
+
self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED)
|
72
|
+
|
73
|
+
def _execute_unblock_req(self):
|
74
|
+
logger.info("Handle unblock req")
|
75
|
+
self._change_state(
|
76
|
+
original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER
|
77
|
+
)
|
78
|
+
self._global_unblock_barrier.local_arrive()
|
79
|
+
|
80
|
+
def _handle_arrive_unblock_barrier(self):
|
81
|
+
logger.info(f"Arrived at unblock barrier ({len(self._pending_reqs)=})")
|
82
|
+
self._change_state(
|
83
|
+
original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED
|
84
|
+
)
|
85
|
+
output_reqs = [*self._pending_reqs]
|
86
|
+
self._pending_reqs.clear()
|
87
|
+
return output_reqs
|
88
|
+
|
89
|
+
def _change_state(self, original: "_State", target: "_State"):
|
90
|
+
assert self._state == original, f"{self._state=} {original=} {target=}"
|
91
|
+
self._state = target
|
92
|
+
|
93
|
+
|
94
|
+
class _State(Enum):
|
95
|
+
UNBLOCKED = auto()
|
96
|
+
BLOCKED = auto()
|
97
|
+
GLOBAL_UNBLOCK_BARRIER = auto()
|
98
|
+
|
99
|
+
|
100
|
+
@contextmanager
|
101
|
+
def input_blocker_guard_region(send_to_scheduler):
|
102
|
+
send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.BLOCK))
|
103
|
+
try:
|
104
|
+
yield
|
105
|
+
finally:
|
106
|
+
send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.UNBLOCK))
|