sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +113 -17
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +11 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +4 -3
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +71 -0
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/vision.py +13 -5
- sglang/srt/layers/communicator.py +21 -4
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +2 -7
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +77 -73
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +3 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +55 -30
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +28 -7
- sglang/srt/managers/scheduler.py +26 -12
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +24 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +7 -6
- sglang/srt/model_executor/forward_batch_info.py +35 -14
- sglang/srt/model_executor/model_runner.py +19 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +72 -33
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +24 -12
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +142 -7
- sglang/srt/two_batch_overlap.py +157 -5
- sglang/srt/utils.py +38 -2
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
@@ -6,15 +6,50 @@ from typing import TYPE_CHECKING, Optional
|
|
6
6
|
|
7
7
|
import torch
|
8
8
|
from sgl_kernel import gelu_and_mul, silu_and_mul
|
9
|
-
from triton_kernels.matmul_ogs import
|
9
|
+
from triton_kernels.matmul_ogs import (
|
10
|
+
FlexCtx,
|
11
|
+
FnSpecs,
|
12
|
+
FusedActivation,
|
13
|
+
PrecisionConfig,
|
14
|
+
matmul_ogs,
|
15
|
+
)
|
16
|
+
from triton_kernels.numerics import InFlexData
|
10
17
|
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
|
11
|
-
|
12
|
-
from sglang.srt.utils import direct_register_custom_op
|
18
|
+
from triton_kernels.swiglu import swiglu_fn
|
13
19
|
|
14
20
|
if TYPE_CHECKING:
|
15
21
|
from sglang.srt.layers.moe.topk import TopKOutput
|
16
22
|
|
17
23
|
|
24
|
+
def quantize(w, dtype, dev, **opt):
    """Convert a weight tensor into the tuple form consumed by ``matmul_ogs``.

    Args:
        w: Weight tensor to convert.
        dtype: Target format: ``"bf16"``, ``"fp8"`` or ``"mx4"``. Any other
            value fails the ``"mx4"`` assertion.
        dev: Not used by this function; accepted for call-site compatibility.
        **opt: Extra options. The ``"mx4"`` path reads
            ``opt["swizzle_mx_scale"]`` (KeyError if absent).

    Returns:
        ``"bf16"``: ``(w_bf16, InFlexData())`` -- a 2-tuple.
        ``"fp8"``: ``(w_fp8, InFlexData(dtype=..., scale=...), MicroscalingCtx())``.
        ``"mx4"``: ``(w_packed, InFlexData(), MicroscalingCtx(...))``.

    NOTE(review): the bf16 path returns two values while the fp8/mx4 paths
    return three; callers must unpack according to the dtype they request.
    """
    if dtype == "bf16":
        return w.to(torch.bfloat16), InFlexData()

    if dtype == "fp8":
        # MicroscalingCtx is not imported at module level, so this path used to
        # fail with NameError. Import it lazily: module import stays unaffected
        # and a missing symbol surfaces as an ImportError that names it.
        # NOTE(review): import location follows triton_kernels' own benchmark
        # code; confirm against the pinned triton_kernels version.
        from triton_kernels.matmul_ogs import MicroscalingCtx

        # Cast, then re-lay-out so the last two dims are stored transposed
        # while the logical shape is unchanged.
        wq = w.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous().transpose(-1, -2)
        return (
            wq,
            InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)),
            MicroscalingCtx(),
        )

    assert dtype == "mx4", f"{dtype=}"
    # Same lazy-import rationale as the fp8 branch (names were never in scope).
    # NOTE(review): confirm these paths/signatures for the pinned triton_kernels.
    from triton_kernels.matmul_ogs import MicroscalingCtx
    from triton_kernels.numerics_details.mxfp import downcast_to_mxfp

    swizzle_mx_scale = opt["swizzle_mx_scale"]
    swizzle_axis = 2 if swizzle_mx_scale else None
    w = w.to(torch.bfloat16)
    w, mx_scales, weight_scale_shape = downcast_to_mxfp(
        w, torch.uint8, axis=1, swizzle_axis=swizzle_axis
    )
    return (
        w,
        InFlexData(),
        MicroscalingCtx(
            weight_scale=mx_scales,
            swizzle_mx=swizzle_mx_scale,
            actual_weight_scale_shape=weight_scale_shape,
        ),
    )
