sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,394 @@
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+ import json
15
+ import logging
16
+ from dataclasses import dataclass
17
+ from pathlib import Path
18
+ from typing import List, Optional
19
+
20
+ import torch
21
+ import torch.distributed
22
+ import torch.nn.functional as F
23
+
24
+ from sglang.srt.configs.model_config import ModelConfig
25
+ from sglang.srt.managers import deepseek_eplb
26
+ from sglang.srt.model_loader import get_model_architecture
27
+ from sglang.srt.server_args import ServerArgs
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
@dataclass
class ExpertLocationMetadata:
    """Lookup tables between logical and physical expert ids, indexed by layer.

    ``-1`` is the padding value used for unused slots in
    ``logical_to_all_physical_map``.
    """

    # (layers, num_physical_experts)
    physical_to_logical_map: torch.Tensor
    # (layers, num_logical_experts, X)
    logical_to_all_physical_map: torch.Tensor
    # (layers, num_logical_experts)
    logical_to_all_physical_map_num_valid: torch.Tensor
    # (layers, num_logical_experts)
    logical_to_rank_dispatch_physical_map: torch.Tensor

    # -------------------------------- properties ------------------------------------

    @property
    def num_layers(self) -> int:
        """Size of the first dimension of ``physical_to_logical_map``."""
        return self.physical_to_logical_map.size(0)

    @property
    def num_physical_experts(self) -> int:
        """Size of the second dimension of ``physical_to_logical_map``."""
        return self.physical_to_logical_map.size(1)

    @property
    def num_local_physical_experts(self) -> int:
        """Physical experts per EP rank; asserts the split is exact."""
        total, ranks = self.num_physical_experts, self.ep_size
        assert total % ranks == 0
        return total // ranks

    @property
    def num_logical_experts(self) -> int:
        """Size of the second dimension of ``logical_to_all_physical_map``."""
        return self.logical_to_all_physical_map.size(1)

    @property
    def ep_size(self):
        """Expert-parallel size, currently taken from the distributed world size."""
        # TODO change when EP size != world size
        return torch.distributed.get_world_size()

    def __post_init__(self):
        """Check that the four tensors have the expected ranks and agree on sizes."""
        # Tuple unpacking doubles as a rank check on every tensor.
        p2l_layers, p2l_physical = self.physical_to_logical_map.shape
        all_layers, all_logical, all_physical = self.logical_to_all_physical_map.shape
        valid_layers, valid_logical = self.logical_to_all_physical_map_num_valid.shape
        dispatch_layers, dispatch_logical = (
            self.logical_to_rank_dispatch_physical_map.shape
        )

        assert p2l_layers == all_layers == valid_layers == dispatch_layers
        assert all_logical == valid_logical == dispatch_logical
        assert p2l_physical == all_physical

    # -------------------------------- construction ------------------------------------

    @staticmethod
    def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
        """Trivial location - logical expert i corresponds to physical expert i"""
        common = ExpertLocationMetadata._init_common(server_args, model_config)
        expert_config = common["model_config_for_expert_location"]

        # Physical ids beyond the logical count wrap around via the modulo.
        physical_ids = torch.arange(0, common["num_physical_experts"])
        per_layer_ids = physical_ids.repeat(expert_config.num_layers, 1)

        return ExpertLocationMetadata.init_by_mapping(
            server_args,
            model_config,
            physical_to_logical_map=per_layer_ids % expert_config.num_logical_experts,
        )

    @staticmethod
    def init_by_mapping(
        server_args: ServerArgs,
        model_config: ModelConfig,
        physical_to_logical_map,
    ):
        """Build the metadata from an explicit physical-to-logical mapping.

        ``physical_to_logical_map`` may be a tensor or anything accepted by
        ``torch.tensor``; it is moved to ``server_args.device``.
        """
        mapping = physical_to_logical_map
        if not isinstance(mapping, torch.Tensor):
            mapping = torch.tensor(mapping)
        mapping = mapping.to(server_args.device)

        common = ExpertLocationMetadata._init_common(server_args, model_config)
        expert_config = common["model_config_for_expert_location"]
        logical_to_all = _compute_logical_to_all_physical_map(
            mapping,
            num_logical_experts=expert_config.num_logical_experts,
        )

        return ExpertLocationMetadata._init_raw(
            ep_size=common["ep_size"],
            physical_to_logical_map=mapping,
            logical_to_all_physical_map=logical_to_all,
        )

    @staticmethod
    def init_by_eplb(
        server_args: ServerArgs, model_config: ModelConfig, logical_count: torch.Tensor
    ):
        """Build the metadata from the output of ``deepseek_eplb.rebalance_experts``.

        ``logical_count`` is converted to a tensor when needed, given a leading
        dimension when it is 2-D, and moved to ``server_args.device``.
        """
        if not isinstance(logical_count, torch.Tensor):
            logical_count = torch.tensor(logical_count)
        if logical_count.dim() == 2:
            logical_count = logical_count.unsqueeze(0)
        logical_count = logical_count.to(server_args.device)

        common = ExpertLocationMetadata._init_common(server_args, model_config)
        expert_config = common["model_config_for_expert_location"]
        num_physical_experts = common["num_physical_experts"]
        ep_size = common["ep_size"]

        # A "null" disaggregation mode is handled as the decode phase.
        mode = server_args.disaggregation_mode
        phase = "decode" if mode == "null" else mode

        rebalanced = deepseek_eplb.rebalance_experts(
            tokens_per_expert=logical_count,
            num_physical_experts=num_physical_experts,
            num_local_physical_experts=num_physical_experts // ep_size,
            num_groups=expert_config.num_groups,
            num_nodes=server_args.nnodes,
            phase=phase,
        )
        # The third returned value is not needed here.
        physical_to_logical_map, logical_to_all_physical_map, _expert_count = rebalanced

        return ExpertLocationMetadata._init_raw(
            ep_size=ep_size,
            physical_to_logical_map=physical_to_logical_map,
            logical_to_all_physical_map=logical_to_all_physical_map,
        )

    @staticmethod
    def _init_common(server_args: ServerArgs, model_config: ModelConfig):
        """Derive the values shared by all constructors and return them as a dict."""
        expert_config = ModelConfigForExpertLocation.from_model_config(model_config)

        # Physical experts = logical experts plus the configured redundant ones.
        num_physical_experts = (
            expert_config.num_logical_experts + server_args.ep_num_redundant_experts
        )
        ep_size = server_args.ep_size
        assert num_physical_experts % ep_size == 0

        return {
            "model_config_for_expert_location": expert_config,
            "num_physical_experts": num_physical_experts,
            "num_local_physical_experts": num_physical_experts // ep_size,
            "ep_size": ep_size,
        }

    @staticmethod
    def _init_raw(
        ep_size: int,
        physical_to_logical_map: torch.Tensor,
        logical_to_all_physical_map: torch.Tensor,
    ):
        """Assemble the dataclass from the two primary mapping tensors."""
        _num_layers, num_physical_experts = physical_to_logical_map.shape

        # Pad the last dimension with -1 up to num_physical_experts.
        pad_width = num_physical_experts - logical_to_all_physical_map.shape[-1]
        padded_map = F.pad(logical_to_all_physical_map, (0, pad_width), value=-1)

        num_valid = torch.count_nonzero(logical_to_all_physical_map != -1, dim=-1)

        dispatch_map = compute_logical_to_rank_dispatch_physical_map(
            logical_to_all_physical_map=logical_to_all_physical_map,
            logical_to_all_physical_map_num_valid=num_valid,
            num_gpus=ep_size,
            num_physical_experts=num_physical_experts,
            ep_rank=torch.distributed.get_rank(),
        )

        return ExpertLocationMetadata(
            physical_to_logical_map=physical_to_logical_map,
            logical_to_all_physical_map=padded_map,
            logical_to_all_physical_map_num_valid=num_valid,
            logical_to_rank_dispatch_physical_map=dispatch_map,
        )

    # -------------------------------- mutation ------------------------------------

    def update(
        self,
        other: "ExpertLocationMetadata",
    ):
        """Copy every tensor of ``other`` into this object's tensors in place."""
        assert self.ep_size == other.ep_size

        tensor_fields = (
            "physical_to_logical_map",
            "logical_to_all_physical_map",
            "logical_to_all_physical_map_num_valid",
            "logical_to_rank_dispatch_physical_map",
        )
        for name in tensor_fields:
            # Slice assignment writes into the existing tensor object.
            getattr(self, name)[...] = getattr(other, name)

    # -------------------------------- usage ------------------------------------

    def logical_to_all_physical(
        self, layer_id: int, logical_expert_id: int
    ) -> List[int]:
        """Return the physical expert ids of one logical expert, padding removed."""
        slots = self.logical_to_all_physical_map[layer_id, logical_expert_id].tolist()
        return [slot for slot in slots if slot != -1]
