sglang 0.4.9.post5__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. sglang/srt/configs/model_config.py +3 -0
  2. sglang/srt/entrypoints/http_server.py +13 -1
  3. sglang/srt/entrypoints/openai/protocol.py +3 -1
  4. sglang/srt/entrypoints/openai/serving_base.py +5 -2
  5. sglang/srt/layers/moe/ep_moe/layer.py +152 -37
  6. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
  7. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  8. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  9. sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
  10. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
  11. sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
  12. sglang/srt/layers/moe/topk.py +6 -2
  13. sglang/srt/layers/quantization/modelopt_quant.py +2 -0
  14. sglang/srt/managers/data_parallel_controller.py +4 -0
  15. sglang/srt/managers/io_struct.py +12 -0
  16. sglang/srt/managers/scheduler.py +29 -0
  17. sglang/srt/managers/scheduler_input_blocker.py +106 -0
  18. sglang/srt/managers/tokenizer_manager.py +43 -9
  19. sglang/srt/managers/tp_worker.py +5 -0
  20. sglang/srt/model_executor/model_runner.py +15 -13
  21. sglang/srt/models/deepseek_v2.py +13 -56
  22. sglang/srt/models/qwen3_moe.py +12 -69
  23. sglang/srt/poll_based_barrier.py +31 -0
  24. sglang/srt/server_args.py +8 -0
  25. sglang/srt/two_batch_overlap.py +8 -3
  26. sglang/test/test_utils.py +53 -0
  27. sglang/version.py +1 -1
  28. {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +2 -1
  29. {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +32 -25
  30. {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
  31. {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
  32. {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 128,
5
+ "BLOCK_SIZE_K": 256,
6
+ "GROUP_SIZE_M": 64,
7
+ "num_warps": 4,
8
+ "num_stages": 3
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 32,
15
+ "num_warps": 4,
16
+ "num_stages": 4
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 128,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 1,
23
+ "num_warps": 4,
24
+ "num_stages": 4
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 1,
31
+ "num_warps": 4,
32
+ "num_stages": 4
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 16,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 256,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 128,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 32,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 32,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 32,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 32,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 16,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 32,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 16,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 64,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 1,
103
+ "num_warps": 4,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 64,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 16,
111
+ "num_warps": 4,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 128,
116
+ "BLOCK_SIZE_N": 128,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 1,
119
+ "num_warps": 8,
120
+ "num_stages": 3
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 16,
127
+ "num_warps": 8,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 64,
132
+ "BLOCK_SIZE_N": 128,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 16,
135
+ "num_warps": 4,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 1,
143
+ "num_warps": 8,
144
+ "num_stages": 3
145
+ }
146
+ }
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 128,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 4
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 64,
15
+ "num_warps": 4,
16
+ "num_stages": 4
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 128,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 1,
23
+ "num_warps": 4,
24
+ "num_stages": 4
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 1,
31
+ "num_warps": 4,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 1,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 256,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 128,
53
+ "BLOCK_SIZE_K": 256,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 256,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 2
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 256,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 2
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 256,
78
+ "GROUP_SIZE_M": 16,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 16,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 256,
86
+ "GROUP_SIZE_M": 16,
87
+ "num_warps": 4,
88
+ "num_stages": 2
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 32,
95
+ "num_warps": 4,
96
+ "num_stages": 4
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 64,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 16,
103
+ "num_warps": 4,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 64,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 16,
111
+ "num_warps": 4,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 128,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 16,
119
+ "num_warps": 8,
120
+ "num_stages": 4
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 256,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 16,
127
+ "num_warps": 8,
128
+ "num_stages": 4
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 64,
132
+ "BLOCK_SIZE_N": 128,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 1,
135
+ "num_warps": 4,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 256,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 1,
143
+ "num_warps": 8,
144
+ "num_stages": 4
145
+ }
146
+ }
File without changes
@@ -0,0 +1,48 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from enum import Enum, auto
5
+ from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable
6
+
7
+ import torch
8
+
9
+
10
class DispatchOutputFormat(Enum):
    """Tag identifying which layout a dispatch output uses."""

    # Member values come from auto(), so only the names are meaningful.
    standard = auto()
    deepep_normal = auto()
    deepep_ll = auto()

    def is_standard(self) -> bool:
        """Return True when this member is ``standard``."""
        return self is DispatchOutputFormat.standard

    def is_deepep_normal(self) -> bool:
        """Return True when this member is ``deepep_normal``."""
        return self is DispatchOutputFormat.deepep_normal

    def is_deepep_ll(self) -> bool:
        """Return True when this member is ``deepep_ll``."""
        return self is DispatchOutputFormat.deepep_ll
23
+
24
+
25
@runtime_checkable
class DispatchOutput(Protocol):
    """Structural type for dispatch outputs, whatever their format.

    Any object exposing a ``format`` attribute satisfies this protocol.
    Because the class is runtime-checkable, ``isinstance`` can be used to
    test for the presence of that attribute.
    """

    @property
    def format(self) -> DispatchOutputFormat:
        """The format tag describing this output."""
        ...
31
+
32
+
33
class BaseDispatcherConfig(ABC):
    """Common ancestor for dispatcher configuration objects.

    The class declares no members of its own; it only provides a shared
    base type for concrete configs.
    """
37
+
38
+
39
class BaseDispatcher(ABC):
    """Abstract interface every dispatcher must implement.

    Subclasses have to override both ``dispatch`` and ``combine``; the class
    cannot be instantiated until they do.
    """

    @abstractmethod
    def dispatch(self, *args, **kwargs) -> DispatchOutput:
        """Produce a dispatch output; arguments are defined by the subclass."""

    @abstractmethod
    def combine(self, *args, **kwargs) -> torch.Tensor:
        """Produce the combined tensor; arguments are defined by the subclass."""
@@ -0,0 +1,19 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import NamedTuple
4
+
5
+ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
6
+ DispatchOutput,
7
+ DispatchOutputFormat,
8
+ )
9
+
10
+
11
class StandardDispatchOutput(NamedTuple):
    """Dispatch output in the standard format.

    The tuple declares no fields; it only reports its format tag.
    """

    @property
    def format(self) -> DispatchOutputFormat:
        """Always ``DispatchOutputFormat.standard``."""
        standard_format = DispatchOutputFormat.standard
        return standard_format
