sglang 0.4.9.post5__py3-none-any.whl → 0.4.10__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (84)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +8 -0
  3. sglang/srt/configs/model_config.py +6 -0
  4. sglang/srt/configs/step3_vl.py +172 -0
  5. sglang/srt/conversation.py +23 -0
  6. sglang/srt/disaggregation/decode.py +2 -8
  7. sglang/srt/disaggregation/prefill.py +2 -6
  8. sglang/srt/distributed/parallel_state.py +86 -1
  9. sglang/srt/entrypoints/engine.py +14 -18
  10. sglang/srt/entrypoints/http_server.py +23 -3
  11. sglang/srt/entrypoints/openai/protocol.py +3 -1
  12. sglang/srt/entrypoints/openai/serving_base.py +5 -2
  13. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  14. sglang/srt/eplb/expert_distribution.py +5 -0
  15. sglang/srt/eplb/expert_location.py +17 -6
  16. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  17. sglang/srt/eplb/expert_location_updater.py +2 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/step3_detector.py +436 -0
  20. sglang/srt/hf_transformers_utils.py +2 -0
  21. sglang/srt/jinja_template_utils.py +4 -1
  22. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  23. sglang/srt/layers/moe/ep_moe/layer.py +98 -603
  24. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
  29. sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
  30. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
  31. sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
  32. sglang/srt/layers/moe/topk.py +6 -2
  33. sglang/srt/layers/quantization/fp8.py +0 -18
  34. sglang/srt/layers/quantization/modelopt_quant.py +2 -0
  35. sglang/srt/layers/quantization/unquant.py +0 -8
  36. sglang/srt/layers/quantization/w4afp8.py +1 -0
  37. sglang/srt/managers/cache_controller.py +143 -45
  38. sglang/srt/managers/data_parallel_controller.py +6 -0
  39. sglang/srt/managers/io_struct.py +12 -2
  40. sglang/srt/managers/scheduler.py +116 -669
  41. sglang/srt/managers/scheduler_input_blocker.py +106 -0
  42. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  43. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  44. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  45. sglang/srt/managers/template_manager.py +62 -19
  46. sglang/srt/managers/tokenizer_manager.py +166 -83
  47. sglang/srt/managers/tp_worker.py +9 -0
  48. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  49. sglang/srt/mem_cache/hicache_storage.py +45 -11
  50. sglang/srt/mem_cache/hiradix_cache.py +15 -4
  51. sglang/srt/mem_cache/memory_pool_host.py +73 -1
  52. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  53. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  54. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
  55. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  56. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  57. sglang/srt/model_executor/model_runner.py +20 -13
  58. sglang/srt/models/arcee.py +532 -0
  59. sglang/srt/models/deepseek_v2.py +15 -56
  60. sglang/srt/models/glm4_moe.py +3 -1
  61. sglang/srt/models/granitemoe.py +3 -0
  62. sglang/srt/models/grok.py +3 -0
  63. sglang/srt/models/hunyuan.py +1 -0
  64. sglang/srt/models/llama4.py +3 -0
  65. sglang/srt/models/mixtral.py +3 -0
  66. sglang/srt/models/olmoe.py +3 -0
  67. sglang/srt/models/phimoe.py +1 -0
  68. sglang/srt/models/qwen3_moe.py +12 -69
  69. sglang/srt/models/step3_vl.py +994 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/poll_based_barrier.py +31 -0
  73. sglang/srt/reasoning_parser.py +2 -1
  74. sglang/srt/server_args.py +18 -13
  75. sglang/srt/speculative/eagle_worker.py +2 -0
  76. sglang/srt/two_batch_overlap.py +8 -3
  77. sglang/test/test_utils.py +53 -0
  78. sglang/utils.py +0 -11
  79. sglang/version.py +1 -1
  80. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/METADATA +4 -4
  81. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/RECORD +84 -64
  82. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
  83. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler_input_blocker.py (new file)
@@ -0,0 +1,106 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import logging
+from contextlib import contextmanager
+from enum import Enum, auto
+from typing import Any, List, Optional
+
+from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
+from sglang.srt.poll_based_barrier import PollBasedBarrier
+
+logger = logging.getLogger(__name__)
+
+
+class SchedulerInputBlocker:
+    def __init__(self, noop: bool):
+        self._state = _State.UNBLOCKED
+        self._pending_reqs = []
+        self._noop = noop
+        self._global_unblock_barrier = PollBasedBarrier(noop=noop)
+
+    def handle(self, recv_reqs: Optional[List[Any]]):
+        assert (recv_reqs is None) == self._noop
+
+        if not self._noop:
+            output_reqs = []
+            for recv_req in recv_reqs:
+                output_reqs += self._handle_recv_req(recv_req)
+
+        global_arrived_unblock_barrier = (
+            self._global_unblock_barrier.poll_global_arrived()
+        )
+        if (
+            self._state == _State.GLOBAL_UNBLOCK_BARRIER
+            and global_arrived_unblock_barrier
+        ):
+            output_reqs += self._handle_arrive_unblock_barrier()
+
+        if not self._noop:
+            return output_reqs
+
+    def _handle_recv_req(self, recv_req):
+        if isinstance(recv_req, BlockReqInput):
+            if recv_req.type == BlockReqType.BLOCK:
+                self._execute_block_req()
+                return []
+            elif recv_req.type == BlockReqType.UNBLOCK:
+                self._execute_unblock_req()
+                return []
+            else:
+                raise NotImplementedError(f"{recv_req=}")
+        else:
+            if self._state == _State.UNBLOCKED:
+                return [recv_req]
+            else:
+                self._pending_reqs.append(recv_req)
+                return []
+
+    def _execute_block_req(self):
+        logger.info("Handle block req")
+        self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED)
+
+    def _execute_unblock_req(self):
+        logger.info("Handle unblock req")
+        self._change_state(
+            original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER
+        )
+        self._global_unblock_barrier.local_arrive()
+
+    def _handle_arrive_unblock_barrier(self):
+        logger.info(f"Arrived at unblock barrier ({len(self._pending_reqs)=})")
+        self._change_state(
+            original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED
+        )
+        output_reqs = [*self._pending_reqs]
+        self._pending_reqs.clear()
+        return output_reqs
+
+    def _change_state(self, original: "_State", target: "_State"):
+        assert self._state == original, f"{self._state=} {original=} {target=}"
+        self._state = target
+
+
+class _State(Enum):
+    UNBLOCKED = auto()
+    BLOCKED = auto()
+    GLOBAL_UNBLOCK_BARRIER = auto()
+
+
+@contextmanager
+def input_blocker_guard_region(send_to_scheduler):
+    send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.BLOCK))
+    try:
+        yield
+    finally:
+        send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.UNBLOCK))
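The sender-side half of this protocol is the input_blocker_guard_region context manager above: it brackets a critical section with BLOCK and UNBLOCK control messages so the scheduler buffers any ordinary requests that arrive in between. A minimal sketch of that contract follows, assuming sglang 0.4.10 is installed; FakeSocket is a hypothetical stand-in for the ZMQ socket the tokenizer manager normally supplies, and it only records what would have been sent.

from sglang.srt.managers.io_struct import BlockReqType
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region

class FakeSocket:
    def __init__(self):
        self.sent = []  # record outgoing control messages instead of sending them

    def send_pyobj(self, obj):
        self.sent.append(obj)

sock = FakeSocket()
with input_blocker_guard_region(send_to_scheduler=sock):
    pass  # e.g., update weights while the scheduler holds incoming requests

# BLOCK is sent on entry, UNBLOCK on exit (even if the body raises).
assert [m.type for m in sock.sent] == [BlockReqType.BLOCK, BlockReqType.UNBLOCK]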
sglang/srt/managers/scheduler_metrics_mixin.py (new file)
@@ -0,0 +1,229 @@
+import logging
+import time
+from collections import defaultdict
+from typing import List, Optional
+
+from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
+from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.managers.schedule_policy import PrefillAdder
+from sglang.srt.managers.scheduler import Req, ScheduleBatch
+from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
+from sglang.srt.utils import get_bool_env_var
+
+logger = logging.getLogger(__name__)
+
+RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
+
+
+class KvMetrics:
+    def __init__(self):
+        self.request_active_slots = None
+        self.request_total_slots = None
+        self.kv_active_blocks = None
+        self.kv_total_blocks = None
+        self.num_requests_waiting = None
+        self.gpu_cache_usage_perc = None
+        self.gpu_prefix_cache_hit_rate = None
+        self.data_parallel_rank = None
+
+
+class SchedulerMetricsMixin:
+    def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
+        self.last_gen_throughput: float = 0.0
+        self.last_input_throughput: float = 0.0
+        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
+        self.spec_num_total_accepted_tokens = 0
+        self.spec_num_total_forward_ct = 0
+        self.cum_spec_accept_length = 0
+        self.cum_spec_accept_count = 0
+        self.total_retracted_reqs = 0
+        self.stats = SchedulerStats()
+        if self.enable_metrics:
+            engine_type = "unified"
+            labels = {
+                "model_name": self.server_args.served_model_name,
+                "engine_type": engine_type,
+                "tp_rank": tp_rank,
+                "pp_rank": pp_rank,
+            }
+            if dp_rank is not None:
+                labels["dp_rank"] = dp_rank
+            self.metrics_collector = SchedulerMetricsCollector(labels=labels)
+
+    def init_kv_events(self, kv_events_config: Optional[str]):
+        if self.enable_kv_cache_events:
+            self.kv_event_publisher = EventPublisherFactory.create(
+                kv_events_config, self.attn_dp_rank
+            )
+
+    def log_prefill_stats(
+        self,
+        adder: PrefillAdder,
+        can_run_list: List[Req],
+        running_bs: int,
+    ):
+        gap_latency = time.perf_counter() - self.last_prefill_stats_tic
+        self.last_prefill_stats_tic = time.perf_counter()
+        self.last_input_throughput = self.last_prefill_tokens / gap_latency
+        self.last_prefill_tokens = adder.log_input_tokens
+
+        if self.is_hybrid:
+            (
+                full_num_used,
+                swa_num_used,
+                full_token_usage,
+                swa_token_usage,
+                _,
+                _,
+                _,
+                _,
+            ) = self._get_swa_token_info()
+            num_used = max(full_num_used, swa_num_used)
+            token_usage = max(full_token_usage, swa_token_usage)
+            token_msg = (
+                f"full token usage: {full_token_usage:.2f}, "
+                f"swa token usage: {swa_token_usage:.2f}, "
+            )
+        else:
+            num_used, token_usage, _, _ = self._get_token_info()
+            token_msg = f"token usage: {token_usage:.2f}, "
+
+        num_new_seq = len(can_run_list)
+        f = (
+            f"Prefill batch. "
+            f"#new-seq: {num_new_seq}, "
+            f"#new-token: {adder.log_input_tokens}, "
+            f"#cached-token: {adder.log_hit_tokens}, "
+            f"{token_msg}"
+        )
+
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
+            f += f"#queue-req: {len(self.waiting_queue)}, "
+            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
+            f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
+        else:
+            f += f"#running-req: {running_bs}, "
+            f += f"#queue-req: {len(self.waiting_queue)}, "
+
+        logger.info(f)
+
+        if self.enable_metrics:
+            total_tokens = adder.log_input_tokens + adder.log_hit_tokens
+
+            cache_hit_rate = (
+                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
+            )
+            self.stats.num_running_reqs = running_bs
+            self.stats.num_used_tokens = num_used
+            self.stats.token_usage = round(token_usage, 2)
+            self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.cache_hit_rate = cache_hit_rate
+
+            total_queue_latency = 0
+            for req in can_run_list:
+                total_queue_latency += req.queue_time_end - req.queue_time_start
+            self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
+
+            self.metrics_collector.log_stats(self.stats)
+            self._emit_kv_metrics()
+        self._publish_kv_events()
+
+    def log_decode_stats(
+        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
+    ):
+        batch = running_batch or self.running_batch
+
+        gap_latency = time.perf_counter() - self.last_decode_stats_tic
+        self.last_decode_stats_tic = time.perf_counter()
+        self.last_gen_throughput = self.num_generated_tokens / gap_latency
+        self.num_generated_tokens = 0
+        num_running_reqs = len(batch.reqs)
+        if self.is_hybrid:
+            (
+                full_num_used,
+                swa_num_used,
+                full_token_usage,
+                swa_token_usage,
+                _,
+                _,
+                _,
+                _,
+            ) = self._get_swa_token_info()
+            num_used = max(full_num_used, swa_num_used)
+            token_usage = max(full_token_usage, swa_token_usage)
+            token_msg = (
+                f"#full token: {full_num_used}, "
+                f"full token usage: {full_token_usage:.2f}, "
+                f"#swa token: {swa_num_used}, "
+                f"swa token usage: {swa_token_usage:.2f}, "
+            )
+        else:
+            num_used, token_usage, _, _ = self._get_token_info()
+            token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
+
+        if RECORD_STEP_TIME:
+            self.step_time_dict[num_running_reqs].append(
+                gap_latency / self.server_args.decode_log_interval
+            )
+
+        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
+
+        if self.spec_algorithm.is_none():
+            spec_accept_length = 0
+        else:
+            spec_accept_length = (
+                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
+            )
+            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
+            self.cum_spec_accept_count += self.spec_num_total_forward_ct
+            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
+            msg += f"accept len: {spec_accept_length:.2f}, "
+
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+            msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
+
+        msg += (
+            f"cuda graph: {can_run_cuda_graph}, "
+            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
+            f"#queue-req: {len(self.waiting_queue)}, "
+        )
+
+        logger.info(msg)
+        if self.enable_metrics:
+            self.stats.num_running_reqs = num_running_reqs
+            self.stats.num_used_tokens = num_used
+            self.stats.token_usage = round(token_usage, 2)
+            self.stats.cache_hit_rate = 0.0
+            self.stats.gen_throughput = self.last_gen_throughput
+            self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
+            self.stats.spec_accept_length = spec_accept_length
+            self.stats.total_retracted_reqs = self.total_retracted_reqs
+            self.metrics_collector.log_stats(self.stats)
+            self._emit_kv_metrics()
+        self._publish_kv_events()
+
+    def _emit_kv_metrics(self):
+        kv_metrics = KvMetrics()
+        kv_metrics.request_active_slots = self.stats.num_running_reqs
+        kv_metrics.request_total_slots = self.max_running_requests
+        kv_metrics.kv_active_blocks = int(
+            self.stats.token_usage * self.max_total_num_tokens
+        )
+        kv_metrics.kv_total_blocks = self.max_total_num_tokens
+        kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
+        kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
+        kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
+        kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
+
+        if not self.send_metrics_from_scheduler.closed:
+            self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
+
+    def _publish_kv_events(self):
+        if self.enable_kv_cache_events:
+            events = self.tree_cache.take_events()
+            if events:
+                batch = KVEventBatch(ts=time.time(), events=events)
+                self.kv_event_publisher.publish(batch)
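Two of the derived quantities above are easy to sanity-check by hand. A toy calculation in plain Python, with made-up token counts, of the prefix-cache hit rate from log_prefill_stats and the active-block figure from _emit_kv_metrics:

# Hypothetical token counts for one prefill batch.
log_input_tokens = 3000  # tokens newly computed this batch
log_hit_tokens = 1000    # tokens reused from the prefix cache
total_tokens = log_input_tokens + log_hit_tokens
cache_hit_rate = log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
assert cache_hit_rate == 0.25

# KvMetrics derives active KV "blocks" from the usage fraction of the token pool.
token_usage = 0.42
max_total_num_tokens = 100_000
kv_active_blocks = int(token_usage * max_total_num_tokens)
assert kv_active_blocks == 42_000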
sglang/srt/managers/scheduler_profiler_mixin.py (new file)
@@ -0,0 +1,279 @@
+import logging
+import os
+import time
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+
+from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
+
+logger = logging.getLogger(__name__)
+
+
+class SchedulerProfilerMixin:
+
+    def init_profier(self):
+        self.torch_profiler = None
+        self.torch_profiler_output_dir: Optional[str] = None
+        self.profiler_activities: Optional[List[str]] = None
+        self.profile_id: Optional[str] = None
+        self.profiler_start_forward_ct: Optional[int] = None
+        self.profiler_target_forward_ct: Optional[int] = None
+        self.profiler_target_prefill_ct: Optional[int] = None
+        self.profiler_target_decode_ct: Optional[int] = None
+        self.profiler_prefill_ct: Optional[int] = None
+        self.profiler_decode_ct: Optional[int] = None
+        self.profile_by_stage: bool = False
+        self.profile_steps: Optional[int] = None
+        self.profile_in_progress: bool = False
+        self.rpd_profiler = None
+
+    def init_profile(
+        self,
+        output_dir: Optional[str],
+        start_step: Optional[int],
+        num_steps: Optional[int],
+        activities: Optional[List[str]],
+        with_stack: Optional[bool],
+        record_shapes: Optional[bool],
+        profile_by_stage: bool,
+        profile_id: str,
+    ) -> ProfileReqOutput:
+        if self.profile_in_progress:
+            return ProfileReqOutput(
+                success=False,
+                message="Profiling is already in progress. Call /stop_profile first.",
+            )
+
+        self.profile_by_stage = profile_by_stage
+
+        if output_dir is None:
+            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
+        if activities is None:
+            activities = ["CPU", "GPU"]
+
+        self.torch_profiler_output_dir = output_dir
+        self.torch_profiler_with_stack = with_stack
+        self.torch_profiler_record_shapes = record_shapes
+        self.profiler_activities = activities
+        self.profile_id = profile_id
+
+        if start_step:
+            self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)
+
+        if num_steps:
+            self.profile_steps = num_steps
+            if self.profile_by_stage:
+                self.profiler_target_prefill_ct = num_steps
+                self.profiler_target_decode_ct = num_steps
+                self.profiler_prefill_ct = 0
+                self.profiler_decode_ct = 0
+            elif start_step:
+                self.profiler_target_forward_ct = (
+                    self.profiler_start_forward_ct + num_steps
+                )
+            else:
+                self.profiler_target_forward_ct = self.forward_ct + num_steps
+            # The caller will be notified when reaching profiler_target_forward_ct
+        else:
+            self.profiler_target_forward_ct = None
+
+        return ProfileReqOutput(success=True, message="Succeeded")
+
+    def start_profile(
+        self, stage: Optional[ForwardMode] = None
+    ) -> ProfileReqOutput | None:
+        stage_str = f" for {stage.__str__()}" if stage else ""
+        logger.info(
+            f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
+        )
+
+        activities = self.profiler_activities
+        with_stack = self.torch_profiler_with_stack
+        record_shapes = self.torch_profiler_record_shapes
+
+        activity_map = {
+            "CPU": torch.profiler.ProfilerActivity.CPU,
+            "GPU": torch.profiler.ProfilerActivity.CUDA,
+        }
+        torchprof_activities = [
+            activity_map[a] for a in activities if a in activity_map
+        ]
+
+        if "RPD" in activities:
+            from rpdTracerControl import rpdTracerControl
+
+            rpdTracerControl.skipCreate()
+
+            self.rpd_profile_path = os.path.join(
+                self.torch_profiler_output_dir,
+                "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
+            )
+
+            if self.tp_rank == 0:
+                import sqlite3
+
+                from rocpd.schema import RocpdSchema
+
+                if os.path.exists("trace.rpd"):
+                    os.unlink("trace.rpd")
+                schema = RocpdSchema()
+                connection = sqlite3.connect("trace.rpd")
+                schema.writeSchema(connection)
+                connection.commit()
+                del connection
+            torch.distributed.barrier(self.tp_cpu_group)
+
+            self.rpd_profiler = rpdTracerControl()
+            self.rpd_profiler.setPythonTrace(True)
+            self.rpd_profiler.start()
+            self.rpd_profiler.rangePush("", "rpd profile range", "")
+            self.profile_in_progress = True
+        elif torchprof_activities:
+            self.torch_profiler = torch.profiler.profile(
+                activities=torchprof_activities,
+                with_stack=with_stack if with_stack is not None else True,
+                record_shapes=record_shapes if record_shapes is not None else False,
+            )
+            self.torch_profiler.start()
+            self.profile_in_progress = True
+
+        if "MEM" in activities:
+            torch.cuda.memory._record_memory_history(max_entries=100000)
+            self.profile_in_progress = True
+
+        if "CUDA_PROFILER" in activities:
+            torch.cuda.cudart().cudaProfilerStart()
+            self.profile_in_progress = True
+
+        return ProfileReqOutput(success=True, message="Succeeded")
+
+    def stop_profile(
+        self, stage: Optional[ForwardMode] = None
+    ) -> ProfileReqOutput | None:
+        if not self.profile_in_progress:
+            return ProfileReqOutput(
+                success=False,
+                message="Profiling is not in progress. Call /start_profile first.",
+            )
+
+        if not Path(self.torch_profiler_output_dir).exists():
+            Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
+
+        stage_suffix = f"-{stage.__str__()}" if stage else ""
+        logger.info("Stop profiling" + stage_suffix + "...")
+        if self.torch_profiler is not None:
+            self.torch_profiler.stop()
+            self.torch_profiler.export_chrome_trace(
+                os.path.join(
+                    self.torch_profiler_output_dir,
+                    self.profile_id
+                    + f"-TP-{self.tp_rank}"
+                    + stage_suffix
+                    + ".trace.json.gz",
+                )
+            )
+            torch.distributed.barrier(self.tp_cpu_group)
+
+        if self.rpd_profiler is not None:
+            self.rpd_profiler.rangePop()
+            self.rpd_profiler.stop()
+            self.rpd_profiler.flush()
+
+            torch.distributed.barrier(self.tp_cpu_group)
+            if self.tp_rank == 0:
+                from sglang.srt.utils import rpd_to_chrome_trace
+
+                rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
+            self.rpd_profiler = None
+            self.rpd_profiler_path = None
+
+        if self.profiler_activities is not None and "MEM" in self.profiler_activities:
+            memory_profile_path = os.path.join(
+                self.torch_profiler_output_dir,
+                str(time.time())
+                + f"-TP-{self.tp_rank}-memory"
+                + stage_suffix
+                + ".pickle",
+            )
+            torch.cuda.memory._dump_snapshot(memory_profile_path)
+            torch.cuda.memory._record_memory_history(enabled=None)
+
+        if "CUDA_PROFILER" in self.profiler_activities:
+            torch.cuda.cudart().cudaProfilerStop()
+
+        logger.info(
+            "Profiling done. Traces are saved to: %s",
+            self.torch_profiler_output_dir,
+        )
+        self.torch_profiler = None
+        self.profile_in_progress = False
+        self.profiler_start_forward_ct = None
+
+        return ProfileReqOutput(success=True, message="Succeeded.")
+
+    def _profile_batch_predicate(self, batch):
+        if self.profile_by_stage:
+            if batch.forward_mode.is_prefill():
+                if self.profiler_prefill_ct == 0:
+                    self.start_profile(batch.forward_mode)
+                self.profiler_prefill_ct += 1
+                if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
+                    if self.profile_in_progress:
+                        self.stop_profile(stage=ForwardMode.EXTEND)
+            elif batch.forward_mode.is_decode():
+                if self.profiler_decode_ct == 0:
+                    if self.profile_in_progress:
+                        # force trace flush
+                        self.stop_profile(ForwardMode.EXTEND)
+                    self.start_profile(batch.forward_mode)
+                self.profiler_decode_ct += 1
+                if self.profiler_decode_ct > self.profiler_target_decode_ct:
+                    if self.profile_in_progress:
+                        self.stop_profile(stage=ForwardMode.DECODE)
+            elif batch.forward_mode.is_idle():
+                pass
+            else:
+                raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
+        else:
+            # Check profiler
+            if (
+                self.profiler_target_forward_ct
+                and self.profiler_target_forward_ct <= self.forward_ct
+            ):
+                self.stop_profile()
+            if (
+                self.profiler_start_forward_ct
+                and self.profiler_start_forward_ct == self.forward_ct
+            ):
+                self.start_profile()
+
+    def profile(self, recv_req: ProfileReq):
+        if recv_req.type == ProfileReqType.START_PROFILE:
+            if recv_req.profile_by_stage or recv_req.start_step:
+                return self.init_profile(
+                    recv_req.output_dir,
+                    recv_req.start_step,
+                    recv_req.num_steps,
+                    recv_req.activities,
+                    recv_req.with_stack,
+                    recv_req.record_shapes,
+                    recv_req.profile_by_stage,
+                    recv_req.profile_id,
+                )
+            else:
+                self.init_profile(
+                    recv_req.output_dir,
+                    recv_req.start_step,
+                    recv_req.num_steps,
+                    recv_req.activities,
+                    recv_req.with_stack,
+                    recv_req.record_shapes,
+                    recv_req.profile_by_stage,
+                    recv_req.profile_id,
+                )
+                return self.start_profile(True)
+        else:
+            return self.stop_profile()
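Stripped of the scheduler bookkeeping, start_profile/stop_profile wrap a standard torch.profiler session. Below is a minimal sketch of that pattern with an illustrative workload and output path; the "GPU" activity is skipped when CUDA is unavailable so the snippet also runs on CPU-only machines.

import torch

# Same string-to-activity mapping the mixin uses.
activity_map = {
    "CPU": torch.profiler.ProfilerActivity.CPU,
    "GPU": torch.profiler.ProfilerActivity.CUDA,
}
requested = ["CPU", "GPU"] if torch.cuda.is_available() else ["CPU"]
activities = [activity_map[a] for a in requested if a in activity_map]

prof = torch.profiler.profile(activities=activities, with_stack=True)
prof.start()
torch.randn(512, 512) @ torch.randn(512, 512)  # stand-in workload
prof.stop()

# The mixin names traces <profile_id>-TP-<rank>.trace.json.gz; this path is arbitrary.
prof.export_chrome_trace("/tmp/example.trace.json.gz")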