243
+
244
+
245
# Process-wide metadata instance; stays None until the setter below is called.
_global_expert_location_metadata: Optional[ExpertLocationMetadata] = None


def get_global_expert_location_metadata():
    """Return the registered global metadata, or None when none has been set."""
    return _global_expert_location_metadata


def set_global_expert_location_metadata(value):
    """Register ``value`` as the global metadata; asserts nothing was set before."""
    global _global_expert_location_metadata
    already_set = _global_expert_location_metadata is not None
    assert not already_set
    _global_expert_location_metadata = value
256
+
257
+
258
def _compute_logical_to_all_physical_map(
    physical_to_logical_map: torch.Tensor, num_logical_experts: int
):
    """Invert a physical-to-logical map into a padded logical-to-physical tensor.

    For every layer and logical expert, the physical expert ids mapped to it
    are collected in ascending order; shorter rows are padded with -1. The
    result is created on the same device as ``physical_to_logical_map``.
    """
    # This is rarely called, so we use for loops for maximum clarity

    # Unpacking also checks that the input is 2-D.
    num_layers, _num_physical_experts = physical_to_logical_map.shape

    buckets = [[[] for _ in range(num_logical_experts)] for _ in range(num_layers)]
    for layer_buckets, layer_row in zip(buckets, physical_to_logical_map.tolist()):
        for physical_expert_id, logical_expert_id in enumerate(layer_row):
            layer_buckets[logical_expert_id].append(physical_expert_id)

    padded = _pad_nested_array(buckets, pad_value=-1)
    return torch.tensor(padded, device=physical_to_logical_map.device)