|
51
|
+
|
52
|
+
|
18
53
|
def triton_kernel_moe_forward(
|
19
54
|
hidden_states: torch.Tensor,
|
20
55
|
w1: torch.Tensor,
|
@@ -146,3 +181,153 @@ def triton_kernel_fused_experts(
|
|
146
181
|
)
|
147
182
|
|
148
183
|
return intermediate_cache3
|
184
|
+
|
185
|
+
|
186
|
+
def triton_kernel_moe_with_bias_forward(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w1_pcg,
    b1: torch.Tensor,
    w2: torch.Tensor,
    w2_pcg,
    b2: torch.Tensor,
    topk_output: TopKOutput,
    inplace: bool = False,
    activation: str = "silu",
    use_fp8_w8a8: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    activation_alpha: Optional[float] = None,
    swiglu_limit: Optional[int] = None,
) -> torch.Tensor:
    """Unpack a Triton-kernel top-k result and run the biased fused-experts MLP.

    ``topk_output`` must report the triton-kernel format; it is destructured
    into ``(routing_data, gather_idx, scatter_idx)``. Every other argument is
    forwarded unchanged to ``triton_kernel_fused_experts_with_bias``, whose
    result is returned.
    """
    assert topk_output.format.is_triton_kernel()
    route_data, gather_idx, scatter_idx = topk_output

    # Quantization-related options; the callee currently rejects non-defaults.
    quant_options = dict(
        use_fp8_w8a8=use_fp8_w8a8,
        per_channel_quant=per_channel_quant,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
    )

    return triton_kernel_fused_experts_with_bias(
        hidden_states,
        w1=w1,
        w1_pcg=w1_pcg,
        b1=b1,
        w2=w2,
        w2_pcg=w2_pcg,
        b2=b2,
        routing_data=route_data,
        gather_indx=gather_idx,
        scatter_indx=scatter_idx,
        inplace=inplace,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        activation_alpha=activation_alpha,
        swiglu_limit=swiglu_limit,
        **quant_options,
    )
|
237
|
+
|
238
|
+
|
239
|
+
def triton_kernel_fused_experts_with_bias(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w1_pcg,
    b1: torch.Tensor,
    w2: torch.Tensor,
    w2_pcg,
    b2: torch.Tensor,
    routing_data: RoutingData,
    gather_indx: GatherIndx,
    scatter_indx: ScatterIndx,
    inplace: bool = False,
    activation: str = "silu",
    use_fp8_w8a8: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    activation_alpha: Optional[float] = None,
    swiglu_limit: Optional[int] = None,
) -> torch.Tensor:
    """Run the biased two-GEMM MoE MLP through ``matmul_ogs``.

    First ``matmul_ogs(hidden_states, w1, b1)`` gathered by ``gather_indx``
    with a fused SwiGLU activation, then ``matmul_ogs(intermediate, w2, b2)``
    scattered by ``scatter_indx`` and weighted by ``routing_data.gate_scal``.

    Args:
        hidden_states: 2-D bfloat16 activations; ``shape[-1]`` must equal
            ``w1.shape[-2]``.
        w1, b1: First-layer expert weights (must be 3-D) and bias.
        w1_pcg: ``PrecisionConfig`` for ``w1``; rebuilt when ``w1`` is bf16.
        w2, b2: Second-layer expert weights and bias; ``w2.shape[-1]`` must
            equal ``w1.shape[1]``.
        w2_pcg: ``PrecisionConfig`` for ``w2``; rebuilt when ``w1`` is bf16.
        routing_data, gather_indx, scatter_indx: Routing for ``matmul_ogs``.
        inplace: Must be False.
        activation: Not consulted -- SwiGLU is always applied.
        use_fp8_w8a8, per_channel_quant: Must be False.
        global_num_experts: Not consulted; kept for interface compatibility.
        expert_map, w1_scale, w2_scale, a1_scale, a2_scale, block_shape:
            Must be None.
        activation_alpha, swiglu_limit: Passed to the fused SwiGLU as its
            ``alpha`` and ``limit`` arguments.

    Returns:
        The output of the second ``matmul_ogs`` call.

    Raises:
        AssertionError: on an unsupported option, dtype, or shape mismatch.
    """
    # Options this kernel path does not implement must be left at defaults.
    assert not use_fp8_w8a8, "use_fp8_w8a8 is not supported"
    assert not per_channel_quant, "per_channel_quant is not supported"
    assert expert_map is None, "expert_map is not supported"
    assert w1_scale is None, "w1_scale is not supported"
    assert w2_scale is None, "w2_scale is not supported"
    assert a1_scale is None, "a1_scale is not supported"
    assert a2_scale is None, "a2_scale is not supported"
    assert block_shape is None, "block_shape is not supported"

    # Type check.
    assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
    # TODO: assert each of w1/w2 is bf16 or mxfp4 once an mxfp4 check exists.

    # Shape check.
    assert hidden_states.ndim == 2, "hidden_states must be 2D"
    assert (
        hidden_states.shape[-1] == w1.shape[-2]
    ), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
    assert (
        w2.shape[-1] == w1.shape[1]
    ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"

    # Feature check.
    assert not inplace, "Inplace is not supported in new triton MoE kernel"
    assert len(w1.shape) == 3, f"w1 must be 3D, got shape {w1.shape}"

    # TODO maybe completely remove this branch
    if w1.dtype == torch.bfloat16:
        # Plain bf16 weights: rebuild both precision configs from default
        # InFlexData, replacing whatever configs the caller passed in.
        w1, w1_flex = quantize(w1, "bf16", "cuda")
        w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex))

        w2, w2_flex = quantize(w2, "bf16", "cuda")
        w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex))

    # NOTE(review): `activation` is ignored; SwiGLU is always fused here.
    act = FusedActivation(
        FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
        (activation_alpha, swiglu_limit),
        2,
    )

    intermediate_cache = matmul_ogs(
        hidden_states,
        w1,
        b1,
        routing_data,
        gather_indx=gather_indx,
        precision_config=w1_pcg,
        gammas=None,
        fused_activation=act,
    )

    # Routing weights (gate_scal) are applied on the second GEMM's scatter.
    return matmul_ogs(
        intermediate_cache,
        w2,
        b2,
        routing_data,
        scatter_indx=scatter_indx,
        precision_config=w2_pcg,
        gammas=routing_data.gate_scal,
    )
