sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/bench_one_batch.py +3 -0
  3. sglang/srt/configs/__init__.py +8 -0
  4. sglang/srt/configs/model_config.py +4 -0
  5. sglang/srt/configs/step3_vl.py +172 -0
  6. sglang/srt/conversation.py +23 -0
  7. sglang/srt/disaggregation/decode.py +2 -8
  8. sglang/srt/disaggregation/launch_lb.py +5 -20
  9. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  10. sglang/srt/disaggregation/prefill.py +2 -6
  11. sglang/srt/distributed/parallel_state.py +86 -1
  12. sglang/srt/entrypoints/engine.py +14 -18
  13. sglang/srt/entrypoints/http_server.py +10 -2
  14. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  15. sglang/srt/eplb/expert_distribution.py +5 -0
  16. sglang/srt/eplb/expert_location.py +17 -6
  17. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  18. sglang/srt/eplb/expert_location_updater.py +2 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/step3_detector.py +436 -0
  21. sglang/srt/hf_transformers_utils.py +2 -0
  22. sglang/srt/jinja_template_utils.py +4 -1
  23. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  24. sglang/srt/layers/attention/utils.py +6 -1
  25. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +39 -674
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
  29. sglang/srt/layers/quantization/fp8.py +52 -18
  30. sglang/srt/layers/quantization/unquant.py +0 -8
  31. sglang/srt/layers/quantization/w4afp8.py +1 -0
  32. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  33. sglang/srt/managers/cache_controller.py +165 -67
  34. sglang/srt/managers/data_parallel_controller.py +2 -0
  35. sglang/srt/managers/io_struct.py +0 -2
  36. sglang/srt/managers/scheduler.py +90 -671
  37. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  38. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  39. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  40. sglang/srt/managers/template_manager.py +62 -19
  41. sglang/srt/managers/tokenizer_manager.py +123 -74
  42. sglang/srt/managers/tp_worker.py +4 -0
  43. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  44. sglang/srt/mem_cache/hicache_storage.py +60 -17
  45. sglang/srt/mem_cache/hiradix_cache.py +36 -8
  46. sglang/srt/mem_cache/memory_pool.py +15 -118
  47. sglang/srt/mem_cache/memory_pool_host.py +418 -29
  48. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  49. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  50. sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
  51. sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
  52. sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
  53. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
  54. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  55. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  56. sglang/srt/model_executor/cuda_graph_runner.py +25 -1
  57. sglang/srt/model_executor/model_runner.py +13 -1
  58. sglang/srt/model_loader/weight_utils.py +2 -0
  59. sglang/srt/models/arcee.py +532 -0
  60. sglang/srt/models/deepseek_v2.py +7 -6
  61. sglang/srt/models/glm4_moe.py +6 -4
  62. sglang/srt/models/granitemoe.py +3 -0
  63. sglang/srt/models/grok.py +3 -0
  64. sglang/srt/models/hunyuan.py +1 -0
  65. sglang/srt/models/llama4.py +3 -0
  66. sglang/srt/models/mixtral.py +3 -0
  67. sglang/srt/models/olmoe.py +3 -0
  68. sglang/srt/models/phimoe.py +1 -0
  69. sglang/srt/models/step3_vl.py +991 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/reasoning_parser.py +2 -1
  73. sglang/srt/server_args.py +49 -18
  74. sglang/srt/speculative/eagle_worker.py +2 -0
  75. sglang/srt/utils.py +1 -0
  76. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  77. sglang/utils.py +0 -11
  78. sglang/version.py +1 -1
  79. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
  80. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
  81. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
  82. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
  83. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler_metrics_mixin.py
@@ -0,0 +1,229 @@
+ import logging
+ import time
+ from collections import defaultdict
+ from typing import List, Optional
+
+ from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
+ from sglang.srt.disaggregation.utils import DisaggregationMode
+ from sglang.srt.managers.schedule_policy import PrefillAdder
+ from sglang.srt.managers.scheduler import Req, ScheduleBatch
+ from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
+ from sglang.srt.utils import get_bool_env_var
+
+ logger = logging.getLogger(__name__)
+
+ RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
+
+
+ class KvMetrics:
+     def __init__(self):
+         self.request_active_slots = None
+         self.request_total_slots = None
+         self.kv_active_blocks = None
+         self.kv_total_blocks = None
+         self.num_requests_waiting = None
+         self.gpu_cache_usage_perc = None
+         self.gpu_prefix_cache_hit_rate = None
+         self.data_parallel_rank = None
+
+
+ class SchedulerMetricsMixin:
+     def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
+         self.last_gen_throughput: float = 0.0
+         self.last_input_throughput: float = 0.0
+         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
+         self.spec_num_total_accepted_tokens = 0
+         self.spec_num_total_forward_ct = 0
+         self.cum_spec_accept_length = 0
+         self.cum_spec_accept_count = 0
+         self.total_retracted_reqs = 0
+         self.stats = SchedulerStats()
+         if self.enable_metrics:
+             engine_type = "unified"
+             labels = {
+                 "model_name": self.server_args.served_model_name,
+                 "engine_type": engine_type,
+                 "tp_rank": tp_rank,
+                 "pp_rank": pp_rank,
+             }
+             if dp_rank is not None:
+                 labels["dp_rank"] = dp_rank
+             self.metrics_collector = SchedulerMetricsCollector(labels=labels)
+
+     def init_kv_events(self, kv_events_config: Optional[str]):
+         if self.enable_kv_cache_events:
+             self.kv_event_publisher = EventPublisherFactory.create(
+                 kv_events_config, self.attn_dp_rank
+             )
+
+     def log_prefill_stats(
+         self,
+         adder: PrefillAdder,
+         can_run_list: List[Req],
+         running_bs: int,
+     ):
+         gap_latency = time.perf_counter() - self.last_prefill_stats_tic
+         self.last_prefill_stats_tic = time.perf_counter()
+         self.last_input_throughput = self.last_prefill_tokens / gap_latency
+         self.last_prefill_tokens = adder.log_input_tokens
+
+         if self.is_hybrid:
+             (
+                 full_num_used,
+                 swa_num_used,
+                 full_token_usage,
+                 swa_token_usage,
+                 _,
+                 _,
+                 _,
+                 _,
+             ) = self._get_swa_token_info()
+             num_used = max(full_num_used, swa_num_used)
+             token_usage = max(full_token_usage, swa_token_usage)
+             token_msg = (
+                 f"full token usage: {full_token_usage:.2f}, "
+                 f"swa token usage: {swa_token_usage:.2f}, "
+             )
+         else:
+             num_used, token_usage, _, _ = self._get_token_info()
+             token_msg = f"token usage: {token_usage:.2f}, "
+
+         num_new_seq = len(can_run_list)
+         f = (
+             f"Prefill batch. "
+             f"#new-seq: {num_new_seq}, "
+             f"#new-token: {adder.log_input_tokens}, "
+             f"#cached-token: {adder.log_hit_tokens}, "
+             f"{token_msg}"
+         )
+
+         if self.disaggregation_mode == DisaggregationMode.PREFILL:
+             f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
+             f += f"#queue-req: {len(self.waiting_queue)}, "
+             f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
+             f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
+         else:
+             f += f"#running-req: {running_bs}, "
+             f += f"#queue-req: {len(self.waiting_queue)}, "
+
+         logger.info(f)
+
+         if self.enable_metrics:
+             total_tokens = adder.log_input_tokens + adder.log_hit_tokens
+
+             cache_hit_rate = (
+                 adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
+             )
+             self.stats.num_running_reqs = running_bs
+             self.stats.num_used_tokens = num_used
+             self.stats.token_usage = round(token_usage, 2)
+             self.stats.num_queue_reqs = len(self.waiting_queue)
+             self.stats.cache_hit_rate = cache_hit_rate
+
+             total_queue_latency = 0
+             for req in can_run_list:
+                 total_queue_latency += req.queue_time_end - req.queue_time_start
+             self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
+
+             self.metrics_collector.log_stats(self.stats)
+             self._emit_kv_metrics()
+         self._publish_kv_events()
+
+     def log_decode_stats(
+         self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
+     ):
+         batch = running_batch or self.running_batch
+
+         gap_latency = time.perf_counter() - self.last_decode_stats_tic
+         self.last_decode_stats_tic = time.perf_counter()
+         self.last_gen_throughput = self.num_generated_tokens / gap_latency
+         self.num_generated_tokens = 0
+         num_running_reqs = len(batch.reqs)
+         if self.is_hybrid:
+             (
+                 full_num_used,
+                 swa_num_used,
+                 full_token_usage,
+                 swa_token_usage,
+                 _,
+                 _,
+                 _,
+                 _,
+             ) = self._get_swa_token_info()
+             num_used = max(full_num_used, swa_num_used)
+             token_usage = max(full_token_usage, swa_token_usage)
+             token_msg = (
+                 f"#full token: {full_num_used}, "
+                 f"full token usage: {full_token_usage:.2f}, "
+                 f"#swa token: {swa_num_used}, "
+                 f"swa token usage: {swa_token_usage:.2f}, "
+             )
+         else:
+             num_used, token_usage, _, _ = self._get_token_info()
+             token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
+
+         if RECORD_STEP_TIME:
+             self.step_time_dict[num_running_reqs].append(
+                 gap_latency / self.server_args.decode_log_interval
+             )
+
+         msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
+
+         if self.spec_algorithm.is_none():
+             spec_accept_length = 0
+         else:
+             spec_accept_length = (
+                 self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
+             )
+             self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
+             self.cum_spec_accept_count += self.spec_num_total_forward_ct
+             self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
+             msg += f"accept len: {spec_accept_length:.2f}, "
+
+         if self.disaggregation_mode == DisaggregationMode.DECODE:
+             msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+             msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
+
+         msg += (
+             f"cuda graph: {can_run_cuda_graph}, "
+             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
+             f"#queue-req: {len(self.waiting_queue)}, "
+         )
+
+         logger.info(msg)
+         if self.enable_metrics:
+             self.stats.num_running_reqs = num_running_reqs
+             self.stats.num_used_tokens = num_used
+             self.stats.token_usage = round(token_usage, 2)
+             self.stats.cache_hit_rate = 0.0
+             self.stats.gen_throughput = self.last_gen_throughput
+             self.stats.num_queue_reqs = len(self.waiting_queue)
+             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
+             self.stats.spec_accept_length = spec_accept_length
+             self.stats.total_retracted_reqs = self.total_retracted_reqs
+             self.metrics_collector.log_stats(self.stats)
+             self._emit_kv_metrics()
+         self._publish_kv_events()
+
+     def _emit_kv_metrics(self):
+         kv_metrics = KvMetrics()
+         kv_metrics.request_active_slots = self.stats.num_running_reqs
+         kv_metrics.request_total_slots = self.max_running_requests
+         kv_metrics.kv_active_blocks = int(
+             self.stats.token_usage * self.max_total_num_tokens
+         )
+         kv_metrics.kv_total_blocks = self.max_total_num_tokens
+         kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
+         kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
+         kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
+         kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
+
+         if not self.send_metrics_from_scheduler.closed:
+             self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
+
+     def _publish_kv_events(self):
+         if self.enable_kv_cache_events:
+             events = self.tree_cache.take_events()
+             if events:
+                 batch = KVEventBatch(ts=time.time(), events=events)
+                 self.kv_event_publisher.publish(batch)
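
The hunk above shows SchedulerMetricsMixin in isolation; in the package it is presumably mixed into the scheduler class, whose attributes (server_args, waiting_queue, running_batch, and so on) the mixin reads through self. Below is a minimal, self-contained sketch of that composition pattern, not code from the package: ToyStats, ToyMetricsMixin, and ToyScheduler are hypothetical stand-ins used only to illustrate how mixin methods rely on host-class state.

# Hypothetical sketch of the mixin composition pattern; names are illustrative.
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToyStats:
    num_running_reqs: int = 0
    num_queue_reqs: int = 0
    gen_throughput: float = 0.0


class ToyMetricsMixin:
    # The mixin only touches attributes that the host scheduler class provides via self.
    def init_metrics(self) -> None:
        self.stats = ToyStats()
        self.step_time_dict: Dict[int, List[float]] = defaultdict(list)

    def log_decode_stats(self, gen_throughput: float) -> None:
        self.stats.num_running_reqs = len(self.running_reqs)
        self.stats.num_queue_reqs = len(self.waiting_queue)
        self.stats.gen_throughput = gen_throughput
        print(
            f"Decode batch. #running-req: {self.stats.num_running_reqs}, "
            f"#queue-req: {self.stats.num_queue_reqs}, "
            f"gen throughput (token/s): {gen_throughput:.2f}"
        )


class ToyScheduler(ToyMetricsMixin):
    def __init__(self) -> None:
        self.running_reqs = ["req-0", "req-1"]  # stand-in for the running batch
        self.waiting_queue = ["req-2"]          # stand-in for the waiting queue
        self.init_metrics()


if __name__ == "__main__":
    ToyScheduler().log_decode_stats(gen_throughput=123.4)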
sglang/srt/managers/scheduler_profiler_mixin.py
@@ -0,0 +1,279 @@
+ import logging
+ import os
+ import time
+ from pathlib import Path
+ from typing import List, Optional
+
+ import torch
+
+ from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode
+
+ logger = logging.getLogger(__name__)
+
+
+ class SchedulerProfilerMixin:
+
+     def init_profier(self):
+         self.torch_profiler = None
+         self.torch_profiler_output_dir: Optional[str] = None
+         self.profiler_activities: Optional[List[str]] = None
+         self.profile_id: Optional[str] = None
+         self.profiler_start_forward_ct: Optional[int] = None
+         self.profiler_target_forward_ct: Optional[int] = None
+         self.profiler_target_prefill_ct: Optional[int] = None
+         self.profiler_target_decode_ct: Optional[int] = None
+         self.profiler_prefill_ct: Optional[int] = None
+         self.profiler_decode_ct: Optional[int] = None
+         self.profile_by_stage: bool = False
+         self.profile_steps: Optional[int] = None
+         self.profile_in_progress: bool = False
+         self.rpd_profiler = None
+
+     def init_profile(
+         self,
+         output_dir: Optional[str],
+         start_step: Optional[int],
+         num_steps: Optional[int],
+         activities: Optional[List[str]],
+         with_stack: Optional[bool],
+         record_shapes: Optional[bool],
+         profile_by_stage: bool,
+         profile_id: str,
+     ) -> ProfileReqOutput:
+         if self.profile_in_progress:
+             return ProfileReqOutput(
+                 success=False,
+                 message="Profiling is already in progress. Call /stop_profile first.",
+             )
+
+         self.profile_by_stage = profile_by_stage
+
+         if output_dir is None:
+             output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
+         if activities is None:
+             activities = ["CPU", "GPU"]
+
+         self.torch_profiler_output_dir = output_dir
+         self.torch_profiler_with_stack = with_stack
+         self.torch_profiler_record_shapes = record_shapes
+         self.profiler_activities = activities
+         self.profile_id = profile_id
+
+         if start_step:
+             self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)
+
+         if num_steps:
+             self.profile_steps = num_steps
+             if self.profile_by_stage:
+                 self.profiler_target_prefill_ct = num_steps
+                 self.profiler_target_decode_ct = num_steps
+                 self.profiler_prefill_ct = 0
+                 self.profiler_decode_ct = 0
+             elif start_step:
+                 self.profiler_target_forward_ct = (
+                     self.profiler_start_forward_ct + num_steps
+                 )
+             else:
+                 self.profiler_target_forward_ct = self.forward_ct + num_steps
+                 # The caller will be notified when reaching profiler_target_forward_ct
+         else:
+             self.profiler_target_forward_ct = None
+
+         return ProfileReqOutput(success=True, message="Succeeded")
+
+     def start_profile(
+         self, stage: Optional[ForwardMode] = None
+     ) -> ProfileReqOutput | None:
+         stage_str = f" for {stage.__str__()}" if stage else ""
+         logger.info(
+             f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
+         )
+
+         activities = self.profiler_activities
+         with_stack = self.torch_profiler_with_stack
+         record_shapes = self.torch_profiler_record_shapes
+
+         activity_map = {
+             "CPU": torch.profiler.ProfilerActivity.CPU,
+             "GPU": torch.profiler.ProfilerActivity.CUDA,
+         }
+         torchprof_activities = [
+             activity_map[a] for a in activities if a in activity_map
+         ]
+
+         if "RPD" in activities:
+             from rpdTracerControl import rpdTracerControl
+
+             rpdTracerControl.skipCreate()
+
+             self.rpd_profile_path = os.path.join(
+                 self.torch_profiler_output_dir,
+                 "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
+             )
+
+             if self.tp_rank == 0:
+                 import sqlite3
+
+                 from rocpd.schema import RocpdSchema
+
+                 if os.path.exists("trace.rpd"):
+                     os.unlink("trace.rpd")
+                 schema = RocpdSchema()
+                 connection = sqlite3.connect("trace.rpd")
+                 schema.writeSchema(connection)
+                 connection.commit()
+                 del connection
+             torch.distributed.barrier(self.tp_cpu_group)
+
+             self.rpd_profiler = rpdTracerControl()
+             self.rpd_profiler.setPythonTrace(True)
+             self.rpd_profiler.start()
+             self.rpd_profiler.rangePush("", "rpd profile range", "")
+             self.profile_in_progress = True
+         elif torchprof_activities:
+             self.torch_profiler = torch.profiler.profile(
+                 activities=torchprof_activities,
+                 with_stack=with_stack if with_stack is not None else True,
+                 record_shapes=record_shapes if record_shapes is not None else False,
+             )
+             self.torch_profiler.start()
+             self.profile_in_progress = True
+
+         if "MEM" in activities:
+             torch.cuda.memory._record_memory_history(max_entries=100000)
+             self.profile_in_progress = True
+
+         if "CUDA_PROFILER" in activities:
+             torch.cuda.cudart().cudaProfilerStart()
+             self.profile_in_progress = True
+
+         return ProfileReqOutput(success=True, message="Succeeded")
+
+     def stop_profile(
+         self, stage: Optional[ForwardMode] = None
+     ) -> ProfileReqOutput | None:
+         if not self.profile_in_progress:
+             return ProfileReqOutput(
+                 success=False,
+                 message="Profiling is not in progress. Call /start_profile first.",
+             )
+
+         if not Path(self.torch_profiler_output_dir).exists():
+             Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
+
+         stage_suffix = f"-{stage.__str__()}" if stage else ""
+         logger.info("Stop profiling" + stage_suffix + "...")
+         if self.torch_profiler is not None:
+             self.torch_profiler.stop()
+             self.torch_profiler.export_chrome_trace(
+                 os.path.join(
+                     self.torch_profiler_output_dir,
+                     self.profile_id
+                     + f"-TP-{self.tp_rank}"
+                     + stage_suffix
+                     + ".trace.json.gz",
+                 )
+             )
+             torch.distributed.barrier(self.tp_cpu_group)
+
+         if self.rpd_profiler is not None:
+             self.rpd_profiler.rangePop()
+             self.rpd_profiler.stop()
+             self.rpd_profiler.flush()
+
+             torch.distributed.barrier(self.tp_cpu_group)
+             if self.tp_rank == 0:
+                 from sglang.srt.utils import rpd_to_chrome_trace
+
+                 rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
+             self.rpd_profiler = None
+             self.rpd_profiler_path = None
+
+         if self.profiler_activities is not None and "MEM" in self.profiler_activities:
+             memory_profile_path = os.path.join(
+                 self.torch_profiler_output_dir,
+                 str(time.time())
+                 + f"-TP-{self.tp_rank}-memory"
+                 + stage_suffix
+                 + ".pickle",
+             )
+             torch.cuda.memory._dump_snapshot(memory_profile_path)
+             torch.cuda.memory._record_memory_history(enabled=None)
+
+         if "CUDA_PROFILER" in self.profiler_activities:
+             torch.cuda.cudart().cudaProfilerStop()
+
+         logger.info(
+             "Profiling done. Traces are saved to: %s",
+             self.torch_profiler_output_dir,
+         )
+         self.torch_profiler = None
+         self.profile_in_progress = False
+         self.profiler_start_forward_ct = None
+
+         return ProfileReqOutput(success=True, message="Succeeded.")
+
+     def _profile_batch_predicate(self, batch):
+         if self.profile_by_stage:
+             if batch.forward_mode.is_prefill():
+                 if self.profiler_prefill_ct == 0:
+                     self.start_profile(batch.forward_mode)
+                 self.profiler_prefill_ct += 1
+                 if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
+                     if self.profile_in_progress:
+                         self.stop_profile(stage=ForwardMode.EXTEND)
+             elif batch.forward_mode.is_decode():
+                 if self.profiler_decode_ct == 0:
+                     if self.profile_in_progress:
+                         # force trace flush
+                         self.stop_profile(ForwardMode.EXTEND)
+                     self.start_profile(batch.forward_mode)
+                 self.profiler_decode_ct += 1
+                 if self.profiler_decode_ct > self.profiler_target_decode_ct:
+                     if self.profile_in_progress:
+                         self.stop_profile(stage=ForwardMode.DECODE)
+             elif batch.forward_mode.is_idle():
+                 pass
+             else:
+                 raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
+         else:
+             # Check profiler
+             if (
+                 self.profiler_target_forward_ct
+                 and self.profiler_target_forward_ct <= self.forward_ct
+             ):
+                 self.stop_profile()
+             if (
+                 self.profiler_start_forward_ct
+                 and self.profiler_start_forward_ct == self.forward_ct
+             ):
+                 self.start_profile()
+
+     def profile(self, recv_req: ProfileReq):
+         if recv_req.type == ProfileReqType.START_PROFILE:
+             if recv_req.profile_by_stage or recv_req.start_step:
+                 return self.init_profile(
+                     recv_req.output_dir,
+                     recv_req.start_step,
+                     recv_req.num_steps,
+                     recv_req.activities,
+                     recv_req.with_stack,
+                     recv_req.record_shapes,
+                     recv_req.profile_by_stage,
+                     recv_req.profile_id,
+                 )
+             else:
+                 self.init_profile(
+                     recv_req.output_dir,
+                     recv_req.start_step,
+                     recv_req.num_steps,
+                     recv_req.activities,
+                     recv_req.with_stack,
+                     recv_req.record_shapes,
+                     recv_req.profile_by_stage,
+                     recv_req.profile_id,
+                 )
+                 return self.start_profile(True)
+         else:
+             return self.stop_profile()
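
The error messages in SchedulerProfilerMixin refer to /start_profile and /stop_profile server endpoints, and init_profile() falls back to the SGLANG_TORCH_PROFILER_DIR environment variable for the output directory. The sketch below shows one plausible way to drive those endpoints against a locally running server; the payload fields are inferred from the init_profile() signature above, so the exact request schema and the base URL are assumptions, not verified API details.

# Hedged usage sketch: drive the profiler endpoints on an assumed local server.
import requests

BASE_URL = "http://127.0.0.1:30000"  # assumed address of a running sglang server

# Ask the scheduler to profile the next 5 forward steps with CPU + GPU traces.
resp = requests.post(
    f"{BASE_URL}/start_profile",
    json={
        "output_dir": "/tmp/sglang_traces",  # else SGLANG_TORCH_PROFILER_DIR or /tmp
        "num_steps": 5,
        "activities": ["CPU", "GPU"],
        "profile_by_stage": False,
    },
    timeout=30,
)
print(resp.status_code, resp.text)

# ... send generation requests here so there is work to trace ...

# Stop explicitly if num_steps was not given (otherwise the scheduler stops itself).
resp = requests.post(f"{BASE_URL}/stop_profile", timeout=30)
print(resp.status_code, resp.text)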
sglang/srt/managers/scheduler_update_weights_mixin.py
@@ -0,0 +1,142 @@
+ import logging
+ from typing import Tuple
+
+ import torch
+
+ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
+ from sglang.srt.managers.io_struct import (
+     GetWeightsByNameReqInput,
+     GetWeightsByNameReqOutput,
+     InitWeightsUpdateGroupReqInput,
+     InitWeightsUpdateGroupReqOutput,
+     ReleaseMemoryOccupationReqInput,
+     ReleaseMemoryOccupationReqOutput,
+     ResumeMemoryOccupationReqInput,
+     ResumeMemoryOccupationReqOutput,
+     UpdateWeightFromDiskReqInput,
+     UpdateWeightFromDiskReqOutput,
+     UpdateWeightsFromDistributedReqInput,
+     UpdateWeightsFromDistributedReqOutput,
+     UpdateWeightsFromTensorReqInput,
+     UpdateWeightsFromTensorReqOutput,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ class SchedulerUpdateWeightsMixin:
+
+     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+         """In-place update of the weights from disk."""
+         success, message = self.tp_worker.update_weights_from_disk(recv_req)
+         if success:
+             flush_cache_success = self.flush_cache()
+             assert flush_cache_success, "Cache flush failed after updating weights"
+         else:
+             logger.error(message)
+         return UpdateWeightFromDiskReqOutput(success, message, 0)
+
+     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+         """Initialize the online model parameter update group."""
+         success, message = self.tp_worker.init_weights_update_group(recv_req)
+         return InitWeightsUpdateGroupReqOutput(success, message)
+
+     def update_weights_from_distributed(
+         self,
+         recv_req: UpdateWeightsFromDistributedReqInput,
+     ) -> Tuple[bool, str]:
+         """Update the online model parameter."""
+         success, message = self.tp_worker.update_weights_from_distributed(recv_req)
+         if success:
+             if recv_req.flush_cache:
+                 flush_cache_success = self.flush_cache()
+                 assert flush_cache_success, "Cache flush failed after updating weights"
+         else:
+             logger.error(message)
+         return UpdateWeightsFromDistributedReqOutput(success, message)
+
+     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+         """Update the online model parameter from tensors."""
+         success, message = self.tp_worker.update_weights_from_tensor(recv_req)
+         # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
+         if success:
+             if recv_req.flush_cache:
+                 flush_cache_success = self.flush_cache()
+                 assert flush_cache_success, "Cache flush failed after updating weights"
+         else:
+             logger.error(message)
+         torch.distributed.barrier(group=self.tp_cpu_group)
+         return UpdateWeightsFromTensorReqOutput(success, message)
+
+     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+         parameter = self.tp_worker.get_weights_by_name(recv_req)
+         return GetWeightsByNameReqOutput(parameter)
+
+     def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
+         tags = recv_req.tags
+
+         if tags is None or len(tags) == 0:
+             tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+
+         if GPU_MEMORY_TYPE_KV_CACHE in tags:
+             self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
+             self.flush_cache()
+
+         if GPU_MEMORY_TYPE_WEIGHTS in tags:
+             self.stashed_model_static_state = _export_static_state(
+                 self.tp_worker.worker.model_runner.model
+             )
+             torch.distributed.barrier(self.tp_cpu_group)
+             self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
+
+         return ReleaseMemoryOccupationReqOutput()
+
+     def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
+         tags = recv_req.tags
+
+         if tags is None or len(tags) == 0:
+             tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+
+         if GPU_MEMORY_TYPE_WEIGHTS in tags:
+             self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
+             torch.distributed.barrier(self.tp_cpu_group)
+             _import_static_state(
+                 self.tp_worker.worker.model_runner.model,
+                 self.stashed_model_static_state,
+             )
+             del self.stashed_model_static_state
+
+         if GPU_MEMORY_TYPE_KV_CACHE in tags:
+             self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
+
+         return ResumeMemoryOccupationReqOutput()
+
+     def save_remote_model(self, params):
+         url = params["url"]
+
+         worker = self.tp_worker.worker
+
+         worker.model_runner.save_remote_model(url)
+
+     def save_sharded_model(self, params):
+         worker = self.tp_worker.worker
+
+         worker.model_runner.save_sharded_model(
+             path=params["path"],
+             pattern=params["pattern"],
+             max_size=params["max_size"],
+         )
+
+
+ def _export_static_state(model):
+     return dict(
+         buffers=[
+             (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
+         ]
+     )
+
+
+ def _import_static_state(model, static_params):
+     self_named_buffers = dict(model.named_buffers())
+     for name, tensor in static_params["buffers"]:
+         self_named_buffers[name][...] = tensor
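
The module-level helpers _export_static_state and _import_static_state above stash and restore only a model's buffers around the weight release/resume cycle. The standalone snippet below illustrates that round-trip on a plain torch.nn.Module; it reuses the two helpers verbatim, while the BatchNorm module and the deliberate buffer corruption are illustrative and not part of the package.

# Illustration of the buffer stash/restore round-trip used by
# release_memory_occupation / resume_memory_occupation above.
import torch
import torch.nn as nn


def _export_static_state(model):
    # Snapshot every registered buffer (e.g. BatchNorm running stats).
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    # Copy the snapshot back into the live buffers in place.
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


model = nn.BatchNorm1d(4)
snapshot = _export_static_state(model)      # stash buffers before releasing weights
model.running_mean.add_(1.0)                # simulate the buffers being clobbered
_import_static_state(model, snapshot)       # restore them on resume
print(torch.allclose(model.running_mean, torch.zeros(4)))  # True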