sglang 0.4.9.post5__py3-none-any.whl → 0.4.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +23 -3
- sglang/srt/entrypoints/openai/protocol.py +3 -1
- sglang/srt/entrypoints/openai/serving_base.py +5 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +98 -603
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
- sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
- sglang/srt/layers/moe/topk.py +6 -2
- sglang/srt/layers/quantization/fp8.py +0 -18
- sglang/srt/layers/quantization/modelopt_quant.py +2 -0
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/managers/cache_controller.py +143 -45
- sglang/srt/managers/data_parallel_controller.py +6 -0
- sglang/srt/managers/io_struct.py +12 -2
- sglang/srt/managers/scheduler.py +116 -669
- sglang/srt/managers/scheduler_input_blocker.py +106 -0
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +166 -83
- sglang/srt/managers/tp_worker.py +9 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +45 -11
- sglang/srt/mem_cache/hiradix_cache.py +15 -4
- sglang/srt/mem_cache/memory_pool_host.py +73 -1
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/model_runner.py +20 -13
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +15 -56
- sglang/srt/models/glm4_moe.py +3 -1
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/qwen3_moe.py +12 -69
- sglang/srt/models/step3_vl.py +994 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/poll_based_barrier.py +31 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +18 -13
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/srt/two_batch_overlap.py +8 -3
- sglang/test/test_utils.py +53 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/METADATA +4 -4
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/RECORD +84 -64
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,106 @@
|
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
14
|
+
import logging
|
15
|
+
from contextlib import contextmanager
|
16
|
+
from enum import Enum, auto
|
17
|
+
from typing import Any, List, Optional
|
18
|
+
|
19
|
+
from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
|
20
|
+
from sglang.srt.poll_based_barrier import PollBasedBarrier
|
21
|
+
|
22
|
+
logger = logging.getLogger(__name__)
|
23
|
+
|
24
|
+
|
25
|
+
class SchedulerInputBlocker:
    """Gate for scheduler input requests.

    While blocked, ordinary requests are buffered; `BlockReqInput` control
    messages drive the state machine. After an unblock request, the blocker
    waits on a global barrier so all participants release their buffered
    requests together.
    """

    def __init__(self, noop: bool):
        # In noop mode this instance only participates in the global barrier
        # protocol and never sees concrete requests.
        self._noop = noop
        self._state = _State.UNBLOCKED
        self._pending_reqs = []
        self._global_unblock_barrier = PollBasedBarrier(noop=noop)

    def handle(self, recv_reqs: Optional[List[Any]]):
        """Filter `recv_reqs` through the blocker; returns the requests that
        may proceed now (None input / None output in noop mode)."""
        assert (recv_reqs is None) == self._noop

        if not self._noop:
            output_reqs = []
            for req in recv_reqs:
                output_reqs.extend(self._handle_recv_req(req))

        barrier_ready = self._global_unblock_barrier.poll_global_arrived()
        if self._state == _State.GLOBAL_UNBLOCK_BARRIER and barrier_ready:
            # Everyone arrived: release all buffered requests.
            output_reqs.extend(self._handle_arrive_unblock_barrier())

        if not self._noop:
            return output_reqs

    def _handle_recv_req(self, recv_req):
        """Route one request: control messages mutate state, ordinary
        requests pass through or get buffered."""
        if not isinstance(recv_req, BlockReqInput):
            if self._state == _State.UNBLOCKED:
                return [recv_req]
            self._pending_reqs.append(recv_req)
            return []
        if recv_req.type == BlockReqType.BLOCK:
            self._execute_block_req()
            return []
        if recv_req.type == BlockReqType.UNBLOCK:
            self._execute_unblock_req()
            return []
        raise NotImplementedError(f"{recv_req=}")

    def _execute_block_req(self):
        logger.info("Handle block req")
        self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED)

    def _execute_unblock_req(self):
        logger.info("Handle unblock req")
        self._change_state(
            original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER
        )
        self._global_unblock_barrier.local_arrive()

    def _handle_arrive_unblock_barrier(self):
        logger.info(f"Arrived at unblock barrier ({len(self._pending_reqs)=})")
        self._change_state(
            original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED
        )
        released = list(self._pending_reqs)
        self._pending_reqs.clear()
        return released

    def _change_state(self, original: "_State", target: "_State"):
        # Assert the expected transition so illegal sequences fail loudly.
        assert self._state == original, f"{self._state=} {original=} {target=}"
        self._state = target
