sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/expert_distribution.py
@@ -1,81 +1,729 @@
- import json
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
  import logging
+ import os
  import time
- from collections import defaultdict
- from typing import Dict, List, Tuple
+ from abc import ABC
+ from collections import deque
+ from contextlib import contextmanager
+ from pathlib import Path
+ from typing import Dict, List, Literal, Optional, Tuple, Type

+ import einops
  import torch
+ import torch.distributed
+
+ from sglang.srt.managers.expert_location import ExpertLocationMetadata
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.server_args import ServerArgs
+ from sglang.srt.utils import Withable, get_bool_env_var

  logger = logging.getLogger(__name__)

+ # --------------------------------------- Entrypoint -----------------------------------------
+
+ _OutputMode = Literal["file", "object"]
+
+
+ class ExpertDistributionRecorder(ABC):
+     """Global expert distribution recording"""
+
+     @staticmethod
+     def init_new(
+         server_args: ServerArgs,
+         expert_location_metadata: "ExpertLocationMetadata",
+         rank: int,
+     ):
+         if server_args.expert_distribution_recorder_mode is not None:
+             return _ExpertDistributionRecorderReal(
+                 server_args, expert_location_metadata, rank
+             )
+         else:
+             return _ExpertDistributionRecorderNoop()
+
+     @contextmanager
+     def with_current_layer(self, layer_idx):
+         yield
+
+     @contextmanager
+     def with_debug_name(self, debug_name):
+         yield
+
+     @contextmanager
+     def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
+         yield
+
+     def on_select_experts(self, topk_ids: torch.Tensor):
+         pass
+
+     def on_deepep_dispatch_normal(
+         self,
+         local_physical_count_of_layer: List[int],
+         num_tokens_per_rank,
+         num_tokens_per_rdma_rank,
+         num_tokens_per_expert,
+     ):
+         pass
+
+     def on_deepep_dispatch_low_latency(
+         self, local_physical_count_of_layer: torch.Tensor
+     ):
+         pass
+
+     def start_record(self):
+         self._on_not_implemented()
+
+     def stop_record(self):
+         self._on_not_implemented()
+
+     def dump_record(self, output_mode: _OutputMode = "file"):
+         self._on_not_implemented()
+
+     def _on_not_implemented(self):
+         raise Exception(
+             "Please set ServerArgs.expert_distribution_recorder_mode to use ExpertDistributionRecorder."
+         )
+
+
+ class _ExpertDistributionRecorderNoop(ExpertDistributionRecorder):
+     pass

- # global expert distribution recording
- class ExpertDistributionRecorder:
-     # This class is a singleton class
-     def __new__(cls):
-         if not hasattr(cls, "instance"):
-             cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls)
-         return cls.instance

-     def __init__(self):
-         # the length of the dictionary is the number of layers
-         # the length of the list is the number of tokens
-         # the length of the tuple is topk's k value
-         self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict(
-             list
+ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
+     def __init__(
+         self,
+         server_args: ServerArgs,
+         expert_location_metadata: "ExpertLocationMetadata",
+         rank: int,
+     ):
+         self._server_args = server_args
+         self._expert_location_metadata = expert_location_metadata
+
+         self._recording = False
+         self._current_forward_pass_id = Withable()
+         self._current_layer_idx = Withable()
+         self._current_debug_name = Withable()
+         self._accumulator = _Accumulator.init_new(
+             server_args, expert_location_metadata, rank
          )
-         self._record = False
-         self._current_layer_id = "UNKNOWN"
+         self._single_pass_gatherers = {
+             k: _SinglePassGatherer.init_new(server_args, expert_location_metadata, rank)
+             for k in self._accumulator.get_single_pass_gatherer_keys()
+         }
+
+     def with_current_layer(self, layer_idx):
+         return self._current_layer_idx.with_value(layer_idx)
+
+     def with_debug_name(self, debug_name):
+         return self._current_debug_name.with_value(debug_name)

-     def set_current_layer(self, layer_idx):
-         self._current_layer_id = layer_idx
+     @contextmanager
+     def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
+         with self._current_forward_pass_id.with_value(forward_pass_id):
+             self._on_forward_pass_start(forward_batch)
+             try:
+                 yield
+             finally:
+                 self._on_forward_pass_end(forward_pass_id)

-     def record_new_token(self, topk_ids):
-         if not self._record:
+     def _on_forward_pass_start(self, forward_batch: ForwardBatch):
+         if not self._recording:
              return
-         topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
-         torch.cuda.synchronize()
-         for i in topk_ids_list:
-             self._expert_distribution_record[self._current_layer_id].append(tuple(i))
+         for gatherer_key, gatherer in self._single_pass_gatherers.items():
+             gatherer.reset()
+             gatherer.on_forward_pass_start(forward_batch)

-     def reset(self):
+     def _on_forward_pass_end(self, forward_pass_id: int):
+         if not self._recording:
+             return
+         for gatherer_key, gatherer in self._single_pass_gatherers.items():
+             single_pass_data = gatherer.collect()
+             self._accumulator.append(forward_pass_id, gatherer_key, single_pass_data)
+
+     def on_select_experts(self, topk_ids: torch.Tensor):
+         self._on_hook("on_select_experts", topk_ids=topk_ids)
+
+     def on_deepep_dispatch_normal(
+         self,
+         local_physical_count_of_layer: List[int],
+         num_tokens_per_rank,
+         num_tokens_per_rdma_rank,
+         num_tokens_per_expert,
+     ):
+         self._on_hook(
+             "on_deepep_dispatch_normal",
+             local_physical_count_of_layer=local_physical_count_of_layer,
+             num_tokens_per_rank=num_tokens_per_rank,
+             num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+             num_tokens_per_expert=num_tokens_per_expert,
+         )
+
+     def on_deepep_dispatch_low_latency(
+         self, local_physical_count_of_layer: torch.Tensor
+     ):
+         self._on_hook(
+             "on_deepep_dispatch_low_latency",
+             local_physical_count_of_layer=local_physical_count_of_layer,
+         )
+
+     def _on_hook(self, hook_name: str, **kwargs):
+         if not (self._recording or torch.cuda.is_current_stream_capturing()):
+             return
+         gatherer = self._single_pass_gatherers[
+             self._accumulator.get_single_pass_gatherer_key(
+                 self._current_debug_name.value
+             )
+         ]
+         getattr(gatherer, hook_name)(layer_idx=self._current_layer_idx.value, **kwargs)
+
+     def _reset(self):
          """Reset the expert distribution recorder."""
-         logger.info("Resetting expert distribution record...")
-         self._record = False
-         self._expert_distribution_record.clear()
-         self._current_layer_id = "UNKNOWN"
+         logger.info("Resetting ExpertDistributionRecorder...")
+         assert (
+             self._current_layer_idx.value is None
+         ), f"{self._current_layer_idx.value=}"
+         for gatherer in self._single_pass_gatherers.values():
+             gatherer.reset()
+         self._accumulator.reset()

      def start_record(self):
-         """Start recording the expert distribution. Reset the recorder and set the recording flag to True."""
-         if self._record == True:
+         """Start recording the expert distribution."""
+         if self._recording:
              logger.warning(
                  "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
              )
-         self.reset()
-         self._record = True
+         self._reset()
+         self._recording = True

      def stop_record(self):
-         """Stop recording the expert distribution. Set the recording flag to False."""
-         if self._record == False:
+         """Stop recording the expert distribution."""
+         if not self._recording:
              logger.warning(
                  "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?"
              )
-         self._record = False
-
-     def dump_record(self):
-         """Dump the expert distribution record to a file. Reset the recorder after dumping."""
-         results = {}
-         for layer_idx, layer_record in self._expert_distribution_record.items():
-             results[layer_idx] = defaultdict(int)
-             for token_record in layer_record:
-                 for expert_idx in token_record:
-                     results[layer_idx][expert_idx] += 1
-         with open(
-             f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv",
-             "w",
-         ) as fd:
-             fd.write("layer_id,expert_id,count\n")
-             for layer_idx, layer_results in results.items():
-                 for expert_idx, count in layer_results.items():
-                     fd.write(f"{layer_idx},{expert_idx},{count}\n")
-         self.reset()
+         self._recording = False
+
+     def dump_record(self, output_mode: _OutputMode = "file"):
+         """Dump the expert distribution record and reset the recorder after dumping."""
+         output = self._accumulator.dump(output_mode=output_mode)
+         self._reset()
+         return output
+
+
+ _global_expert_distribution_recorder: Optional[ExpertDistributionRecorder] = (
+     _ExpertDistributionRecorderNoop()
+ )
+
+
+ def get_global_expert_distribution_recorder():
+     return _global_expert_distribution_recorder
+
+
+ def set_global_expert_distribution_recorder(value):
+     global _global_expert_distribution_recorder
+     _global_expert_distribution_recorder = value
+
+
+ # --------------------------------------- SinglePassGatherer -----------------------------------------
+
+
+ class _SinglePassGatherer(ABC):
+     @staticmethod
+     def init_new(
+         server_args: ServerArgs,
+         expert_location_metadata: "ExpertLocationMetadata",
+         rank: int,
+     ) -> "_SinglePassGatherer":
+         if server_args.expert_distribution_recorder_mode == "per_token":
+             return _DetailSinglePassGatherer(
+                 server_args, expert_location_metadata, rank
+             )
+         if server_args.enable_deepep_moe:
+             if server_args.deepep_mode == "normal":
+                 return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
+             elif server_args.deepep_mode == "low_latency":
+                 return _DeepepLowLatencySinglePassGatherer(
+                     expert_location_metadata, rank
+                 )
+             else:
+                 raise NotImplementedError
+         return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
+
+     def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
+         self._expert_location_metadata = expert_location_metadata
+         self._rank = rank
+
+     def on_forward_pass_start(self, forward_batch: ForwardBatch):
+         pass
+
+     def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
+         pass
+
+     def on_deepep_dispatch_normal(
+         self,
+         layer_idx: int,
+         local_physical_count_of_layer: List[int],
+         num_tokens_per_rank,
+         num_tokens_per_rdma_rank,
+         num_tokens_per_expert,
+     ):
+         pass
+
+     def on_deepep_dispatch_low_latency(
+         self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
+     ):
+         pass
+
+     def reset(self):
+         raise NotImplementedError
+
+     def collect(self) -> Dict:
+         raise NotImplementedError
+
+
+ class _LayerBasedSinglePassGatherer(_SinglePassGatherer):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self._objects_of_layer = {}
+
+     def _on_layer_data(self, layer_idx: int, objects: List[int]):
+         assert 0 <= layer_idx < self._expert_location_metadata.num_layers
+         if layer_idx in self._objects_of_layer:
+             self._objects_of_layer[layer_idx] = _list_sum(
+                 self._objects_of_layer[layer_idx], objects
+             )
+         else:
+             self._objects_of_layer[layer_idx] = objects
+
+     def reset(self):
+         self._objects_of_layer.clear()
+
+     def _collect_objects(self, pad_len: int) -> torch.Tensor:
+         data = [
+             self._objects_of_layer.get(layer_index) or ([0] * pad_len)
+             for layer_index in range(self._expert_location_metadata.num_layers)
+         ]
+         return torch.tensor(data)
+
+
+ def _list_sum(a: List, b: List) -> List:
+     return [x + y for x, y in zip(a, b, strict=True)]
+
+
+ class _SelectExpertsSinglePassGatherer(_LayerBasedSinglePassGatherer):
+     # pretty slow, but we will use the DeepEP Gatherer in production
+     def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
+         topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
+         torch.cuda.synchronize()
+
+         global_physical_count = [
+             0
+         ] * self._expert_location_metadata.num_physical_experts
+         for token_record in topk_ids_list:
+             for global_physical_expert_idx in token_record:
+                 global_physical_count[global_physical_expert_idx] += 1
+
+         self._on_layer_data(layer_idx, global_physical_count)
+
+     def collect(self) -> Dict:
+         global_physical_count = super()._collect_objects(
+             pad_len=self._expert_location_metadata.num_physical_experts
+         )
+         return dict(global_physical_count=global_physical_count)
+
+
+ class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer):
+     def on_deepep_dispatch_normal(
+         self,
+         layer_idx: int,
+         local_physical_count_of_layer: List[int],
+         num_tokens_per_rank,
+         num_tokens_per_rdma_rank,
+         num_tokens_per_expert,
+     ):
+         assert isinstance(local_physical_count_of_layer, list)
+         self._on_layer_data(layer_idx, local_physical_count_of_layer)
+
+     def collect(self) -> Dict:
+         local_physical_count = super()._collect_objects(
+             pad_len=self._expert_location_metadata.num_local_physical_experts
+         )
+         global_physical_count = _convert_local_to_global_physical_count(
+             local_physical_count,
+             rank=self._rank,
+             num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
+             num_physical_experts=self._expert_location_metadata.num_physical_experts,
+         )
+         return dict(global_physical_count=global_physical_count)
+
+
+ class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self._data = torch.zeros(
+             (
+                 self._expert_location_metadata.num_layers,
+                 self._expert_location_metadata.num_local_physical_experts,
+             ),
+             dtype=torch.int,
+             device="cuda",
+         )
+
+     def on_deepep_dispatch_low_latency(
+         self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
+     ):
+         # Most naive implementation, can optimize later
+         self._data[layer_idx, :] += local_physical_count_of_layer
+
+     def reset(self):
+         self._data[...] = 0
+
+     def collect(self) -> Dict:
+         # Can optimize if bottleneck
+         global_physical_count = _convert_local_to_global_physical_count(
+             self._data,
+             rank=self._rank,
+             num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
+             num_physical_experts=self._expert_location_metadata.num_physical_experts,
+         )
+         return dict(global_physical_count=global_physical_count)
+
+
+ def _convert_local_to_global_physical_count(
+     local_physical_count: torch.Tensor,
+     rank: int,
+     num_local_physical_experts: int,
+     num_physical_experts: int,
+ ) -> torch.Tensor:
+     dtype = local_physical_count.dtype
+     device = local_physical_count.device
+     num_layers, _ = local_physical_count.shape
+
+     ans = torch.zeros((num_layers, num_physical_experts), dtype=dtype, device=device)
+     ans[
+         :, num_local_physical_experts * rank : num_local_physical_experts * (rank + 1)
+     ] = local_physical_count
+     return ans
+
+
+ # --------------------------------------- Accumulator -----------------------------------------
+
+ _SINGLE_PASS_GATHERER_KEY_PRIMARY = "primary"
+
+
+ class _Accumulator(ABC):
+     @staticmethod
+     def init_new(
+         server_args: ServerArgs,
+         expert_location_metadata: "ExpertLocationMetadata",
+         rank: int,
+     ) -> "_Accumulator":
+         return _Accumulator.get_class(server_args)(
+             server_args, expert_location_metadata, rank
+         )
+
+     @staticmethod
+     def get_class(server_args: ServerArgs) -> Type["_Accumulator"]:
+         return {
+             "stat": _StatAccumulator,
+             # TODO pr-chain: enable this later
+             # "per_pass": _DetailAccumulator,
+             # "per_token": _DetailAccumulator,
+         }[server_args.expert_distribution_recorder_mode]
+
+     def __init__(
+         self,
+         server_args: ServerArgs,
+         expert_location_metadata: "ExpertLocationMetadata",
+         rank: int,
+     ):
+         self._server_args = server_args
+         self._expert_location_metadata = expert_location_metadata
+         self._rank = rank
+
+     def get_single_pass_gatherer_keys(self):
+         return [_SINGLE_PASS_GATHERER_KEY_PRIMARY]
+
+     def get_single_pass_gatherer_key(self, debug_name: Optional[str]):
+         return _SINGLE_PASS_GATHERER_KEY_PRIMARY
+
+     def append(
+         self,
+         forward_pass_id: int,
+         gatherer_key: str,
+         single_pass_data: Dict,
+     ):
+         pass
+
+     def reset(self):
+         pass
+
+     def dump(self, output_mode: _OutputMode):
+         pass
+
+
+ class _UtilizationRateAccumulatorMixin(_Accumulator):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         self._enable = self._server_args.enable_expert_distribution_metrics
+
+         if self._enable:
+             window_sizes = [10, 100, 1000]
+             self._history = _DequeCollection(maxlens=window_sizes)
+             self._rank = torch.distributed.get_rank()
+
+     def append(
+         self,
+         forward_pass_id: int,
+         gatherer_key: str,
+         single_pass_data: Dict,
+     ):
+         super().append(forward_pass_id, gatherer_key, single_pass_data)
+         if self._enable:
+             self._append_utilization_rate(
+                 forward_pass_id, single_pass_data["global_physical_count"]
+             )
+
+     def reset(self):
+         super().reset()
+         if self._enable:
+             self._history.clear()
+
+     def _append_utilization_rate(
+         self, forward_pass_id: int, single_pass_global_physical_count: torch.Tensor
+     ):
+         gpu_physical_count = compute_gpu_physical_count(
+             single_pass_global_physical_count,
+             num_gpu=self._expert_location_metadata.ep_size,
+         )
+         gpu_physical_count = gpu_physical_count.to(self._server_args.device)
+         torch.distributed.reduce(
+             gpu_physical_count, dst=0, op=torch.distributed.ReduceOp.SUM
+         )
+
+         if self._rank == 0:
+             utilization_rate_tensor = compute_utilization_rate(gpu_physical_count)
+             utilization_rate = torch.mean(utilization_rate_tensor).item()
+             self._history.append(utilization_rate)
+
+             gpu_physical_count_sum = gpu_physical_count.sum().item()
+
+             logger.info(
+                 f"[Expert Balancedness] "
+                 f"forward_pass_id={forward_pass_id} "
+                 f"current_pass_balancedness={utilization_rate:.03f} "
+                 f"{''.join(f'last_{size}_average_balancedness={value:.03f} ' for size, value in self._history.mean().items())} "
+                 f"gpu_physical_count_sum={gpu_physical_count_sum}"
+                 # f"current_pass_per_layer={[round(x, 2) for x in utilization_rate_tensor.cpu().tolist()]}"
+             )
+
+
+ class _DequeCollection:
+     def __init__(self, maxlens: List[int]):
+         self._dequeues = [deque(maxlen=maxlen) for maxlen in maxlens]
+
+     def append(self, value):
+         for d in self._dequeues:
+             d.append(value)
+
+     def clear(self):
+         for d in self._dequeues:
+             d.clear()
+
+     def mean(self) -> Dict[int, float]:
+         return {d.maxlen: sum(d) / len(d) for d in self._dequeues}
+
+
+ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self._global_physical_count_of_buffered_step = _Buffer.init_new(
+             item_shape=(
+                 self._expert_location_metadata.num_layers,
+                 # Cannot use local_physical_count to support select_experts
+                 self._expert_location_metadata.num_physical_experts,
+             ),
+             buffer_size=self._server_args.expert_distribution_recorder_buffer_size,
+             dtype=torch.int32,
+             device=self._server_args.device,
+         )
+
+     def append(
+         self,
+         forward_pass_id: int,
+         gatherer_key: str,
+         single_pass_data: Dict,
+     ):
+         super().append(forward_pass_id, gatherer_key, single_pass_data)
+         # Can optimize if overhead here is large
+         self._global_physical_count_of_buffered_step.append(
+             single_pass_data["global_physical_count"]
+         )
+
+     def reset(self):
+         super().reset()
+         self._global_physical_count_of_buffered_step.reset()
+
+     def dump(self, output_mode: _OutputMode):
+         logical_count_of_buffered_step = _convert_global_physical_count_to_logical_count(
+             self._global_physical_count_of_buffered_step.get_all(),
+             num_layers=self._expert_location_metadata.num_layers,
+             num_logical_experts=self._expert_location_metadata.num_logical_experts,
+             physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map,
+         )
+         torch.distributed.all_reduce(
+             logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM
+         )
+         output = dict(
+             rank=self._rank,
+             logical_count=logical_count_of_buffered_step,
+         )
+
+         if output_mode == "file":
+             if self._rank == 0:
+                 _dump_to_file(f"expert_distribution_recorder_{time.time()}.pt", output)
+         elif output_mode == "object":
+             return output
+         else:
+             raise NotImplementedError
+
+
+ def _dump_to_file(name, data):
+     save_dir = Path(os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", "/tmp"))
+     path_output = save_dir / name
+     logger.info(f"Write expert distribution to {path_output}")
+     if not save_dir.exists():
+         save_dir.mkdir(parents=True, exist_ok=True)
+     torch.save(data, str(path_output))
+
+
+ class _Buffer:
+     @staticmethod
+     def init_new(item_shape: Tuple, buffer_size: int, dtype, device):
+         if buffer_size < 0:
+             return _InfiniteBuffer(item_shape, dtype=dtype, device=device)
+         else:
+             return _CircularBuffer(item_shape, buffer_size, dtype=dtype, device=device)
+
+     def append(self, value: torch.Tensor):
+         raise NotImplementedError
+
+     def get_all(self) -> torch.Tensor:
+         raise NotImplementedError
+
+     def reset(self):
+         raise NotImplementedError
+
+
+ class _CircularBuffer(_Buffer):
+     def __init__(self, item_shape: Tuple, buffer_size: int, dtype, device):
+         self._buffer = torch.zeros(
+             (buffer_size, *item_shape), dtype=dtype, device=device
+         )
+         self._curr_index = 0
+
+     def append(self, value: torch.Tensor):
+         self._buffer[self._curr_index] = value
+         self._curr_index = (self._curr_index + 1) % len(self._buffer)
+
+     def get_all(self) -> torch.Tensor:
+         return self._buffer
+
+     def reset(self):
+         self._buffer[...] = 0
+
+
+ class _InfiniteBuffer(_Buffer):
+     def __init__(self, item_shape: Tuple, dtype, device):
+         self._item_shape = item_shape
+         self._buffer = torch.zeros((128, *item_shape), dtype=dtype, device=device)
+         self._size = 0
+
+     def append(self, value: torch.Tensor):
+         curr_buffer_size = len(self._buffer)
+         dtype = self._buffer.dtype
+         device = self._buffer.device
+
+         if self._size == curr_buffer_size:
+             new_buffer = torch.zeros(
+                 (2 * curr_buffer_size, *self._item_shape), dtype=dtype, device=device
+             )
+             new_buffer[:curr_buffer_size] = self._buffer
+             self._buffer = new_buffer
+
+         self._buffer[self._size] = value
+         self._size += 1
+
+     def get_all(self) -> torch.Tensor:
+         return self._buffer[: self._size]
+
+     def reset(self):
+         self._buffer[...] = 0
+         self._size = 0
+
+
+ def _convert_global_physical_count_to_logical_count(
+     # (whatever, num_layers, num_physical_experts)
+     global_physical_count: torch.Tensor,
+     num_layers: int,
+     num_logical_experts: int,
+     physical_to_logical_map: torch.Tensor,
+ ):
+     dim_extra, _, _ = global_physical_count.shape
+     dtype = global_physical_count.dtype
+     device = global_physical_count.device
+     logical_count = torch.zeros(
+         (dim_extra, num_layers, num_logical_experts), dtype=dtype, device=device
+     )
+     logical_count.scatter_add_(
+         dim=2,
+         index=physical_to_logical_map.unsqueeze(0)
+         .expand(dim_extra, -1, -1)
+         .to(torch.int64),
+         src=global_physical_count,
+     )
+     return logical_count
+
+
+ def compute_gpu_physical_count(
+     physical_count_of_whatever: torch.Tensor,  # (..., num_layer, num_physical_expert)
+     num_gpu: int,
+ ):
+     """output: gpu_physical_count_of_batch (..., num_layer, num_gpu)"""
+     return einops.reduce(
+         physical_count_of_whatever,
+         "... num_layer (num_gpu num_expert_per_gpu) -> ... num_layer num_gpu",
+         "sum",
+         num_gpu=num_gpu,
+     )
+
+
+ def compute_utilization_rate(
+     gpu_physical_count_of_batch: torch.Tensor,  # (..., num_layer, num_gpu)
+ ):
+     """output: utilization_rate (..., num_layer)"""
+     gpu_physical_count_of_batch = gpu_physical_count_of_batch.float()
+     max_gpu_physical_count = einops.reduce(
+         gpu_physical_count_of_batch,
+         "... num_layer num_gpu -> ... num_layer",
+         "max",
+     )
+     avg_gpu_physical_count = einops.reduce(
+         gpu_physical_count_of_batch,
+         "... num_layer num_gpu -> ... num_layer",
+         "mean",
+     )
+     return (avg_gpu_physical_count + 1e-5) / (max_gpu_physical_count + 1e-5)
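
The two helpers at the end of the new file are what produce the "balancedness" number logged by _UtilizationRateAccumulatorMixin: physical-expert counts are summed per GPU, then each layer's mean-over-max ratio is taken, so 1.0 means every GPU received the same load. A minimal illustrative sketch of that calculation (not part of the diff; assumes sglang 0.4.6.post5 plus torch and einops are installed, and the toy numbers are invented):

import torch

from sglang.srt.managers.expert_distribution import (
    compute_gpu_physical_count,
    compute_utilization_rate,
)

# 2 layers x 4 physical experts; experts 0-1 live on GPU 0, experts 2-3 on GPU 1
physical_count = torch.tensor(
    [
        [10, 10, 10, 10],  # layer 0: tokens spread evenly
        [40, 0, 0, 0],  # layer 1: everything routed to one expert on GPU 0
    ]
)
gpu_count = compute_gpu_physical_count(physical_count, num_gpu=2)
# gpu_count -> [[20, 20], [40, 0]]
print(compute_utilization_rate(gpu_count))
# -> roughly [1.00, 0.50]: mean/max per layer, the value logged as balancedness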