sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/expert_distribution.py +704 -56 (removed lines longer than the viewer width appear truncated below)

```diff
@@ -1,81 +1,729 @@
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 import logging
+import os
 import time
-from
-from
+from abc import ABC
+from collections import deque
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Dict, List, Literal, Optional, Tuple, Type

+import einops
 import torch
+import torch.distributed
+
+from sglang.srt.managers.expert_location import ExpertLocationMetadata
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import Withable, get_bool_env_var

 logger = logging.getLogger(__name__)

+# --------------------------------------- Entrypoint -----------------------------------------
+
+_OutputMode = Literal["file", "object"]
+
+
+class ExpertDistributionRecorder(ABC):
+    """Global expert distribution recording"""
+
+    @staticmethod
+    def init_new(
+        server_args: ServerArgs,
+        expert_location_metadata: "ExpertLocationMetadata",
+        rank: int,
+    ):
+        if server_args.expert_distribution_recorder_mode is not None:
+            return _ExpertDistributionRecorderReal(
+                server_args, expert_location_metadata, rank
+            )
+        else:
+            return _ExpertDistributionRecorderNoop()
+
+    @contextmanager
+    def with_current_layer(self, layer_idx):
+        yield
+
+    @contextmanager
+    def with_debug_name(self, debug_name):
+        yield
+
+    @contextmanager
+    def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
+        yield
+
+    def on_select_experts(self, topk_ids: torch.Tensor):
+        pass
+
+    def on_deepep_dispatch_normal(
+        self,
+        local_physical_count_of_layer: List[int],
+        num_tokens_per_rank,
+        num_tokens_per_rdma_rank,
+        num_tokens_per_expert,
+    ):
+        pass
+
+    def on_deepep_dispatch_low_latency(
+        self, local_physical_count_of_layer: torch.Tensor
+    ):
+        pass
+
+    def start_record(self):
+        self._on_not_implemented()
+
+    def stop_record(self):
+        self._on_not_implemented()
+
+    def dump_record(self, output_mode: _OutputMode = "file"):
+        self._on_not_implemented()
+
+    def _on_not_implemented(self):
+        raise Exception(
+            "Please set ServerArgs.expert_distribution_recorder_mode to use ExpertDistributionRecorder."
+        )
+
+
+class _ExpertDistributionRecorderNoop(ExpertDistributionRecorder):
+    pass

-# global expert distribution recording
-class ExpertDistributionRecorder:
-    # This class is a singleton class
-    def __new__(cls):
-        if not hasattr(cls, "instance"):
-            cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls)
-        return cls.instance

-
-
-
-
-
-
+class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        expert_location_metadata: "ExpertLocationMetadata",
+        rank: int,
+    ):
+        self._server_args = server_args
+        self._expert_location_metadata = expert_location_metadata
+
+        self._recording = False
+        self._current_forward_pass_id = Withable()
+        self._current_layer_idx = Withable()
+        self._current_debug_name = Withable()
+        self._accumulator = _Accumulator.init_new(
+            server_args, expert_location_metadata, rank
         )
-        self.
-
+        self._single_pass_gatherers = {
+            k: _SinglePassGatherer.init_new(server_args, expert_location_metadata, rank)
+            for k in self._accumulator.get_single_pass_gatherer_keys()
+        }
+
+    def with_current_layer(self, layer_idx):
+        return self._current_layer_idx.with_value(layer_idx)
+
+    def with_debug_name(self, debug_name):
+        return self._current_debug_name.with_value(debug_name)

-
-
+    @contextmanager
+    def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
+        with self._current_forward_pass_id.with_value(forward_pass_id):
+            self._on_forward_pass_start(forward_batch)
+            try:
+                yield
+            finally:
+                self._on_forward_pass_end(forward_pass_id)

-    def
-        if not self.
+    def _on_forward_pass_start(self, forward_batch: ForwardBatch):
+        if not self._recording:
             return
-
-
-
-            self._expert_distribution_record[self._current_layer_id].append(tuple(i))
+        for gatherer_key, gatherer in self._single_pass_gatherers.items():
+            gatherer.reset()
+            gatherer.on_forward_pass_start(forward_batch)

-    def
+    def _on_forward_pass_end(self, forward_pass_id: int):
+        if not self._recording:
+            return
+        for gatherer_key, gatherer in self._single_pass_gatherers.items():
+            single_pass_data = gatherer.collect()
+            self._accumulator.append(forward_pass_id, gatherer_key, single_pass_data)
+
+    def on_select_experts(self, topk_ids: torch.Tensor):
+        self._on_hook("on_select_experts", topk_ids=topk_ids)
+
+    def on_deepep_dispatch_normal(
+        self,
+        local_physical_count_of_layer: List[int],
+        num_tokens_per_rank,
+        num_tokens_per_rdma_rank,
+        num_tokens_per_expert,
+    ):
+        self._on_hook(
+            "on_deepep_dispatch_normal",
+            local_physical_count_of_layer=local_physical_count_of_layer,
+            num_tokens_per_rank=num_tokens_per_rank,
+            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+            num_tokens_per_expert=num_tokens_per_expert,
+        )
+
+    def on_deepep_dispatch_low_latency(
+        self, local_physical_count_of_layer: torch.Tensor
+    ):
+        self._on_hook(
+            "on_deepep_dispatch_low_latency",
+            local_physical_count_of_layer=local_physical_count_of_layer,
+        )
+
+    def _on_hook(self, hook_name: str, **kwargs):
+        if not (self._recording or torch.cuda.is_current_stream_capturing()):
+            return
+        gatherer = self._single_pass_gatherers[
+            self._accumulator.get_single_pass_gatherer_key(
+                self._current_debug_name.value
+            )
+        ]
+        getattr(gatherer, hook_name)(layer_idx=self._current_layer_idx.value, **kwargs)
+
+    def _reset(self):
         """Reset the expert distribution recorder."""
-        logger.info("Resetting
-
-
-        self.
+        logger.info("Resetting ExpertDistributionRecorder...")
+        assert (
+            self._current_layer_idx.value is None
+        ), f"{self._current_layer_idx.value=}"
+        for gatherer in self._single_pass_gatherers.values():
+            gatherer.reset()
+        self._accumulator.reset()

     def start_record(self):
-        """Start recording the expert distribution.
-        if self.
+        """Start recording the expert distribution."""
+        if self._recording:
             logger.warning(
                 "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
             )
-        self.
-        self.
+        self._reset()
+        self._recording = True

     def stop_record(self):
-        """Stop recording the expert distribution.
-        if self.
+        """Stop recording the expert distribution."""
+        if not self._recording:
             logger.warning(
                 "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?"
             )
-        self.
-
-    def dump_record(self):
-        """Dump the expert distribution record
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self._recording = False
+
+    def dump_record(self, output_mode: _OutputMode = "file"):
+        """Dump the expert distribution record and reset the recorder after dumping."""
+        output = self._accumulator.dump(output_mode=output_mode)
+        self._reset()
+        return output
+
+
+_global_expert_distribution_recorder: Optional[ExpertDistributionRecorder] = (
+    _ExpertDistributionRecorderNoop()
+)
+
+
+def get_global_expert_distribution_recorder():
+    return _global_expert_distribution_recorder
+
+
+def set_global_expert_distribution_recorder(value):
+    global _global_expert_distribution_recorder
+    _global_expert_distribution_recorder = value
+
+
+# --------------------------------------- SinglePassGatherer -----------------------------------------
+
+
+class _SinglePassGatherer(ABC):
+    @staticmethod
+    def init_new(
+        server_args: ServerArgs,
+        expert_location_metadata: "ExpertLocationMetadata",
+        rank: int,
+    ) -> "_SinglePassGatherer":
+        if server_args.expert_distribution_recorder_mode == "per_token":
+            return _DetailSinglePassGatherer(
+                server_args, expert_location_metadata, rank
+            )
+        if server_args.enable_deepep_moe:
+            if server_args.deepep_mode == "normal":
+                return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
+            elif server_args.deepep_mode == "low_latency":
+                return _DeepepLowLatencySinglePassGatherer(
+                    expert_location_metadata, rank
+                )
+            else:
+                raise NotImplementedError
+        return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
+
+    def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
+        self._expert_location_metadata = expert_location_metadata
+        self._rank = rank
+
+    def on_forward_pass_start(self, forward_batch: ForwardBatch):
+        pass
+
+    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
+        pass
+
+    def on_deepep_dispatch_normal(
+        self,
+        layer_idx: int,
+        local_physical_count_of_layer: List[int],
+        num_tokens_per_rank,
+        num_tokens_per_rdma_rank,
+        num_tokens_per_expert,
+    ):
+        pass
+
+    def on_deepep_dispatch_low_latency(
+        self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
+    ):
+        pass
+
+    def reset(self):
+        raise NotImplementedError
+
+    def collect(self) -> Dict:
+        raise NotImplementedError
+
+
+class _LayerBasedSinglePassGatherer(_SinglePassGatherer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._objects_of_layer = {}
+
+    def _on_layer_data(self, layer_idx: int, objects: List[int]):
+        assert 0 <= layer_idx < self._expert_location_metadata.num_layers
+        if layer_idx in self._objects_of_layer:
+            self._objects_of_layer[layer_idx] = _list_sum(
+                self._objects_of_layer[layer_idx], objects
+            )
+        else:
+            self._objects_of_layer[layer_idx] = objects
+
+    def reset(self):
+        self._objects_of_layer.clear()
+
+    def _collect_objects(self, pad_len: int) -> torch.Tensor:
+        data = [
+            self._objects_of_layer.get(layer_index) or ([0] * pad_len)
+            for layer_index in range(self._expert_location_metadata.num_layers)
+        ]
+        return torch.tensor(data)
+
+
+def _list_sum(a: List, b: List) -> List:
+    return [x + y for x, y in zip(a, b, strict=True)]
+
+
+class _SelectExpertsSinglePassGatherer(_LayerBasedSinglePassGatherer):
+    # pretty slow, but we will use the DeepEP Gatherer in production
+    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
+        topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
+        torch.cuda.synchronize()
+
+        global_physical_count = [
+            0
+        ] * self._expert_location_metadata.num_physical_experts
+        for token_record in topk_ids_list:
+            for global_physical_expert_idx in token_record:
+                global_physical_count[global_physical_expert_idx] += 1
+
+        self._on_layer_data(layer_idx, global_physical_count)
+
+    def collect(self) -> Dict:
+        global_physical_count = super()._collect_objects(
+            pad_len=self._expert_location_metadata.num_physical_experts
+        )
+        return dict(global_physical_count=global_physical_count)
+
+
+class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer):
+    def on_deepep_dispatch_normal(
+        self,
+        layer_idx: int,
+        local_physical_count_of_layer: List[int],
+        num_tokens_per_rank,
+        num_tokens_per_rdma_rank,
+        num_tokens_per_expert,
+    ):
+        assert isinstance(local_physical_count_of_layer, list)
+        self._on_layer_data(layer_idx, local_physical_count_of_layer)
+
+    def collect(self) -> Dict:
+        local_physical_count = super()._collect_objects(
+            pad_len=self._expert_location_metadata.num_local_physical_experts
+        )
+        global_physical_count = _convert_local_to_global_physical_count(
+            local_physical_count,
+            rank=self._rank,
+            num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
+            num_physical_experts=self._expert_location_metadata.num_physical_experts,
+        )
+        return dict(global_physical_count=global_physical_count)
+
+
+class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._data = torch.zeros(
+            (
+                self._expert_location_metadata.num_layers,
+                self._expert_location_metadata.num_local_physical_experts,
+            ),
+            dtype=torch.int,
+            device="cuda",
+        )
+
+    def on_deepep_dispatch_low_latency(
+        self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
+    ):
+        # Most naive implementation, can optimize later
+        self._data[layer_idx, :] += local_physical_count_of_layer
+
+    def reset(self):
+        self._data[...] = 0
+
+    def collect(self) -> Dict:
+        # Can optimize if bottleneck
+        global_physical_count = _convert_local_to_global_physical_count(
+            self._data,
+            rank=self._rank,
+            num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
+            num_physical_experts=self._expert_location_metadata.num_physical_experts,
+        )
+        return dict(global_physical_count=global_physical_count)
+
+
+def _convert_local_to_global_physical_count(
+    local_physical_count: torch.Tensor,
+    rank: int,
+    num_local_physical_experts: int,
+    num_physical_experts: int,
+) -> torch.Tensor:
+    dtype = local_physical_count.dtype
+    device = local_physical_count.device
+    num_layers, _ = local_physical_count.shape
+
+    ans = torch.zeros((num_layers, num_physical_experts), dtype=dtype, device=device)
+    ans[
+        :, num_local_physical_experts * rank : num_local_physical_experts * (rank + 1)
+    ] = local_physical_count
+    return ans
+
+
+# --------------------------------------- Accumulator -----------------------------------------
+
+_SINGLE_PASS_GATHERER_KEY_PRIMARY = "primary"
+
+
+class _Accumulator(ABC):
+    @staticmethod
+    def init_new(
+        server_args: ServerArgs,
+        expert_location_metadata: "ExpertLocationMetadata",
+        rank: int,
+    ) -> "_Accumulator":
+        return _Accumulator.get_class(server_args)(
+            server_args, expert_location_metadata, rank
+        )
+
+    @staticmethod
+    def get_class(server_args: ServerArgs) -> Type["_Accumulator"]:
+        return {
+            "stat": _StatAccumulator,
+            # TODO pr-chain: enable this later
+            # "per_pass": _DetailAccumulator,
+            # "per_token": _DetailAccumulator,
+        }[server_args.expert_distribution_recorder_mode]
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        expert_location_metadata: "ExpertLocationMetadata",
+        rank: int,
+    ):
+        self._server_args = server_args
+        self._expert_location_metadata = expert_location_metadata
+        self._rank = rank
+
+    def get_single_pass_gatherer_keys(self):
+        return [_SINGLE_PASS_GATHERER_KEY_PRIMARY]
+
+    def get_single_pass_gatherer_key(self, debug_name: Optional[str]):
+        return _SINGLE_PASS_GATHERER_KEY_PRIMARY
+
+    def append(
+        self,
+        forward_pass_id: int,
+        gatherer_key: str,
+        single_pass_data: Dict,
+    ):
+        pass
+
+    def reset(self):
+        pass
+
+    def dump(self, output_mode: _OutputMode):
+        pass
+
+
+class _UtilizationRateAccumulatorMixin(_Accumulator):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self._enable = self._server_args.enable_expert_distribution_metrics
+
+        if self._enable:
+            window_sizes = [10, 100, 1000]
+            self._history = _DequeCollection(maxlens=window_sizes)
+            self._rank = torch.distributed.get_rank()
+
+    def append(
+        self,
+        forward_pass_id: int,
+        gatherer_key: str,
+        single_pass_data: Dict,
+    ):
+        super().append(forward_pass_id, gatherer_key, single_pass_data)
+        if self._enable:
+            self._append_utilization_rate(
+                forward_pass_id, single_pass_data["global_physical_count"]
+            )
+
+    def reset(self):
+        super().reset()
+        if self._enable:
+            self._history.clear()
+
+    def _append_utilization_rate(
+        self, forward_pass_id: int, single_pass_global_physical_count: torch.Tensor
+    ):
+        gpu_physical_count = compute_gpu_physical_count(
+            single_pass_global_physical_count,
+            num_gpu=self._expert_location_metadata.ep_size,
+        )
+        gpu_physical_count = gpu_physical_count.to(self._server_args.device)
+        torch.distributed.reduce(
+            gpu_physical_count, dst=0, op=torch.distributed.ReduceOp.SUM
+        )
+
+        if self._rank == 0:
+            utilization_rate_tensor = compute_utilization_rate(gpu_physical_count)
+            utilization_rate = torch.mean(utilization_rate_tensor).item()
+            self._history.append(utilization_rate)
+
+            gpu_physical_count_sum = gpu_physical_count.sum().item()
+
+            logger.info(
+                f"[Expert Balancedness] "
+                f"forward_pass_id={forward_pass_id} "
+                f"current_pass_balancedness={utilization_rate:.03f} "
+                f"{''.join(f'last_{size}_average_balancedness={value:.03f} ' for size, value in self._history.mean().items())} "
+                f"gpu_physical_count_sum={gpu_physical_count_sum}"
+                # f"current_pass_per_layer={[round(x, 2) for x in utilization_rate_tensor.cpu().tolist()]}"
+            )
+
+
+class _DequeCollection:
+    def __init__(self, maxlens: List[int]):
+        self._dequeues = [deque(maxlen=maxlen) for maxlen in maxlens]
+
+    def append(self, value):
+        for d in self._dequeues:
+            d.append(value)
+
+    def clear(self):
+        for d in self._dequeues:
+            d.clear()
+
+    def mean(self) -> Dict[int, float]:
+        return {d.maxlen: sum(d) / len(d) for d in self._dequeues}
+
+
+class _StatAccumulator(_UtilizationRateAccumulatorMixin):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._global_physical_count_of_buffered_step = _Buffer.init_new(
+            item_shape=(
+                self._expert_location_metadata.num_layers,
+                # Cannot use local_physical_count to support select_experts
+                self._expert_location_metadata.num_physical_experts,
+            ),
+            buffer_size=self._server_args.expert_distribution_recorder_buffer_size,
+            dtype=torch.int32,
+            device=self._server_args.device,
+        )
+
+    def append(
+        self,
+        forward_pass_id: int,
+        gatherer_key: str,
+        single_pass_data: Dict,
+    ):
+        super().append(forward_pass_id, gatherer_key, single_pass_data)
+        # Can optimize if overhead here is large
+        self._global_physical_count_of_buffered_step.append(
+            single_pass_data["global_physical_count"]
+        )
+
+    def reset(self):
+        super().reset()
+        self._global_physical_count_of_buffered_step.reset()
+
+    def dump(self, output_mode: _OutputMode):
+        logical_count_of_buffered_step = _convert_global_physical_count_to_logical_count(
+            self._global_physical_count_of_buffered_step.get_all(),
+            num_layers=self._expert_location_metadata.num_layers,
+            num_logical_experts=self._expert_location_metadata.num_logical_experts,
+            physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map,
+        )
+        torch.distributed.all_reduce(
+            logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM
+        )
+        output = dict(
+            rank=self._rank,
+            logical_count=logical_count_of_buffered_step,
+        )
+
+        if output_mode == "file":
+            if self._rank == 0:
+                _dump_to_file(f"expert_distribution_recorder_{time.time()}.pt", output)
+        elif output_mode == "object":
+            return output
+        else:
+            raise NotImplementedError
+
+
+def _dump_to_file(name, data):
+    save_dir = Path(os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", "/tmp"))
+    path_output = save_dir / name
+    logger.info(f"Write expert distribution to {path_output}")
+    if not save_dir.exists():
+        save_dir.mkdir(parents=True, exist_ok=True)
+    torch.save(data, str(path_output))
+
+
+class _Buffer:
+    @staticmethod
+    def init_new(item_shape: Tuple, buffer_size: int, dtype, device):
+        if buffer_size < 0:
+            return _InfiniteBuffer(item_shape, dtype=dtype, device=device)
+        else:
+            return _CircularBuffer(item_shape, buffer_size, dtype=dtype, device=device)
+
+    def append(self, value: torch.Tensor):
+        raise NotImplementedError
+
+    def get_all(self) -> torch.Tensor:
+        raise NotImplementedError
+
+    def reset(self):
+        raise NotImplementedError
+
+
+class _CircularBuffer(_Buffer):
+    def __init__(self, item_shape: Tuple, buffer_size: int, dtype, device):
+        self._buffer = torch.zeros(
+            (buffer_size, *item_shape), dtype=dtype, device=device
+        )
+        self._curr_index = 0
+
+    def append(self, value: torch.Tensor):
+        self._buffer[self._curr_index] = value
+        self._curr_index = (self._curr_index + 1) % len(self._buffer)
+
+    def get_all(self) -> torch.Tensor:
+        return self._buffer
+
+    def reset(self):
+        self._buffer[...] = 0
+
+
+class _InfiniteBuffer(_Buffer):
+    def __init__(self, item_shape: Tuple, dtype, device):
+        self._item_shape = item_shape
+        self._buffer = torch.zeros((128, *item_shape), dtype=dtype, device=device)
+        self._size = 0
+
+    def append(self, value: torch.Tensor):
+        curr_buffer_size = len(self._buffer)
+        dtype = self._buffer.dtype
+        device = self._buffer.device
+
+        if self._size == curr_buffer_size:
+            new_buffer = torch.zeros(
+                (2 * curr_buffer_size, *self._item_shape), dtype=dtype, device=device
+            )
+            new_buffer[:curr_buffer_size] = self._buffer
+            self._buffer = new_buffer
+
+        self._buffer[self._size] = value
+        self._size += 1
+
+    def get_all(self) -> torch.Tensor:
+        return self._buffer[: self._size]
+
+    def reset(self):
+        self._buffer[...] = 0
+        self._size = 0
+
+
+def _convert_global_physical_count_to_logical_count(
+    # (whatever, num_layers, num_physical_experts)
+    global_physical_count: torch.Tensor,
+    num_layers: int,
+    num_logical_experts: int,
+    physical_to_logical_map: torch.Tensor,
+):
+    dim_extra, _, _ = global_physical_count.shape
+    dtype = global_physical_count.dtype
+    device = global_physical_count.device
+    logical_count = torch.zeros(
+        (dim_extra, num_layers, num_logical_experts), dtype=dtype, device=device
+    )
+    logical_count.scatter_add_(
+        dim=2,
+        index=physical_to_logical_map.unsqueeze(0)
+        .expand(dim_extra, -1, -1)
+        .to(torch.int64),
+        src=global_physical_count,
+    )
+    return logical_count
+
+
+def compute_gpu_physical_count(
+    physical_count_of_whatever: torch.Tensor,  # (..., num_layer, num_physical_expert)
+    num_gpu: int,
+):
+    """output: gpu_physical_count_of_batch (..., num_layer, num_gpu)"""
+    return einops.reduce(
+        physical_count_of_whatever,
+        "... num_layer (num_gpu num_expert_per_gpu) -> ... num_layer num_gpu",
+        "sum",
+        num_gpu=num_gpu,
+    )
+
+
+def compute_utilization_rate(
+    gpu_physical_count_of_batch: torch.Tensor,  # (..., num_layer, num_gpu)
+):
+    """output: utilization_rate (..., num_layer)"""
+    gpu_physical_count_of_batch = gpu_physical_count_of_batch.float()
+    max_gpu_physical_count = einops.reduce(
+        gpu_physical_count_of_batch,
+        "... num_layer num_gpu -> ... num_layer",
+        "max",
+    )
+    avg_gpu_physical_count = einops.reduce(
+        gpu_physical_count_of_batch,
+        "... num_layer num_gpu -> ... num_layer",
+        "mean",
+    )
+    return (avg_gpu_physical_count + 1e-5) / (max_gpu_physical_count + 1e-5)
```
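The rewrite replaces the old singleton `ExpertDistributionRecorder` with an abstract entrypoint, a no-op implementation, and a real one selected by `ServerArgs.expert_distribution_recorder_mode`, wired to pluggable per-pass gatherers and accumulators. A minimal usage sketch of how these pieces fit together, based only on the code above; `server_args`, `expert_location_metadata`, `rank`, `forward_pass_id`, `forward_batch`, `layer_idx`, and `topk_ids` are placeholders for objects the engine normally constructs, a CUDA device is assumed, and `torch.distributed` must already be initialized because `_StatAccumulator.dump()` all-reduces across ranks:

```python
# Sketch only, not sglang's actual wiring inside the scheduler/model runner.
from sglang.srt.managers.expert_distribution import (
    ExpertDistributionRecorder,
    set_global_expert_distribution_recorder,
)

# expert_distribution_recorder_mode="stat" selects _StatAccumulator;
# None would yield the no-op recorder whose hooks all do nothing.
recorder = ExpertDistributionRecorder.init_new(
    server_args, expert_location_metadata, rank
)
set_global_expert_distribution_recorder(recorder)

recorder.start_record()  # resets gatherers/accumulator, then starts recording
with recorder.with_forward_pass(forward_pass_id, forward_batch):
    with recorder.with_current_layer(layer_idx):
        # MoE layers report routing decisions through hooks like this one;
        # topk_ids is the (num_tokens, top_k) tensor of selected expert ids.
        recorder.on_select_experts(topk_ids=topk_ids)
recorder.stop_record()

# "object" returns dict(rank=..., logical_count=...) and resets; the default
# "file" mode writes expert_distribution_recorder_<timestamp>.pt on rank 0
# under $SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR (default /tmp).
output = recorder.dump_record(output_mode="object")
```

The `[Expert Balancedness]` log line emitted by `_UtilizationRateAccumulatorMixin` is the per-layer mean/max ratio of per-GPU token counts: for counts of, say, 10, 10, and 20 tokens on three GPUs, balancedness is about 13.33/20, roughly 0.67, while 1.0 means perfectly even expert load.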