sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl
This diff reflects the changes between two publicly released package versions, as they appear in their public registry, and is provided for informational purposes only.
- sglang/bench_offline_throughput.py +20 -0
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +4 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +10 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +39 -674
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
- sglang/srt/layers/quantization/fp8.py +52 -18
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/managers/cache_controller.py +165 -67
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +0 -2
- sglang/srt/managers/scheduler.py +90 -671
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +123 -74
- sglang/srt/managers/tp_worker.py +4 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +60 -17
- sglang/srt/mem_cache/hiradix_cache.py +36 -8
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +418 -29
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/cuda_graph_runner.py +25 -1
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/glm4_moe.py +6 -4
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/step3_vl.py +991 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +49 -18
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/srt/utils.py +1 -0
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
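
The largest structural change in the listing above is the refactor of `sglang/srt/managers/scheduler.py` (+90 -671) into three new mixin modules, whose full contents are reproduced in the hunks below. A minimal sketch of the composition pattern this implies, assuming (the `scheduler.py` hunk itself is not shown in this diff) that the host `Scheduler` class inherits the mixins and supplies the attributes they reference:

```python
# Hypothetical sketch only; the real Scheduler defines more base classes,
# attributes, and methods than shown here.
from sglang.srt.managers.scheduler_metrics_mixin import SchedulerMetricsMixin
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
from sglang.srt.managers.scheduler_update_weights_mixin import (
    SchedulerUpdateWeightsMixin,
)


class Scheduler(
    SchedulerMetricsMixin,
    SchedulerProfilerMixin,
    SchedulerUpdateWeightsMixin,
):
    """Host class: expected to provide server_args, waiting_queue, tp_worker,
    tp_cpu_group, forward_ct, enable_metrics, etc., which the mixins read."""
```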
sglang/srt/managers/scheduler_metrics_mixin.py
@@ -0,0 +1,229 @@
+import logging
+import time
+from collections import defaultdict
+from typing import List, Optional
+
+from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
+from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.managers.schedule_policy import PrefillAdder
+from sglang.srt.managers.scheduler import Req, ScheduleBatch
+from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
+from sglang.srt.utils import get_bool_env_var
+
+logger = logging.getLogger(__name__)
+
+RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
+
+
+class KvMetrics:
+    def __init__(self):
+        self.request_active_slots = None
+        self.request_total_slots = None
+        self.kv_active_blocks = None
+        self.kv_total_blocks = None
+        self.num_requests_waiting = None
+        self.gpu_cache_usage_perc = None
+        self.gpu_prefix_cache_hit_rate = None
+        self.data_parallel_rank = None
+
+
+class SchedulerMetricsMixin:
+    def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
+        self.last_gen_throughput: float = 0.0
+        self.last_input_throughput: float = 0.0
+        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
+        self.spec_num_total_accepted_tokens = 0
+        self.spec_num_total_forward_ct = 0
+        self.cum_spec_accept_length = 0
+        self.cum_spec_accept_count = 0
+        self.total_retracted_reqs = 0
+        self.stats = SchedulerStats()
+        if self.enable_metrics:
+            engine_type = "unified"
+            labels = {
+                "model_name": self.server_args.served_model_name,
+                "engine_type": engine_type,
+                "tp_rank": tp_rank,
+                "pp_rank": pp_rank,
+            }
+            if dp_rank is not None:
+                labels["dp_rank"] = dp_rank
+            self.metrics_collector = SchedulerMetricsCollector(labels=labels)
+
+    def init_kv_events(self, kv_events_config: Optional[str]):
+        if self.enable_kv_cache_events:
+            self.kv_event_publisher = EventPublisherFactory.create(
+                kv_events_config, self.attn_dp_rank
+            )
+
+    def log_prefill_stats(
+        self,
+        adder: PrefillAdder,
+        can_run_list: List[Req],
+        running_bs: int,
+    ):
+        gap_latency = time.perf_counter() - self.last_prefill_stats_tic
+        self.last_prefill_stats_tic = time.perf_counter()
+        self.last_input_throughput = self.last_prefill_tokens / gap_latency
+        self.last_prefill_tokens = adder.log_input_tokens
+
+        if self.is_hybrid:
+            (
+                full_num_used,
+                swa_num_used,
+                full_token_usage,
+                swa_token_usage,
+                _,
+                _,
+                _,
+                _,
+            ) = self._get_swa_token_info()
+            num_used = max(full_num_used, swa_num_used)
+            token_usage = max(full_token_usage, swa_token_usage)
+            token_msg = (
+                f"full token usage: {full_token_usage:.2f}, "
+                f"swa token usage: {swa_token_usage:.2f}, "
+            )
+        else:
+            num_used, token_usage, _, _ = self._get_token_info()
+            token_msg = f"token usage: {token_usage:.2f}, "
+
+        num_new_seq = len(can_run_list)
+        f = (
+            f"Prefill batch. "
+            f"#new-seq: {num_new_seq}, "
+            f"#new-token: {adder.log_input_tokens}, "
+            f"#cached-token: {adder.log_hit_tokens}, "
+            f"{token_msg}"
+        )
+
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
+            f += f"#queue-req: {len(self.waiting_queue)}, "
+            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
+            f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
+        else:
+            f += f"#running-req: {running_bs}, "
+            f += f"#queue-req: {len(self.waiting_queue)}, "
+
+        logger.info(f)
+
+        if self.enable_metrics:
+            total_tokens = adder.log_input_tokens + adder.log_hit_tokens
+
+            cache_hit_rate = (
+                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
+            )
+            self.stats.num_running_reqs = running_bs
+            self.stats.num_used_tokens = num_used
+            self.stats.token_usage = round(token_usage, 2)
+            self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.cache_hit_rate = cache_hit_rate
+
+            total_queue_latency = 0
+            for req in can_run_list:
+                total_queue_latency += req.queue_time_end - req.queue_time_start
+            self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
+
+            self.metrics_collector.log_stats(self.stats)
+            self._emit_kv_metrics()
+        self._publish_kv_events()
+
+    def log_decode_stats(
+        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
+    ):
+        batch = running_batch or self.running_batch
+
+        gap_latency = time.perf_counter() - self.last_decode_stats_tic
+        self.last_decode_stats_tic = time.perf_counter()
+        self.last_gen_throughput = self.num_generated_tokens / gap_latency
+        self.num_generated_tokens = 0
+        num_running_reqs = len(batch.reqs)
+        if self.is_hybrid:
+            (
+                full_num_used,
+                swa_num_used,
+                full_token_usage,
+                swa_token_usage,
+                _,
+                _,
+                _,
+                _,
+            ) = self._get_swa_token_info()
+            num_used = max(full_num_used, swa_num_used)
+            token_usage = max(full_token_usage, swa_token_usage)
+            token_msg = (
+                f"#full token: {full_num_used}, "
+                f"full token usage: {full_token_usage:.2f}, "
+                f"#swa token: {swa_num_used}, "
+                f"swa token usage: {swa_token_usage:.2f}, "
+            )
+        else:
+            num_used, token_usage, _, _ = self._get_token_info()
+            token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
+
+        if RECORD_STEP_TIME:
+            self.step_time_dict[num_running_reqs].append(
+                gap_latency / self.server_args.decode_log_interval
+            )
+
+        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
+
+        if self.spec_algorithm.is_none():
+            spec_accept_length = 0
+        else:
+            spec_accept_length = (
+                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
+            )
+            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
+            self.cum_spec_accept_count += self.spec_num_total_forward_ct
+            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
+            msg += f"accept len: {spec_accept_length:.2f}, "
+
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+            msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
+
+        msg += (
+            f"cuda graph: {can_run_cuda_graph}, "
+            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
+            f"#queue-req: {len(self.waiting_queue)}, "
+        )
+
+        logger.info(msg)
+        if self.enable_metrics:
+            self.stats.num_running_reqs = num_running_reqs
+            self.stats.num_used_tokens = num_used
+            self.stats.token_usage = round(token_usage, 2)
+            self.stats.cache_hit_rate = 0.0
+            self.stats.gen_throughput = self.last_gen_throughput
+            self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
+            self.stats.spec_accept_length = spec_accept_length
+            self.stats.total_retracted_reqs = self.total_retracted_reqs
+            self.metrics_collector.log_stats(self.stats)
+            self._emit_kv_metrics()
+        self._publish_kv_events()
+
+    def _emit_kv_metrics(self):
+        kv_metrics = KvMetrics()
+        kv_metrics.request_active_slots = self.stats.num_running_reqs
+        kv_metrics.request_total_slots = self.max_running_requests
+        kv_metrics.kv_active_blocks = int(
+            self.stats.token_usage * self.max_total_num_tokens
+        )
+        kv_metrics.kv_total_blocks = self.max_total_num_tokens
+        kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
+        kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
+        kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
+        kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
+
+        if not self.send_metrics_from_scheduler.closed:
+            self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
+
+    def _publish_kv_events(self):
+        if self.enable_kv_cache_events:
+            events = self.tree_cache.take_events()
+            if events:
+                batch = KVEventBatch(ts=time.time(), events=events)
+                self.kv_event_publisher.publish(batch)
sglang/srt/managers/scheduler_profiler_mixin.py
@@ -0,0 +1,279 @@
+import logging
+import os
+import time
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+
+from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
+
+logger = logging.getLogger(__name__)
+
+
+class SchedulerProfilerMixin:
+
+    def init_profier(self):
+        self.torch_profiler = None
+        self.torch_profiler_output_dir: Optional[str] = None
+        self.profiler_activities: Optional[List[str]] = None
+        self.profile_id: Optional[str] = None
+        self.profiler_start_forward_ct: Optional[int] = None
+        self.profiler_target_forward_ct: Optional[int] = None
+        self.profiler_target_prefill_ct: Optional[int] = None
+        self.profiler_target_decode_ct: Optional[int] = None
+        self.profiler_prefill_ct: Optional[int] = None
+        self.profiler_decode_ct: Optional[int] = None
+        self.profile_by_stage: bool = False
+        self.profile_steps: Optional[int] = None
+        self.profile_in_progress: bool = False
+        self.rpd_profiler = None
+
+    def init_profile(
+        self,
+        output_dir: Optional[str],
+        start_step: Optional[int],
+        num_steps: Optional[int],
+        activities: Optional[List[str]],
+        with_stack: Optional[bool],
+        record_shapes: Optional[bool],
+        profile_by_stage: bool,
+        profile_id: str,
+    ) -> ProfileReqOutput:
+        if self.profile_in_progress:
+            return ProfileReqOutput(
+                success=False,
+                message="Profiling is already in progress. Call /stop_profile first.",
+            )
+
+        self.profile_by_stage = profile_by_stage
+
+        if output_dir is None:
+            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
+        if activities is None:
+            activities = ["CPU", "GPU"]
+
+        self.torch_profiler_output_dir = output_dir
+        self.torch_profiler_with_stack = with_stack
+        self.torch_profiler_record_shapes = record_shapes
+        self.profiler_activities = activities
+        self.profile_id = profile_id
+
+        if start_step:
+            self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)
+
+        if num_steps:
+            self.profile_steps = num_steps
+            if self.profile_by_stage:
+                self.profiler_target_prefill_ct = num_steps
+                self.profiler_target_decode_ct = num_steps
+                self.profiler_prefill_ct = 0
+                self.profiler_decode_ct = 0
+            elif start_step:
+                self.profiler_target_forward_ct = (
+                    self.profiler_start_forward_ct + num_steps
+                )
+            else:
+                self.profiler_target_forward_ct = self.forward_ct + num_steps
+            # The caller will be notified when reaching profiler_target_forward_ct
+        else:
+            self.profiler_target_forward_ct = None
+
+        return ProfileReqOutput(success=True, message="Succeeded")
+
+    def start_profile(
+        self, stage: Optional[ForwardMode] = None
+    ) -> ProfileReqOutput | None:
+        stage_str = f" for {stage.__str__()}" if stage else ""
+        logger.info(
+            f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
+        )
+
+        activities = self.profiler_activities
+        with_stack = self.torch_profiler_with_stack
+        record_shapes = self.torch_profiler_record_shapes
+
+        activity_map = {
+            "CPU": torch.profiler.ProfilerActivity.CPU,
+            "GPU": torch.profiler.ProfilerActivity.CUDA,
+        }
+        torchprof_activities = [
+            activity_map[a] for a in activities if a in activity_map
+        ]
+
+        if "RPD" in activities:
+            from rpdTracerControl import rpdTracerControl
+
+            rpdTracerControl.skipCreate()
+
+            self.rpd_profile_path = os.path.join(
+                self.torch_profiler_output_dir,
+                "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
+            )
+
+            if self.tp_rank == 0:
+                import sqlite3
+
+                from rocpd.schema import RocpdSchema
+
+                if os.path.exists("trace.rpd"):
+                    os.unlink("trace.rpd")
+                schema = RocpdSchema()
+                connection = sqlite3.connect("trace.rpd")
+                schema.writeSchema(connection)
+                connection.commit()
+                del connection
+            torch.distributed.barrier(self.tp_cpu_group)
+
+            self.rpd_profiler = rpdTracerControl()
+            self.rpd_profiler.setPythonTrace(True)
+            self.rpd_profiler.start()
+            self.rpd_profiler.rangePush("", "rpd profile range", "")
+            self.profile_in_progress = True
+        elif torchprof_activities:
+            self.torch_profiler = torch.profiler.profile(
+                activities=torchprof_activities,
+                with_stack=with_stack if with_stack is not None else True,
+                record_shapes=record_shapes if record_shapes is not None else False,
+            )
+            self.torch_profiler.start()
+            self.profile_in_progress = True
+
+        if "MEM" in activities:
+            torch.cuda.memory._record_memory_history(max_entries=100000)
+            self.profile_in_progress = True
+
+        if "CUDA_PROFILER" in activities:
+            torch.cuda.cudart().cudaProfilerStart()
+            self.profile_in_progress = True
+
+        return ProfileReqOutput(success=True, message="Succeeded")
+
+    def stop_profile(
+        self, stage: Optional[ForwardMode] = None
+    ) -> ProfileReqOutput | None:
+        if not self.profile_in_progress:
+            return ProfileReqOutput(
+                success=False,
+                message="Profiling is not in progress. Call /start_profile first.",
+            )
+
+        if not Path(self.torch_profiler_output_dir).exists():
+            Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
+
+        stage_suffix = f"-{stage.__str__()}" if stage else ""
+        logger.info("Stop profiling" + stage_suffix + "...")
+        if self.torch_profiler is not None:
+            self.torch_profiler.stop()
+            self.torch_profiler.export_chrome_trace(
+                os.path.join(
+                    self.torch_profiler_output_dir,
+                    self.profile_id
+                    + f"-TP-{self.tp_rank}"
+                    + stage_suffix
+                    + ".trace.json.gz",
+                )
+            )
+            torch.distributed.barrier(self.tp_cpu_group)
+
+        if self.rpd_profiler is not None:
+            self.rpd_profiler.rangePop()
+            self.rpd_profiler.stop()
+            self.rpd_profiler.flush()
+
+            torch.distributed.barrier(self.tp_cpu_group)
+            if self.tp_rank == 0:
+                from sglang.srt.utils import rpd_to_chrome_trace
+
+                rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
+            self.rpd_profiler = None
+            self.rpd_profiler_path = None
+
+        if self.profiler_activities is not None and "MEM" in self.profiler_activities:
+            memory_profile_path = os.path.join(
+                self.torch_profiler_output_dir,
+                str(time.time())
+                + f"-TP-{self.tp_rank}-memory"
+                + stage_suffix
+                + ".pickle",
+            )
+            torch.cuda.memory._dump_snapshot(memory_profile_path)
+            torch.cuda.memory._record_memory_history(enabled=None)
+
+        if "CUDA_PROFILER" in self.profiler_activities:
+            torch.cuda.cudart().cudaProfilerStop()
+
+        logger.info(
+            "Profiling done. Traces are saved to: %s",
+            self.torch_profiler_output_dir,
+        )
+        self.torch_profiler = None
+        self.profile_in_progress = False
+        self.profiler_start_forward_ct = None
+
+        return ProfileReqOutput(success=True, message="Succeeded.")
+
+    def _profile_batch_predicate(self, batch):
+        if self.profile_by_stage:
+            if batch.forward_mode.is_prefill():
+                if self.profiler_prefill_ct == 0:
+                    self.start_profile(batch.forward_mode)
+                self.profiler_prefill_ct += 1
+                if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
+                    if self.profile_in_progress:
+                        self.stop_profile(stage=ForwardMode.EXTEND)
+            elif batch.forward_mode.is_decode():
+                if self.profiler_decode_ct == 0:
+                    if self.profile_in_progress:
+                        # force trace flush
+                        self.stop_profile(ForwardMode.EXTEND)
+                    self.start_profile(batch.forward_mode)
+                self.profiler_decode_ct += 1
+                if self.profiler_decode_ct > self.profiler_target_decode_ct:
+                    if self.profile_in_progress:
+                        self.stop_profile(stage=ForwardMode.DECODE)
+            elif batch.forward_mode.is_idle():
+                pass
+            else:
+                raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
+        else:
+            # Check profiler
+            if (
+                self.profiler_target_forward_ct
+                and self.profiler_target_forward_ct <= self.forward_ct
+            ):
+                self.stop_profile()
+            if (
+                self.profiler_start_forward_ct
+                and self.profiler_start_forward_ct == self.forward_ct
+            ):
+                self.start_profile()
+
+    def profile(self, recv_req: ProfileReq):
+        if recv_req.type == ProfileReqType.START_PROFILE:
+            if recv_req.profile_by_stage or recv_req.start_step:
+                return self.init_profile(
+                    recv_req.output_dir,
+                    recv_req.start_step,
+                    recv_req.num_steps,
+                    recv_req.activities,
+                    recv_req.with_stack,
+                    recv_req.record_shapes,
+                    recv_req.profile_by_stage,
+                    recv_req.profile_id,
+                )
+            else:
+                self.init_profile(
+                    recv_req.output_dir,
+                    recv_req.start_step,
+                    recv_req.num_steps,
+                    recv_req.activities,
+                    recv_req.with_stack,
+                    recv_req.record_shapes,
+                    recv_req.profile_by_stage,
+                    recv_req.profile_id,
+                )
+                return self.start_profile(True)
+        else:
+            return self.stop_profile()
sglang/srt/managers/scheduler_update_weights_mixin.py
@@ -0,0 +1,142 @@
+import logging
+from typing import Tuple
+
+import torch
+
+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
+from sglang.srt.managers.io_struct import (
+    GetWeightsByNameReqInput,
+    GetWeightsByNameReqOutput,
+    InitWeightsUpdateGroupReqInput,
+    InitWeightsUpdateGroupReqOutput,
+    ReleaseMemoryOccupationReqInput,
+    ReleaseMemoryOccupationReqOutput,
+    ResumeMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqOutput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightFromDiskReqOutput,
+    UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromDistributedReqOutput,
+    UpdateWeightsFromTensorReqInput,
+    UpdateWeightsFromTensorReqOutput,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class SchedulerUpdateWeightsMixin:
+
+    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+        """In-place update of the weights from disk."""
+        success, message = self.tp_worker.update_weights_from_disk(recv_req)
+        if success:
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
+        return UpdateWeightFromDiskReqOutput(success, message, 0)
+
+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        """Initialize the online model parameter update group."""
+        success, message = self.tp_worker.init_weights_update_group(recv_req)
+        return InitWeightsUpdateGroupReqOutput(success, message)
+
+    def update_weights_from_distributed(
+        self,
+        recv_req: UpdateWeightsFromDistributedReqInput,
+    ) -> Tuple[bool, str]:
+        """Update the online model parameter."""
+        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
+        if success:
+            if recv_req.flush_cache:
+                flush_cache_success = self.flush_cache()
+                assert flush_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
+        return UpdateWeightsFromDistributedReqOutput(success, message)
+
+    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+        """Update the online model parameter from tensors."""
+        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
+        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
+        if success:
+            if recv_req.flush_cache:
+                flush_cache_success = self.flush_cache()
+                assert flush_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
+        torch.distributed.barrier(group=self.tp_cpu_group)
+        return UpdateWeightsFromTensorReqOutput(success, message)
+
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        parameter = self.tp_worker.get_weights_by_name(recv_req)
+        return GetWeightsByNameReqOutput(parameter)
+
+    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
+        tags = recv_req.tags
+
+        if tags is None or len(tags) == 0:
+            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+
+        if GPU_MEMORY_TYPE_KV_CACHE in tags:
+            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
+            self.flush_cache()
+
+        if GPU_MEMORY_TYPE_WEIGHTS in tags:
+            self.stashed_model_static_state = _export_static_state(
+                self.tp_worker.worker.model_runner.model
+            )
+            torch.distributed.barrier(self.tp_cpu_group)
+            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
+
+        return ReleaseMemoryOccupationReqOutput()
+
+    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
+        tags = recv_req.tags
+
+        if tags is None or len(tags) == 0:
+            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+
+        if GPU_MEMORY_TYPE_WEIGHTS in tags:
+            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
+            torch.distributed.barrier(self.tp_cpu_group)
+            _import_static_state(
+                self.tp_worker.worker.model_runner.model,
+                self.stashed_model_static_state,
+            )
+            del self.stashed_model_static_state
+
+        if GPU_MEMORY_TYPE_KV_CACHE in tags:
+            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
+
+        return ResumeMemoryOccupationReqOutput()
+
+    def save_remote_model(self, params):
+        url = params["url"]
+
+        worker = self.tp_worker.worker
+
+        worker.model_runner.save_remote_model(url)
+
+    def save_sharded_model(self, params):
+        worker = self.tp_worker.worker
+
+        worker.model_runner.save_sharded_model(
+            path=params["path"],
+            pattern=params["pattern"],
+            max_size=params["max_size"],
+        )
+
+
+def _export_static_state(model):
+    return dict(
+        buffers=[
+            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
+        ]
+    )
+
+
+def _import_static_state(model, static_params):
+    self_named_buffers = dict(model.named_buffers())
+    for name, tensor in static_params["buffers"]:
+        self_named_buffers[name][...] = tensor