sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -1
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +8 -7
- sglang/srt/disaggregation/decode.py +8 -4
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +68 -5
- sglang/srt/entrypoints/openai/protocol.py +2 -9
- sglang/srt/entrypoints/openai/serving_chat.py +60 -265
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +55 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +11 -13
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +10 -35
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +58 -41
- sglang/srt/layers/quantization/mxfp4.py +20 -3
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +66 -116
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +24 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +43 -49
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +18 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +53 -44
- sglang/srt/mem_cache/allocator.py +39 -214
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -23
- sglang/srt/model_executor/forward_batch_info.py +33 -14
- sglang/srt/model_executor/model_runner.py +179 -81
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +79 -38
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +11 -11
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +142 -20
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +10 -27
- sglang/srt/models/llama4.py +19 -6
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +142 -140
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +16 -12
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/dp_attention.py
CHANGED
@@ -4,7 +4,7 @@ import functools
 import logging
 from contextlib import contextmanager
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import torch
 import triton
@@ -18,21 +18,26 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.server_args import ServerArgs
+
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
-_ATTN_TP_GROUP = None
-_ATTN_TP_RANK = None
-_ATTN_TP_SIZE = None
-_ATTN_DP_RANK = None
-_ATTN_DP_SIZE = None
-_LOCAL_ATTN_DP_SIZE = None
-_LOCAL_ATTN_DP_RANK = None
+_ATTN_TP_GROUP: Optional[GroupCoordinator] = None
+_ATTN_TP_RANK: Optional[int] = None
+_ATTN_TP_SIZE: Optional[int] = None
+_ATTN_DP_RANK: Optional[int] = None
+_ATTN_DP_SIZE: Optional[int] = None
+_LOCAL_ATTN_DP_SIZE: Optional[int] = None
+_LOCAL_ATTN_DP_RANK: Optional[int] = None
+_ENABLE_DP_ATTENTION_FLAG: bool = False
 
 
-class DPPaddingMode(IntEnum):
+class DpPaddingMode(IntEnum):
 
     # Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
     MAX_LEN = auto()
@@ -40,13 +45,13 @@ class DPPaddingMode(IntEnum):
     SUM_LEN = auto()
 
     def is_max_len(self):
-        return self == DPPaddingMode.MAX_LEN
+        return self == DpPaddingMode.MAX_LEN
 
     def is_sum_len(self):
-        return self == DPPaddingMode.SUM_LEN
+        return self == DpPaddingMode.SUM_LEN
 
     @classmethod
-    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
+    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
         # we choose the mode that minimizes the communication cost
         max_len = max(global_num_tokens)
         sum_len = sum(global_num_tokens)
@@ -56,10 +61,76 @@ class DPPaddingMode(IntEnum):
         return cls.SUM_LEN
 
     @classmethod
-    def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
+    def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
         return cls.MAX_LEN
 
 
+class _DpGatheredBufferWrapper:
+
+    _hidden_size: int
+    _dtype: torch.dtype
+    _device: torch.device
+    _global_dp_buffer_len: int
+    _local_dp_buffer_len: int
+
+    @classmethod
+    def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
+        cls._hidden_size = hidden_size
+        cls._dtype = dtype
+        cls._device = device
+
+    @classmethod
+    def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int):
+        cls._global_dp_buffer_len = global_dp_buffer_len
+        cls._local_dp_buffer_len = local_dp_buffer_len
+
+    @classmethod
+    def get_global_dp_buffer(cls) -> torch.Tensor:
+        return torch.empty(
+            (cls._global_dp_buffer_len, cls._hidden_size),
+            dtype=cls._dtype,
+            device=cls._device,
+        )
+
+    @classmethod
+    def get_local_dp_buffer(cls) -> torch.Tensor:
+        return torch.empty(
+            (cls._local_dp_buffer_len, cls._hidden_size),
+            dtype=cls._dtype,
+            device=cls._device,
+        )
+
+    @classmethod
+    def get_global_dp_buffer_len(cls) -> int:
+        return cls._global_dp_buffer_len
+
+    @classmethod
+    def get_local_dp_buffer_len(cls) -> int:
+        return cls._local_dp_buffer_len
+
+
+def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int):
+    _DpGatheredBufferWrapper.set_dp_buffer_len(
+        global_dp_buffer_len, local_dp_buffer_len
+    )
+
+
+def get_global_dp_buffer() -> torch.Tensor:
+    return _DpGatheredBufferWrapper.get_global_dp_buffer()
+
+
+def get_local_dp_buffer() -> torch.Tensor:
+    return _DpGatheredBufferWrapper.get_local_dp_buffer()
+
+
+def get_global_dp_buffer_len() -> int:
+    return _DpGatheredBufferWrapper.get_global_dp_buffer_len()
+
+
+def get_local_dp_buffer_len() -> int:
+    return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
@@ -89,18 +160,24 @@ def compute_dp_attention_local_info(
 
 
 def initialize_dp_attention(
-    enable_dp_attention: bool,
-    tp_rank: int,
-    tp_size: int,
-    dp_size: int,
-    moe_dense_tp_size: int,
-    pp_size: int,
+    server_args: ServerArgs,
+    model_config: ModelConfig,
 ):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
-    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
+    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK, _ENABLE_DP_ATTENTION_FLAG
 
     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
 
+    enable_dp_attention = server_args.enable_dp_attention
+    tp_size = server_args.tp_size
+    dp_size = server_args.dp_size
+    moe_dense_tp_size = server_args.moe_dense_tp_size
+    pp_size = server_args.pp_size
+
+    tp_rank = get_tensor_model_parallel_rank()
+
+    _ENABLE_DP_ATTENTION_FLAG = enable_dp_attention
+
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
@@ -135,38 +212,48 @@ def initialize_dp_attention(
         group_name="attention_tp",
     )
 
+    _DpGatheredBufferWrapper.set_metadata(
+        hidden_size=model_config.hidden_size,
+        dtype=model_config.dtype,
+        device=torch.device("cuda"),
+    )
 
-def get_attention_tp_group():
+
+def is_dp_attention_enabled() -> bool:
+    return _ENABLE_DP_ATTENTION_FLAG
+
+
+def get_attention_tp_group() -> GroupCoordinator:
     assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
     return _ATTN_TP_GROUP
 
 
-def get_attention_tp_rank():
+def get_attention_tp_rank() -> int:
     assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
     return _ATTN_TP_RANK
 
 
-def get_attention_tp_size():
+def get_attention_tp_size() -> int:
     assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
     return _ATTN_TP_SIZE
 
 
-def get_attention_dp_rank():
+def get_attention_dp_rank() -> int:
     assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
     return _ATTN_DP_RANK
 
 
-def get_attention_dp_size():
+def get_attention_dp_size() -> int:
     assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
     return _ATTN_DP_SIZE
 
 
-def get_local_attention_dp_rank():
+def get_local_attention_dp_rank() -> int:
    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
     return _LOCAL_ATTN_DP_RANK
 
 
-def get_local_attention_dp_size():
+def get_local_attention_dp_size() -> int:
     assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
     return _LOCAL_ATTN_DP_SIZE
 
@@ -292,6 +379,10 @@ def _dp_gather_via_all_gather(
     forward_batch: ForwardBatch,
     is_partial: bool,
 ):
+    if get_attention_tp_size() == 1:
+        get_tp_group().all_gather_into_tensor(global_tokens, local_tokens)
+        return
+
     if not is_partial:
         if get_attention_tp_rank() != 0:
             local_tokens.fill_(0)
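Note on migrating callers: `initialize_dp_attention` now takes `server_args` and `model_config` instead of six scalar arguments, and it publishes the gathered-buffer metadata at init time. A minimal call-site sketch, assuming the caller already holds `ServerArgs` and `ModelConfig` instances (the helper name below is hypothetical):

```python
# Sketch only: shows the new keyword shape of initialize_dp_attention().
from sglang.srt.layers.dp_attention import (
    get_attention_tp_size,
    initialize_dp_attention,
)

def init_attention_groups(server_args, model_config):
    # Old call shape (removed in this diff): positional enable_dp_attention,
    # tp_rank, tp_size, dp_size, moe_dense_tp_size, pp_size.
    initialize_dp_attention(server_args=server_args, model_config=model_config)
    # Getters are now annotated and still assert that init ran first.
    return get_attention_tp_size()
```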
sglang/srt/layers/flashinfer_comm_fusion.py
CHANGED
@@ -1,5 +1,5 @@
 import logging
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()
 
 
 def ensure_workspace_initialized(
-    max_token_num: int =
+    max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
 ):
     """Ensure workspace is initialized"""
     if not is_flashinfer_available() or _flashinfer_comm is None:
@@ -124,8 +124,8 @@ def flashinfer_allreduce_residual_rmsnorm(
     residual: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    max_token_num: int =
-    use_oneshot: bool =
+    max_token_num: int = 2048,
+    use_oneshot: Optional[bool] = None,
     trigger_completion_at_end: bool = False,
     fp32_acc: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
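The `use_oneshot` default changes from a fixed bool to `Optional[bool] = None`, a tri-state where `None` defers the choice to the kernel. A generic sketch of that pattern; the heuristic below is a made-up placeholder, not the FlashInfer rule:

```python
# Sketch only: tri-state Optional[bool] default as adopted above.
from typing import Optional

def resolve_use_oneshot(
    use_oneshot: Optional[bool], token_num: int, max_token_num: int = 2048
) -> bool:
    if use_oneshot is not None:
        return use_oneshot  # explicit caller override
    # Placeholder auto-heuristic (assumption, not from the source).
    return token_num <= max_token_num
```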
sglang/srt/layers/linear.py
CHANGED
@@ -1294,6 +1294,7 @@ class RowParallelLinear(LinearBase):
         with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
             output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
             sm.tag(output_parallel)
+
         if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
         else:
sglang/srt/layers/logits_processor.py
CHANGED
@@ -27,7 +27,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
-    DPPaddingMode,
+    DpPaddingMode,
     attn_tp_all_gather,
     attn_tp_all_gather_into_tensor,
     dp_gather_replicate,
@@ -35,7 +35,9 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
+    get_global_dp_buffer,
     get_local_attention_dp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -108,14 +110,12 @@ class LogitsMetadata:
     # The start position of local hidden states.
     dp_local_start_pos: Optional[torch.Tensor] = None
     dp_local_num_tokens: Optional[torch.Tensor] = None
-
-    # Buffer to gather logits from all ranks.
-    forward_batch_gathered_buffer: Optional[torch.Tensor] = None
+    global_dp_buffer_len: Optional[int] = None
     # Number of tokens to sample per DP rank
     global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
     # The gather mode for DP attention
-    dp_padding_mode: Optional[DPPaddingMode] = None
+    dp_padding_mode: Optional[DpPaddingMode] = None
     # for padding
     padded_static_len: int = -1
 
@@ -164,11 +164,10 @@ class LogitsMetadata:
             global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
             dp_local_start_pos=forward_batch.dp_local_start_pos,
             dp_local_num_tokens=forward_batch.dp_local_num_tokens,
-
-            forward_batch_gathered_buffer=forward_batch.gathered_buffer,
+            global_dp_buffer_len=forward_batch.global_dp_buffer_len,
             global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
             global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=DPPaddingMode.SUM_LEN,
+            dp_padding_mode=DpPaddingMode.SUM_LEN,
         )
 
     def compute_dp_attention_metadata(self):
@@ -188,16 +187,11 @@ class LogitsMetadata:
 
         if self.global_num_tokens_for_logprob_cpu is not None:
             # create a smaller buffer to reduce peak memory usage
-            self.gathered_buffer = torch.empty(
-                (
-                    sum(self.global_num_tokens_for_logprob_cpu),
-                    self.gathered_buffer.shape[1],
-                ),
-                dtype=self.gathered_buffer.dtype,
-                device=self.gathered_buffer.device,
-            )
+            self.global_dp_buffer_len = sum(self.global_num_tokens_for_logprob_cpu)
         else:
-            self.gathered_buffer = self.forward_batch_gathered_buffer
+            self.global_dp_buffer_len = self.global_dp_buffer_len
+
+        set_dp_buffer_len(self.global_dp_buffer_len, self.dp_local_num_tokens)
 
 
 class LogitsProcessor(nn.Module):
@@ -443,7 +437,7 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                logits_metadata.gathered_buffer,
+                get_global_dp_buffer(),
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
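The net effect of the dp_attention and logits_processor changes: `LogitsMetadata` no longer carries a gather buffer tensor, only its length; the length is published through `set_dp_buffer_len()`, and the buffer is materialized on demand by `get_global_dp_buffer()`. A minimal sketch of that flow, using toy sizes and CPU so it runs anywhere (the private wrapper import is only for the sketch; in sglang, `initialize_dp_attention` calls `set_metadata`):

```python
import torch
from sglang.srt.layers.dp_attention import (
    _DpGatheredBufferWrapper,
    get_global_dp_buffer,
    set_dp_buffer_len,
)

# Normally done once inside initialize_dp_attention().
_DpGatheredBufferWrapper.set_metadata(
    hidden_size=4096, dtype=torch.bfloat16, device=torch.device("cpu")
)
# Per batch: record lengths instead of threading a tensor through ForwardBatch.
set_dp_buffer_len(global_dp_buffer_len=256, local_dp_buffer_len=64)
buf = get_global_dp_buffer()  # lazily allocated, shape (256, 4096)
```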
sglang/srt/layers/moe/cutlass_moe.py
CHANGED
@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
-from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import is_cuda
 
 _is_cuda = is_cuda()
@@ -124,6 +124,7 @@ def cutlass_fused_experts_fp8(
 
     if is_cuda:
         from sglang.srt.layers.quantization.fp8_kernel import (
+            per_group_transpose,
             per_token_group_quant_fp8_hopper_moe_mn_major,
             sglang_per_token_group_quant_fp8,
         )
@@ -152,15 +153,12 @@ def cutlass_fused_experts_fp8(
         k,
     )
 
-
-
-
-
-
-
-    rep_a_q, rep_a1_scales = per_token_group_quant_fp8_hopper_moe_mn_major(
-        rep_a, expert_offsets, problem_sizes1, 128
-    )
+    a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
+    rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
+    rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
+
+    if not is_sm100_supported():
+        rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)
     w1_scale = w1_scale.contiguous()
 
     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
@@ -193,12 +191,9 @@ def cutlass_fused_experts_fp8(
     intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
     silu_and_mul(c1, intermediate)
 
-
-
-
-    intemediate_q, a2_scale = per_token_group_quant_fp8_hopper_moe_mn_major(
-        intermediate, expert_offsets, problem_sizes2, 128
-    )
+    intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+    if not is_sm100_supported():
+        a2_scale = per_group_transpose(a2_scale, expert_offsets)
     w2_scale = w2_scale.contiguous()
 
     fp8_blockwise_scaled_grouped_mm(
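Shape bookkeeping for the reworked activation quantization above (group size 128): quantize once per token group, then replicate and reorder rows per selected expert. A small illustrative sketch with toy sizes; the real tensors come from `sglang_per_token_group_quant_fp8` and `shuffle_rows` as shown in the diff:

```python
# Illustrative sizes only.
m, k, topk, group_size = 8, 512, 2, 128

a_shape = (m, k)                        # bf16 activations before quantization
a_q_shape = (m, k)                      # fp8 payload, same shape
a1_scale_shape = (m, k // group_size)   # one scale per 128-wide group -> (8, 4)
rep_a_q_shape = (m * topk, k)           # rows replicated per expert via a_map
rep_a1_scales_shape = (m * topk, k // group_size)

assert a1_scale_shape == (8, 4)
# On pre-SM100 GPUs the scales are additionally transposed per expert group
# (per_group_transpose) before the grouped GEMM.
```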
sglang/srt/layers/moe/cutlass_w4a8_moe.py
CHANGED
@@ -11,7 +11,7 @@ from sgl_kernel import (
 )
 
 from sglang.srt.layers.moe.ep_moe.kernels import (
-    post_reorder_triton_kernel,
+    post_reorder_triton_kernel_for_cutlass_moe,
     pre_reorder_triton_kernel_for_cutlass_moe,
     run_cutlass_moe_ep_preproess,
 )
@@ -199,14 +199,13 @@ def cutlass_w4a8_moe(
     )
 
     output = torch.empty_like(a)
-    post_reorder_triton_kernel[(m,)](
+    post_reorder_triton_kernel_for_cutlass_moe[(m,)](
         c2,
         output,
         src2dst,
-
+        local_topk_ids,
         topk_weights,
-        start_expert_id,
-        end_expert_id,
+        num_experts,
         topk,
         k,
         0,
sglang/srt/layers/moe/ep_moe/kernels.py
CHANGED
@@ -581,6 +581,49 @@ def post_reorder_triton_kernel(
     )
 
 
+@triton.jit
+def post_reorder_triton_kernel_for_cutlass_moe(
+    down_output_ptr,
+    output_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    topk_weights_ptr,
+    num_experts,
+    topk,
+    hidden_size,
+    dst_start,
+    BLOCK_SIZE: tl.constexpr,
+):
+    InDtype = down_output_ptr.dtype.element_ty
+
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+    topk_weights_ptr = topk_weights_ptr + src_idx * topk
+
+    store_ptr = output_ptr + src_idx * hidden_size
+
+    vec = tl.arange(0, BLOCK_SIZE)
+
+    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+        offset = start_offset + vec
+        mask = offset < hidden_size
+
+        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
+        for idx in range(topk):
+            expert_id = tl.load(topk_ids_ptr + idx)
+            if expert_id != num_experts:
+                dst_idx_int32 = tl.load(src2dst_ptr + idx)
+                dst_idx = dst_idx_int32.to(tl.int64)
+                dst_idx = dst_idx - dst_start
+                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
+                load_ptr = down_output_ptr + dst_idx * hidden_size
+                in_data = tl.load(load_ptr + offset, mask=mask)
+                sum_vec += in_data * weigh_scale
+        tl.store(store_ptr + offset, sum_vec, mask=mask)
+
+
 @triton.jit
 def compute_m_range(
     pid,
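A hypothetical launch sketch for the new kernel, mirroring the call site in cutlass_w4a8_moe.py: one program per source token, each summing its topk weighted expert outputs, with `expert_id == num_experts` treated as a padding sentinel and skipped. Requires a CUDA device; all sizes and the BLOCK_SIZE value are assumptions for illustration:

```python
import torch
from sglang.srt.layers.moe.ep_moe.kernels import (
    post_reorder_triton_kernel_for_cutlass_moe,
)

m, topk, hidden, num_experts = 16, 2, 1024, 8
down_out = torch.randn(m * topk, hidden, device="cuda", dtype=torch.bfloat16)
output = torch.empty(m, hidden, device="cuda", dtype=torch.bfloat16)
src2dst = torch.arange(m * topk, device="cuda", dtype=torch.int32).view(m, topk)
topk_ids = torch.randint(0, num_experts, (m, topk), device="cuda", dtype=torch.int32)
topk_weights = torch.rand(m, topk, device="cuda", dtype=torch.float32)

post_reorder_triton_kernel_for_cutlass_moe[(m,)](
    down_out, output, src2dst, topk_ids, topk_weights,
    num_experts, topk, hidden, 0,  # dst_start = 0
    BLOCK_SIZE=512,
)
```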
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -34,6 +34,7 @@ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip,
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
+        AscendDeepEPLLOutput,
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,
@@ -387,7 +388,8 @@ class DeepEPMoE(EPMoE):
             return_recv_hook=True,
         )
 
-        if self.deepep_mode.enable_low_latency():
+        if self.deepep_mode.enable_low_latency() and not _is_npu:
+            # NPU supports low_latency deepep without deepgemm
             assert (
                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
@@ -404,7 +406,7 @@ class DeepEPMoE(EPMoE):
             )
             # the last one is invalid rank_id
             self.expert_mask[:-1] = 1
-        else:
+        elif not _is_npu:
             self.w13_weight_fp8 = (
                 self.w13_weight,
                 (
@@ -459,6 +461,8 @@ class DeepEPMoE(EPMoE):
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
+        if _is_npu:
+            return self.forward_npu(dispatch_output)
         if dispatch_output.format.is_deepep_normal():
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_contiguous(dispatch_output)
@@ -723,6 +727,60 @@ class DeepEPMoE(EPMoE):
 
         return down_output
 
+    def forward_npu(
+        self,
+        dispatch_output: DeepEPLLOutput,
+    ):
+        if TYPE_CHECKING:
+            assert isinstance(dispatch_output, AscendDeepEPLLOutput)
+        hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
+        assert self.quant_method is not None
+        assert self.activation == "silu"
+
+        # NOTE: Ascend's Dispatch & Combine does not support FP16
+        output_dtype = torch.bfloat16
+
+        pertoken_scale = hidden_states[1]
+        hidden_states = hidden_states[0]
+
+        group_list_type = 1
+        seg_indptr = seg_indptr.to(torch.int64)
+
+        import torch_npu
+
+        # gmm1: gate_up_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[self.w13_weight],
+            scale=[self.w13_weight_scale.to(output_dtype)],
+            per_token_scale=[pertoken_scale],
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=seg_indptr,
+            output_dtype=output_dtype,
+        )[0]
+
+        # act_fn: swiglu
+        hidden_states = torch_npu.npu_swiglu(hidden_states)
+
+        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
+
+        # gmm2: down_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[self.w2_weight],
+            scale=[self.w2_weight_scale.to(output_dtype)],
+            per_token_scale=[swiglu_out_scale],
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=seg_indptr,
+            output_dtype=output_dtype,
+        )[0]
+
+        return hidden_states
+
 
 def get_moe_impl_class():
     if global_server_args_dict["moe_a2a_backend"].is_deepep():
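For orientation, a schematic CPU analogue of the `forward_npu` pipeline above, ignoring quantization and the grouped, expert-batched layout: gate_up projection (gmm1), SwiGLU activation, then down projection (gmm2). Names and shapes are illustrative, not the NPU implementation:

```python
import torch
import torch.nn.functional as F

def expert_ffn_reference(x: torch.Tensor, w13: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
    gate_up = x @ w13                 # corresponds to npu_grouped_matmul (gmm1)
    gate, up = gate_up.chunk(2, dim=-1)
    act = F.silu(gate) * up           # corresponds to npu_swiglu
    return act @ w2                   # corresponds to npu_grouped_matmul (gmm2)

x = torch.randn(4, 64)
w13 = torch.randn(64, 256)   # concatenated gate and up projections (2 * 128)
w2 = torch.randn(128, 64)
print(expert_ffn_reference(x, w13, w2).shape)  # torch.Size([4, 64])
```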