|
92
|
+
|
93
|
+
|
94
|
+
class _State(Enum):
|
95
|
+
UNBLOCKED = auto()
|
96
|
+
BLOCKED = auto()
|
97
|
+
GLOBAL_UNBLOCK_BARRIER = auto()
|
98
|
+
|
99
|
+
|
100
|
+
@contextmanager
def input_blocker_guard_region(send_to_scheduler):
    """Context manager that blocks scheduler input for the duration of the
    `with` body.

    Sends a BLOCK control message to the scheduler on entry and an UNBLOCK
    on exit; the UNBLOCK is sent in a `finally` so the scheduler is released
    even if the guarded region raises.

    Args:
        send_to_scheduler: object with a ``send_pyobj`` method (presumably a
            zmq socket -- TODO confirm) used to deliver control messages.
    """
    send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.BLOCK))
    try:
        yield
    finally:
        send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.UNBLOCK))
|
@@ -0,0 +1,229 @@
|
|
1
|
+
import logging
|
2
|
+
import time
|
3
|
+
from collections import defaultdict
|
4
|
+
from typing import List, Optional
|
5
|
+
|
6
|
+
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
|
7
|
+
from sglang.srt.disaggregation.utils import DisaggregationMode
|
8
|
+
from sglang.srt.managers.schedule_policy import PrefillAdder
|
9
|
+
from sglang.srt.managers.scheduler import Req, ScheduleBatch
|
10
|
+
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
11
|
+
from sglang.srt.utils import get_bool_env_var
|
12
|
+
|
13
|
+
logger = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
|
16
|
+
|
17
|
+
|
18
|
+
class KvMetrics:
    """Mutable snapshot of KV-cache and request-slot metrics.

    All fields start as ``None``; the scheduler fills them in (see
    ``_emit_kv_metrics``) before shipping the object out.
    """

    # Every metric field carried by this snapshot.
    _FIELDS = (
        "request_active_slots",
        "request_total_slots",
        "kv_active_blocks",
        "kv_total_blocks",
        "num_requests_waiting",
        "gpu_cache_usage_perc",
        "gpu_prefix_cache_hit_rate",
        "data_parallel_rank",
    )

    def __init__(self):
        for field_name in self._FIELDS:
            setattr(self, field_name, None)
