sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,394 @@
|
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
14
|
+
import json
|
15
|
+
import logging
|
16
|
+
from dataclasses import dataclass
|
17
|
+
from pathlib import Path
|
18
|
+
from typing import List, Optional
|
19
|
+
|
20
|
+
import torch
|
21
|
+
import torch.distributed
|
22
|
+
import torch.nn.functional as F
|
23
|
+
|
24
|
+
from sglang.srt.configs.model_config import ModelConfig
|
25
|
+
from sglang.srt.managers import deepseek_eplb
|
26
|
+
from sglang.srt.model_loader import get_model_architecture
|
27
|
+
from sglang.srt.server_args import ServerArgs
|
28
|
+
|
29
|
+
logger = logging.getLogger(__name__)
|
30
|
+
|
31
|
+
|
32
|
+
@dataclass
class ExpertLocationMetadata:
    """Per-layer tables describing which physical expert slots host which logical experts.

    Every table is indexed by layer first. Rows of ``logical_to_all_physical_map``
    that have fewer replicas than columns are filled with -1.
    """

    # (layers, num_physical_experts): logical expert id stored in each physical slot
    physical_to_logical_map: torch.Tensor
    # (layers, num_logical_experts, X): physical ids of every replica, -1 padded
    logical_to_all_physical_map: torch.Tensor
    # (layers, num_logical_experts): how many entries of the row above are not -1
    logical_to_all_physical_map_num_valid: torch.Tensor
    # (layers, num_logical_experts): the single physical id this rank dispatches to
    logical_to_rank_dispatch_physical_map: torch.Tensor

    # -------------------------------- properties ------------------------------------

    @property
    def num_layers(self) -> int:
        """Number of layers (size of the first dimension of the tables)."""
        return self.physical_to_logical_map.size(0)

    @property
    def num_physical_experts(self) -> int:
        """Number of physical expert slots per layer."""
        return self.physical_to_logical_map.size(1)

    @property
    def num_local_physical_experts(self) -> int:
        """Physical slots per EP rank; asserts the slots split evenly across ranks."""
        total_slots = self.num_physical_experts
        world = self.ep_size
        assert total_slots % world == 0
        return total_slots // world

    @property
    def num_logical_experts(self) -> int:
        """Number of logical experts per layer."""
        return self.logical_to_all_physical_map.size(1)

    @property
    def ep_size(self):
        """Expert-parallel size, currently read from the torch.distributed world size."""
        # TODO change when EP size != world size
        return torch.distributed.get_world_size()

    def __post_init__(self):
        # Unpacking also checks each table's rank; then cross-check shared dimensions.
        p2l_layers, p2l_physical = self.physical_to_logical_map.shape
        l2a_layers, l2a_logical, l2a_physical = self.logical_to_all_physical_map.shape
        valid_layers, valid_logical = self.logical_to_all_physical_map_num_valid.shape
        dispatch_layers, dispatch_logical = (
            self.logical_to_rank_dispatch_physical_map.shape
        )
        assert p2l_layers == l2a_layers == valid_layers == dispatch_layers
        assert l2a_logical == valid_logical == dispatch_logical
        assert p2l_physical == l2a_physical

    # -------------------------------- construction ------------------------------------

    @staticmethod
    def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
        """Build the trivial layout: slot ``p`` holds logical expert ``p % num_logical_experts``.

        The same assignment is used for every layer.
        """
        common = ExpertLocationMetadata._init_common(server_args, model_config)
        expert_config = common["model_config_for_expert_location"]

        slot_ids = torch.arange(0, common["num_physical_experts"])
        per_layer_slots = slot_ids.repeat(expert_config.num_layers, 1)
        physical_to_logical_map = per_layer_slots % expert_config.num_logical_experts

        return ExpertLocationMetadata.init_by_mapping(
            server_args,
            model_config,
            physical_to_logical_map=physical_to_logical_map,
        )

    @staticmethod
    def init_by_mapping(
        server_args: ServerArgs,
        model_config: ModelConfig,
        physical_to_logical_map,
    ):
        """Build metadata from an explicit physical->logical table.

        ``physical_to_logical_map`` may be a tensor or anything ``torch.tensor`` accepts.
        It is moved to ``server_args.device`` before the inverse tables are derived.
        """
        mapping = (
            physical_to_logical_map
            if isinstance(physical_to_logical_map, torch.Tensor)
            else torch.tensor(physical_to_logical_map)
        )
        mapping = mapping.to(server_args.device)

        common = ExpertLocationMetadata._init_common(server_args, model_config)
        num_logical_experts = common[
            "model_config_for_expert_location"
        ].num_logical_experts
        inverse_map = _compute_logical_to_all_physical_map(
            mapping,
            num_logical_experts=num_logical_experts,
        )

        return ExpertLocationMetadata._init_raw(
            ep_size=common["ep_size"],
            physical_to_logical_map=mapping,
            logical_to_all_physical_map=inverse_map,
        )

    @staticmethod
    def init_by_eplb(
        server_args: ServerArgs, model_config: ModelConfig, logical_count: torch.Tensor
    ):
        """Build metadata by running ``deepseek_eplb.rebalance_experts`` on token counts.

        A 2-D ``logical_count`` gets a new leading dimension. The counts are moved to
        ``server_args.device`` first. A disaggregation mode of "null" is passed to EPLB
        as the "decode" phase.
        """
        counts = (
            logical_count
            if isinstance(logical_count, torch.Tensor)
            else torch.tensor(logical_count)
        )
        if counts.dim() == 2:
            counts = counts.unsqueeze(0)
        counts = counts.to(server_args.device)

        common = ExpertLocationMetadata._init_common(server_args, model_config)
        expert_config = common["model_config_for_expert_location"]

        mode = server_args.disaggregation_mode
        phase = "decode" if mode == "null" else mode

        # rebalance_experts returns a third value (per-expert replica count) unused here.
        physical_to_logical_map, logical_to_all_physical_map, _ = (
            deepseek_eplb.rebalance_experts(
                tokens_per_expert=counts,
                num_physical_experts=common["num_physical_experts"],
                num_local_physical_experts=common["num_local_physical_experts"],
                num_groups=expert_config.num_groups,
                num_nodes=server_args.nnodes,
                phase=phase,
            )
        )

        return ExpertLocationMetadata._init_raw(
            ep_size=common["ep_size"],
            physical_to_logical_map=physical_to_logical_map,
            logical_to_all_physical_map=logical_to_all_physical_map,
        )

    @staticmethod
    def _init_common(server_args: ServerArgs, model_config: ModelConfig):
        """Collect the sizes every constructor needs.

        Physical slots = logical experts + ``server_args.ep_num_redundant_experts``.
        They must divide evenly by ``server_args.ep_size``.
        """
        expert_config = ModelConfigForExpertLocation.from_model_config(model_config)

        total_slots = (
            expert_config.num_logical_experts + server_args.ep_num_redundant_experts
        )
        ep_size = server_args.ep_size
        assert total_slots % ep_size == 0

        return {
            "model_config_for_expert_location": expert_config,
            "num_physical_experts": total_slots,
            "num_local_physical_experts": total_slots // ep_size,
            "ep_size": ep_size,
        }

    @staticmethod
    def _init_raw(
        ep_size: int,
        physical_to_logical_map: torch.Tensor,
        logical_to_all_physical_map: torch.Tensor,
    ):
        """Assemble the dataclass from the two primary tables.

        The replica table is right-padded with -1 to ``num_physical_experts`` columns.
        Non -1 entries are counted per row, and this rank's dispatch target is chosen
        for each logical expert.
        """
        _, num_physical_experts = physical_to_logical_map.shape

        missing_columns = num_physical_experts - logical_to_all_physical_map.shape[-1]
        padded_map = F.pad(
            logical_to_all_physical_map,
            (0, missing_columns),
            value=-1,
        )

        num_valid = torch.count_nonzero(logical_to_all_physical_map != -1, dim=-1)

        dispatch_map = compute_logical_to_rank_dispatch_physical_map(
            logical_to_all_physical_map=logical_to_all_physical_map,
            logical_to_all_physical_map_num_valid=num_valid,
            num_gpus=ep_size,
            num_physical_experts=num_physical_experts,
            ep_rank=torch.distributed.get_rank(),
        )

        return ExpertLocationMetadata(
            physical_to_logical_map=physical_to_logical_map,
            logical_to_all_physical_map=padded_map,
            logical_to_all_physical_map_num_valid=num_valid,
            logical_to_rank_dispatch_physical_map=dispatch_map,
        )

    # -------------------------------- mutation ------------------------------------

    def update(
        self,
        other: "ExpertLocationMetadata",
    ):
        """Overwrite this object's tables with ``other``'s, writing into the existing tensors.

        The copy uses ``dst[...] = src``, so the tensor objects (and any views of them)
        keep their identity and see the new values.
        """
        assert self.ep_size == other.ep_size

        tensor_fields = (
            "physical_to_logical_map",
            "logical_to_all_physical_map",
            "logical_to_all_physical_map_num_valid",
            "logical_to_rank_dispatch_physical_map",
        )
        for name in tensor_fields:
            getattr(self, name)[...] = getattr(other, name)

    # -------------------------------- usage ------------------------------------

    def logical_to_all_physical(
        self, layer_id: int, logical_expert_id: int
    ) -> List[int]:
        """Return the physical ids hosting ``logical_expert_id`` in ``layer_id``, without padding."""
        row = self.logical_to_all_physical_map[layer_id, logical_expert_id].tolist()
        return [physical_id for physical_id in row if physical_id != -1]