284
+
285
+
286
+ def _pad_nested_array(arr, pad_value):
287
+ max_len = max(len(inner) for outer in arr for inner in outer)
288
+ padded = [
289
+ [inner + [pad_value] * (max_len - len(inner)) for inner in outer]
290
+ for outer in arr
291
+ ]
292
+ return padded
293
+
294
+
295
# TODO use more sophisticated approaches
def compute_logical_to_rank_dispatch_physical_map(
    logical_to_all_physical_map: torch.Tensor,
    logical_to_all_physical_map_num_valid: torch.Tensor,
    num_gpus: int,
    num_physical_experts: int,
    ep_rank: int,
    base_seed: int = 42,
):
    """Choose one physical expert per (layer, logical expert) for ``ep_rank``.

    A replica is first drawn pseudo-randomly with a generator seeded by
    ``base_seed + ep_rank``. The draw is then overridden whenever a valid
    replica's id, floor-divided by the per-GPU expert count, equals
    ``ep_rank``; if several slots qualify, the last one wins.

    Returns:
        A (layers, num_logical_experts) tensor; asserted to contain no -1.
    """
    device = logical_to_all_physical_map.device
    experts_per_gpu = num_physical_experts // num_gpus
    num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
    output_shape = (num_layers, num_logical_experts)

    rng = torch.Generator(device=device)
    rng.manual_seed(base_seed + ep_rank)

    # Random slot index, reduced into the valid range of each logical expert.
    random_draw = torch.randint(
        0, 65536, output_shape, dtype=torch.int32, device=device, generator=rng
    )
    replica_choice = random_draw % logical_to_all_physical_map_num_valid

    dispatch_map = torch.gather(
        logical_to_all_physical_map, dim=2, index=replica_choice.unsqueeze(-1)
    ).squeeze(-1)
    assert dispatch_map.shape == output_shape

    max_replicas = logical_to_all_physical_map_num_valid.max().item()
    for replica_slot in range(max_replicas):
        candidate = logical_to_all_physical_map[:, :, replica_slot]
        # Padding (-1) never counts as a replica on this rank.
        on_this_rank = (candidate != -1) & ((candidate // experts_per_gpu) == ep_rank)
        dispatch_map = torch.where(on_this_rank, candidate, dispatch_map)

    assert torch.all(dispatch_map != -1)
    return dispatch_map
338
+
339
+
340
@dataclass
class ModelConfigForExpertLocation:
    """Model dimensions needed to compute expert locations."""

    num_layers: int
    num_logical_experts: int
    num_groups: Optional[int] = None

    @staticmethod
    def init_dummy():
        """Return a placeholder config with one layer and one logical expert."""
        return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)

    @staticmethod
    def from_model_config(model_config: ModelConfig):
        """Ask the model class for its expert-location config.

        Falls back to ``init_dummy()`` when the model class does not define
        ``get_model_config_for_expert_location``.
        """
        model_class, _ = get_model_architecture(model_config)
        if not hasattr(model_class, "get_model_config_for_expert_location"):
            return ModelConfigForExpertLocation.init_dummy()
        return model_class.get_model_config_for_expert_location(model_config.hf_config)
359
+
360
+
361
def compute_initial_expert_location_metadata(
    server_args: ServerArgs, model_config: ModelConfig
) -> ExpertLocationMetadata:
    """Build the initial metadata from ``server_args.init_expert_location``.

    The setting is either the literal ``"trivial"``, a path ending in ``.pt``
    (read with ``torch.load``), a path ending in ``.json``, or an inline JSON
    string. The loaded mapping must contain ``physical_to_logical_map`` or
    ``logical_count``.

    Raises:
        NotImplementedError: If the loaded mapping contains neither key.
    """
    data = server_args.init_expert_location
    if data == "trivial":
        logger.info("init_expert_location from trivial")
        return ExpertLocationMetadata.init_trivial(server_args, model_config)

    # TODO unify with the utils function
    if data.endswith(".pt"):
        data_dict = torch.load(data, weights_only=True)
    else:
        # Either a JSON file on disk or the JSON document itself.
        raw_json = Path(data).read_text() if data.endswith(".json") else data
        data_dict = json.loads(raw_json)

    if "physical_to_logical_map" in data_dict:
        logger.info(
            "init_expert_location from init_by_mapping using ServerArgs.init_expert_location"
        )
        # Every key of data_dict is forwarded as a keyword argument.
        return ExpertLocationMetadata.init_by_mapping(
            server_args, model_config, **data_dict
        )

    if "logical_count" in data_dict:
        logger.info(
            "init_expert_location from init_by_eplb using ServerArgs.init_expert_location"
        )
        return ExpertLocationMetadata.init_by_eplb(
            server_args, model_config, logical_count=data_dict["logical_count"]
        )

    raise NotImplementedError(
        f"Unknown init_expert_location format ({list(data_dict.keys())=})"
    )
@@ -0,0 +1,91 @@
1
+ # Copyright 2023-2025 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Literal, Optional
17
+
18
+ import torch
19
+
20
+ from sglang.srt.managers.expert_location import get_global_expert_location_metadata
21
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
22
+
23
+
24
@dataclass
class ExpertLocationDispatchInfo:
    """Per-layer slices of the global expert location metadata used for dispatch."""

    # The values dispatched on by ``topk_ids_logical_to_physical``; any other
    # value makes that function raise NotImplementedError.
    ep_dispatch_algorithm: Literal["static", "dynamic"]
    # (num_logical_experts,)
    partial_logical_to_rank_dispatch_physical_map: torch.Tensor
    # (num_logical_experts, X)
    partial_logical_to_all_physical_map: torch.Tensor
    # (num_logical_experts,)
    partial_logical_to_all_physical_map_num_valid: torch.Tensor
    num_physical_experts: int

    @classmethod
    def init_new(cls, layer_id: int):
        """Create the dispatch info for ``layer_id`` from the global state.

        Returns:
            None when ``global_server_args_dict["ep_dispatch_algorithm"]`` is
            None; otherwise an instance holding the ``layer_id`` rows of the
            global expert location metadata.
        """
        ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
        if ep_dispatch_algorithm is None:
            return None

        # Only fetched once dispatch is known to be enabled.
        expert_location_metadata = get_global_expert_location_metadata()

        return cls(
            ep_dispatch_algorithm=ep_dispatch_algorithm,
            partial_logical_to_rank_dispatch_physical_map=expert_location_metadata.logical_to_rank_dispatch_physical_map[
                layer_id, :
            ],
            partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[
                layer_id, :
            ],
            partial_logical_to_all_physical_map_num_valid=expert_location_metadata.logical_to_all_physical_map_num_valid[
                layer_id, :
            ],
            num_physical_experts=expert_location_metadata.num_physical_experts,
        )
56
+
57
+
58
def topk_ids_logical_to_physical(
    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
    """Translate logical expert ids in ``topk_ids`` to physical expert ids.

    Args:
        topk_ids: Logical expert ids.
        info: Dispatch info for the layer, or None to skip translation.

    Returns:
        ``topk_ids`` unchanged when ``info`` is None, otherwise the result of
        the dispatch routine selected by ``info.ep_dispatch_algorithm``.

    Raises:
        NotImplementedError: If the algorithm is neither "static" nor "dynamic".
    """
    if info is None:
        return topk_ids

    algorithm = info.ep_dispatch_algorithm
    if algorithm == "static":
        return _topk_ids_logical_to_physical_static(topk_ids, info)
    if algorithm == "dynamic":
        return _topk_ids_logical_to_physical_dynamic(topk_ids, info)
    # Name the offending value so a misconfiguration is easy to diagnose.
    raise NotImplementedError(f"Unsupported ep_dispatch_algorithm: {algorithm!r}")
69
+
70
+
71
+ def _topk_ids_logical_to_physical_static(
72
+ topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
73
+ ) -> torch.Tensor:
74
+ return info.partial_logical_to_rank_dispatch_physical_map[topk_ids]
75
+
76
+
77
+ def _topk_ids_logical_to_physical_dynamic(
78
+ topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
79
+ ) -> torch.Tensor:
80
+ topk_ids_original_shape = topk_ids.shape
81
+ device = topk_ids.device
82
+ topk_ids = topk_ids.flatten()
83
+
84
+ chosen_dispatch_index = (
85
+ torch.randint(0, 65536, topk_ids.shape, dtype=torch.int32, device=device)
86
+ % info.partial_logical_to_all_physical_map_num_valid[topk_ids]
87
+ )
88
+ topk_ids = info.partial_logical_to_all_physical_map[topk_ids, chosen_dispatch_index]
89
+
90
+ topk_ids = topk_ids.view(topk_ids_original_shape)
91
+ return topk_ids
@@ -12,7 +12,7 @@
12
12
  # limitations under the License.
13
13
  # ==============================================================================
14
14
  """
15
- The definition of objects transfered between different
15
+ The definition of objects transferred between different
16
16
  processes (TokenizerManager, DetokenizerManager, Controller).
17
17
  """
18
18
 
@@ -22,13 +22,15 @@ from dataclasses import dataclass, field
22
22
  from enum import Enum
23
23
  from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
24
24
 
25
+ from sglang.srt.mm_utils import has_valid_data
26
+
25
27
  # handle serialization of Image for pydantic
26
28
  if TYPE_CHECKING:
27
29
  from PIL.Image import Image
28
30
  else:
29
31
  Image = Any
30
32
 
31
- from sglang.srt.managers.schedule_batch import BaseFinishReason
33
+ from sglang.srt.managers.schedule_batch import BaseFinishReason, flatten_nested_list
32
34
  from sglang.srt.sampling.sampling_params import SamplingParams
33
35
 
34
36
 
@@ -40,6 +42,10 @@ class SessionParams:
40
42
  replace: Optional[bool] = None
41
43
 
42
44
 
45
+ AudioDataItem = Union[str, Dict]
46
+ ImageDataItem = Union[Image, str, Dict]
47
+
48
+
43
49
  @dataclass
44
50
  class GenerateReqInput:
45
51
  # The input prompt. It can be a single prompt or a batch of prompts.
@@ -55,10 +61,10 @@ class GenerateReqInput:
55
61
  # - List of lists of images (multiple images per request)
56
62
  # See also python/sglang/srt/utils.py:load_image for more details.
57
63
  image_data: Optional[
58
- Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
64
+ Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem]
59
65
  ] = None
60
66
  # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
61
- audio_data: Optional[Union[List[str], str]] = None
67
+ audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
62
68
  # The sampling_params. See descriptions below.
63
69
  sampling_params: Optional[Union[List[Dict], Dict]] = None
64
70
  # The request id.
@@ -100,6 +106,9 @@ class GenerateReqInput:
100
106
  bootstrap_port: Optional[Union[List[int], int]] = None
101
107
  bootstrap_room: Optional[Union[List[int], int]] = None
102
108
 
109
+ def contains_mm_input(self) -> bool:
110
+ return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
111
+
103
112
  def normalize_batch_and_arguments(self):
104
113
  """
105
114
  Normalize the batch size and arguments for the request.
@@ -398,6 +407,7 @@ class GenerateReqInput:
398
407
  else None
399
408
  ),
400
409
  return_hidden_states=self.return_hidden_states,
410
+ # if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
401
411
  bootstrap_host=(
402
412
  self.bootstrap_host[i] if self.bootstrap_host is not None else None
403
413
  ),
@@ -483,6 +493,9 @@ class EmbeddingReqInput:
483
493
  # The modalities of the image data [image, multi-images, video]
484
494
  modalities: Optional[List[str]] = None
485
495
 
496
+ def contains_mm_input(self) -> bool:
497
+ return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
498
+
486
499
  def normalize_batch_and_arguments(self):
487
500
  # at least one of text, input_ids, or image should be provided
488
501
  if self.text is None and self.input_ids is None and self.image_data is None:
@@ -836,6 +849,8 @@ class ProfileReqInput:
836
849
  # the caller doesn't need to run stop_profile.
837
850
  num_steps: Optional[int] = None
838
851
  activities: Optional[List[Literal["CPU", "GPU", "MEM", "CUDA_PROFILER"]]] = None
852
+ with_stack: Optional[bool] = None
853
+ record_shapes: Optional[bool] = None
839
854
 
840
855
 
841
856
  class ProfileReqType(Enum):