17
+
18
+
19
+ assert isinstance(StandardDispatchOutput, DispatchOutput)
@@ -397,7 +397,9 @@ def grouped_topk_gpu(
397
397
  .reshape(num_token, -1)
398
398
  ) # [n, e]
399
399
  tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
400
- topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
400
+ topk_weights, topk_ids = torch.topk(
401
+ tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
402
+ )
401
403
  if num_fused_shared_experts:
402
404
  topk_ids[:, -1] = torch.randint(
403
405
  low=num_experts,
@@ -486,7 +488,9 @@ def biased_grouped_topk_impl(
486
488
  tmp_scores = scores_for_choice.masked_fill(
487
489
  ~score_mask.bool(), float("-inf")
488
490
  ) # [n, e]
489
- _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
491
+ _, topk_ids = torch.topk(
492
+ tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
493
+ )
490
494
  topk_weights = scores.gather(1, topk_ids)
491
495
 
492
496
  if num_fused_shared_experts:
@@ -900,6 +900,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
900
900
  layer.w13_blockscale_swizzled = Parameter(
901
901
  w13_blockscale_swizzled, requires_grad=False
902
902
  )
903
+ del layer.w13_weight_scale
903
904
 
904
905
  # This is for quantization, so we need to invert it.
905
906
  layer.w13_input_scale_quant = Parameter(
@@ -935,6 +936,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
935
936
  layer.w2_blockscale_swizzled = Parameter(
936
937
  w2_blockscale_swizzled, requires_grad=False
937
938
  )
939
+ del layer.w2_weight_scale
938
940
  layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
939
941
 
940
942
  device = layer.w13_weight.device
@@ -26,6 +26,7 @@ import zmq
26
26
 
27
27
  from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
28
28
  from sglang.srt.managers.io_struct import (
29
+ BlockReqInput,
29
30
  TokenizedEmbeddingReqInput,
30
31
  TokenizedGenerateReqInput,
31
32
  )
@@ -282,6 +283,9 @@ class DataParallelController:
282
283
  ),
283
284
  ):
284
285
  self.dispatching(recv_req)
286
+ elif isinstance(recv_req, BlockReqInput):
287
+ for worker in self.workers:
288
+ worker.send_pyobj(recv_req)
285
289
  else:
286
290
  # Send other control messages to first worker of tp group
287
291
  for worker in self.workers[:: self.control_message_step]:
@@ -911,6 +911,8 @@ class AbortReq:
911
911
  rid: str = ""
912
912
  # Whether to abort all requests
913
913
  abort_all: bool = False
914
+ # The finished reason data
915
+ finished_reason: Optional[Dict[str, Any]] = None
914
916
 
915
917
 
916
918
  @dataclass
@@ -1101,3 +1103,13 @@ class LoRAUpdateResult:
1101
1103
 
1102
1104
 
1103
1105
  LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
1106
+
1107
+
1108
class BlockReqType(Enum):
    """Kind of block request: start blocking or stop blocking."""

    # Request to start blocking.
    BLOCK = 1
    # Request to stop blocking.
    UNBLOCK = 2
1111
+
1112
+
1113
@dataclass
class BlockReqInput:
    """Control message carrying a single block/unblock instruction."""

    # Which action is requested.
    type: BlockReqType
@@ -24,6 +24,7 @@ import time
24
24
  from collections import defaultdict, deque
25
25
  from concurrent import futures
26
26
  from dataclasses import dataclass
27
+ from http import HTTPStatus
27
28
  from pathlib import Path
28
29
  from types import SimpleNamespace
29
30
  from typing import Dict, List, Optional, Tuple, Union
@@ -122,6 +123,7 @@ from sglang.srt.managers.schedule_policy import (
122
123
  PrefillAdder,
123
124
  SchedulePolicy,
124
125
  )
126
+ from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
125
127
  from sglang.srt.managers.scheduler_output_processor_mixin import (
126
128
  SchedulerOutputProcessorMixin,
127
129
  )
@@ -370,6 +372,7 @@ class Scheduler(
370
372
  self.max_total_num_tokens,
371
373
  self.max_prefill_tokens,
372
374
  self.max_running_requests,
375
+ self.max_queued_requests,
373
376
  self.max_req_len,
374
377
  self.max_req_input_len,
375
378
  self.random_seed,
@@ -502,6 +505,12 @@ class Scheduler(
502
505
  )
503
506
  self.init_profier()
504
507
 
508
+ self.input_blocker = (
509
+ SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
510
+ if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
511
+ else None
512
+ )
513
+
505
514
  # Init metrics stats
506
515
  self.init_metrics(tp_rank, pp_rank, dp_rank)
507
516
  self.init_kv_events(server_args.kv_events_config)
@@ -1033,6 +1042,9 @@ class Scheduler(
1033
1042
  else:
1034
1043
  recv_reqs = None
1035
1044
 
1045
+ if self.input_blocker is not None:
1046
+ recv_reqs = self.input_blocker.handle(recv_reqs)
1047
+
1036
1048
  if self.server_args.enable_dp_attention:
1037
1049
  if self.attn_tp_rank == 0:
1038
1050
  work_reqs = [
@@ -1086,6 +1098,19 @@ class Scheduler(
1086
1098
  self.return_health_check_ct += 1
1087
1099
  continue
1088
1100
 
1101
+ # If it is a work request, accept or reject the request based on the request queue size.
1102
+ if is_work_request(recv_req):
1103
+ if len(self.waiting_queue) + 1 > self.max_queued_requests:
1104
+ abort_req = AbortReq(
1105
+ recv_req.rid,
1106
+ finished_reason={
1107
+ "type": "abort",
1108
+ "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
1109
+ "message": "The request queue is full.",
1110
+ },
1111
+ )
1112
+ self.send_to_tokenizer.send_pyobj(abort_req)
1113
+ continue
1089
1114
  output = self._request_dispatcher(recv_req)
1090
1115
  if output is not None:
1091
1116
  if isinstance(output, RpcReqOutput):
@@ -2902,6 +2927,10 @@ def is_health_check_generate_req(recv_req):
2902
2927
  return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
2903
2928
 
2904
2929
 
2930
def is_work_request(recv_req):
    """Return True if ``recv_req`` is a tokenized generate or embedding request."""
    work_request_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
    return isinstance(recv_req, work_request_types)
2932
+
2933
+
2905
2934
  def _export_static_state(model):
2906
2935
  return dict(
2907
2936
  buffers=[
@@ -0,0 +1,106 @@
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+ import logging
15
+ from contextlib import contextmanager
16
+ from enum import Enum, auto
17
+ from typing import Any, List, Optional
18
+
19
+ from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
20
+ from sglang.srt.poll_based_barrier import PollBasedBarrier
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
class SchedulerInputBlocker:
    """Filter that can hold back incoming scheduler requests.

    The blocker moves through three states: ``UNBLOCKED`` -> ``BLOCKED`` ->
    ``GLOBAL_UNBLOCK_BARRIER`` -> ``UNBLOCKED``. While not ``UNBLOCKED``,
    every request that is not a ``BlockReqInput`` is buffered. The buffer is
    released once the unblock barrier reports global arrival.
    """

    def __init__(self, noop: bool):
        """Create the blocker.

        Args:
            noop: When True, ``handle`` expects ``None`` as input and returns
                ``None``; the flag is also forwarded to the barrier.
        """
        self._noop = noop
        self._state = _State.UNBLOCKED
        self._pending_reqs = []
        self._global_unblock_barrier = PollBasedBarrier(noop=noop)

    def handle(self, recv_reqs: Optional[List[Any]]):
        """Filter one batch of received requests.

        Args:
            recv_reqs: The received requests, or ``None`` exactly when this
                blocker was created with ``noop=True``.

        Returns:
            The requests that may proceed (including buffered ones released
            in this call), or ``None`` in no-op mode.
        """
        assert (recv_reqs is None) == self._noop

        active = not self._noop
        if active:
            output_reqs = []
            for incoming in recv_reqs:
                output_reqs.extend(self._handle_recv_req(incoming))

        # The barrier is polled on every call, in no-op mode as well.
        barrier_arrived = self._global_unblock_barrier.poll_global_arrived()
        waiting_on_barrier = self._state == _State.GLOBAL_UNBLOCK_BARRIER
        if waiting_on_barrier and barrier_arrived:
            output_reqs.extend(self._handle_arrive_unblock_barrier())

        if active:
            return output_reqs

    def _handle_recv_req(self, recv_req):
        """Process one request and return the list of requests to pass on."""
        if not isinstance(recv_req, BlockReqInput):
            # Ordinary request: forward it now, or buffer it while blocked.
            if self._state == _State.UNBLOCKED:
                return [recv_req]
            self._pending_reqs.append(recv_req)
            return []

        # Control request: it changes state and is never forwarded itself.
        if recv_req.type == BlockReqType.BLOCK:
            self._execute_block_req()
        elif recv_req.type == BlockReqType.UNBLOCK:
            self._execute_unblock_req()
        else:
            raise NotImplementedError(f"{recv_req=}")
        return []

    def _execute_block_req(self):
        """Enter the BLOCKED state; only valid from UNBLOCKED."""
        logger.info("Handle block req")
        self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED)

    def _execute_unblock_req(self):
        """Leave BLOCKED and register local arrival at the unblock barrier."""
        logger.info("Handle unblock req")
        self._change_state(
            original=_State.BLOCKED,
            target=_State.GLOBAL_UNBLOCK_BARRIER,
        )
        self._global_unblock_barrier.local_arrive()

    def _handle_arrive_unblock_barrier(self):
        """Return to UNBLOCKED and hand back every buffered request."""
        logger.info(f"Arrived at unblock barrier ({len(self._pending_reqs)=})")
        self._change_state(
            original=_State.GLOBAL_UNBLOCK_BARRIER,
            target=_State.UNBLOCKED,
        )
        # Copy before clearing so the caller gets its own list.
        released = list(self._pending_reqs)
        self._pending_reqs.clear()
        return released

    def _change_state(self, original: "_State", target: "_State"):
        """Switch to ``target`` after asserting the current state is ``original``."""
        assert self._state == original, f"{self._state=} {original=} {target=}"
        self._state = target
92
+
93
+
94
+ class _State(Enum):
95
+ UNBLOCKED = auto()
96
+ BLOCKED = auto()
97
+ GLOBAL_UNBLOCK_BARRIER = auto()
98
+
99
+
100
@contextmanager
def input_blocker_guard_region(send_to_scheduler):
    """Bracket a region with BLOCK / UNBLOCK requests sent to the scheduler.

    A BLOCK request is sent on entry via ``send_to_scheduler.send_pyobj``.
    An UNBLOCK request is always sent on exit, even if the body raises.

    Args:
        send_to_scheduler: Object exposing ``send_pyobj``.
    """
    block_req = BlockReqInput(BlockReqType.BLOCK)
    send_to_scheduler.send_pyobj(block_req)
    try:
        yield
    finally:
        unblock_req = BlockReqInput(BlockReqType.UNBLOCK)
        send_to_scheduler.send_pyobj(unblock_req)