|
243
|
+
|
244
|
+
|
245
|
+
# Process-wide ExpertLocationMetadata; stays None until
# set_global_expert_location_metadata() installs a value (which it allows only once).
_global_expert_location_metadata: Optional[ExpertLocationMetadata] = None
|
246
|
+
|
247
|
+
|
248
|
+
def get_global_expert_location_metadata():
    """Return the process-wide expert location metadata, or None if it was never set."""
    current = _global_expert_location_metadata
    return current


def set_global_expert_location_metadata(value):
    """Install the process-wide expert location metadata.

    This may be done only once. A second call fails the assertion while a value is
    already installed.
    """
    global _global_expert_location_metadata
    already_installed = _global_expert_location_metadata is not None
    assert not already_installed
    _global_expert_location_metadata = value
|
256
|
+
|
257
|
+
|
258
|
+
def _compute_logical_to_all_physical_map(
|
259
|
+
physical_to_logical_map: torch.Tensor, num_logical_experts: int
|
260
|
+
):
|
261
|
+
# This is rarely called, so we use for loops for maximum clarity
|
262
|
+
|
263
|
+
num_layers, num_physical_experts = physical_to_logical_map.shape
|
264
|
+
|
265
|
+
logical_to_all_physical_map = [
|
266
|
+
[[] for _ in range(num_logical_experts)] for _ in range(num_layers)
|
267
|
+
]
|
268
|
+
for layer_id in range(num_layers):
|
269
|
+
for physical_expert_id in range(num_physical_experts):
|
270
|
+
logical_expert_id = physical_to_logical_map[
|
271
|
+
layer_id, physical_expert_id
|
272
|
+
].item()
|
273
|
+
logical_to_all_physical_map[layer_id][logical_expert_id].append(
|
274
|
+
physical_expert_id
|
275
|
+
)
|
276
|
+
|
277
|
+
logical_to_all_physical_map = _pad_nested_array(
|
278
|
+
logical_to_all_physical_map, pad_value=-1
|
279
|
+
)
|
280
|
+
|
281
|
+
return torch.tensor(
|
282
|
+
logical_to_all_physical_map, device=physical_to_logical_map.device
|
283
|
+
)
|
284
|
+
|
285
|
+
|
286
|
+
def _pad_nested_array(arr, pad_value):
|
287
|
+
max_len = max(len(inner) for outer in arr for inner in outer)
|
288
|
+
padded = [
|
289
|
+
[inner + [pad_value] * (max_len - len(inner)) for inner in outer]
|
290
|
+
for outer in arr
|
291
|
+
]
|
292
|
+
return padded
|
293
|
+
|
294
|
+
|
295
|
+
# TODO use more sophisticated approaches
def compute_logical_to_rank_dispatch_physical_map(
    logical_to_all_physical_map: torch.Tensor,
    logical_to_all_physical_map_num_valid: torch.Tensor,
    num_gpus: int,
    num_physical_experts: int,
    ep_rank: int,
    base_seed: int = 42,
):
    """Pick, for rank ``ep_rank``, one physical replica per (layer, logical expert).

    A replica on this rank's own GPU is preferred. Physical id ``p`` is treated as
    living on GPU ``p // (num_physical_experts // num_gpus)``, and if several replicas
    are local, the one at the highest column wins. Otherwise a replica is chosen
    pseudo-randomly with a generator seeded by ``base_seed + ep_rank``, so the choice
    is reproducible per rank.

    Args:
        logical_to_all_physical_map: (layers, num_logical_experts, X) physical ids,
            -1 for padding. Assumes the valid ids occupy the leading
            ``num_valid`` columns of each row (right padding).
        logical_to_all_physical_map_num_valid: (layers, num_logical_experts) count
            of non -1 entries per row.
        num_gpus: number of EP ranks the physical experts are split across.
        num_physical_experts: total physical expert slots per layer.
        ep_rank: rank for which the dispatch table is computed.
        base_seed: seed offset for the random fallback choice.

    Returns:
        (layers, num_logical_experts) tensor of chosen physical expert ids.

    Raises:
        ValueError: if some logical expert has no valid physical replica. Without
            this check the random choice below would take a modulo by zero.
    """
    device = logical_to_all_physical_map.device

    if bool((logical_to_all_physical_map_num_valid <= 0).any()):
        raise ValueError(
            "every logical expert needs at least one physical replica, "
            "but logical_to_all_physical_map_num_valid contains zeros"
        )

    num_local_physical_experts = num_physical_experts // num_gpus
    num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape

    g = torch.Generator(device=device)
    g.manual_seed(base_seed + ep_rank)

    # Random fallback: pick one of the valid leading columns of each row.
    output_shape = (num_layers, num_logical_experts)
    chosen_index = (
        torch.randint(
            0, 65536, output_shape, dtype=torch.int32, device=device, generator=g
        )
        % logical_to_all_physical_map_num_valid
    )
    logical_to_rank_dispatch_physical_map = torch.gather(
        logical_to_all_physical_map, dim=2, index=chosen_index.unsqueeze(-1)
    ).squeeze(-1)
    assert logical_to_rank_dispatch_physical_map.shape == output_shape

    # Override with a replica on this rank's GPU wherever one exists.
    for index in range(logical_to_all_physical_map_num_valid.max().item()):
        partial_logical_to_all_physical_map = logical_to_all_physical_map[:, :, index]
        is_valid = partial_logical_to_all_physical_map != -1
        is_same_gpu = (
            partial_logical_to_all_physical_map // num_local_physical_experts
        ) == ep_rank
        logical_to_rank_dispatch_physical_map = torch.where(
            is_valid & is_same_gpu,
            partial_logical_to_all_physical_map,
            logical_to_rank_dispatch_physical_map,
        )

    assert torch.all(logical_to_rank_dispatch_physical_map != -1)
    return logical_to_rank_dispatch_physical_map