|
28
|
+
|
29
|
+
|
30
|
+
class SchedulerMetricsMixin:
    """Scheduler mixin: throughput/usage logging, Prometheus-style stats,
    KV metrics emission, and KV cache event publishing.

    NOTE(review): every method reads attributes owned by the Scheduler
    (e.g. ``self.enable_metrics``, ``self.waiting_queue``, ``self.is_hybrid``,
    ``self.running_batch``) -- this class is only meaningful mixed into it.
    """

    def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
        """Initialize throughput counters, speculative-decoding accumulators,
        and (when metrics are enabled) the labeled metrics collector."""
        self.last_gen_throughput: float = 0.0
        self.last_input_throughput: float = 0.0
        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
        # Speculative decoding acceptance bookkeeping (windowed + cumulative).
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        self.total_retracted_reqs = 0
        self.stats = SchedulerStats()
        if self.enable_metrics:
            engine_type = "unified"
            labels = {
                "model_name": self.server_args.served_model_name,
                "engine_type": engine_type,
                "tp_rank": tp_rank,
                "pp_rank": pp_rank,
            }
            # dp_rank is optional; only label it when data parallelism is on.
            if dp_rank is not None:
                labels["dp_rank"] = dp_rank
            self.metrics_collector = SchedulerMetricsCollector(labels=labels)

    def init_kv_events(self, kv_events_config: Optional[str]):
        """Create the KV cache event publisher when event publishing is on."""
        if self.enable_kv_cache_events:
            self.kv_event_publisher = EventPublisherFactory.create(
                kv_events_config, self.attn_dp_rank
            )

    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Log one prefill batch and, if enabled, update the stats collector.

        Args:
            adder: the PrefillAdder that assembled this batch (token counts).
            can_run_list: requests admitted into the batch.
            running_bs: number of currently running requests.
        """
        # Input throughput is tokens of the *previous* window divided by the
        # elapsed time since the last prefill log.
        gap_latency = time.perf_counter() - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = time.perf_counter()
        self.last_input_throughput = self.last_prefill_tokens / gap_latency
        self.last_prefill_tokens = adder.log_input_tokens

        if self.is_hybrid:
            # Hybrid (full + sliding-window attention) pools: report both and
            # use the max as the effective usage.
            (
                full_num_used,
                swa_num_used,
                full_token_usage,
                swa_token_usage,
                _,
                _,
                _,
                _,
            ) = self._get_swa_token_info()
            num_used = max(full_num_used, swa_num_used)
            token_usage = max(full_token_usage, swa_token_usage)
            token_msg = (
                f"full token usage: {full_token_usage:.2f}, "
                f"swa token usage: {swa_token_usage:.2f}, "
            )
        else:
            num_used, token_usage, _, _ = self._get_token_info()
            token_msg = f"token usage: {token_usage:.2f}, "

        num_new_seq = len(can_run_list)
        # `f` accumulates the human-readable log line.
        f = (
            f"Prefill batch. "
            f"#new-seq: {num_new_seq}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"{token_msg}"
        )

        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            # Disaggregated prefill node: also show bootstrap/transfer queues.
            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
            f += f"#queue-req: {len(self.waiting_queue)}, "
            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
            f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
        else:
            f += f"#running-req: {running_bs}, "
            f += f"#queue-req: {len(self.waiting_queue)}, "

        logger.info(f)

        if self.enable_metrics:
            total_tokens = adder.log_input_tokens + adder.log_hit_tokens

            cache_hit_rate = (
                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
            )
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.cache_hit_rate = cache_hit_rate

            total_queue_latency = 0
            for req in can_run_list:
                total_queue_latency += req.queue_time_end - req.queue_time_start
            # NOTE(review): divides by num_new_seq with no zero guard --
            # assumes this is never called with an empty can_run_list; confirm.
            self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq

            self.metrics_collector.log_stats(self.stats)
            self._emit_kv_metrics()
        self._publish_kv_events()

    def log_decode_stats(
        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
    ):
        """Log one decode-interval summary and, if enabled, update stats.

        Args:
            can_run_cuda_graph: whether this interval ran with CUDA graphs.
            running_batch: batch to report on; defaults to self.running_batch.
        """
        batch = running_batch or self.running_batch

        # Generation throughput over the window since the last decode log.
        gap_latency = time.perf_counter() - self.last_decode_stats_tic
        self.last_decode_stats_tic = time.perf_counter()
        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(batch.reqs)
        if self.is_hybrid:
            (
                full_num_used,
                swa_num_used,
                full_token_usage,
                swa_token_usage,
                _,
                _,
                _,
                _,
            ) = self._get_swa_token_info()
            num_used = max(full_num_used, swa_num_used)
            token_usage = max(full_token_usage, swa_token_usage)
            token_msg = (
                f"#full token: {full_num_used}, "
                f"full token usage: {full_token_usage:.2f}, "
                f"#swa token: {swa_num_used}, "
                f"swa token usage: {swa_token_usage:.2f}, "
            )
        else:
            num_used, token_usage, _, _ = self._get_token_info()
            token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "

        if RECORD_STEP_TIME:
            # Optional per-batch-size step-time recording (env-gated).
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
        else:
            # Average accepted tokens per forward step over this window, then
            # fold the window into the cumulative counters and reset it.
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
            )
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg += f"accept len: {spec_accept_length:.2f}, "

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
            msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "

        msg += (
            f"cuda graph: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}, "
        )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            # Cache hit rate is a prefill-side concept; zero during decode.
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.stats.total_retracted_reqs = self.total_retracted_reqs
            self.metrics_collector.log_stats(self.stats)
            self._emit_kv_metrics()
        self._publish_kv_events()

    def _emit_kv_metrics(self):
        """Build a KvMetrics snapshot from self.stats and send it out-of-band
        (presumably over a zmq socket -- TODO confirm)."""
        kv_metrics = KvMetrics()
        kv_metrics.request_active_slots = self.stats.num_running_reqs
        kv_metrics.request_total_slots = self.max_running_requests
        kv_metrics.kv_active_blocks = int(
            self.stats.token_usage * self.max_total_num_tokens
        )
        kv_metrics.kv_total_blocks = self.max_total_num_tokens
        kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
        kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
        kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
        kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0

        if not self.send_metrics_from_scheduler.closed:
            self.send_metrics_from_scheduler.send_pyobj(kv_metrics)

    def _publish_kv_events(self):
        """Drain pending KV-cache tree events and publish them in one batch."""
        if self.enable_kv_cache_events:
            events = self.tree_cache.take_events()
            if events:
                batch = KVEventBatch(ts=time.time(), events=events)
                self.kv_event_publisher.publish(batch)
|
@@ -0,0 +1,279 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
3
|
+
import time
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import List, Optional
|
6
|
+
|
7
|
+
import torch
|
8
|
+
|
9
|
+
from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
|
10
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
11
|
+
|
12
|
+
logger = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
class SchedulerProfilerMixin:
    """Scheduler mixin implementing the /start_profile and /stop_profile
    endpoints: torch profiler, RPD (ROCm) tracing, CUDA memory-history, and
    the CUDA profiler API, either for a fixed step window or per forward
    stage (prefill vs decode)."""

    def init_profier(self):
        # NOTE(review): method name is misspelled ("profier") but is part of
        # the public interface -- callers elsewhere use this exact name.
        """Reset all profiler bookkeeping to the idle state."""
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.profiler_activities: Optional[List[str]] = None
        self.profile_id: Optional[str] = None
        self.profiler_start_forward_ct: Optional[int] = None
        self.profiler_target_forward_ct: Optional[int] = None
        self.profiler_target_prefill_ct: Optional[int] = None
        self.profiler_target_decode_ct: Optional[int] = None
        self.profiler_prefill_ct: Optional[int] = None
        self.profiler_decode_ct: Optional[int] = None
        self.profile_by_stage: bool = False
        self.profile_steps: Optional[int] = None
        self.profile_in_progress: bool = False
        self.rpd_profiler = None

    def init_profile(
        self,
        output_dir: Optional[str],
        start_step: Optional[int],
        num_steps: Optional[int],
        activities: Optional[List[str]],
        with_stack: Optional[bool],
        record_shapes: Optional[bool],
        profile_by_stage: bool,
        profile_id: str,
    ) -> ProfileReqOutput:
        """Record profiling configuration and compute start/stop step targets.

        Does not start any profiler itself; it only arms the bookkeeping that
        `_profile_batch_predicate` / callers act on.

        Returns:
            ProfileReqOutput with success=False if a profile is already
            running, success=True otherwise.
        """
        if self.profile_in_progress:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        self.profile_by_stage = profile_by_stage

        if output_dir is None:
            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        if activities is None:
            activities = ["CPU", "GPU"]

        self.torch_profiler_output_dir = output_dir
        self.torch_profiler_with_stack = with_stack
        self.torch_profiler_record_shapes = record_shapes
        self.profiler_activities = activities
        self.profile_id = profile_id

        if start_step:
            # Never schedule the start in the past relative to forward_ct.
            self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)

        if num_steps:
            self.profile_steps = num_steps
            if self.profile_by_stage:
                # Per-stage mode: count prefill and decode batches separately.
                self.profiler_target_prefill_ct = num_steps
                self.profiler_target_decode_ct = num_steps
                self.profiler_prefill_ct = 0
                self.profiler_decode_ct = 0
            elif start_step:
                self.profiler_target_forward_ct = (
                    self.profiler_start_forward_ct + num_steps
                )
            else:
                self.profiler_target_forward_ct = self.forward_ct + num_steps
            # The caller will be notified when reaching profiler_target_forward_ct
        else:
            # No step budget: profiling runs until an explicit /stop_profile.
            self.profiler_target_forward_ct = None

        return ProfileReqOutput(success=True, message="Succeeded")

    def start_profile(
        self, stage: Optional[ForwardMode] = None
    ) -> ProfileReqOutput | None:
        """Start the configured profilers (torch / RPD / MEM / CUDA_PROFILER).

        Args:
            stage: forward stage being profiled, for log messages only.
                NOTE(review): one caller passes ``True`` here (see
                ``profile``), which only affects the log text.
        """
        stage_str = f" for {stage.__str__()}" if stage else ""
        logger.info(
            f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
        )

        activities = self.profiler_activities
        with_stack = self.torch_profiler_with_stack
        record_shapes = self.torch_profiler_record_shapes

        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        if "RPD" in activities:
            # ROCm profiler data path; mutually exclusive with the torch
            # profiler branch below ("elif").
            from rpdTracerControl import rpdTracerControl

            rpdTracerControl.skipCreate()

            self.rpd_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
            )

            if self.tp_rank == 0:
                import sqlite3

                from rocpd.schema import RocpdSchema

                # Rank 0 creates a fresh trace.rpd database (cwd-relative).
                if os.path.exists("trace.rpd"):
                    os.unlink("trace.rpd")
                schema = RocpdSchema()
                connection = sqlite3.connect("trace.rpd")
                schema.writeSchema(connection)
                connection.commit()
                del connection
            # All TP ranks wait for the database before tracing.
            torch.distributed.barrier(self.tp_cpu_group)

            self.rpd_profiler = rpdTracerControl()
            self.rpd_profiler.setPythonTrace(True)
            self.rpd_profiler.start()
            self.rpd_profiler.rangePush("", "rpd profile range", "")
            self.profile_in_progress = True
        elif torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                # Defaults: stack traces on, shape recording off.
                with_stack=with_stack if with_stack is not None else True,
                record_shapes=record_shapes if record_shapes is not None else False,
            )
            self.torch_profiler.start()
            self.profile_in_progress = True

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)
            self.profile_in_progress = True

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStart()
            self.profile_in_progress = True

        return ProfileReqOutput(success=True, message="Succeeded")

    def stop_profile(
        self, stage: Optional[ForwardMode] = None
    ) -> ProfileReqOutput | None:
        """Stop active profilers and write traces to the output directory.

        Args:
            stage: forward stage suffix appended to trace filenames.
        """
        if not self.profile_in_progress:
            return ProfileReqOutput(
                success=False,
                message="Profiling is not in progress. Call /start_profile first.",
            )

        if not Path(self.torch_profiler_output_dir).exists():
            Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)

        stage_suffix = f"-{stage.__str__()}" if stage else ""
        logger.info("Stop profiling" + stage_suffix + "...")
        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            self.torch_profiler.export_chrome_trace(
                os.path.join(
                    self.torch_profiler_output_dir,
                    self.profile_id
                    + f"-TP-{self.tp_rank}"
                    + stage_suffix
                    + ".trace.json.gz",
                )
            )
            # Keep TP ranks in lockstep so no rank races ahead mid-export.
            torch.distributed.barrier(self.tp_cpu_group)

        if self.rpd_profiler is not None:
            self.rpd_profiler.rangePop()
            self.rpd_profiler.stop()
            self.rpd_profiler.flush()

            torch.distributed.barrier(self.tp_cpu_group)
            if self.tp_rank == 0:
                from sglang.srt.utils import rpd_to_chrome_trace

                rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
            self.rpd_profiler = None
            # NOTE(review): attribute name mismatch -- the path above is
            # stored as `rpd_profile_path`, so this line clears a different,
            # otherwise-unused attribute and leaves the real one stale.
            self.rpd_profiler_path = None

        if self.profiler_activities is not None and "MEM" in self.profiler_activities:
            memory_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                str(time.time())
                + f"-TP-{self.tp_rank}-memory"
                + stage_suffix
                + ".pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            # enabled=None disables further memory-history recording.
            torch.cuda.memory._record_memory_history(enabled=None)

        if "CUDA_PROFILER" in self.profiler_activities:
            torch.cuda.cudart().cudaProfilerStop()

        logger.info(
            "Profiling done. Traces are saved to: %s",
            self.torch_profiler_output_dir,
        )
        self.torch_profiler = None
        self.profile_in_progress = False
        self.profiler_start_forward_ct = None

        return ProfileReqOutput(success=True, message="Succeeded.")

    def _profile_batch_predicate(self, batch):
        """Called per batch: start/stop the profiler when the batch counter
        crosses the configured thresholds (stage-based or step-based)."""
        if self.profile_by_stage:
            if batch.forward_mode.is_prefill():
                if self.profiler_prefill_ct == 0:
                    self.start_profile(batch.forward_mode)
                self.profiler_prefill_ct += 1
                if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
                    if self.profile_in_progress:
                        self.stop_profile(stage=ForwardMode.EXTEND)
            elif batch.forward_mode.is_decode():
                if self.profiler_decode_ct == 0:
                    if self.profile_in_progress:
                        # force trace flush
                        self.stop_profile(ForwardMode.EXTEND)
                    self.start_profile(batch.forward_mode)
                self.profiler_decode_ct += 1
                if self.profiler_decode_ct > self.profiler_target_decode_ct:
                    if self.profile_in_progress:
                        self.stop_profile(stage=ForwardMode.DECODE)
            elif batch.forward_mode.is_idle():
                pass
            else:
                raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
        else:
            # Check profiler
            if (
                self.profiler_target_forward_ct
                and self.profiler_target_forward_ct <= self.forward_ct
            ):
                self.stop_profile()
            if (
                self.profiler_start_forward_ct
                and self.profiler_start_forward_ct == self.forward_ct
            ):
                self.start_profile()

    def profile(self, recv_req: ProfileReq):
        """Handle a profile control request (start or stop).

        Deferred starts (stage-based or with a start_step) only arm the
        config; immediate starts arm the config and start right away.
        """
        if recv_req.type == ProfileReqType.START_PROFILE:
            if recv_req.profile_by_stage or recv_req.start_step:
                return self.init_profile(
                    recv_req.output_dir,
                    recv_req.start_step,
                    recv_req.num_steps,
                    recv_req.activities,
                    recv_req.with_stack,
                    recv_req.record_shapes,
                    recv_req.profile_by_stage,
                    recv_req.profile_id,
                )
            else:
                self.init_profile(
                    recv_req.output_dir,
                    recv_req.start_step,
                    recv_req.num_steps,
                    recv_req.activities,
                    recv_req.with_stack,
                    recv_req.record_shapes,
                    recv_req.profile_by_stage,
                    recv_req.profile_id,
                )
                # NOTE(review): passes True as the `stage` argument; it is
                # only interpolated into a log message -- confirm intent.
                return self.start_profile(True)
        else:
            return self.stop_profile()
|