sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +3 -13
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +158 -8
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +119 -75
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +5 -2
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +18 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +71 -53
- sglang/srt/conversation.py +78 -46
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +11 -3
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +236 -138
- sglang/srt/disaggregation/nixl/conn.py +242 -71
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +51 -2
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +31 -4
- sglang/srt/entrypoints/http_server.py +45 -3
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +147 -51
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
- sglang/srt/layers/moe/ep_moe/layer.py +121 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +77 -71
- sglang/srt/layers/quantization/fp8.py +110 -97
- sglang/srt/layers/quantization/fp8_kernel.py +81 -62
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +11 -14
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +13 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +93 -23
- sglang/srt/managers/schedule_policy.py +11 -8
- sglang/srt/managers/scheduler.py +140 -100
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +157 -47
- sglang/srt/managers/tp_worker.py +21 -21
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +4 -2
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +57 -41
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +3 -3
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +77 -39
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +3 -1
- sglang/srt/models/llama4.py +58 -13
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +52 -42
- sglang/srt/openai_api/protocol.py +20 -16
- sglang/srt/reasoning_parser.py +1 -1
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +2 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +64 -10
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +41 -6
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +92 -15
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/dp_attention.py
CHANGED
@@ -24,8 +24,10 @@ if TYPE_CHECKING:
 _ATTN_TP_GROUP = None
 _ATTN_TP_RANK = None
 _ATTN_TP_SIZE = None
-
-
+_ATTN_DP_RANK = None
+_ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_RANK = None
 
 
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
@@ -33,9 +35,27 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
         return tp_rank, tp_size, 0
 
     attn_tp_size = tp_size // dp_size
-
+    attn_dp_rank = tp_rank // attn_tp_size
     attn_tp_rank = tp_rank % attn_tp_size
-
+
+    return attn_tp_rank, attn_tp_size, attn_dp_rank
+
+
+def compute_dp_attention_local_info(
+    enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
+):
+    if not enable_dp_attention:
+        return tp_rank, tp_size, 0
+
+    local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
+    local_tp_rank = tp_rank % local_tp_size
+    local_dp_size = max(1, dp_size // (tp_size // local_tp_size))
+
+    local_attn_tp_size = local_tp_size // local_dp_size
+    local_attn_dp_rank = local_tp_rank // local_attn_tp_size
+    local_attn_tp_rank = local_tp_rank % local_attn_tp_size
+
+    return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank
 
 
 def initialize_dp_attention(
@@ -43,22 +63,32 @@ def initialize_dp_attention(
     tp_rank: int,
     tp_size: int,
     dp_size: int,
+    moe_dense_tp_size: int,
     pp_size: int,
 ):
-    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE,
+    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
+    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
 
     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
 
-    _ATTN_TP_RANK, _ATTN_TP_SIZE,
+    _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
+    _, _, _LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
+        enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
+    )
 
     if enable_dp_attention:
         local_rank = tp_rank % (tp_size // dp_size)
-
+        _ATTN_DP_SIZE = dp_size
+        if moe_dense_tp_size is None:
+            _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
+        else:
+            _LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
     else:
         local_rank = tp_rank
-
+        _ATTN_DP_SIZE = 1
+        _LOCAL_ATTN_DP_SIZE = 1
 
     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
@@ -93,13 +123,33 @@ def get_attention_tp_size():
 
 
 def get_attention_dp_rank():
-    assert
-    return
+    assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _ATTN_DP_RANK
 
 
 def get_attention_dp_size():
-    assert
-    return
+    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _ATTN_DP_SIZE
+
+
+def get_local_attention_dp_rank():
+    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_RANK
+
+
+def get_local_attention_dp_size():
+    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_SIZE
+
+
+def get_local_attention_dp_rank():
+    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_RANK
+
+
+def get_local_attention_dp_size():
+    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_SIZE
 
 
 @contextmanager
@@ -112,19 +162,19 @@ def disable_dp_size():
     Args:
         tp_group (GroupCoordinator): the tp group coordinator
     """
-    global
-    assert
+    global _ATTN_DP_SIZE
+    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
 
-    old_dp_size =
-
+    old_dp_size = _ATTN_DP_SIZE
+    _ATTN_DP_SIZE = 1
     try:
         yield
     finally:
-
+        _ATTN_DP_SIZE = old_dp_size
 
 
 def get_dp_local_info(forward_batch: ForwardBatch):
-    dp_rank =
+    dp_rank = get_local_attention_dp_rank()
 
     if forward_batch.dp_local_start_pos is None:
         cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
@@ -201,7 +251,7 @@ def _dp_gather(
         global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
     )
 
-    # Input IDs are in int 32. We should use inplace_all_reduce for local case
+    # Input IDs are in int 32. We should use inplace_all_reduce for local case because of custom all reduce.
     NUM_GPUS_PER_NODE = 8
     if (
         not local_tokens.dtype.is_floating_point
@@ -252,12 +302,12 @@ def dp_scatter(
     )
 
 
-def
+def attn_tp_reduce_scatter(
     output: torch.Tensor,
     input_list: List[torch.Tensor],
 ):
     return get_attention_tp_group().reduce_scatter(output, input_list)
 
 
-def
+def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
     return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
sglang/srt/layers/layernorm.py
CHANGED
@@ -76,7 +76,7 @@ class RMSNorm(CustomOp):
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not x.is_contiguous():
-            # NOTE:
+            # NOTE: Remove this if aiter kernel supports discontinuous input
             x = x.contiguous()
         if residual is not None:
             fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
sglang/srt/layers/logits_processor.py
CHANGED
@@ -23,15 +23,17 @@ import triton.language as tl
 from torch import nn
 
 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
     dp_gather_replicate,
     dp_scatter,
-    get_attention_dp_rank,
     get_attention_dp_size,
+    get_attention_tp_size,
+    get_local_attention_dp_rank,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -45,6 +47,18 @@ from sglang.srt.utils import dump_to_file
 logger = logging.getLogger(__name__)
 
 
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
+from sglang.srt.utils import dump_to_file
+
+logger = logging.getLogger(__name__)
+
+
 @dataclasses.dataclass
 class LogitsProcessorOutput:
     ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
@@ -169,7 +183,7 @@ class LogitsMetadata:
             return
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
-        dp_rank =
+        dp_rank = get_local_attention_dp_rank()
         if dp_rank == 0:
             dp_local_start_pos = torch.zeros_like(
                 self.global_num_tokens_for_logprob_gpu[0]
@@ -198,12 +212,20 @@ class LogitsProcessor(nn.Module):
         super().__init__()
         self.config = config
         self.logit_scale = logit_scale
-        self.
-
-
-
-
-
+        self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
+        if self.use_attn_tp_group:
+            self.attn_tp_size = get_attention_tp_size()
+            self.do_tensor_parallel_all_gather = (
+                not skip_all_gather and self.attn_tp_size > 1
+            )
+            self.do_tensor_parallel_all_gather_dp_attn = False
+        else:
+            self.do_tensor_parallel_all_gather = (
+                not skip_all_gather and get_tensor_model_parallel_world_size() > 1
+            )
+            self.do_tensor_parallel_all_gather_dp_attn = (
+                self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
+            )
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None
         )
@@ -315,7 +337,8 @@ class LogitsProcessor(nn.Module):
 
         if self.debug_tensor_dump_output_folder:
             assert (
-                not self.do_tensor_parallel_all_gather
+                not self.do_tensor_parallel_all_gather
+                or get_local_attention_dp_size() == 1
             ), "dp attention + sharded lm_head doesn't support full logits"
             full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
             dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
@@ -442,7 +465,19 @@ class LogitsProcessor(nn.Module):
             logits.mul_(self.logit_scale)
 
         if self.do_tensor_parallel_all_gather:
-
+            if self.use_attn_tp_group:
+                global_logits = torch.empty(
+                    (self.config.vocab_size, logits.shape[0]),
+                    device=logits.device,
+                    dtype=logits.dtype,
+                )
+                global_logits = global_logits.T
+                attn_tp_all_gather(
+                    list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits
+                )
+                logits = global_logits
+            else:
+                logits = tensor_model_parallel_all_gather(logits)
 
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits, global_logits = (
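
The use_attn_tp_group branch above gathers vocab-sharded logits across the attention TP group instead of the full TP group. A CPU-only sketch of the shape handling is shown below; it is illustrative rather than the sglang API, and Tensor.copy_ stands in for the attn_tp_all_gather collective. Allocating global_logits as (vocab_size, num_tokens) and transposing it makes every tensor_split chunk a contiguous slab of the underlying buffer, so each rank's shard can be written into place.

import torch

vocab_size, num_tokens, attn_tp_size = 8, 3, 2
local_vocab = vocab_size // attn_tp_size

# Each attention-TP rank holds logits for its vocab shard: (num_tokens, local_vocab).
local_shards = [torch.full((num_tokens, local_vocab), float(r)) for r in range(attn_tp_size)]

# Buffer laid out as (vocab_size, num_tokens), viewed transposed as (num_tokens, vocab_size).
global_logits = torch.empty((vocab_size, num_tokens)).T
chunks = list(global_logits.tensor_split(attn_tp_size, dim=-1))

for chunk, shard in zip(chunks, local_shards):  # stand-in for attn_tp_all_gather
    chunk.copy_(shard)

print(global_logits.shape)  # torch.Size([3, 8]); columns 0-3 hold rank 0, columns 4-7 hold rank 1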
sglang/srt/layers/moe/ep_moe/kernels.py
CHANGED
@@ -5,16 +5,23 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import is_cuda
 
+logger = logging.getLogger(__name__)
+
 _is_cuda = is_cuda()
 if _is_cuda:
     from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
+        sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
     )
-
+
+try:
+    from deep_gemm import ceil_div
+except ImportError:
+    logger.error(f"Failed to import ceil_div from deep_gemm.")
+
+import triton.language as tl
 
 
 @triton.jit
@@ -109,7 +116,7 @@ def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
     seg_indptr = torch.empty(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
     src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int64)
 
-    # Find
+    # Find offset
    expert_ids = torch.arange(
        num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
    )
@@ -654,10 +661,7 @@ def grouped_gemm_triton(
     if block_shape is not None:
         assert len(block_shape) == 2
         block_n, block_k = block_shape[0], block_shape[1]
-
-            a, scale_a = sglang_per_token_group_quant_fp8(a, block_k)
-        else:
-            a, scale_a = per_token_group_quant_fp8(a, block_k)
+        a, scale_a = per_token_group_quant_fp8(a, block_k)
 
         assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
         assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
@@ -707,3 +711,334 @@ def grouped_gemm_triton(
         **config,
     )
     return c
+
+
+@triton.jit
+def _fwd_kernel_ep_scatter_1(
+    num_recv_tokens_per_expert,
+    expert_start_loc,
+    m_indices,
+    num_experts: tl.constexpr,
+    BLOCK_E: tl.constexpr,
+    BLOCK_EXPERT_NUM: tl.constexpr,
+):
+    cur_expert = tl.program_id(0)
+
+    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
+    tokens_per_expert = tl.load(
+        num_recv_tokens_per_expert + offset_cumsum,
+        mask=offset_cumsum < num_experts,
+        other=0,
+    )
+    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
+    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
+
+    cur_expert_start = tl.load(expert_start_loc + cur_expert)
+    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)
+
+    m_indices_start_ptr = m_indices + cur_expert_start
+    off_expert = tl.arange(0, BLOCK_E)
+
+    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
+        tl.store(
+            m_indices_start_ptr + start_m + off_expert,
+            cur_expert,
+        )
+
+
+@triton.jit
+def _fwd_kernel_ep_scatter_2(
+    total_token_num,
+    expert_start_loc,
+    recv_x,
+    recv_x_stride0,
+    recv_x_stride1,
+    recv_x_scale,
+    recv_x_scale_stride0,
+    recv_x_scale_stride1,
+    recv_topk,
+    recv_topk_stride0,
+    recv_topk_stride1,
+    output_tensor,
+    output_tensor_stride0,
+    output_tensor_stride1,
+    output_tensor_scale,
+    output_tensor_scale_stride0,
+    output_tensor_scale_stride1,
+    output_index,
+    output_index_stride0,
+    output_index_stride1,
+    topk_num: tl.constexpr,
+    HIDDEN_SIZE: tl.constexpr,
+    HIDDEN_SIZE_PAD: tl.constexpr,
+    SCALE_HIDDEN_SIZE: tl.constexpr,
+    SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
+):
+    start_token_id = tl.program_id(0)
+    grid_num = tl.num_programs(0)
+
+    offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
+    mask = offset_in < HIDDEN_SIZE
+
+    offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
+    mask_s = offset_in_s < SCALE_HIDDEN_SIZE
+
+    for token_id in range(start_token_id, total_token_num, grid_num):
+        to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
+        to_copy_s = tl.load(
+            recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
+        )
+
+        for topk_index in tl.range(0, topk_num, 1, num_stages=4):
+            expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
+            if expert_id >= 0:
+                dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
+                tl.store(
+                    output_index + token_id * output_index_stride0 + topk_index,
+                    dest_token_index,
+                )
+                output_tensor_ptr = (
+                    output_tensor + dest_token_index * output_tensor_stride0
+                )
+                output_tensor_scale_ptr = (
+                    output_tensor_scale + dest_token_index * output_tensor_scale_stride0
+                )
+                tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
+                tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
+
+
+# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
+@torch.no_grad()
+def ep_scatter(
+    recv_x: torch.Tensor,
+    recv_x_scale: torch.Tensor,
+    recv_topk: torch.Tensor,
+    num_recv_tokens_per_expert: torch.Tensor,
+    expert_start_loc: torch.Tensor,
+    output_tensor: torch.Tensor,
+    output_tensor_scale: torch.Tensor,
+    m_indices: torch.Tensor,
+    output_index: torch.Tensor,
+):
+    BLOCK_E = 128  # token num of per expert is aligned to 128
+    BLOCK_D = 128  # block size of quantization
+    num_warps = 8
+    num_experts = num_recv_tokens_per_expert.shape[0]
+    hidden_size = recv_x.shape[1]
+    # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
+    grid = num_experts
+
+    assert m_indices.shape[0] % BLOCK_E == 0
+
+    _fwd_kernel_ep_scatter_1[(grid,)](
+        num_recv_tokens_per_expert,
+        expert_start_loc,
+        m_indices,
+        num_experts=num_experts,
+        num_warps=num_warps,
+        BLOCK_E=BLOCK_E,
+        BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
+    )
+
+    grid = min(recv_topk.shape[0], 1024 * 8)
+
+    _fwd_kernel_ep_scatter_2[(grid,)](
+        recv_topk.shape[0],
+        expert_start_loc,
+        recv_x,
+        recv_x.stride(0),
+        recv_x.stride(1),
+        recv_x_scale,
+        recv_x_scale.stride(0),
+        recv_x_scale.stride(1),
+        recv_topk,
+        recv_topk.stride(0),
+        recv_topk.stride(1),
+        output_tensor,
+        output_tensor.stride(0),
+        output_tensor.stride(1),
+        output_tensor_scale,
+        output_tensor_scale.stride(0),
+        output_tensor_scale.stride(1),
+        output_index,
+        output_index.stride(0),
+        output_index.stride(1),
+        topk_num=recv_topk.shape[1],
+        num_warps=num_warps,
+        HIDDEN_SIZE=hidden_size,
+        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
+        SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
+        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
+    )
+    return
+
+
+@triton.jit
+def _fwd_kernel_ep_gather(
+    total_token_num,
+    input_tensor,
+    input_tensor_stride0,
+    input_tensor_stride1,
+    recv_topk_ids,
+    recv_topk_ids_stride0,
+    recv_topk_ids_stride1,
+    recv_topk_weight,
+    recv_topk_weight_stride0,
+    recv_topk_weight_stride1,
+    input_index,
+    input_index_stride0,
+    input_index_stride1,
+    output_tensor,
+    output_tensor_stride0,
+    output_tensor_stride1,
+    topk_num: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    cur_block = tl.program_id(0)
+    start_cur_token = tl.program_id(1)
+    grid_num = tl.num_programs(1)
+
+    for cur_token in range(start_cur_token, total_token_num, grid_num):
+        off_d = tl.arange(0, BLOCK_D)
+        accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
+        for topk_index in range(0, topk_num):
+            expert_id = tl.load(
+                recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
+            )
+            if expert_id >= 0:
+                source_token_index = tl.load(
+                    input_index + cur_token * input_index_stride0 + topk_index
+                )
+                acc_weight = tl.load(
+                    recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
+                )
+                tmp = tl.load(
+                    input_tensor
+                    + source_token_index * input_tensor_stride0
+                    + cur_block * BLOCK_D
+                    + off_d
+                )
+                accumulator += tmp.to(tl.float32) * acc_weight
+
+        tl.store(
+            output_tensor
+            + cur_token * output_tensor_stride0
+            + cur_block * BLOCK_D
+            + off_d,
+            accumulator.to(output_tensor.dtype.element_ty),
+        )
+
+
+@torch.no_grad()
+def ep_gather(
+    input_tensor: torch.Tensor,
+    recv_topk_ids: torch.Tensor,
+    recv_topk_weight: torch.Tensor,
+    input_index: torch.Tensor,
+    output_tensor: torch.Tensor,
+):
+    BLOCK_D = 1024  # block size of quantization
+    num_warps = 2
+    num_tokens = output_tensor.shape[0]
+    hidden_size = input_tensor.shape[1]
+    assert hidden_size % BLOCK_D == 0
+    grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
+    _fwd_kernel_ep_gather[grid](
+        num_tokens,
+        input_tensor,
+        input_tensor.stride(0),
+        input_tensor.stride(1),
+        recv_topk_ids,
+        recv_topk_ids.stride(0),
+        recv_topk_ids.stride(1),
+        recv_topk_weight,
+        recv_topk_weight.stride(0),
+        recv_topk_weight.stride(1),
+        input_index,
+        input_index.stride(0),
+        input_index.stride(1),
+        output_tensor,
+        output_tensor.stride(0),
+        output_tensor.stride(1),
+        topk_num=recv_topk_ids.shape[1],
+        num_warps=num_warps,
+        BLOCK_D=BLOCK_D,
+    )
+    return
+
+
+# copy from
+# https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58
+def get_tma_aligned_size(x: int, element_size: int) -> int:
+    """
+    Global memory address of TMA must be 16-byte aligned.
+    Since we use column-major layout for the LHS scaling tensor,
+    the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.
+
+    Arguments:
+        x: original M-axis shape of the LHS scaling tensor.
+        element_size: element size of the LHS scaling tensor.
+
+    Returns:
+        M-axis shape of the LHS scaling tensor after padding.
+    """
+    tma_alignment_bytes = 16
+    assert tma_alignment_bytes % element_size == 0
+    alignment = tma_alignment_bytes // element_size
+    return ceil_div(x, alignment) * alignment
+
+
+@triton.jit
+def _tma_align_input_scale_kernel(
+    input_scale_ptr,
+    output_ptr,
+    m,
+    k_div_block_size,
+    input_scale_stride_m,
+    input_scale_stride_k,
+    output_stride_m,
+    output_stride_k,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    pid_m = tl.program_id(axis=0)
+    grid_m = tl.num_programs(0)
+    k_offsets = tl.arange(0, BLOCK_SIZE_K)
+
+    for m_base in range(pid_m, m, grid_m):
+        input_offset = (
+            input_scale_ptr
+            + m_base * input_scale_stride_m
+            + k_offsets * input_scale_stride_k
+        )
+        input_data = tl.load(input_offset, mask=k_offsets < k_div_block_size)
+
+        output_offset = (
+            output_ptr + k_offsets * output_stride_k + m_base * output_stride_m
+        )
+        tl.store(output_offset, input_data, mask=k_offsets < k_div_block_size)
+
+
+# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py
+def tma_align_input_scale(input_scale: torch.Tensor):
+    assert input_scale.dim() == 2
+    m, k_div_block_size = input_scale.shape
+    padd_m = get_tma_aligned_size(m, input_scale.element_size())
+    output = torch.empty(
+        (k_div_block_size, padd_m), dtype=input_scale.dtype, device=input_scale.device
+    )
+
+    grid_m = min(m, 8192)
+    BLOCK_SIZE_K = triton.next_power_of_2(k_div_block_size)
+
+    _tma_align_input_scale_kernel[(grid_m,)](
+        input_scale_ptr=input_scale,
+        output_ptr=output,
+        m=m,
+        k_div_block_size=k_div_block_size,
+        input_scale_stride_m=input_scale.stride(0),
+        input_scale_stride_k=input_scale.stride(1),
+        output_stride_m=output.stride(1),  # Note: these are swapped
+        output_stride_k=output.stride(0),  # for column-major
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+    )
+    return output.t()[:m]