|
338
|
+
|
339
|
+
|
340
|
+
@dataclass
class ModelConfigForExpertLocation:
    """Model dimensions needed to lay out experts.

    Holds the number of layers and logical experts, plus an optional expert group
    count that ``init_by_eplb`` forwards to the EPLB rebalancer.
    """

    num_layers: int
    num_logical_experts: int
    num_groups: Optional[int] = None

    @staticmethod
    def init_dummy():
        """Placeholder config: one layer holding one logical expert."""
        return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)

    @staticmethod
    def from_model_config(model_config: ModelConfig):
        """Ask the model's architecture class for its expert-location config.

        Calls ``get_model_config_for_expert_location(hf_config)`` on the class when
        the class defines it. Otherwise returns the dummy config.
        """
        model_class, _ = get_model_architecture(model_config)
        if not hasattr(model_class, "get_model_config_for_expert_location"):
            return ModelConfigForExpertLocation.init_dummy()
        return model_class.get_model_config_for_expert_location(
            model_config.hf_config
        )
|
359
|
+
|
360
|
+
|
361
|
+
def compute_initial_expert_location_metadata(
    server_args: ServerArgs, model_config: ModelConfig
) -> ExpertLocationMetadata:
    """Build the startup expert layout from ``server_args.init_expert_location``.

    Accepted values:
    - "trivial";
    - a path ending in ".pt" (loaded with ``torch.load(..., weights_only=True)``);
    - a path ending in ".json";
    - otherwise, an inline JSON string.

    The loaded dict must contain "physical_to_logical_map" (its items are passed as
    keyword arguments to ``init_by_mapping``) or "logical_count" (passed to
    ``init_by_eplb``). Any other dict raises NotImplementedError.
    """
    source = server_args.init_expert_location
    if source == "trivial":
        logger.info("init_expert_location from trivial")
        return ExpertLocationMetadata.init_trivial(server_args, model_config)

    # TODO unify with the utils function
    if source.endswith(".pt"):
        data_dict = torch.load(source, weights_only=True)
    elif source.endswith(".json"):
        data_dict = json.loads(Path(source).read_text())
    else:
        data_dict = json.loads(source)

    if "physical_to_logical_map" in data_dict:
        logger.info(
            "init_expert_location from init_by_mapping using ServerArgs.init_expert_location"
        )
        return ExpertLocationMetadata.init_by_mapping(
            server_args, model_config, **data_dict
        )

    if "logical_count" in data_dict:
        logger.info(
            "init_expert_location from init_by_eplb using ServerArgs.init_expert_location"
        )
        return ExpertLocationMetadata.init_by_eplb(
            server_args, model_config, logical_count=data_dict["logical_count"]
        )

    # NOTE: the name `data_dict` appears in this message via the `=` specifier.
    raise NotImplementedError(
        f"Unknown init_expert_location format ({list(data_dict.keys())=})"
    )
