sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ sglang/srt/layers/communicator.py
@@ -0,0 +1,451 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import Dict, Optional, Tuple
+
+import torch.distributed
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
+    attn_tp_reduce_scatter,
+    dp_gather_partial,
+    dp_scatter,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    get_local_attention_dp_size,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+
+class ScatterMode(Enum):
+    SCATTERED = auto()
+    TP_ATTN_FULL = auto()
+    FULL = auto()
+
+
+@dataclass
+class _LayerModeComputationContext:
+    num_layers: int
+    layer_id: int
+    is_layer_sparse: bool
+    is_previous_layer_sparse: Optional[bool]
+
+    def previous_layer(self):
+        assert self.is_previous_layer_sparse is not None
+        return _LayerModeComputationContext(
+            layer_id=self.layer_id - 1,
+            is_layer_sparse=self.is_previous_layer_sparse,
+            is_previous_layer_sparse=None,
+            num_layers=self.num_layers,
+        )
+
+
+@dataclass
+class LayerScatterModes:
+    layer_input_mode: ScatterMode
+    attn_mode: ScatterMode
+    # Can be further split into e.g. mlp_input_mode and mlp_output_mode if needed
+    mlp_mode: ScatterMode
+    middle_residual_mode: ScatterMode
+    layer_output_mode: ScatterMode
+
+    @classmethod
+    def init_new(cls, **kwargs):
+        context = _LayerModeComputationContext(**kwargs)
+        return cls(
+            layer_input_mode=cls._compute_layer_input_mode(context),
+            attn_mode=ScatterMode.TP_ATTN_FULL,
+            mlp_mode=cls._compute_mlp_mode(context),
+            middle_residual_mode=cls._compute_middle_residual_mode(context),
+            layer_output_mode=cls._compute_layer_output_mode(context),
+        )
+
+    @classmethod
+    def _compute_layer_input_mode(cls, context: _LayerModeComputationContext):
+        if context.layer_id == 0:
+            return ScatterMode.TP_ATTN_FULL
+        return cls._compute_layer_output_mode(context.previous_layer())
+
+    @classmethod
+    def _compute_mlp_mode(cls, context: _LayerModeComputationContext):
+        if context.is_layer_sparse:
+            return (
+                ScatterMode.SCATTERED
+                if global_server_args_dict["enable_deepep_moe"]
+                else ScatterMode.FULL
+            )
+        else:
+            return (
+                ScatterMode.SCATTERED
+                if enable_moe_dense_fully_dp()
+                else ScatterMode.FULL
+            )
+
+    @classmethod
+    def _compute_middle_residual_mode(cls, context: _LayerModeComputationContext):
+        mlp_mode = cls._compute_mlp_mode(context)
+        if mlp_mode == ScatterMode.SCATTERED:
+            return ScatterMode.SCATTERED
+        if mlp_mode == ScatterMode.FULL:
+            return ScatterMode.TP_ATTN_FULL
+        raise NotImplementedError
+
+    @classmethod
+    def _compute_layer_output_mode(cls, context: _LayerModeComputationContext):
+        mlp_mode = cls._compute_mlp_mode(context)
+        if context.layer_id == context.num_layers - 1:
+            return ScatterMode.TP_ATTN_FULL
+        if mlp_mode == ScatterMode.SCATTERED:
+            return ScatterMode.SCATTERED
+        if mlp_mode == ScatterMode.FULL:
+            return ScatterMode.TP_ATTN_FULL
+        raise NotImplementedError
+
+
+def enable_moe_dense_fully_dp():
+    return global_server_args_dict["moe_dense_tp_size"] == 1
+
+
+class LayerCommunicator:
+    def __init__(
+        self,
+        layer_scatter_modes: LayerScatterModes,
+        input_layernorm: torch.nn.Module,
+        post_attention_layernorm: torch.nn.Module,
+    ):
+        self.layer_scatter_modes = layer_scatter_modes
+        self.input_layernorm = input_layernorm
+        self.post_attention_layernorm = post_attention_layernorm
+
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.attn_tp_size = get_attention_tp_size()
+        self.local_attn_dp_size = get_local_attention_dp_size()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.process_group_sizes = {
+            ScatterMode.SCATTERED: 1,
+            ScatterMode.TP_ATTN_FULL: self.attn_tp_size,
+            ScatterMode.FULL: self.tp_size,
+        }
+
+    def prepare_attn(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        if hidden_states.shape[0] == 0:
+            residual = hidden_states
+        else:
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        hidden_states = _communicate_simple(
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+            input_mode=self.layer_scatter_modes.layer_input_mode,
+            output_mode=self.layer_scatter_modes.attn_mode,
+            context=self._compute_context(forward_batch),
+        )
+
+        return hidden_states, residual
+
+    def prepare_mlp(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        return _communicate_with_all_reduce_and_layer_norm(
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=forward_batch,
+            hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
+            residual_input_mode=self.layer_scatter_modes.layer_input_mode,
+            hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
+            residual_output_mode=self.layer_scatter_modes.middle_residual_mode,
+            layernorm=self.post_attention_layernorm,
+            context=self._compute_context(forward_batch),
+        )
+
+    def postprocess_layer(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        return _communicate_summable_tensor_pair(
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=forward_batch,
+            hidden_states_input_mode=self.layer_scatter_modes.mlp_mode,
+            residual_input_mode=self.layer_scatter_modes.middle_residual_mode,
+            output_mode=self.layer_scatter_modes.layer_output_mode,
+            context=self._compute_context(forward_batch),
+        )
+
+    def _compute_context(self, forward_batch: ForwardBatch):
+        return _Context(
+            num_tokens_of_mode=_compute_num_tokens_of_mode(
+                forward_batch,
+                attn_tp_rank=self.attn_tp_rank,
+                attn_tp_size=self.attn_tp_size,
+            ),
+            process_group_sizes=self.process_group_sizes,
+            attn_tp_rank=self.attn_tp_rank,
+            attn_tp_size=self.attn_tp_size,
+            local_attn_dp_size=self.local_attn_dp_size,
+            tp_size=self.tp_size,
+        )
+
+
+def _compute_num_tokens_of_mode(
+    forward_batch: ForwardBatch, attn_tp_rank: int, attn_tp_size: int
+):
+    tp_attn_full_num_tokens = forward_batch.input_ids.shape[0]
+    return {
+        ScatterMode.SCATTERED: _torch_tensor_split_len(
+            tp_attn_full_num_tokens, attn_tp_size, attn_tp_rank
+        ),
+        ScatterMode.TP_ATTN_FULL: tp_attn_full_num_tokens,
+        ScatterMode.FULL: (
+            forward_batch.gathered_buffer.shape[0]
+            if global_server_args_dict["enable_dp_attention"]
+            else forward_batch.input_ids.shape[0]
+        ),
+    }
+
+
+def _torch_tensor_split_len(tensor_len: int, n: int, output_index: int):
+    if output_index < int(tensor_len % n):
+        return int(tensor_len / n) + 1
+    else:
+        return int(tensor_len / n)
+
+
+@dataclass
+class _Context:
+    num_tokens_of_mode: Dict["ScatterMode", int]
+    process_group_sizes: Dict["ScatterMode", int]
+    attn_tp_rank: int
+    attn_tp_size: int
+    local_attn_dp_size: int
+    tp_size: int
+
+    def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"):
+        return self.process_group_sizes[a] == self.process_group_sizes[b]
+
+    def check_shape(self, x: torch.Tensor, mode: ScatterMode):
+        if x is None:
+            return
+
+        actual_num_tokens = x.shape[0]
+        expect_num_tokens = self.num_tokens_of_mode[mode]
+        assert (
+            actual_num_tokens == expect_num_tokens
+        ), f"{actual_num_tokens=} {expect_num_tokens=} {mode=} {x.shape=} {self.num_tokens_of_mode=} {self.process_group_sizes=}"
+        return x
+
+    def check_shapes(
+        self, xs: Tuple[torch.Tensor, ...], modes: Tuple[ScatterMode, ...]
+    ) -> Tuple[torch.Tensor, ...]:
+        return tuple(
+            [self.check_shape(x, mode) for x, mode in zip(xs, modes, strict=True)]
+        )
+
+
+def _communicate_simple(
+    hidden_states: torch.Tensor,
+    forward_batch: ForwardBatch,
+    input_mode: ScatterMode,
+    output_mode: ScatterMode,
+    context: _Context,
+) -> torch.Tensor:
+    def _inner():
+        nonlocal hidden_states
+
+        if context.is_same_group_size(input_mode, output_mode):
+            return hidden_states
+
+        if (input_mode == ScatterMode.SCATTERED) and (
+            output_mode == ScatterMode.TP_ATTN_FULL
+        ):
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            attn_tp_all_gather(
+                list(hidden_states.tensor_split(context.attn_tp_size)),
+                local_hidden_states,
+            )
+            return hidden_states
+
+        raise NotImplementedError(f"{input_mode=} {output_mode=}")
+
+    context.check_shape(hidden_states, input_mode)
+    return context.check_shape(_inner(), output_mode)
+
+
+def _communicate_with_all_reduce_and_layer_norm(
+    hidden_states: torch.Tensor,
+    residual: torch.Tensor,
+    hidden_states_input_mode: ScatterMode,
+    residual_input_mode: ScatterMode,
+    hidden_states_output_mode: ScatterMode,
+    residual_output_mode: ScatterMode,
+    forward_batch: ForwardBatch,
+    layernorm: torch.nn.Module,
+    context: _Context,
+):
+    """Besides communication, needs to
+    1. All reduce in tp_attn_group on hidden_states
+    2. Apply layer norm
+    """
+
+    def _inner():
+        nonlocal hidden_states, residual
+
+        if (
+            context.is_same_group_size(
+                hidden_states_input_mode, hidden_states_output_mode
+            )
+            and context.is_same_group_size(residual_input_mode, residual_output_mode)
+            and context.attn_tp_size == 1
+        ):
+            # TODO move these `if shape != 0` into LayerNorm itself
+            if hidden_states.shape[0] != 0:
+                hidden_states, residual = layernorm(hidden_states, residual)
+            return hidden_states, residual
+
+        if (
+            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (hidden_states_output_mode == ScatterMode.FULL)
+            and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
+        ):
+            if context.local_attn_dp_size != 1:
+                if context.attn_tp_rank == 0:
+                    hidden_states += residual
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                dp_scatter(residual, hidden_states, forward_batch)
+                if hidden_states.shape[0] != 0:
+                    hidden_states = layernorm(hidden_states)
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                hidden_states, residual = layernorm(hidden_states, residual)
+            return hidden_states, residual
+
+        if (
+            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (
+                residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
+            )
+            and (hidden_states_output_mode == ScatterMode.SCATTERED)
+            and (residual_output_mode == ScatterMode.SCATTERED)
+        ):
+            tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
+            hidden_states = tensor_list[context.attn_tp_rank]
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
+            if residual_input_mode == ScatterMode.TP_ATTN_FULL:
+                residual = residual.tensor_split(context.attn_tp_size)[
+                    context.attn_tp_rank
+                ]
+            if hidden_states.shape[0] != 0:
+                hidden_states, residual = layernorm(hidden_states, residual)
+            return hidden_states, residual
+
+        raise NotImplementedError(
+            f"{hidden_states_input_mode=} {residual_input_mode=} {residual_output_mode=} {residual_output_mode=}"
+        )
+
+    context.check_shapes(
+        (hidden_states, residual), (hidden_states_input_mode, residual_input_mode)
+    )
+    return context.check_shapes(
+        _inner(), (hidden_states_output_mode, residual_output_mode)
+    )
+
+
+def _communicate_summable_tensor_pair(
+    hidden_states: torch.Tensor,
+    residual: torch.Tensor,
+    forward_batch: ForwardBatch,
+    hidden_states_input_mode: ScatterMode,
+    residual_input_mode: ScatterMode,
+    output_mode: ScatterMode,
+    context: _Context,
+):
+    """It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed."""
+
+    def _inner():
+        nonlocal hidden_states, residual
+
+        if context.is_same_group_size(
+            hidden_states_input_mode, output_mode
+        ) and context.is_same_group_size(residual_input_mode, output_mode):
+            return hidden_states, residual
+
+        if (
+            (hidden_states_input_mode == ScatterMode.FULL)
+            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (output_mode == ScatterMode.TP_ATTN_FULL)
+        ):
+            # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
+            # important: forward_batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+            return hidden_states, residual
+
+        if (
+            (hidden_states_input_mode == ScatterMode.SCATTERED)
+            and (residual_input_mode == ScatterMode.SCATTERED)
+            and (output_mode == ScatterMode.TP_ATTN_FULL)
+        ):
+            hidden_states += residual
+            residual = None
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            attn_tp_all_gather(
+                list(hidden_states.tensor_split(context.attn_tp_size)),
+                local_hidden_states,
+            )
+            return hidden_states, residual
+
+        raise NotImplementedError(
+            f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
+        )
+
+    context.check_shapes(
+        (hidden_states, residual), (hidden_states_input_mode, residual_input_mode)
+    )
+    return context.check_shapes(_inner(), (output_mode, output_mode))
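For orientation, the sketch below shows how a decoder layer could wire the three LayerCommunicator hooks around its attention and MLP sub-modules. It is not part of the diff: the layer class, its sub-modules, and the fused add-RMSNorm call convention are illustrative assumptions, and it presumes a running SGLang model where the distributed groups and global_server_args_dict are already initialized. Only LayerScatterModes.init_new, LayerCommunicator, prepare_attn, prepare_mlp, and postprocess_layer come from the new module above.

# Hypothetical wiring of LayerCommunicator inside a decoder layer (illustration only).
from torch import nn

from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class ExampleDecoderLayer(nn.Module):  # assumed class; not from the diff
    def __init__(self, config, layer_id: int, self_attn, mlp, input_ln, post_attn_ln):
        super().__init__()
        self.self_attn = self_attn  # attention module (assumed)
        self.mlp = mlp              # dense or MoE MLP (assumed)
        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=LayerScatterModes.init_new(
                layer_id=layer_id,
                num_layers=config.num_hidden_layers,
                is_layer_sparse=False,            # assumed: dense layer
                is_previous_layer_sparse=False,   # assumed: dense stack
            ),
            input_layernorm=input_ln,               # fused add-RMSNorm style module (assumed)
            post_attention_layernorm=post_attn_ln,  # fused add-RMSNorm style module (assumed)
        )

    def forward(self, positions, hidden_states, forward_batch: ForwardBatch, residual):
        # 1) input layernorm + gather hidden_states into the attention-TP layout
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )
        if hidden_states.shape[0] != 0:
            hidden_states = self.self_attn(positions, hidden_states, forward_batch)
        # 2) all-reduce / DP-gather + post-attention layernorm before the MLP
        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )
        hidden_states = self.mlp(hidden_states)
        # 3) scatter/gather back to the layout expected by the next layer
        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )
        return hidden_states, residual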
--- sglang/srt/layers/dp_attention.py
+++ sglang/srt/layers/dp_attention.py
@@ -142,16 +142,6 @@ def get_local_attention_dp_size():
     return _LOCAL_ATTN_DP_SIZE
 
 
-def get_local_attention_dp_rank():
-    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
-    return _LOCAL_ATTN_DP_RANK
-
-
-def get_local_attention_dp_size():
-    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
-    return _LOCAL_ATTN_DP_SIZE
-
-
 @contextmanager
 def disable_dp_size():
     """Patch the tp group temporarily until this function ends.
--- /dev/null
+++ sglang/srt/layers/moe/cutlass_moe.py
@@ -0,0 +1,207 @@
+"""Cutlass MoE kernel."""
+
+import functools
+import json
+import logging
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    import sgl_kernel
+    from sgl_kernel import (
+        fp8_blockwise_scaled_grouped_mm,
+        prepare_moe_input,
+        silu_and_mul,
+    )
+
+
+def cutlass_fused_experts(
+    a: torch.Tensor,
+    w1_q: torch.Tensor,
+    w2_q: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    a1_strides: torch.Tensor,
+    c1_strides: torch.Tensor,
+    a2_strides: torch.Tensor,
+    c2_strides: torch.Tensor,
+    workspace: torch.Tensor,
+    a_ptrs: torch.Tensor,
+    b_ptrs: torch.Tensor,
+    out_ptrs: torch.Tensor,
+    a_scales_ptrs: torch.Tensor,
+    b_scales_ptrs: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    problem_sizes1: torch.Tensor,
+    problem_sizes2: torch.Tensor,
+    use_fp8_blockscale: bool = True,
+) -> torch.Tensor:
+    """Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations.
+
+    This function implements a Mixture of Experts (MoE) layer with a SwiGLU/SiLU
+    activation, leveraging custom kernels likely derived from CUTLASS principles
+    for grouped matrix multiplication (`fp8_blockwise_scaled_grouped_mm`) and
+    data preparation (`prepare_moe_input`, `silu_and_mul`).
+
+    It handles per-token routing, quantizes input activations to FP8 with
+    per-token scales, performs the expert computations using FP8 GEMMs with
+    pre-quantized FP8 weights (per-block scales), applies the SiLU activation,
+    and combines the results weighted by the router scores.
+
+    Args:
+        a (torch.Tensor): Input activations. Shape: `(m, k)`, where `m` is the total
+            number of tokens and `k` is the hidden size. Expected dtype: `torch.half`
+            or `torch.bfloat16`.
+        w1_q (torch.Tensor): Pre-quantized FP8 weight tensor for the first GEMM
+            (up-projection part of SwiGLU). Expected shape: `(E, k, n*2)`, where
+            `E` is the number of experts, `k` is the hidden size, and `n*2` is the
+            intermediate size (`I`). Expected dtype: `torch.float8_e4m3fn`.
+            Note: This shape implies weights are stored as (num_experts, hidden_size, intermediate_size).
+        w2_q (torch.Tensor): Pre-quantized FP8 weight tensor for the second GEMM
+            (down-projection). Expected shape: `(E, n, k)`, where `n` is half the
+            intermediate size (`I // 2`). Expected dtype: `torch.float8_e4m3fn`.
+            Note: This shape implies weights are stored as (num_experts, intermediate_size // 2, hidden_size).
+        w1_scale (torch.Tensor): Scales corresponding to `w1_q` (per-block scales).
+            Shape: `(E, num_blocks_n, num_blocks_k)`. Dtype: `torch.float32`.
+        w2_scale (torch.Tensor): Scales corresponding to `w2_q` (per-block scales).
+            Shape: `(E, num_blocks_k, num_blocks_n)`. Dtype: `torch.float32`.
+        topk_weights (torch.Tensor): Router weights for the selected top-k experts
+            for each token. Shape: `(m, topk)`. Dtype should ideally match `a`.
+        topk_ids (torch.Tensor): Indices of the selected top-k experts for each token.
+            Shape: `(m, topk)`. Dtype: `torch.int32`.
+        a1_strides (torch.Tensor): Stride information for the first GEMM's 'a' input.
+            Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
+            Note: Its exact usage within `fp8_blockwise_scaled_grouped_mm` needs clarification
+            as it's passed as both a_stride and b_stride in the first call.
+        c1_strides (torch.Tensor): Stride information for the first GEMM's 'c' output.
+            Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
+        a2_strides (torch.Tensor): Stride information for the second GEMM's 'a' input.
+            Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
+            Note: Its exact usage within `fp8_blockwise_scaled_grouped_mm` needs clarification
+            as it's passed as both a_stride and b_stride in the second call.
+        c2_strides (torch.Tensor): Stride information for the second GEMM's 'c' output.
+            Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
+        workspace (torch.Tensor): Reusable workspace for the underlying kernel.
+        a_ptrs (torch.Tensor): Pointers container for calculating offsets of the input activations for each expert.
+        b_ptrs (torch.Tensor): Pointers container for calculating offsets of the input weights for each expert.
+        out_ptrs (torch.Tensor): Pointers container for calculating offsets of the output activations for each expert.
+        a_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert.
+        b_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert.
+        use_fp8_blockscale (bool, optional): Flag indicating usage of FP8 with
+            block scaling. Currently, only `True` is supported. Defaults to `True`.
+
+    Returns:
+        torch.Tensor: The computed MoE layer output. Shape: `(m, k)`, dtype matches `a`.
+
+    Raises:
+        AssertionError: If input shapes, dtypes, or flags are inconsistent or unsupported.
+        NotImplementedError: If CUDA is not available or `sgl_kernel` is not properly installed.
+    """
+    assert use_fp8_blockscale, "Only support fp8 blockscale for now"
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
+    assert w1_q.dtype == torch.float8_e4m3fn
+    assert w2_q.dtype == torch.float8_e4m3fn
+    assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
+    assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
+    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
+    assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
+    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
+    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
+    assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
+
+    if is_cuda:
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_fp8,
+        )
+
+    out_dtype = a.dtype
+    num_experts = w1_q.size(0)
+    m = a.size(0)
+    k = w1_q.size(1)
+    n = w2_q.size(1)
+
+    topk = topk_ids.size(1)
+
+    a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
+    device = a_q.device
+
+    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+
+    prepare_moe_input(
+        topk_ids,
+        expert_offsets,
+        problem_sizes1,
+        problem_sizes2,
+        a_map,
+        c_map,
+        num_experts,
+        n,
+        k,
+    )
+
+    rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
+    rep_a1_scales = a1_scale[a_map]
+
+    c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
+    c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
+
+    a_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)
+    w_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)
+
+    fp8_blockwise_scaled_grouped_mm(
+        c1,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        rep_a_q,
+        w1_q,
+        rep_a1_scales,
+        w1_scale,
+        a1_strides,
+        a1_strides,
+        c1_strides,
+        a_sf_layout,
+        w_sf_layout,
+        problem_sizes1,
+        expert_offsets[:-1],
+        workspace,
+    )
+
+    intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
+    silu_and_mul(c1, intermediate)
+
+    intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+
+    fp8_blockwise_scaled_grouped_mm(
+        c2,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        intemediate_q,
+        w2_q,
+        a2_scale,
+        w2_scale,
+        a2_strides,
+        a2_strides,
+        c2_strides,
+        a_sf_layout,
+        w_sf_layout,
+        problem_sizes2,
+        expert_offsets[:-1],
+        workspace,
+    )
+    return (
+        c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
+    ).sum(dim=1)
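As a plain-PyTorch sanity reference (not part of the diff), the math that the docstring above describes can be written directly: route each token to its top-k experts, apply the per-expert W1, a SiLU-and-mul, the per-expert W2, then combine with the router weights. This unquantized sketch ignores the FP8 quantization, stride, and pointer bookkeeping of the CUTLASS path; the gate/up split order inside silu_and_mul is assumed to be the usual first-half/second-half convention.

import torch


def reference_moe(a, w1, w2, topk_weights, topk_ids):
    """Unquantized reference for the MoE computation (illustrative, not the CUTLASS path).

    a: (m, k), w1: (E, k, n*2), w2: (E, n, k), topk_weights / topk_ids: (m, topk)
    """
    m, k = a.shape
    topk = topk_ids.shape[1]
    out = torch.zeros(m, topk, k, dtype=a.dtype, device=a.device)
    for e in range(w1.shape[0]):
        # tokens (and their top-k slot) routed to expert e
        token_idx, slot_idx = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        h = a[token_idx] @ w1[e]                       # (num_tokens, n*2)
        gate, up = h.chunk(2, dim=-1)                  # assumed silu_and_mul split order
        out[token_idx, slot_idx] = (torch.nn.functional.silu(gate) * up) @ w2[e]
    # combine expert outputs with router weights, mirroring the final combination above
    return (out * topk_weights.unsqueeze(-1).to(out.dtype)).sum(dim=1)


# Example: 4 tokens, hidden size 16, 2 experts, top-2 routing, intermediate size 8.
a = torch.randn(4, 16)
w1 = torch.randn(2, 16, 16)  # n*2 = 16, so n = 8
w2 = torch.randn(2, 8, 16)
topk_ids = torch.tensor([[0, 1]] * 4, dtype=torch.int32)
topk_weights = torch.full((4, 2), 0.5)
print(reference_moe(a, w1, w2, topk_weights, topk_ids).shape)  # torch.Size([4, 16])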