sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +5 -4
  3. sglang/bench_one_batch_server.py +23 -15
  4. sglang/bench_serving.py +133 -57
  5. sglang/compile_deep_gemm.py +4 -4
  6. sglang/srt/configs/model_config.py +39 -28
  7. sglang/srt/conversation.py +1 -1
  8. sglang/srt/disaggregation/decode.py +122 -133
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  10. sglang/srt/disaggregation/fake/conn.py +3 -13
  11. sglang/srt/disaggregation/kv_events.py +357 -0
  12. sglang/srt/disaggregation/mini_lb.py +57 -24
  13. sglang/srt/disaggregation/mooncake/conn.py +11 -2
  14. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  15. sglang/srt/disaggregation/nixl/conn.py +9 -19
  16. sglang/srt/disaggregation/prefill.py +126 -44
  17. sglang/srt/disaggregation/utils.py +116 -5
  18. sglang/srt/distributed/utils.py +3 -3
  19. sglang/srt/entrypoints/EngineBase.py +5 -0
  20. sglang/srt/entrypoints/engine.py +28 -8
  21. sglang/srt/entrypoints/http_server.py +6 -4
  22. sglang/srt/entrypoints/http_server_engine.py +5 -2
  23. sglang/srt/function_call/base_format_detector.py +250 -0
  24. sglang/srt/function_call/core_types.py +34 -0
  25. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  26. sglang/srt/function_call/ebnf_composer.py +234 -0
  27. sglang/srt/function_call/function_call_parser.py +175 -0
  28. sglang/srt/function_call/llama32_detector.py +74 -0
  29. sglang/srt/function_call/mistral_detector.py +84 -0
  30. sglang/srt/function_call/pythonic_detector.py +163 -0
  31. sglang/srt/function_call/qwen25_detector.py +67 -0
  32. sglang/srt/function_call/utils.py +35 -0
  33. sglang/srt/hf_transformers_utils.py +46 -7
  34. sglang/srt/layers/attention/aiter_backend.py +513 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +63 -17
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  37. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  38. sglang/srt/layers/attention/triton_backend.py +3 -0
  39. sglang/srt/layers/attention/utils.py +2 -2
  40. sglang/srt/layers/attention/vision.py +1 -1
  41. sglang/srt/layers/communicator.py +451 -0
  42. sglang/srt/layers/dp_attention.py +0 -10
  43. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  44. sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
  45. sglang/srt/layers/moe/ep_moe/layer.py +104 -50
  46. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  48. sglang/srt/layers/moe/topk.py +66 -9
  49. sglang/srt/layers/multimodal.py +70 -0
  50. sglang/srt/layers/quantization/__init__.py +7 -2
  51. sglang/srt/layers/quantization/deep_gemm.py +5 -3
  52. sglang/srt/layers/quantization/fp8.py +90 -0
  53. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  54. sglang/srt/layers/quantization/gptq.py +298 -6
  55. sglang/srt/layers/quantization/int8_kernel.py +18 -5
  56. sglang/srt/layers/quantization/qoq.py +244 -0
  57. sglang/srt/lora/lora_manager.py +1 -3
  58. sglang/srt/managers/deepseek_eplb.py +278 -0
  59. sglang/srt/managers/eplb_manager.py +55 -0
  60. sglang/srt/managers/expert_distribution.py +704 -56
  61. sglang/srt/managers/expert_location.py +394 -0
  62. sglang/srt/managers/expert_location_dispatch.py +91 -0
  63. sglang/srt/managers/io_struct.py +16 -3
  64. sglang/srt/managers/mm_utils.py +293 -139
  65. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  66. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  67. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  68. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  69. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  70. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  71. sglang/srt/managers/multimodal_processors/llava.py +3 -3
  72. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  73. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  74. sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
  75. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  76. sglang/srt/managers/schedule_batch.py +49 -21
  77. sglang/srt/managers/schedule_policy.py +4 -5
  78. sglang/srt/managers/scheduler.py +92 -50
  79. sglang/srt/managers/session_controller.py +1 -1
  80. sglang/srt/managers/tokenizer_manager.py +99 -24
  81. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  82. sglang/srt/mem_cache/chunk_cache.py +3 -1
  83. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  84. sglang/srt/mem_cache/memory_pool.py +74 -52
  85. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  86. sglang/srt/mem_cache/radix_cache.py +58 -5
  87. sglang/srt/metrics/collector.py +2 -2
  88. sglang/srt/mm_utils.py +10 -0
  89. sglang/srt/model_executor/cuda_graph_runner.py +20 -9
  90. sglang/srt/model_executor/expert_location_updater.py +422 -0
  91. sglang/srt/model_executor/forward_batch_info.py +4 -0
  92. sglang/srt/model_executor/model_runner.py +144 -54
  93. sglang/srt/model_loader/loader.py +10 -6
  94. sglang/srt/models/clip.py +5 -1
  95. sglang/srt/models/deepseek_v2.py +297 -343
  96. sglang/srt/models/exaone.py +8 -3
  97. sglang/srt/models/gemma3_mm.py +70 -33
  98. sglang/srt/models/llama4.py +10 -2
  99. sglang/srt/models/llava.py +26 -18
  100. sglang/srt/models/mimo_mtp.py +220 -0
  101. sglang/srt/models/minicpmo.py +5 -12
  102. sglang/srt/models/mistral.py +71 -1
  103. sglang/srt/models/mllama.py +3 -3
  104. sglang/srt/models/qwen2.py +95 -26
  105. sglang/srt/models/qwen2_5_vl.py +8 -0
  106. sglang/srt/models/qwen2_moe.py +330 -60
  107. sglang/srt/models/qwen2_vl.py +6 -0
  108. sglang/srt/models/qwen3.py +52 -10
  109. sglang/srt/models/qwen3_moe.py +411 -48
  110. sglang/srt/models/siglip.py +294 -0
  111. sglang/srt/openai_api/adapter.py +28 -16
  112. sglang/srt/openai_api/protocol.py +6 -0
  113. sglang/srt/operations.py +154 -0
  114. sglang/srt/operations_strategy.py +31 -0
  115. sglang/srt/server_args.py +134 -24
  116. sglang/srt/speculative/eagle_utils.py +131 -0
  117. sglang/srt/speculative/eagle_worker.py +47 -2
  118. sglang/srt/utils.py +68 -12
  119. sglang/test/test_cutlass_moe.py +278 -0
  120. sglang/test/test_utils.py +2 -36
  121. sglang/utils.py +2 -2
  122. sglang/version.py +1 -1
  123. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
  124. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
  125. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  126. sglang/srt/function_call_parser.py +0 -858
  127. sglang/srt/platforms/interface.py +0 -371
  128. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  129. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/deepseek_eplb.py
@@ -0,0 +1,278 @@
+# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package
+
+from typing import Literal, Tuple
+
+import torch
+
+
+def pack_groups(tokens_per_group: torch.Tensor, num_nodes: int) -> torch.Tensor:
+    num_layers, num_groups = tokens_per_group.shape
+    assert num_groups % num_nodes == 0
+    groups_per_rank = num_groups // num_nodes
+
+    indices = tokens_per_group.float().sort(-1, descending=True).indices.cpu()
+    ret = torch.full_like(
+        tokens_per_group, fill_value=-1, dtype=torch.int64, device="cpu"
+    )
+    for layer in range(num_layers):
+        node_tokens = [0] * num_nodes
+        node_groups = [0] * num_nodes
+        for group in indices[layer]:
+
+            def key_func(rank: int) -> int:
+                if node_groups[rank] >= groups_per_rank:
+                    return 1, 0
+                else:
+                    return 0, node_tokens[rank]
+
+            rank = min(range(num_nodes), key=key_func)
+            assert node_groups[rank] < groups_per_rank
+            ret[layer, group] = rank * groups_per_rank + node_groups[rank]
+            node_tokens[rank] += tokens_per_group[layer, group]
+            node_groups[rank] += 1
+    return ret
+
+
+def make_redundant_experts_chunkwise(
+    tokens_per_expert: torch.Tensor,
+    num_physical_experts: int,
+    num_local_physical_experts: int,
+    num_physical_experts_per_chunk: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    num_steps, num_moe_layers, num_logical_experts = tokens_per_expert.shape
+    num_redundancy_experts = num_physical_experts - num_logical_experts
+
+    physical_to_logical_map = torch.empty(
+        num_moe_layers,
+        num_physical_experts,
+        dtype=torch.int,
+        device=tokens_per_expert.device,
+    )
+    logical_to_physical_map = torch.full(
+        (num_moe_layers, num_logical_experts, num_redundancy_experts + 1),
+        -1,
+        dtype=torch.int,
+        device=tokens_per_expert.device,
+    )
+    logical_count = torch.ones(
+        num_moe_layers,
+        num_logical_experts,
+        dtype=torch.int,
+        device=tokens_per_expert.device,
+    )
+
+    assert num_physical_experts % num_physical_experts_per_chunk == 0
+    num_chunks = num_physical_experts // num_physical_experts_per_chunk
+    assert num_logical_experts % num_chunks == 0
+    num_logical_experts_per_group = num_logical_experts // num_chunks
+    assert num_redundancy_experts % num_chunks == 0
+    num_redundancy_experts_per_group = num_redundancy_experts // num_chunks
+
+    arange_num_moe_layers_num_groups = torch.arange(
+        num_moe_layers * num_chunks, dtype=torch.int, device=tokens_per_expert.device
+    )
+    arange_num_logical_experts = torch.arange(
+        num_logical_experts, dtype=torch.int, device=tokens_per_expert.device
+    )
+    arange_num_logical_experts_per_group = torch.arange(
+        num_logical_experts_per_group, dtype=torch.int, device=tokens_per_expert.device
+    )
+    arange_num_groups = torch.arange(
+        num_chunks, dtype=torch.int, device=tokens_per_expert.device
+    )
+    physical_to_logical_map.view(
+        num_moe_layers, num_chunks, num_physical_experts_per_chunk
+    )[:, :, :num_logical_experts_per_group] = arange_num_logical_experts.view(
+        num_chunks, num_logical_experts_per_group
+    )
+    logical_to_physical_map[:, :, 0] = (
+        arange_num_logical_experts_per_group.expand(
+            num_chunks, num_logical_experts_per_group
+        )
+        + arange_num_groups[:, None] * num_physical_experts_per_chunk
+    ).view(num_logical_experts)
+
+    tokens_per_expert_all_diff = tokens_per_expert + arange_num_logical_experts * 1e-4
+    for i in range(num_redundancy_experts_per_group):
+        score = (
+            tokens_per_expert_all_diff / logical_count
+        )  # NOTE: Values in score must be different from each other
+        score1 = tokens_per_expert / (logical_count + 1)
+        score = score.view(
+            num_steps, num_moe_layers, num_chunks, num_logical_experts_per_group
+        )
+        score1 = score1.view_as(score)
+        values, indices = score.max(-1, keepdim=True)
+        values = values.expand_as(score).contiguous()
+        score.scatter_(-1, indices, score1.gather(-1, indices))
+        values.scatter_(-1, indices, score.max(-1, keepdim=True).values)
+        redundancy_indices = values.sum(0).argmin(-1)
+        physical_to_logical_map.view(
+            num_moe_layers, num_chunks, num_physical_experts_per_chunk
+        )[:, :, num_logical_experts_per_group + i] = (
+            redundancy_indices + arange_num_groups * num_logical_experts_per_group
+        )
+        redundancy_count = (
+            logical_count.view(
+                num_moe_layers * num_chunks, num_logical_experts_per_group
+            )
+            .gather(-1, redundancy_indices.view(num_moe_layers * num_chunks, 1))
+            .squeeze(1)
+        )
+        physical_redundancy_indices = (
+            (
+                arange_num_groups * num_physical_experts_per_chunk
+                + num_logical_experts_per_group
+                + i
+            )
+            .expand(num_moe_layers, num_chunks)
+            .flatten()
+        )
+        logical_to_physical_map.view(
+            num_moe_layers * num_chunks,
+            num_logical_experts_per_group,
+            num_redundancy_experts + 1,
+        )[
+            arange_num_moe_layers_num_groups,
+            redundancy_indices.view(num_moe_layers * num_chunks),
+            redundancy_count,
+        ] = physical_redundancy_indices
+        logical_count.view(num_moe_layers * num_chunks, num_logical_experts_per_group)[
+            arange_num_moe_layers_num_groups,
+            redundancy_indices.view(num_moe_layers * num_chunks),
+        ] += 1
+
+    if num_local_physical_experts > 1:
+        # Load-balancing between GPUs
+        physical_to_logical_map_int64 = physical_to_logical_map.to(torch.int64)
+        counts = logical_count.gather(-1, physical_to_logical_map_int64)
+        score = tokens_per_expert.sum(0).gather(-1, physical_to_logical_map_int64)
+        score = score / counts
+        score = score.view(num_moe_layers, num_chunks, num_physical_experts_per_chunk)
+        indices = score.argsort(-1, descending=True)
+        indices += torch.arange(
+            0,
+            num_physical_experts,
+            num_physical_experts_per_chunk,
+            dtype=indices.dtype,
+            device=indices.device,
+        )[None, :, None]
+
+        assert num_physical_experts_per_chunk % num_local_physical_experts == 0
+        num_local_groups = num_physical_experts_per_chunk // num_local_physical_experts
+        indices = indices.view(
+            num_moe_layers, num_chunks, num_local_physical_experts, num_local_groups
+        )
+        indices[:, :, 1::2, :] = indices[:, :, 1::2, :].flip(-1)
+        indices = indices.transpose(2, 3)
+        indices = indices.reshape(num_moe_layers, num_physical_experts)
+        physical_to_logical_map = physical_to_logical_map.gather(-1, indices)
+        mask = logical_to_physical_map == -1
+        logical_to_physical_map[mask] = 0
+        logical_to_physical_map = (
+            indices.argsort(-1)
+            .gather(
+                -1, logical_to_physical_map.view(num_moe_layers, -1).to(torch.int64)
+            )
+            .view_as(logical_to_physical_map)
+            .to(torch.int)
+        )
+        logical_to_physical_map[mask] = -1
+
+    return physical_to_logical_map, logical_to_physical_map, logical_count
+
+
+def decode_rebalance_experts(
+    tokens_per_expert: torch.Tensor,
+    num_physical_experts: int,
+    num_local_physical_experts: int,
+):
+    return make_redundant_experts_chunkwise(
+        tokens_per_expert,
+        num_physical_experts,
+        num_local_physical_experts,
+        num_physical_experts,
+    )
+
+
+def prefill_rebalance_experts(
+    tokens_per_expert: torch.Tensor,
+    num_physical_experts: int,
+    num_local_physical_experts: int,
+    num_groups: int,
+    num_nodes: int,
+):
+    tokens_per_expert = tokens_per_expert.float().cpu()
+
+    num_steps, _, num_logical_experts = tokens_per_expert.shape
+    assert num_logical_experts % num_groups == 0
+    group_size = num_logical_experts // num_groups
+    assert num_groups % num_nodes == 0, f"{num_groups=} {num_nodes=}"
+
+    tokens_per_group = tokens_per_expert.sum(0).unflatten(-1, (num_groups, -1)).sum(-1)
+    group_perm = pack_groups(
+        tokens_per_group, num_nodes
+    )  # [num_moe_layers, num_groups] => [num_moe_layers, num_nodes]
+
+    # log2mlog [layers, #logexp] -> [layers, #logexp]
+    log2mlog = (
+        (group_perm * group_size).unsqueeze(-1)
+        + torch.arange(group_size, dtype=torch.int64, device=group_perm.device)
+    ).flatten(-2)
+
+    # mlog2log [layers, #logexp] -> [layers, #logexp], inverse of log2mlog
+    mlog2log = torch.empty_like(log2mlog)
+    arange = torch.arange(
+        num_logical_experts, dtype=torch.int64, device=mlog2log.device
+    )
+    mlog2log.scatter_(1, log2mlog, arange.expand(log2mlog.size(0), -1))
+
+    # tokens_per_mlog[i][j][k] = tokens_per_expert[i][j][mlog2log[j][k]]
+    tokens_per_mlog = tokens_per_expert.gather(
+        2, mlog2log.unsqueeze(0).expand(num_steps, -1, -1)
+    )
+
+    phy2mlog, mlog2phy, mlog_count = make_redundant_experts_chunkwise(
+        tokens_per_mlog,
+        num_physical_experts,
+        num_local_physical_experts,
+        num_physical_experts // num_nodes,
+    )
+
+    # phy2log[i][j] = mlog2log[i][phy2mlog[i][j]]
+    phy2log = mlog2log.gather(1, phy2mlog.to(torch.int64))
+
+    # mlog2phy: [num_moe_layers, num_logical_experts, ...]
+    # log2phy[i][j][k] = mlog2phy[i][log2mlog[i][j]][k]
+    log2phy = mlog2phy.gather(
+        1, log2mlog.unsqueeze(-1).expand(-1, -1, mlog2phy.size(-1)).to(torch.int64)
+    )
+
+    # log_count[i][j] = mlog_count[i][log2mlog[i][j]]
+    log_count = mlog_count.gather(1, log2mlog)
+    return phy2log, log2phy, log_count
+
+
+def rebalance_experts(
+    tokens_per_expert: torch.Tensor,
+    num_physical_experts: int,
+    num_local_physical_experts: int,
+    num_groups: int,
+    num_nodes: int,
+    phase: Literal["prefill", "decode"],
+):
+    if phase == "prefill":
+        return prefill_rebalance_experts(
+            tokens_per_expert=tokens_per_expert,
+            num_physical_experts=num_physical_experts,
+            num_local_physical_experts=num_local_physical_experts,
+            num_groups=num_groups,
+            num_nodes=num_nodes,
+        )
+    if phase == "decode":
+        return decode_rebalance_experts(
+            tokens_per_expert=tokens_per_expert,
+            num_physical_experts=num_physical_experts,
+            num_local_physical_experts=num_local_physical_experts,
+        )
+    raise NotImplementedError
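
To make the new EPLB entry point above concrete, here is a minimal usage sketch. It is not part of the diff: the tensor sizes, the random routing counts, and the choice of 8 logical experts on 12 physical slots with 4 experts per GPU are invented for illustration; only the module path sglang/srt/managers/deepseek_eplb.py and the rebalance_experts signature come from the hunk above.

    import torch

    from sglang.srt.managers.deepseek_eplb import rebalance_experts

    # Hypothetical sizes: 10 recorded steps, 2 MoE layers, 8 logical experts
    # replicated onto 12 physical slots, 4 physical experts per GPU.
    num_steps, num_moe_layers, num_logical_experts = 10, 2, 8
    tokens_per_expert = torch.randint(
        0, 1000, (num_steps, num_moe_layers, num_logical_experts)
    )

    phy2log, log2phy, log_count = rebalance_experts(
        tokens_per_expert=tokens_per_expert,
        num_physical_experts=12,
        num_local_physical_experts=4,
        num_groups=8,  # only consulted by the "prefill" phase
        num_nodes=1,
        phase="decode",
    )
    # phy2log: [num_moe_layers, num_physical_experts], physical slot -> logical expert
    # log2phy: [num_moe_layers, num_logical_experts, num_redundant + 1], replica slots, -1 padded
    # log_count: [num_moe_layers, num_logical_experts], replica count per logical expert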
sglang/srt/managers/eplb_manager.py
@@ -0,0 +1,55 @@
+import logging
+import time
+from typing import TYPE_CHECKING
+
+import torch.cuda
+
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import ExpertLocationMetadata
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+logger = logging.getLogger(__name__)
+
+
+class EPLBManager:
+    def __init__(self, model_runner: "ModelRunner"):
+        super().__init__()
+        self._model_runner = model_runner
+        self._server_args = model_runner.server_args
+
+        # Otherwise, the circular buffer will contain stale data. If the case is needed, it can be implemented.
+        assert (
+            self._server_args.eplb_rebalance_num_iterations
+            <= self._server_args.expert_distribution_recorder_buffer_size
+        ), "eplb_rebalance_num_iterations must be less than expert_distribution_recorder_buffer_size"
+
+        get_global_expert_distribution_recorder().start_record()
+
+        logger.info(
+            f"[EPLBManager] system started, will rebalance per {self._server_args.eplb_rebalance_num_iterations} iterations."
+        )
+
+    def on_forward_pass_end(self, forward_pass_id: int):
+        if forward_pass_id % self._server_args.eplb_rebalance_num_iterations == 0:
+            self.rebalance()
+
+    def rebalance(self):
+        logger.info("[EPLBManager] rebalance start")
+        torch.cuda.synchronize()
+        time_start = time.time()
+
+        logical_count = get_global_expert_distribution_recorder().dump_record(
+            output_mode="object"
+        )["logical_count"]
+        expert_location_metadata = ExpertLocationMetadata.init_by_eplb(
+            self._server_args, self._model_runner.model_config, logical_count
+        )
+        self._model_runner.update_expert_location(expert_location_metadata)
+
+        torch.cuda.synchronize()
+        time_end = time.time()
+        logger.info(f"[EPLBManager] rebalance end time={time_end - time_start:.3f}s")