|
sglang/srt/layers/moe/topk.py
CHANGED
@@ -185,8 +185,9 @@ class TopK(CustomOp):
|
|
185
185
|
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
186
186
|
) -> TopKOutput:
|
187
187
|
if self.use_triton_kernels:
|
188
|
+
# renormalize=True is equivalent to sm_first=False
|
188
189
|
routing_data, gather_idx, scatter_idx = routing(
|
189
|
-
router_logits, self.top_k, self.renormalize
|
190
|
+
router_logits, self.top_k, sm_first=not self.renormalize
|
190
191
|
)
|
191
192
|
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
|
192
193
|
else:
|
@@ -397,8 +398,12 @@ def grouped_topk_gpu(
|
|
397
398
|
.reshape(num_token, -1)
|
398
399
|
) # [n, e]
|
399
400
|
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
401
|
+
# TODO: NPU can't support directly evaluating a comparison for now
|
400
402
|
topk_weights, topk_ids = torch.topk(
|
401
|
-
tmp_scores,
|
403
|
+
tmp_scores,
|
404
|
+
k=topk,
|
405
|
+
dim=-1,
|
406
|
+
sorted=(True if num_fused_shared_experts > 0 else False),
|
402
407
|
)
|
403
408
|
if num_fused_shared_experts:
|
404
409
|
topk_ids[:, -1] = torch.randint(
|
@@ -488,8 +493,12 @@ def biased_grouped_topk_impl(
|
|
488
493
|
tmp_scores = scores_for_choice.masked_fill(
|
489
494
|
~score_mask.bool(), float("-inf")
|
490
495
|
) # [n, e]
|
496
|
+
# TODO: NPU can't support directly evaluating a comparison for now
|
491
497
|
_, topk_ids = torch.topk(
|
492
|
-
tmp_scores,
|
498
|
+
tmp_scores,
|
499
|
+
k=topk,
|
500
|
+
dim=-1,
|
501
|
+
sorted=(True if num_fused_shared_experts > 0 else False),
|
493
502
|
)
|
494
503
|
topk_weights = scores.gather(1, topk_ids)
|
495
504
|
|
sglang/srt/layers/moe/utils.py
CHANGED
@@ -1,4 +1,20 @@
|
|
1
|
+
import importlib.util
|
1
2
|
from enum import Enum
|
3
|
+
from functools import lru_cache
|
4
|
+
|
5
|
+
from packaging import version as pkg_version
|
6
|
+
|
7
|
+
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
8
|
+
|
9
|
+
|
10
|
+
@lru_cache(maxsize=1)
def should_use_flashinfer_trtllm_moe():
    """Decide (once, then cached) whether the FlashInfer TRT-LLM MoE path is used.

    Returns the server flag as-is when ``enable_flashinfer_trtllm_moe`` is
    falsy. Otherwise returns True if ``flashinfer`` cannot be found, or
    whether its ``__version__`` is at least ``0.2.9rc1``.

    NOTE(review): a missing flashinfer package yields True here; confirm that
    is intended rather than a disable.
    """
    enabled = global_server_args_dict["enable_flashinfer_trtllm_moe"]
    if not enabled:
        return enabled

    if not importlib.util.find_spec("flashinfer"):
        return True

    installed = pkg_version.parse(importlib.import_module("flashinfer").__version__)
    return installed >= pkg_version.parse("0.2.9rc1")
|
2
18
|
|
3
19
|
|
4
20
|
class MoeA2ABackend(Enum):
|
@@ -47,6 +47,12 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
|
|
47
47
|
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
|
48
48
|
CompressedTensorsConfig,
|
49
49
|
)
|
50
|
+
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
|
51
|
+
|
52
|
+
is_mxfp_supported = mxfp_supported()
|
53
|
+
if is_mxfp_supported:
|
54
|
+
from sglang.srt.layers.quantization.fp4 import MxFp4Config
|
55
|
+
|
50
56
|
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
51
57
|
from sglang.srt.layers.quantization.gptq import (
|
52
58
|
GPTQConfig,
|
@@ -60,6 +66,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
|
|
60
66
|
ModelOptFp8Config,
|
61
67
|
)
|
62
68
|
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
|
69
|
+
from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
|
63
70
|
from sglang.srt.layers.quantization.petit import PetitNvFp4Config
|
64
71
|
from sglang.srt.layers.quantization.qoq import QoQConfig
|
65
72
|
from sglang.srt.layers.quantization.utils import get_linear_quant_method
|
@@ -85,6 +92,21 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
|
85
92
|
"petit_nvfp4": PetitNvFp4Config,
|
86
93
|
}
|
87
94
|
|
95
|
+
|
96
|
+
# Register the MXFP4 config under both "quark" and "mxfp4": the native
# Mxfp4Config on CUDA, the fp4-module MxFp4Config on HIP when MXFP is supported.
if is_cuda():
    BASE_QUANTIZATION_METHODS["quark"] = Mxfp4Config
    BASE_QUANTIZATION_METHODS["mxfp4"] = Mxfp4Config
elif is_mxfp_supported and is_hip():
    BASE_QUANTIZATION_METHODS["quark"] = MxFp4Config
    BASE_QUANTIZATION_METHODS["mxfp4"] = MxFp4Config
|
88
110
|
# VLLM-dependent quantization methods
|
89
111
|
VLLM_QUANTIZATION_METHODS = {
|
90
112
|
"aqlm": AQLMConfig,
|