sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/conversation.py +6 -0
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +196 -51
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +18 -13
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +128 -43
- sglang/srt/disaggregation/utils.py +127 -123
- sglang/srt/entrypoints/engine.py +15 -1
- sglang/srt/entrypoints/http_server.py +13 -2
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/layers/activation.py +19 -0
- sglang/srt/layers/attention/aiter_backend.py +15 -2
- sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
- sglang/srt/layers/attention/flashattention_backend.py +53 -64
- sglang/srt/layers/attention/flashinfer_backend.py +1 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
- sglang/srt/layers/attention/flashmla_backend.py +2 -10
- sglang/srt/layers/attention/triton_backend.py +119 -119
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +23 -5
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +0 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
- sglang/srt/layers/moe/ep_moe/layer.py +42 -32
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
- sglang/srt/layers/moe/topk.py +16 -8
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/lora/lora_manager.py +79 -34
- sglang/srt/lora/mem_pool.py +4 -5
- sglang/srt/managers/cache_controller.py +2 -1
- sglang/srt/managers/io_struct.py +28 -4
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +39 -6
- sglang/srt/managers/scheduler.py +73 -17
- sglang/srt/managers/tokenizer_manager.py +29 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +4 -2
- sglang/srt/mem_cache/memory_pool.py +111 -407
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +36 -12
- sglang/srt/model_executor/cuda_graph_runner.py +122 -55
- sglang/srt/model_executor/forward_batch_info.py +14 -5
- sglang/srt/model_executor/model_runner.py +6 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_v2.py +113 -155
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/openai_api/adapter.py +162 -4
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +318 -233
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
- sglang/srt/speculative/eagle_utils.py +389 -109
- sglang/srt/speculative/eagle_worker.py +134 -43
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +58 -0
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_ops/decode_attention.py
CHANGED
@@ -31,11 +31,6 @@ _is_hip = is_hip()

 logger = logging.getLogger(__name__)

-# TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy.
-logger.warning(
-    "The following error message 'operation scheduled before its operands' can be ignored."
-)
-

 _MIN_BLOCK_KV = 32

@@ -713,7 +708,7 @@ def decode_attention_fwd(
             num_kv_splits,
             max_kv_splits,
             sm_scale,
-            logit_cap,
+            logit_cap=logit_cap,
         )
     else:
         # GQA/MQA/MLA
@@ -729,5 +724,5 @@ def decode_attention_fwd(
             num_kv_splits,
             max_kv_splits,
             sm_scale,
-            logit_cap,
+            logit_cap=logit_cap,
         )
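The two `decode_attention_fwd` call sites above now pass `logit_cap` by keyword instead of by position. The callee signatures are not shown in this hunk, so the minimal sketch below uses hypothetical names purely to illustrate why keyword binding is the safer choice when a callee carries trailing optional parameters:

    # Hypothetical callee, similar in shape to the decode attention kernels:
    # trailing parameters with defaults (names here are illustrative only).
    def attention_kernel(q, k, o, sm_scale, num_kv_splits=8, logit_cap=0.0):
        return f"num_kv_splits={num_kv_splits}, logit_cap={logit_cap}"

    # If the signature ever gains a parameter before logit_cap, a positional call
    # silently binds 30.0 to the wrong slot; the keyword form stays correct.
    print(attention_kernel("q", "k", "o", 1.0, 30.0))            # 30.0 lands in num_kv_splits
    print(attention_kernel("q", "k", "o", 1.0, logit_cap=30.0))  # binds as intended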
sglang/srt/layers/attention/vision.py
CHANGED
@@ -1,15 +1,17 @@
 from __future__ import annotations

+import dataclasses
+import functools
 import math
-from functools import lru_cache
-from typing import Optional, Tuple
+from functools import lru_cache
+from typing import Any, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange

-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, print_info_once

 _is_cuda = is_cuda()

@@ -29,29 +31,42 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix

 ROTARY_EMBED_CLASSES = {
     "normal": apply_rotary_pos_emb,
 }


-
-
+@dataclasses.dataclass
+class SingletonCache:
+    data: Any = None

-
-
-        nonlocal has_run
-        if not has_run:
-            func(*args, **kwargs)
-            has_run = True
+    def set_data(self, value: Any) -> None:
+        self.data = value

-
+    def get_data(self) -> Optional[Any]:
+        return self.data

+    def empty(self) -> bool:
+        return self.get_data() is None

-
-
-
+
+# TODO: requires real seqlens from images
+@functools.lru_cache(maxsize=128)
+def _get_cu_seqlens_for_shape(batch_size: int, seqlen: int, device) -> torch.Tensor:
+    """
+    Generates cumulative sequence lengths (cu_seqlens) for a given batch_size, seqlen, and device.
+    Caches the result based on these parameters.
+    """
+    cu_seqlens = torch.arange(
+        0,
+        (batch_size + 1) * seqlen,
+        step=seqlen,
+        dtype=torch.int32,
+        device=device,
+    )
+    return cu_seqlens


 class VisionSdpaAttention(nn.Module):
@@ -265,8 +280,9 @@ class VisionFlash3Attention(nn.Module):
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        cu_seqlens: Optional[torch.Tensor],
-
+        cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
+        bsz: int,
+        seq_len: int,
         **kwargs,
     ) -> torch.Tensor:
         r"""
@@ -275,7 +291,16 @@ class VisionFlash3Attention(nn.Module):
         Returns:
             [b * s, h, head_size]
         """
-        cu_seqlens
+        if cu_seqlens is None:
+            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+        elif isinstance(cu_seqlens, SingletonCache):
+            if cu_seqlens.empty():
+                cu_seqlens.set_data(
+                    _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+                )
+            cu_seqlens = cu_seqlens.get_data()
+
+        cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
         max_seqlen = seq_lens.max().item()
         output = flash_attn_varlen_func(
@@ -346,11 +371,11 @@ class VisionAttention(nn.Module):
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
                 qkv_backend = "sdpa"
-
+            print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
         else:
             qkv_backend = global_server_args_dict["mm_attention_backend"]

-
+        print_info_once(f"Using {qkv_backend} as multimodal attention backend.")

         self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
             head_dim=self.head_size,
@@ -423,15 +448,16 @@ class VisionAttention(nn.Module):
         # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
         qkv, _ = self.qkv_proj(x)

-        # [s, b, head
+        # [s, b, head, head_dim_sum]
         new_x_shape = qkv.size()[:-1] + (
             head,
-
+            self.q_size + 2 * self.kv_size,
         )
         qkv = qkv.view(*new_x_shape)

         # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
-        q, k, v =
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
         # [s, b, head, head_size] --> [b, s, head, head_size]
         q, k, v = [
             rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
@@ -468,6 +494,7 @@ class VisionAttention(nn.Module):
             k=k,
             v=v,
             bsz=bsz,
+            seq_len=s,
             cu_seqlens=cu_seqlens,
             attention_mask=attention_mask,
         )
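The vision attention changes above replace the old run-once logging helper with `print_info_once`, add a `SingletonCache` holder, and derive `cu_seqlens` on demand from `(batch_size, seqlen)` through an `lru_cache`d helper instead of requiring callers to pass a tensor. Below is a small standalone sketch of that caching path; the class and helper bodies are copied from the added lines, while the surrounding demo values are illustrative:

    import dataclasses
    import functools
    from typing import Any, Optional

    import torch


    @dataclasses.dataclass
    class SingletonCache:
        # Single-slot holder, copied from the class added to vision.py.
        data: Any = None

        def set_data(self, value: Any) -> None:
            self.data = value

        def get_data(self) -> Optional[Any]:
            return self.data

        def empty(self) -> bool:
            return self.get_data() is None


    @functools.lru_cache(maxsize=128)
    def _get_cu_seqlens_for_shape(batch_size: int, seqlen: int, device) -> torch.Tensor:
        # Offsets 0, seqlen, 2*seqlen, ..., batch_size*seqlen, as varlen attention expects.
        return torch.arange(
            0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=device
        )


    # For batch_size=2, seqlen=4 this yields tensor([0, 4, 8], dtype=torch.int32).
    # A SingletonCache filled once can be reused by later layers of the same forward pass.
    cache = SingletonCache()
    if cache.empty():
        cache.set_data(_get_cu_seqlens_for_shape(2, 4, device="cpu"))
    print(cache.get_data())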
sglang/srt/layers/communicator.py
CHANGED
@@ -226,13 +226,13 @@ class LayerCommunicator:

 @dataclass
 class CommunicateContext:
-    process_group_sizes: Dict[
+    process_group_sizes: Dict[ScatterMode, int]
     attn_tp_rank: int
     attn_tp_size: int
     local_attn_dp_size: int
     tp_size: int

-    def is_same_group_size(self, a:
+    def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
         return self.process_group_sizes[a] == self.process_group_sizes[b]

     @classmethod
@@ -244,6 +244,7 @@ class CommunicateContext:
         process_group_sizes = {
             ScatterMode.SCATTERED: 1,
             ScatterMode.TP_ATTN_FULL: attn_tp_size,
+            # TODO: support --moe-dense-tp-size > 1
             ScatterMode.FULL: tp_size,
         }
         return cls(
@@ -323,11 +324,16 @@ class CommunicateWithAllReduceAndLayerNormFn:

         if (
             (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
-            and (
+            and (
+                residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
+            )
             and (hidden_states_output_mode == ScatterMode.FULL)
             and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
         ):
-            return
+            return partial(
+                CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states_and_residual,
+                residual_input_mode=residual_input_mode,
+            )

         if (
             (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
@@ -360,13 +366,25 @@ class CommunicateWithAllReduceAndLayerNormFn:
         return hidden_states, residual

     @staticmethod
-    def
+    def _gather_hidden_states_and_residual(
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         layernorm: torch.nn.Module,
         context: CommunicateContext,
+        *,
+        residual_input_mode,
     ):
+        if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
+            residual, local_residual = (
+                forward_batch.gathered_buffer[
+                    : forward_batch.input_ids.shape[0]
+                ].clone(),
+                residual,
+            )
+            attn_tp_all_gather(
+                list(residual.tensor_split(context.attn_tp_size)), local_residual
+            )
         if context.local_attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
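The dispatcher above now returns a `functools.partial` that binds the branch-specific `residual_input_mode` onto the shared `_gather_hidden_states_and_residual` handler, so callers invoke every branch with the same positional arguments. A minimal sketch of that pattern, using simplified stand-in names rather than the real handler signature:

    from functools import partial


    class Handlers:
        # Simplified stand-in; the real static handler also receives forward_batch,
        # layernorm, and a CommunicateContext.
        @staticmethod
        def gather_hidden_states_and_residual(hidden_states, residual, *, residual_input_mode):
            return f"gather(residual_input_mode={residual_input_mode})"


    def pick_fn(residual_input_mode):
        # Bind the branch-specific option once; callers keep a uniform call signature.
        return partial(
            Handlers.gather_hidden_states_and_residual,
            residual_input_mode=residual_input_mode,
        )


    fn = pick_fn("scattered")
    print(fn("hidden", "residual"))  # gather(residual_input_mode=scattered)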
sglang/srt/layers/linear.py
CHANGED
@@ -546,8 +546,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             param.shard_id.append(loaded_shard_id)
             param.shard_id_map[loaded_shard_id] = len(param.data_container)
             param.data_container.append(loaded_weight)
-            if len(param.data_container) == 2:
-                self.qweight = param.materialize_nested()
             return

         param_data = param.data
@@ -961,8 +959,6 @@ class QKVParallelLinear(ColumnParallelLinear):
             param.shard_id.append(loaded_shard_id)
             param.shard_id_map[loaded_shard_id] = len(param.data_container)
             param.data_container.append(loaded_weight)
-            if len(param.data_container) == 3:
-                self.qweight = param.materialize_nested()
             return

         param_data = param.data
sglang/srt/layers/logits_processor.py
CHANGED
@@ -47,18 +47,6 @@ from sglang.srt.utils import dump_to_file
 logger = logging.getLogger(__name__)


-from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import (
-    CaptureHiddenMode,
-    ForwardBatch,
-    ForwardMode,
-)
-from sglang.srt.utils import dump_to_file
-
-logger = logging.getLogger(__name__)
-
-
 @dataclasses.dataclass
 class LogitsProcessorOutput:
     ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
sglang/srt/layers/moe/ep_moe/kernels.py
CHANGED
@@ -4,6 +4,7 @@ from typing import List, Optional
 import torch
 import triton

+from sglang.math_utils import ceil_div
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import dispose_tensor, is_cuda

@@ -15,11 +16,6 @@ if _is_cuda:
         sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
     )

-    try:
-        from deep_gemm import ceil_div
-    except ImportError:
-        logger.error(f"Failed to import ceil_div from deep_gemm.")
-
 import triton.language as tl


@@ -278,6 +274,7 @@ def _silu_and_mul_post_quant_kernel(
     fp8_min,
     BLOCK_N: tl.constexpr,
     NUM_STAGE: tl.constexpr,
+    SCALE_UE8M0: tl.constexpr,
 ):
     expert_id = tl.program_id(2)
     token_id = tl.program_id(1)
@@ -319,6 +316,8 @@ def _silu_and_mul_post_quant_kernel(
         gate_up = up * gate
         _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
         output_s = _absmax / fp8_max
+        if SCALE_UE8M0:
+            output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
         output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
             output_ptr.dtype.element_ty
         )
@@ -339,6 +338,7 @@ def silu_and_mul_masked_post_quant_fwd(
     output_scale: torch.Tensor,
     quant_group_size: int,
     masked_m: torch.Tensor,
+    scale_ue8m0: bool = False,
 ):
     """
     input shape [expert_num, token_num_padded, hidden_dim]
@@ -395,6 +395,7 @@ def silu_and_mul_masked_post_quant_fwd(
         BLOCK_N=BLOCK_N,
         NUM_STAGE=NUM_STAGES,
         num_warps=num_warps,
+        SCALE_UE8M0=scale_ue8m0,
     )
     return

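The new `SCALE_UE8M0` path rounds each per-group FP8 scale up to the nearest power of two before quantizing, via `tl.exp2(tl.ceil(tl.log2(...)))`, matching the UE8M0 scale format requested when `scale_ue8m0=True`. A plain-PyTorch sketch of the same rounding, with made-up sample values:

    import torch


    def round_scale_to_ue8m0(scale: torch.Tensor) -> torch.Tensor:
        # Power-of-two ceiling, mirroring tl.exp2(tl.ceil(tl.log2(tl.abs(s)))) in the kernel.
        return torch.exp2(torch.ceil(torch.log2(scale.abs())))


    s = torch.tensor([0.0007, 0.0100, 0.0312])  # made-up per-group scales
    print(round_scale_to_ue8m0(s))              # 2**-10, 2**-6, 2**-5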
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -1,30 +1,11 @@
 import logging
 from typing import Callable, List, Optional, Tuple

+import einops
 import torch
+from sgl_kernel import silu_and_mul
 from torch.nn import Module

-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
-from sglang.srt.managers.expert_location import get_global_expert_location_metadata
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-
-try:
-    from deep_gemm import (
-        get_col_major_tma_aligned_tensor,
-        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
-        m_grouped_gemm_fp8_fp8_bf16_nt_masked,
-    )
-    from sgl_kernel import silu_and_mul
-
-    from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
-    )
-
-    use_deep_gemm = True
-except ImportError:
-    use_deep_gemm = False
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -45,6 +26,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -52,10 +34,20 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import (
     scaled_fp8_quant,
+    sglang_per_token_group_quant_fp8,
     sglang_per_token_quant_fp8,
 )
+from sglang.srt.managers.expert_location import get_global_expert_location_metadata
+from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    DeepEPMode,
+    dispose_tensor,
+    get_bool_env_var,
+    is_hip,
+    set_weight_attrs,
+)

 _is_hip = is_hip()

@@ -680,7 +672,6 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn

@@ -920,7 +911,9 @@ class DeepEPMoE(EPMoE):
         )
         self.deepep_mode = deepep_mode
         if self.deepep_mode.enable_low_latency():
-            assert
+            assert (
+                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+            ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
         self.w13_weight_fp8 = (
             self.w13_weight,
             (
@@ -948,7 +941,7 @@ class DeepEPMoE(EPMoE):
     ):
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
-            if
+            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
                 return self.forward_deepgemm_contiguous(
                     hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
                 )
@@ -1145,7 +1138,7 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         input_tensor[1] = tma_align_input_scale(input_tensor[1])
-
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
         del input_tensor
@@ -1169,7 +1162,7 @@ class DeepEPMoE(EPMoE):
             )
             del down_input
             down_input_scale = tma_align_input_scale(down_input_scale)
-
+            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
                 (down_input_fp8, down_input_scale),
                 self.w2_weight_fp8,
                 down_output,
@@ -1202,8 +1195,13 @@ class DeepEPMoE(EPMoE):
         gateup_output = torch.empty(
             (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
         )
-
-            hidden_states_fp8,
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            hidden_states_fp8,
+            self.w13_weight_fp8,
+            gateup_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         dispose_tensor(hidden_states_fp8[0])

@@ -1233,6 +1231,7 @@ class DeepEPMoE(EPMoE):
             down_input_scale,
             scale_block_size,
             masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del gateup_output

@@ -1240,13 +1239,24 @@ class DeepEPMoE(EPMoE):
         n = self.w2_weight.size(1)
         down_input_fp8 = (
             down_input,
-
+            (
+                down_input_scale
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    down_input_scale
+                )
+            ),
         )
         down_output = torch.empty(
             (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
         )
-
-            down_input_fp8,
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            down_input_fp8,
+            self.w2_weight_fp8,
+            down_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )

         return down_output
sglang/srt/layers/moe/ep_moe/token_dispatcher.py
CHANGED
@@ -1,7 +1,7 @@
 import logging
 from dataclasses import dataclass

-from sglang.srt.layers.quantization
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.expert_distribution import (
     get_global_expert_distribution_recorder,
 )
@@ -107,6 +107,8 @@ class DeepEPBuffer:
                 num_rdma_bytes,
                 low_latency_mode=deepep_mode.enable_low_latency(),
                 num_qps_per_rank=num_qps_per_rank,
+                # TODO can be false when unneeded
+                allow_mnnvl=True,
             )
         return cls._buffer

@@ -234,14 +236,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_weights: torch.Tensor,
     ):
         topk_idx = topk_idx.to(torch.int64)
-        if
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             # TODO hard code 128 block quant,use fp8 communication
             hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_idx, topk_weights, previous_event

     def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
-        if
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             (
                 hidden_states,
                 topk_idx,
@@ -343,7 +345,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             previous_event=previous_event,
             async_finish=self.async_finish,
             allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
-            expert_alignment=128 if
+            expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
             config=DeepEPConfig.get_instance().normal_dispatch_config,
         )

@@ -407,7 +409,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        if
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             output = hidden_states
         else:
             if hidden_states.shape[0] > 0:
@@ -540,38 +542,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         use_fp8: bool = False,
     ):
-        """
-        # For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'.
-        # Please make sure to change DeepEP code in internode_ll.cu dispatch / combine as below first and then reinstall.
-        # More details refer: https://github.com/deepseek-ai/DeepEP/issues/15#issuecomment-2709715782
-
-        diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
-        index 76ae2e2..8ecd08f 100644
-        --- a/csrc/kernels/internode_ll.cu
-        +++ b/csrc/kernels/internode_ll.cu
-        @@ -310,8 +310,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
-                      int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
-                      void* workspace, cudaStream_t stream, int phases) {
-            constexpr int kNumMaxTopK = 9;
-        -   constexpr int kNumWarpsPerGroup = 10;
-        -   constexpr int kNumWarpGroups = 3;
-        +   constexpr int kNumWarpsPerGroup = 8;
-        +   constexpr int kNumWarpGroups = 4;
-            EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
-
-            const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
-        @@ -501,8 +501,8 @@ void combine(void* combined_x,
-                     int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
-                     int num_topk, int num_experts, int rank, int num_ranks,
-                     void* workspace, cudaStream_t stream, int phases) {
-        -   constexpr int kNumWarpsPerGroup = 10;
-        -   constexpr int kNumWarpGroups = 3;
-        +   constexpr int kNumWarpsPerGroup = 8;
-        +   constexpr int kNumWarpGroups = 4;
-            constexpr int kNumMaxTopk = 9;
-
-            const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
-        """
         buffer = self._get_buffer()
         packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
             buffer.low_latency_dispatch(
@@ -582,6 +552,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
                 use_fp8=use_fp8,
                 async_finish=not self.return_recv_hook,
                 return_recv_hook=self.return_recv_hook,
+                round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+                and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
+                use_ue8m0=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+                and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
             )
         )
         return packed_recv_hidden, packed_recv_count, event, hook
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -12,6 +12,7 @@ import torch
 import triton
 import triton.language as tl

+from sglang.math_utils import ceil_div
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
@@ -518,10 +519,6 @@ def fused_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)


-def ceil_div(a, b):
-    return (a + b - 1) // b
-
-
 @triton.jit
 def moe_align_block_size_stage1(
     topk_ids_ptr,
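Both this file and `ep_moe/kernels.py` now import `ceil_div` from the new `sglang/math_utils.py` (listed above with +8 lines) instead of keeping local or `deep_gemm` copies. A minimal sketch consistent with the removed local definition; the exact contents of `math_utils.py` are not shown in this diff:

    def ceil_div(a: int, b: int) -> int:
        # Integer ceiling division: smallest n with n * b >= a (for positive b).
        return (a + b - 1) // b


    assert ceil_div(1000, 128) == 8  # 1000 tokens need 8 blocks of 128
    assert ceil_div(1024, 128) == 8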
sglang/srt/layers/moe/topk.py
CHANGED
@@ -249,6 +249,15 @@ def _mask_topk_ids_padded_region(
     topk_ids[indices >= num_token_non_padded, :] = -1


+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def _biased_grouped_topk_postprocess(
+    topk_ids, expert_location_dispatch_info, num_token_non_padded
+):
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_ids
+
+
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -282,14 +291,13 @@ def biased_grouped_topk(
             num_fused_shared_experts,
             routed_scaling_factor,
         )
-        # TODO merge into kernel
-
-
-
-
-
-
-        )(topk_ids, num_token_non_padded)
+        # TODO merge into kernel
+        if (expert_location_dispatch_info is not None) or (
+            num_token_non_padded is not None
+        ):
+            topk_ids = _biased_grouped_topk_postprocess(
+                topk_ids, expert_location_dispatch_info, num_token_non_padded
+            )
         return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
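The topk change pulls the expert-remapping and padding-mask postprocessing into a standalone `_biased_grouped_topk_postprocess` wrapped in `torch.compile`, called only when `expert_location_dispatch_info` or `num_token_non_padded` is set. A minimal sketch of that compile-and-call-conditionally pattern; the body below is a trivial stand-in (it reuses only the masking line from `_mask_topk_ids_padded_region`), and `backend=get_compiler_backend()` is omitted because that helper is sglang-internal:

    import torch


    @torch.compile(dynamic=True)
    def postprocess(topk_ids: torch.Tensor, num_token_non_padded: torch.Tensor) -> torch.Tensor:
        # Trivial stand-in for the real postprocess: mask rows beyond the actual
        # token count with -1 (the same masking _mask_topk_ids_padded_region does).
        indices = torch.arange(topk_ids.size(0), device=topk_ids.device)
        topk_ids = topk_ids.clone()
        topk_ids[indices >= num_token_non_padded, :] = -1
        return topk_ids


    topk_ids = torch.randint(0, 8, (6, 2))
    num_token_non_padded = torch.tensor(4)
    # Pay the compiled-postprocess cost only when the optional inputs are present.
    if num_token_non_padded is not None:
        topk_ids = postprocess(topk_ids, num_token_non_padded)
    print(topk_ids)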