|
@@ -0,0 +1,91 @@
|
|
1
|
+
# Copyright 2023-2025 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
14
|
+
|
15
|
+
from dataclasses import dataclass
|
16
|
+
from typing import Literal, Optional
|
17
|
+
|
18
|
+
import torch
|
19
|
+
|
20
|
+
from sglang.srt.managers.expert_location import get_global_expert_location_metadata
|
21
|
+
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
22
|
+
|
23
|
+
|
24
|
+
@dataclass
class ExpertLocationDispatchInfo:
    """One layer's slice of the global expert-location tables.

    Used by ``topk_ids_logical_to_physical`` to map logical top-k expert ids to
    physical expert ids.
    """

    # Must name an algorithm handled by topk_ids_logical_to_physical
    # ("static" or "dynamic"); any other value raises NotImplementedError there.
    ep_dispatch_algorithm: Literal["static", "dynamic"]
    # (num_logical_experts,)
    partial_logical_to_rank_dispatch_physical_map: torch.Tensor
    # (num_logical_experts, X)
    partial_logical_to_all_physical_map: torch.Tensor
    # (num_logical_experts,)
    partial_logical_to_all_physical_map_num_valid: torch.Tensor
    num_physical_experts: int

    @classmethod
    def init_new(cls, layer_id: int):
        """Build the dispatch info for ``layer_id``.

        Returns None when no dispatch algorithm is configured in
        ``global_server_args_dict["ep_dispatch_algorithm"]``. Otherwise the per-layer
        rows are sliced from the global expert location metadata.
        """
        ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
        if ep_dispatch_algorithm is None:
            # No dispatch requested: skip fetching the global metadata entirely.
            return None

        expert_location_metadata = get_global_expert_location_metadata()
        return cls(
            ep_dispatch_algorithm=ep_dispatch_algorithm,
            partial_logical_to_rank_dispatch_physical_map=expert_location_metadata.logical_to_rank_dispatch_physical_map[
                layer_id, :
            ],
            partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[
                layer_id, :
            ],
            partial_logical_to_all_physical_map_num_valid=expert_location_metadata.logical_to_all_physical_map_num_valid[
                layer_id, :
            ],
            num_physical_experts=expert_location_metadata.num_physical_experts,
        )
