sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +13 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +12 -16
- sglang/srt/disaggregation/prefill.py +17 -13
- sglang/srt/disaggregation/utils.py +46 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +22 -28
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +21 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +19 -9
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +207 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +91 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/io_struct.py +9 -12
- sglang/srt/managers/schedule_batch.py +40 -31
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +147 -62
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +76 -45
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +22 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +108 -26
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +36 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/utils.py +177 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -262,11 +262,14 @@ class FlashInferAttnBackend(AttentionBackend):
         )
 
     def init_cuda_graph_state(
-        self,
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         if kv_indices_buf is None:
             cuda_graph_kv_indices = torch.zeros(
-                (
+                (max_num_tokens * self.max_context_len,),
                 dtype=torch.int32,
                 device="cuda",
             )
@@ -285,7 +288,7 @@ class FlashInferAttnBackend(AttentionBackend):
 
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device="cuda",
             )
@@ -440,7 +443,7 @@ class FlashInferAttnBackend(AttentionBackend):
             raise ValueError("Invalid forward mode")
 
     def get_cuda_graph_seq_len_fill_value(self):
-        return
+        return 1
 
     def forward_extend(
         self,
@@ -1096,7 +1099,7 @@ class FlashInferMultiStepDraftBackend:
 
         self.common_template(forward_batch, kv_indices, call_fn)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_bs * self.max_context_len),
             dtype=torch.int32,
@@ -1105,7 +1108,7 @@ class FlashInferMultiStepDraftBackend:
 
         for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )
 
    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
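Across these hunks, `init_cuda_graph_state` gains a `max_num_tokens` argument and the CUDA-graph scratch buffers are sized per captured token rather than per request. A minimal, self-contained sketch of the new call shape; the `ToyAttnBackend` class and all numbers below are illustrative stand-ins, not sglang code:

import torch
from typing import Optional

class ToyAttnBackend:
    """Toy stand-in that only mirrors the buffer-sizing pattern shown above."""

    def __init__(self, max_context_len: int):
        self.max_context_len = max_context_len

    # 0.4.8-style signature: buffers are keyed to max_num_tokens, not just max_bs.
    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
                (max_num_tokens * self.max_context_len,), dtype=torch.int32
            )
        else:
            self.cuda_graph_kv_indices = kv_indices_buf

backend = ToyAttnBackend(max_context_len=4096)
backend.init_cuda_graph_state(max_bs=8, max_num_tokens=64)
print(backend.cuda_graph_kv_indices.shape)  # torch.Size([262144])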
sglang/srt/layers/attention/flashinfer_mla_backend.py
CHANGED
@@ -199,7 +199,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         )
 
     def init_cuda_graph_state(
-        self,
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         if kv_indices_buf is None:
             cuda_graph_kv_indices = torch.zeros(
@@ -364,7 +367,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
 
     def get_cuda_graph_seq_len_fill_value(self):
-        return
+        return 1
 
     def forward_extend(
         self,
@@ -852,7 +855,7 @@ class FlashInferMLAMultiStepDraftBackend:
 
         self.common_template(forward_batch, kv_indices, call_fn)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_bs * self.max_context_len),
             dtype=torch.int32,
@@ -861,7 +864,7 @@ class FlashInferMLAMultiStepDraftBackend:
 
         for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )
 
    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
sglang/srt/layers/attention/flashmla_backend.py
CHANGED
@@ -148,6 +148,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
     def init_cuda_graph_state(
         self,
         max_bs: int,
+        max_num_tokens: int,
         block_kv_indices: Optional[torch.Tensor] = None,
     ):
         if block_kv_indices is None:
@@ -502,9 +503,11 @@ class FlashMLAMultiStepDraftBackend:
 
         self.common_template(forward_batch, call_fn)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         for i in range(self.speculative_num_steps):
-            self.attn_backends[i].init_cuda_graph_state(
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, block_kv_indices=None
+            )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
         def call_fn(i, forward_batch):
sglang/srt/layers/attention/tbo_backend.py
CHANGED
@@ -32,11 +32,11 @@ class TboAttnBackend(AttentionBackend):
             if forward_batch_child.batch_size > 0:
                 child.init_forward_metadata(forward_batch=forward_batch_child)
 
-    def init_cuda_graph_state(self, max_bs: int):
-        self.primary.init_cuda_graph_state(max_bs=max_bs)
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.primary.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
         for item in self.children:
             # TODO for children, maybe can provide *smaller* max_bs to optimize
-            item.init_cuda_graph_state(max_bs=max_bs)
+            item.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
 
     def init_forward_metadata_capture_cuda_graph(
         self,
sglang/srt/layers/attention/triton_backend.py
CHANGED
@@ -261,6 +261,7 @@ class TritonAttnBackend(AttentionBackend):
             num_kv_splits = None
             attn_logits = None
             attn_lse = None
+
         elif forward_batch.forward_mode.is_draft_extend():
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
@@ -335,24 +336,27 @@ class TritonAttnBackend(AttentionBackend):
         )
 
     def init_cuda_graph_state(
-        self,
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_attn_logits = torch.zeros(
-            (
+            (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
             dtype=torch.float32,
             device=self.device,
         )
         self.cuda_graph_attn_lse = torch.zeros(
-            (
+            (max_num_tokens, self.num_head, self.max_kv_splits),
             dtype=torch.float32,
             device=self.device,
         )
         self.cuda_graph_num_kv_splits = torch.full(
-            (
+            (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
         )
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
-                (
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.int32,
                 device=self.device,
             )
@@ -361,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):
 
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device=self.device,
             )
@@ -369,7 +373,7 @@ class TritonAttnBackend(AttentionBackend):
         if self.sliding_window_size is not None and self.sliding_window_size > 0:
             if kv_indices_buf is None:
                 self.cuda_graph_window_kv_indices = torch.zeros(
-                    (
+                    (max_num_tokens * self.sliding_window_size),
                     dtype=torch.int32,
                     device=self.device,
                 )
@@ -377,7 +381,10 @@ class TritonAttnBackend(AttentionBackend):
                 self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)
 
             self.cuda_graph_window_num_kv_splits = torch.full(
-                (
+                (max_num_tokens,),
+                self.max_kv_splits,
+                dtype=torch.int32,
+                device=self.device,
             )
 
     def init_forward_metadata_capture_cuda_graph(
@@ -458,6 +465,7 @@ class TritonAttnBackend(AttentionBackend):
             )
 
             custom_mask = self.cuda_graph_custom_mask
+            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
             seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
             mask_indptr = self.mask_indptr[: bs + 1]
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
@@ -821,15 +829,15 @@ class TritonMultiStepDraftBackend:
 
         self.common_template(forward_batch, kv_indices, call_fn)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
-            (self.speculative_num_steps,
+            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
             dtype=torch.int32,
             device=self.device,
         )
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
 
    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
sglang/srt/layers/communicator.py
CHANGED
@@ -28,9 +28,9 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
+    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -229,7 +229,7 @@ class CommunicateContext:
     process_group_sizes: Dict[ScatterMode, int]
     attn_tp_rank: int
     attn_tp_size: int
-
+    attn_dp_size: int
     tp_size: int
 
     def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
@@ -239,7 +239,7 @@ class CommunicateContext:
     def init_new(cls):
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()
-
+        attn_dp_size = get_attention_dp_size()
         tp_size = get_tensor_model_parallel_world_size()
         process_group_sizes = {
             ScatterMode.SCATTERED: 1,
@@ -251,7 +251,7 @@ class CommunicateContext:
             process_group_sizes=process_group_sizes,
             attn_tp_rank=attn_tp_rank,
             attn_tp_size=attn_tp_size,
-
+            attn_dp_size=attn_dp_size,
             tp_size=tp_size,
         )
 
@@ -385,7 +385,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             attn_tp_all_gather(
                 list(residual.tensor_split(context.attn_tp_size)), local_residual
             )
-        if context.
+        if context.attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
             hidden_states, local_hidden_states = (
sglang/srt/layers/dp_attention.py
CHANGED
@@ -165,7 +165,8 @@ def disable_dp_size():
 
 
 def get_dp_local_info(forward_batch: ForwardBatch):
-
+    # `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
+    dp_rank = get_attention_dp_rank()
 
     if forward_batch.dp_local_start_pos is None:
         cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
@@ -238,6 +239,10 @@ def _dp_gather(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between global_tokens and local_tokens not allowed"
+    if forward_batch.forward_mode.is_draft_extend():
+        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
+
     memcpy_triton(
         global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
     )
@@ -288,6 +293,10 @@ def dp_scatter(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between local_tokens and global_tokens not allowed"
+    if forward_batch.forward_mode.is_draft_extend():
+        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
+
     memcpy_triton(
         local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
     )
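Both `_dp_gather` and `dp_scatter` now clamp `local_num_tokens` to the actual length of `local_tokens` for draft-extend batches before calling `memcpy_triton`. A self-contained sketch of that clamp on dummy CPU tensors; only the two clamp lines are taken from the diff, the surrounding setup is illustrative:

import torch

# Dummy stand-ins: 5 local token rows, but the (device-side) counter claims 7.
local_tokens = torch.randn(5, 8)
local_num_tokens = torch.tensor(7)

# Clamp added in 0.4.8 for draft-extend batches:
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)

print(local_num_tokens)  # tensor(5) -- the copy never runs past the end of local_tokens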
@@ -301,4 +310,4 @@ attn_tp_reduce_scatter(
 
 
 def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
-    return get_attention_tp_group().all_gather(input_,
+    return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list)
sglang/srt/layers/layernorm.py
CHANGED
@@ -20,11 +20,21 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+)
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_npu = is_npu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import (
@@ -121,6 +131,23 @@ class RMSNorm(CustomOp):
         else:
             return x, residual
 
+    def forward_cpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if _is_cpu_amx_available:
+            if residual is not None:
+                torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
+                    x, residual, self.weight.data, self.variance_epsilon
+                )
+                return x, residual
+            return torch.ops.sgl_kernel.rmsnorm_cpu(
+                x, self.weight.data, self.variance_epsilon
+            )
+        else:
+            return self.forward_native(x, residual)
+
 
 class GemmaRMSNorm(CustomOp):
     def __init__(
@@ -187,7 +214,7 @@ class Gemma3RMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
-if not (_is_cuda or _is_hip):
+if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     logger.info(
         "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
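The new `RMSNorm.forward_cpu` path dispatches to the fused `sgl_kernel` CPU ops when AMX is available and otherwise falls back to `forward_native`. For reference, the computation these ops implement is the standard RMSNorm; the sketch below is the textbook formula, not sglang's exact `forward_native`, which may differ in dtype handling and in-place updates:

import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root-mean-square over the last dimension, then scale by weight.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

def fused_add_rmsnorm_reference(x, residual, weight, eps=1e-6):
    # The fused variant folds the residual add into the normalization:
    # the summed residual is returned alongside its RMSNorm.
    residual = x + residual
    return rmsnorm_reference(residual, weight, eps), residual

x, w = torch.randn(2, 8), torch.ones(8)
print(rmsnorm_reference(x, w).shape)  # torch.Size([2, 8])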
sglang/srt/layers/logits_processor.py
CHANGED
@@ -30,9 +30,9 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather,
     dp_gather_replicate,
     dp_scatter,
+    get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
-    get_local_attention_dp_rank,
     get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -171,7 +171,7 @@ class LogitsMetadata:
             return
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
-        dp_rank =
+        dp_rank = get_attention_dp_rank()
         if dp_rank == 0:
             dp_local_start_pos = torch.zeros_like(
                 self.global_num_tokens_for_logprob_gpu[0]
sglang/srt/layers/moe/ep_moe/kernels.py
CHANGED
@@ -478,11 +478,13 @@ def post_reorder_triton_kernel(
     end_expert_id,
     topk,
     hidden_size,
+    dst_start,
     BLOCK_SIZE: tl.constexpr,
 ):
     InDtype = down_output_ptr.dtype.element_ty
 
-
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
     src2dst_ptr = src2dst_ptr + src_idx * topk
     topk_ids_ptr = topk_ids_ptr + src_idx * topk
     topk_weights_ptr = topk_weights_ptr + src_idx * topk
@@ -501,7 +503,9 @@ def post_reorder_triton_kernel(
         expert_id = tl.load(topk_ids_ptr + idx)
         if expert_id >= start_expert_id and expert_id <= end_expert_id:
             computed = True
-
+            dst_idx_int32 = tl.load(src2dst_ptr + idx)
+            dst_idx = dst_idx_int32.to(tl.int64)
+            dst_idx = dst_idx - dst_start
             weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
             load_ptr = down_output_ptr + dst_idx * hidden_size
             in_data = tl.load(load_ptr + offset, mask=mask)
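Note the int32-to-int64 promotion before the pointer arithmetic (`src_idx_int32.to(tl.int64)`, `dst_idx_int32.to(tl.int64)`): presumably this keeps `row_index * hidden_size` from overflowing a signed 32-bit offset on large batches. A tiny arithmetic check with made-up but plausible sizes:

hidden_size = 7168     # illustrative hidden dimension
dst_idx = 320_000      # illustrative destination row index

offset = dst_idx * hidden_size
print(offset)                # 2293760000
print(offset > 2**31 - 1)    # True -- would overflow a signed 32-bit offset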
@@ -1086,3 +1090,156 @@ def tma_align_input_scale(input_scale: torch.Tensor):
         BLOCK_SIZE_K=BLOCK_SIZE_K,
     )
     return output.t()[:m]
+
+
+@triton.jit
+def compute_masked_m_triton_kernel(seg_indptr, masked_m):
+    expert_id = tl.program_id(0)
+    start = tl.load(seg_indptr + expert_id)
+    end = tl.load(seg_indptr + expert_id + 1)
+    tl.store(masked_m + expert_id, (end - start))
+
+
+@triton.jit
+def deepgemm_compute_src2dst_triton_kernel(
+    topk_ids,
+    reorder_ids,
+    seg_indptr,
+    src2dst,
+    m_max,
+    num_toks,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
+    expert_dst_start = tl.load(seg_indptr + expert_id)
+    expert_dst_offset = dst_id - expert_dst_start
+    dst_id = expert_id * m_max + expert_dst_offset
+    tl.store(src2dst + src_id, dst_id, mask=mask)
+
+
+@triton.jit
+def fill_gateup_input_triton_kernel(
+    input_ptr,
+    scale_ptr,
+    gateup_input_ptr,
+    gateup_input_scale_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    start_expert_id,
+    end_expert_id,
+    topk,
+    m_max,
+    hidden_size,
+    scale_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+    src_ptr = input_ptr + src_idx * hidden_size
+    scale_src_ptr = scale_ptr + src_idx * scale_size
+
+    vec = tl.arange(0, BLOCK_SIZE)
+    for idx in range(topk):
+        expert_id = tl.load(topk_ids_ptr + idx)
+        if expert_id >= start_expert_id and expert_id <= end_expert_id:
+            dst_idx_int32 = tl.load(src2dst_ptr + idx)
+            dst_idx = dst_idx_int32.to(tl.int64)
+            dst_idx = dst_idx - start_expert_id * m_max
+            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+                offset = start_offset + vec
+                mask = offset < hidden_size
+                in_data = tl.load(src_ptr + offset, mask=mask)
+                tl.store(dst_ptr + offset, in_data, mask=mask)
+            scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size
+            for start_offset in tl.range(0, scale_size, BLOCK_SIZE):
+                offset = start_offset + vec
+                mask = offset < scale_size
+                in_scale = tl.load(scale_src_ptr + offset, mask=mask)
+                tl.store(scale_dst_ptr + offset, in_scale, mask=mask)
+
+
+def moe_ep_deepgemm_preprocess(
+    topk_ids: torch.Tensor,
+    num_experts: int,
+    hidden_states: torch.Tensor,
+    top_k: int,
+    start_expert_id,
+    end_expert_id,
+    block_shape,
+    output_dtype: torch.dtype = torch.float8_e4m3fn,
+):
+    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
+    masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)
+
+    compute_seg_indptr_triton_kernel[(num_experts,)](
+        reorder_topk_ids, seg_indptr, topk_ids.numel()
+    )
+
+    grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
+    compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)
+
+    # For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
+    m_max = (hidden_states.size(0) + 255) // 256 * 256
+    expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
+    gateup_input = torch.empty(
+        (int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
+        device=hidden_states.device,
+        dtype=output_dtype,
+    )
+
+    deepgemm_compute_src2dst_triton_kernel[grid](
+        topk_ids,
+        reorder_ids,
+        seg_indptr,
+        src2dst,
+        m_max,
+        topk_ids.numel(),
+        BLOCK_SIZE=256,
+    )
+
+    if block_shape is None:
+        block_shape = [128, 128]
+    assert len(block_shape) == 2
+    block_n, block_k = block_shape[0], block_shape[1]
+    hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)
+
+    gateup_input_scale = torch.empty(
+        (gateup_input.size(0), gateup_input.size(1), scale.size(1)),
+        device=hidden_states.device,
+        dtype=scale.dtype,
+    )
+
+    fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](
+        hidden_states,
+        scale,
+        gateup_input,
+        gateup_input_scale,
+        src2dst,
+        topk_ids,
+        start_expert_id,
+        end_expert_id,
+        top_k,
+        m_max,
+        hidden_states.size(1),
+        scale.size(1),
+        BLOCK_SIZE=1024,
+    )
+
+    return (
+        m_max,
+        masked_m[start_expert_id : (end_expert_id + 1)],
+        expected_m,
+        src2dst,
+        gateup_input,
+        gateup_input_scale,
+    )