|
56
|
+
|
57
|
+
|
58
|
+
def topk_ids_logical_to_physical(
    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
    """Translate logical top-k expert ids into physical expert ids.

    With ``info`` of None the ids are returned unchanged. Otherwise
    ``info.ep_dispatch_algorithm`` selects the strategy: "static" uses the per-rank
    precomputed table, "dynamic" samples a replica per id. Any other value raises
    NotImplementedError.
    """
    if info is None:
        return topk_ids

    algorithm = info.ep_dispatch_algorithm
    if algorithm == "static":
        return _topk_ids_logical_to_physical_static(topk_ids, info)
    elif algorithm == "dynamic":
        return _topk_ids_logical_to_physical_dynamic(topk_ids, info)
    else:
        raise NotImplementedError
|
69
|
+
|
70
|
+
|
71
|
+
def _topk_ids_logical_to_physical_static(
    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
    """Map each logical id through the rank's precomputed dispatch table.

    The result has the same shape as ``topk_ids``.
    """
    dispatch_table = info.partial_logical_to_rank_dispatch_physical_map
    return dispatch_table[topk_ids]
|
75
|
+
|
76
|
+
|
77
|
+
def _topk_ids_logical_to_physical_dynamic(
    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
    """Map each logical id to one of its physical replicas, chosen pseudo-randomly.

    A fresh draw is made from torch's default generator on every call. The replica
    column is ``draw % num_valid`` for that logical id. The result has the same
    shape as ``topk_ids``.
    """
    original_shape = topk_ids.shape
    flat_ids = topk_ids.flatten()

    random_draws = torch.randint(
        0, 65536, flat_ids.shape, dtype=torch.int32, device=topk_ids.device
    )
    replica_column = (
        random_draws % info.partial_logical_to_all_physical_map_num_valid[flat_ids]
    )
    physical_ids = info.partial_logical_to_all_physical_map[flat_ids, replica_column]

    return physical_ids.view(original_shape)
|
sglang/srt/managers/io_struct.py
CHANGED
@@ -22,13 +22,15 @@ from dataclasses import dataclass, field
|
|
22
22
|
from enum import Enum
|
23
23
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
|
24
24
|
|
25
|
+
from sglang.srt.mm_utils import has_valid_data
|
26
|
+
|
25
27
|
# handle serialization of Image for pydantic
|
26
28
|
if TYPE_CHECKING:
|
27
29
|
from PIL.Image import Image
|
28
30
|
else:
|
29
31
|
Image = Any
|
30
32
|
|
31
|
-
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
33
|
+
from sglang.srt.managers.schedule_batch import BaseFinishReason, flatten_nested_list
|
32
34
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
33
35
|
|
34
36
|
|
@@ -40,6 +42,10 @@ class SessionParams:
|
|
40
42
|
replace: Optional[bool] = None
|
41
43
|
|
42
44
|
|
45
|
+
AudioDataItem = Union[str, Dict]
|
46
|
+
ImageDataItem = Union[Image, str, Dict]
|
47
|
+
|
48
|
+
|
43
49
|
@dataclass
|
44
50
|
class GenerateReqInput:
|
45
51
|
# The input prompt. It can be a single prompt or a batch of prompts.
|
@@ -55,10 +61,10 @@ class GenerateReqInput:
|
|
55
61
|
# - List of lists of images (multiple images per request)
|
56
62
|
# See also python/sglang/srt/utils.py:load_image for more details.
|
57
63
|
image_data: Optional[
|
58
|
-
Union[List[List[
|
64
|
+
Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem]
|
59
65
|
] = None
|
60
66
|
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
|
61
|
-
audio_data: Optional[Union[List[
|
67
|
+
audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
|
62
68
|
# The sampling_params. See descriptions below.
|
63
69
|
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
64
70
|
# The request id.
|
@@ -100,6 +106,9 @@ class GenerateReqInput:
|
|
100
106
|
bootstrap_port: Optional[Union[List[int], int]] = None
|
101
107
|
bootstrap_room: Optional[Union[List[int], int]] = None
|
102
108
|
|
109
|
+
def contains_mm_input(self) -> bool:
|
110
|
+
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
|
111
|
+
|
103
112
|
def normalize_batch_and_arguments(self):
|
104
113
|
"""
|
105
114
|
Normalize the batch size and arguments for the request.
|
@@ -398,6 +407,7 @@ class GenerateReqInput:
|
|
398
407
|
else None
|
399
408
|
),
|
400
409
|
return_hidden_states=self.return_hidden_states,
|
410
|
+
# if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
|
401
411
|
bootstrap_host=(
|
402
412
|
self.bootstrap_host[i] if self.bootstrap_host is not None else None
|
403
413
|
),
|
@@ -483,6 +493,9 @@ class EmbeddingReqInput:
|
|
483
493
|
# The modalities of the image data [image, multi-images, video]
|
484
494
|
modalities: Optional[List[str]] = None
|
485
495
|
|
496
|
+
def contains_mm_input(self) -> bool:
|
497
|
+
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
|
498
|
+
|
486
499
|
def normalize_batch_and_arguments(self):
|
487
500
|
# at least one of text, input_ids, or image should be provided
|
488
501
|
if self.text is None and self.input_ids is None and self.image_data is None:
|