sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +0 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +62 -6
- sglang/srt/disaggregation/mini_lb.py +5 -1
- sglang/srt/disaggregation/mooncake/conn.py +32 -62
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/prefill.py +40 -4
- sglang/srt/disaggregation/utils.py +15 -0
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +114 -71
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +8 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -57
- sglang/srt/layers/quantization/fp8_utils.py +187 -262
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +3 -2
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +1 -0
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +2 -4
- sglang/srt/managers/scheduler.py +12 -71
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +7 -2
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/model_runner.py +20 -27
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +289 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +29 -201
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +34 -32
- sglang/srt/speculative/eagle_worker.py +4 -7
- sglang/srt/utils.py +16 -1
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +3 -3
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +92 -91
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashattention_backend.py
CHANGED
@@ -142,6 +142,16 @@ def make_local_attention_virtual_batches(
         seqlens_k_local: Key sequence lengths for local attention
         block_table_local: Block table for local attention
     """
+    # Adjust attention_chunk_size based on the actual sequence length
+    # to avoid index out of bounds errors
+    max_seq_len = seq_lens_np.max()
+    effective_chunk_size = min(attn_chunk_size, max_seq_len)
+    # Make sure effective_chunk_size is divisible by page_size
+    effective_chunk_size = (effective_chunk_size // page_size) * page_size
+    if effective_chunk_size < page_size:
+        effective_chunk_size = page_size
+    attn_chunk_size = effective_chunk_size
+
     q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
     actual_batch_size = seq_lens_np.shape[0]
 
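The ten added lines clamp the local-attention chunk size before any virtual batching happens. Below is a minimal, standalone sketch of that clamping rule, reusing the diff's parameter names (attn_chunk_size, seq_lens_np, page_size); it is an illustration, not the library function itself.

```python
import numpy as np

def clamp_chunk_size(attn_chunk_size: int, seq_lens_np: np.ndarray, page_size: int) -> int:
    """Sketch of the clamping logic added in the hunk above."""
    max_seq_len = int(seq_lens_np.max())
    # Never chunk past the longest sequence in the batch.
    effective = min(attn_chunk_size, max_seq_len)
    # Keep the chunk size a multiple of the KV-cache page size.
    effective = (effective // page_size) * page_size
    # Guard against rounding down to zero.
    return max(effective, page_size)

# Example: chunk size 8192, longest sequence 300 tokens, 16-token pages.
print(clamp_chunk_size(8192, np.array([300, 120]), 16))  # 288
```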
@@ -299,9 +309,7 @@ class FlashAttentionBackend(AttentionBackend):
         self.kv_cache_dtype = model_runner.kv_cache_dtype
         self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
         self.page_size = model_runner.page_size
-        self.use_mla = (
-            model_runner.model_config.attention_arch == AttentionArch.MLA
-        ) and (not global_server_args_dict["disable_mla"])
+        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         self.skip_prefill = skip_prefill
 
         self.topk = topk
@@ -346,6 +354,8 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
+
+                self._init_local_attn_metadata(metadata, device)
             else:
                 # Normal Decode
                 metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
@@ -359,6 +369,8 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
+
+                self._init_local_attn_metadata(metadata, device)
         elif forward_batch.forward_mode.is_target_verify():
             metadata.cache_seqlens_int32 = (
                 forward_batch.seq_lens + self.speculative_num_draft_tokens
@@ -407,49 +419,8 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.cu_seqlens_q = metadata.cu_seqlens_k
 
             # Setup local attention if enabled
-            if (
-                self.attention_chunk_size is not None
-                and forward_batch.forward_mode == ForwardMode.EXTEND
-            ):
-                # Convert tensors to numpy for local attention processing
-                cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
-                seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()
-
-                # Adjust attention_chunk_size based on the actual sequence length
-                # to avoid index out of bounds errors
-                max_seq_len = seq_lens_np.max()
-                effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
-                # Make sure effective_chunk_size is divisible by page_size
-                effective_chunk_size = (
-                    effective_chunk_size // self.page_size
-                ) * self.page_size
-                if effective_chunk_size < self.page_size:
-                    effective_chunk_size = self.page_size
-
-                # Create local attention metadata
-                (
-                    seqlens_q_local_np,
-                    cu_seqlens_q_local_np,
-                    seqlens_k_local_np,
-                    block_table_local,
-                ) = make_local_attention_virtual_batches(
-                    effective_chunk_size,
-                    cu_seqlens_q_np,
-                    seq_lens_np,
-                    metadata.page_table,
-                    self.page_size,
-                )
-
-                local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
-                    local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
-                        device
-                    ),
-                    local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
-                    local_block_table=block_table_local,
-                    local_max_query_len=seqlens_q_local_np.max(),
-                    local_max_seq_len=seqlens_k_local_np.max(),
-                )
-                metadata.local_attn_metadata = local_metadata
+            if forward_batch.forward_mode == ForwardMode.EXTEND:
+                self._init_local_attn_metadata(metadata, device)
 
             # Encoder metadata for cross attention
             if forward_batch.encoder_lens is not None:
@@ -706,6 +677,10 @@ class FlashAttentionBackend(AttentionBackend):
 
         # Use precomputed metadata across all layers
         metadata = self.forward_metadata
+        local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
+        use_local_attention = (
+            self.attention_chunk_size is not None and local_attn_metadata is not None
+        )
 
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
@@ -740,33 +715,60 @@ class FlashAttentionBackend(AttentionBackend):
                 -1, self.page_size, layer.tp_v_head_num, layer.head_dim
             )
 
-            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
             if layer.is_cross_attention:
-                ... (4 removed lines not shown in this rendered diff)
+                # Always use non-chunked logic for cross-attention
+                o = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=metadata.encoder_page_table,
+                    cache_seqlens=metadata.encoder_lens_int32,
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,
+                    max_seqlen_q=1,
+                    softmax_scale=layer.scaling,
+                    causal=False,
+                    window_size=(-1, -1),
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )
+            elif use_local_attention:
+                # Use chunked (local) attention batching for self-attention
+                o = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=local_attn_metadata.local_block_table,
+                    cache_seqlens=local_attn_metadata.local_seqused_k,
+                    cu_seqlens_q=local_attn_metadata.local_query_start_loc,
+                    cu_seqlens_k_new=metadata.cu_seqlens_k,
+                    max_seqlen_q=local_attn_metadata.local_max_query_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    window_size=(-1, -1),
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )
             else:
-                ... (17 removed lines not shown in this rendered diff)
-                    k_descale=k_descale,
-                    v_descale=v_descale,
-                )
+                # Default: single-token self-attention
+                o = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=metadata.page_table,
+                    cache_seqlens=metadata.cache_seqlens_int32,
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k_new=metadata.cu_seqlens_k,
+                    max_seqlen_q=1,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    window_size=window_size,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )
         else:
             # Do absorbed multi-latent attention
             kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
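The decode path above now chooses between three flash_attn_with_kvcache configurations: cross-attention against the encoder KV cache, chunked (local) attention, and the default paged decode. A compact sketch of that selection, with stand-in metadata objects instead of the backend's real FlashAttentionMetadata:

```python
from types import SimpleNamespace

def choose_decode_kwargs(is_cross_attention: bool, use_local_attention: bool,
                         metadata, local_meta) -> dict:
    """Sketch of the three-way dispatch shown in the diff above (not the sglang code)."""
    if is_cross_attention:
        # Cross-attention reads the encoder KV cache and is never causal.
        return dict(page_table=metadata.encoder_page_table,
                    cache_seqlens=metadata.encoder_lens_int32, causal=False)
    if use_local_attention:
        # Chunked ("local") attention swaps in the virtual-batch metadata.
        return dict(page_table=local_meta.local_block_table,
                    cache_seqlens=local_meta.local_seqused_k, causal=True)
    # Default decode over the full paged KV cache.
    return dict(page_table=metadata.page_table,
                cache_seqlens=metadata.cache_seqlens_int32, causal=True)

meta = SimpleNamespace(page_table="page_table", cache_seqlens_int32="cache_seqlens",
                       encoder_page_table="enc_page_table", encoder_lens_int32="enc_lens")
local = SimpleNamespace(local_block_table="local_block_table", local_seqused_k="local_seqused_k")
print(choose_decode_kwargs(False, True, meta, local))
# {'page_table': 'local_block_table', 'cache_seqlens': 'local_seqused_k', 'causal': True}
```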
@@ -988,6 +990,8 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
+        device = seq_lens.device
+
         if forward_mode.is_decode_or_idle():
             metadata = self.decode_cuda_graph_metadata[bs]
 
@@ -1014,6 +1018,8 @@ class FlashAttentionBackend(AttentionBackend):
                 ]
 
                 metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
+                self._init_local_attn_metadata(metadata, device)
             else:
                 # Normal Decode
                 max_len = seq_lens_cpu.max().item()
@@ -1037,6 +1043,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                 metadata.page_table[:, max_seq_pages:].fill_(0)
 
+                self._init_local_attn_metadata(metadata, device)
         elif forward_mode.is_target_verify():
             metadata = self.target_verify_metadata[bs]
             metadata.cache_seqlens_int32.copy_(
@@ -1087,6 +1094,42 @@ class FlashAttentionBackend(AttentionBackend):
         """Get the fill value for sequence length in CUDA graph."""
         return 0
 
+    def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
+        """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
+        if self.attention_chunk_size is None:
+            metadata.local_attn_metadata = None
+            return
+
+        cu_seqlens_q = metadata.cu_seqlens_q
+        cache_seqlens_int32 = metadata.cache_seqlens_int32
+        page_table = metadata.page_table
+        if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
+            metadata.local_attn_metadata = None
+            return
+
+        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
+        seq_lens_np = cache_seqlens_int32.cpu().numpy()
+        (
+            seqlens_q_local_np,
+            cu_seqlens_q_local_np,
+            seqlens_k_local_np,
+            block_table_local,
+        ) = make_local_attention_virtual_batches(
+            self.attention_chunk_size,
+            cu_seqlens_q_np,
+            seq_lens_np,
+            page_table,
+            self.page_size,
+        )
+        local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+            local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
+            local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
+            local_block_table=block_table_local.to(device),
+            local_max_query_len=int(seqlens_q_local_np.max()),
+            local_max_seq_len=int(seqlens_k_local_np.max()),
+        )
+        metadata.local_attn_metadata = local_metadata
+
 
 class FlashAttentionMultiStepBackend:
 
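The new _init_local_attn_metadata helper funnels every call site through make_local_attention_virtual_batches. As a rough, conceptual illustration only (the real helper also remaps the paged block table), splitting per-request query spans into chunk-sized virtual batches looks like this:

```python
import numpy as np

def split_into_virtual_batches(cu_seqlens_q: np.ndarray, chunk: int) -> np.ndarray:
    """Conceptual sketch: break each request's query span into chunk-sized pieces."""
    local_lens = []
    for q_len in np.diff(cu_seqlens_q):
        while q_len > chunk:
            local_lens.append(chunk)
            q_len -= chunk
        local_lens.append(int(q_len))
    # Return new cumulative offsets over the virtual batches.
    return np.concatenate([[0], np.cumsum(local_lens)])

print(split_into_virtual_batches(np.array([0, 5, 800]), chunk=256))
# [  0   5 261 517 773 800]
```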
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -100,8 +100,11 @@ class FlashInferAttnBackend(AttentionBackend):
         self.num_wrappers = 1
         self.dispatch_reason = None
 
-        # Qwen2 models require higher flashinfer workspace size
-        if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
+        # Qwen2/Qwen3 models require higher flashinfer workspace size
+        if (
+            "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures
+            or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
+        ):
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024
 
         # Allocate buffers
sglang/srt/layers/attention/torch_native_backend.py
CHANGED
@@ -6,6 +6,7 @@ import torch
 from torch.nn.functional import scaled_dot_product_attention
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 if TYPE_CHECKING:
@@ -202,6 +203,10 @@ class TorchNativeAttnBackend(AttentionBackend):
         q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
         o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
 
+        causal = True
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
+
         self._run_sdpa_forward_extend(
             q_,
             o_,
@@ -214,7 +219,7 @@ class TorchNativeAttnBackend(AttentionBackend):
             forward_batch.extend_seq_lens,
             scaling=layer.scaling,
             enable_gqa=use_gqa,
-            causal=not layer.is_cross_attention,
+            causal=causal,
         )
         return o
 
sglang/srt/layers/attention/triton_backend.py
CHANGED
@@ -10,6 +10,7 @@ import triton.language as tl
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import get_bool_env_var, get_device_core_count
 
@@ -528,6 +529,10 @@ class TritonAttnBackend(AttentionBackend):
             layer, forward_batch.out_cache_loc, k, v
         )
 
+        causal = True
+        if layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
+
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -539,6 +544,7 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.kv_indptr,
             self.forward_metadata.kv_indices,
             self.forward_metadata.custom_mask,
+            causal,
             self.forward_metadata.mask_indptr,
             self.forward_metadata.max_extend_len,
             layer.scaling,
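Both the torch-native and Triton backends now derive a causal flag from the layer's attention type, so encoder-only models such as the newly added BERT can run bidirectional self-attention. A small sketch of that rule, using an assumed stand-in for the AttentionType enum imported above:

```python
from enum import Enum, auto

class AttentionType(Enum):
    # Assumed members for illustration; the real definition lives in
    # sglang.srt.layers.radix_attention.
    DECODER = auto()
    ENCODER_ONLY = auto()

def is_causal(attn_type: AttentionType, is_cross_attention: bool = False) -> bool:
    # Encoder-only and cross-attention layers attend bidirectionally, so the
    # causal mask is switched off; decoder self-attention stays causal.
    return not (is_cross_attention or attn_type == AttentionType.ENCODER_ONLY)

assert is_causal(AttentionType.DECODER) is True
assert is_causal(AttentionType.ENCODER_ONLY) is False
```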
sglang/srt/layers/attention/triton_ops/extend_attention.py
CHANGED
@@ -74,6 +74,7 @@ def _fwd_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     USE_CUSTOM_MASK: tl.constexpr,
+    IS_CAUSAL: tl.constexpr,
     SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
     STORE_TRANSPOSE: tl.constexpr,
 ):
@@ -129,6 +130,7 @@ def _fwd_kernel(
     for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_seq_len_prefix
+
         offs_kv_loc = tl.load(
             kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
         )
@@ -196,7 +198,11 @@ def _fwd_kernel(
 
     # stage 2: compute the triangle part
 
-    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    cur_block_m_end = (
+        cur_seq_len_extend
+        if not IS_CAUSAL
+        else tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    )
     for start_n in range(0, cur_block_m_end, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_block_m_end
@@ -243,12 +249,15 @@ def _fwd_kernel(
             )
             custom_mask &= mask_m[:, None] & mask_n[None, :]
             qk = tl.where(custom_mask, qk, float("-inf"))
-        else:
+        elif IS_CAUSAL:
             mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
                 start_n + offs_n[None, :]
             )
             mask_causual &= mask_m[:, None] & mask_n[None, :]
             qk = tl.where(mask_causual, qk, float("-inf"))
+        else:
+            mask_non_causal = mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(mask_non_causal, qk, float("-inf"))
 
         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
@@ -299,6 +308,7 @@ def extend_attention_fwd(
     kv_indptr,
     kv_indices,
     custom_mask,
+    is_causal,
     mask_indptr,
     max_len_extend,
     sm_scale=None,
@@ -411,6 +421,7 @@ def extend_attention_fwd(
         Lq=Lq,
         Lv=Lv,
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
+        IS_CAUSAL=is_causal,
         SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
         STORE_TRANSPOSE=_is_hip,
         num_warps=num_warps,
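Inside the Triton kernel, the new IS_CAUSAL constexpr decides whether the stage-2 block applies a triangular mask or only the in-bounds mask. A plain PyTorch illustration (not Triton) of the same masking rule:

```python
import torch

def extend_score_mask(q_len: int, is_causal: bool) -> torch.Tensor:
    # True where attention between extend-query position i and extend-key position j
    # is allowed; mirrors the IS_CAUSAL branch added to the kernel above.
    if is_causal:
        return torch.tril(torch.ones(q_len, q_len, dtype=torch.bool))
    return torch.ones(q_len, q_len, dtype=torch.bool)

scores = torch.randn(4, 4)
masked = scores.masked_fill(~extend_score_mask(4, is_causal=True), float("-inf"))
print(masked)  # upper triangle is -inf, as in the causal branch
```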
sglang/srt/layers/layernorm.py
CHANGED
@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available
 
 _is_cuda = is_cuda_available()
@@ -31,7 +32,6 @@ if _is_cuda:
         rmsnorm,
     )
 
-    from sglang.srt.custom_op import CustomOp
 
 logger = logging.getLogger(__name__)
 
sglang/srt/layers/linear.py
CHANGED
@@ -1,5 +1,6 @@
 """Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""
 
+import itertools
 import logging
 from abc import abstractmethod
 from typing import Dict, List, Optional, Tuple
@@ -61,12 +62,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
 
 
 def adjust_bitsandbytes_4bit_shard(
-    param: Parameter,
+    param: Parameter, shard_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
 ) -> Tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
 
-    total, _ =
-    orig_offset, orig_size =
+    total, _ = shard_offsets["total"]
+    orig_offset, orig_size = shard_offsets[loaded_shard_id]
 
     quantized_total = param.data.shape[0]
     quantized_offset = orig_offset * quantized_total // total
@@ -573,6 +574,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             shard_offsets.append((i, current_shard_offset, output_size))
             current_shard_offset += output_size
         packed_dim = getattr(param, "packed_dim", None)
+
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
         for shard_id, shard_offset, shard_size in shard_offsets:
             # Special case for Quantization.
             # If quantized, we need to adjust the offset and size to account
@@ -585,6 +588,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                     param, shard_size, shard_offset
                 )
 
+            if use_bitsandbytes_4bit:
+                index = list(itertools.accumulate([0] + self.output_sizes))
+                orig_offsets = {
+                    str(i): (index[i], size)
+                    for i, size in enumerate(self.output_sizes)
+                }
+                orig_offsets["total"] = (self.output_size, 0)
+                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
+                    param, orig_offsets, str(shard_id)
+                )
+
             loaded_weight_shard = loaded_weight.narrow(
                 output_dim, shard_offset, shard_size
            )
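The weight loader above builds a per-shard offset table before calling adjust_bitsandbytes_4bit_shard. A standalone sketch of that table construction, with illustrative sizes only:

```python
import itertools

# Illustrative only: build the {shard_id: (offset, size)} table the way the
# MergedColumnParallelLinear weight loader above does for bitsandbytes 4-bit shards.
output_sizes = [4096, 4096]  # e.g. two projections merged into one weight
index = list(itertools.accumulate([0] + output_sizes))
orig_offsets = {str(i): (index[i], size) for i, size in enumerate(output_sizes)}
orig_offsets["total"] = (sum(output_sizes), 0)

print(orig_offsets)
# {'0': (0, 4096), '1': (4096, 4096), 'total': (8192, 0)}
```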
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -2,6 +2,7 @@ import logging
 from typing import Callable, List, Optional, Tuple
 
 import torch
+from torch.nn import Module
 
 try:
     from deep_gemm import (
@@ -13,8 +14,6 @@ try:
 except ImportError:
     use_deep_gemm = False
 
-from torch.nn import Module
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -37,22 +36,17 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode,
+from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
 
-
+_is_hip = is_hip()
 
-if
-    from
-else:
-    from vllm import _custom_ops as vllm_ops
+if _is_hip:
+    from vllm._custom_ops import scaled_fp8_quant
 
 logger = logging.getLogger(__name__)
 
-_is_hip = is_hip()
-
-_buffer = None
-
 
 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
@@ -142,6 +136,7 @@ class EPMoE(torch.nn.Module):
         correction_bias: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
     ):
         super().__init__()
 
@@ -170,6 +165,7 @@ class EPMoE(torch.nn.Module):
         self.correction_bias = correction_bias
         self.custom_routing_function = custom_routing_function
         self.activation = activation
+        self.routed_scaling_factor = routed_scaling_factor
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -221,6 +217,7 @@ class EPMoE(torch.nn.Module):
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
             custom_routing_function=self.custom_routing_function,
+            routed_scaling_factor=self.routed_scaling_factor,
         )
 
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -740,20 +737,12 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
             )
 
             for expert in range(layer.num_experts_per_partition):
-                ... (1 removed line not shown in this rendered diff)
-                    w13_weight[expert, :, :]
-                ... (2 removed lines not shown in this rendered diff)
-                    w2_weight[expert, :, :]
-                ... (1 removed line not shown in this rendered diff)
-                    )
-                else:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
+                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
             return
@@ -986,9 +975,6 @@ class DeepEPMoE(EPMoE):
     ):
         assert self.quant_method is not None
         assert self.activation == "silu"
-        assert (
-            hidden_states_fp8[0].size(0) % 4 == 0
-        ), f"TMA alignment error: {hidden_states_fp8[0].size(0)}"
 
         # GroupGemm-0
        num_groups, m, k = hidden_states_fp8[0].size()
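Several MoE entry points above gain a routed_scaling_factor keyword that is forwarded down to select_experts. As a hedged sketch of what such a factor usually means in DeepSeek-style routing (the actual handling inside select_experts is not reproduced here):

```python
from typing import Optional
import torch

def scale_topk_weights(topk_weights: torch.Tensor,
                       routed_scaling_factor: Optional[float]) -> torch.Tensor:
    # Sketch only: a routed scaling factor typically rescales the (optionally
    # renormalized) top-k routing weights by a model-configured constant.
    if routed_scaling_factor is None:
        return topk_weights
    return topk_weights * routed_scaling_factor

print(scale_topk_weights(torch.tensor([[0.6, 0.4]]), 2.5))
# tensor([[1.5000, 1.0000]])
```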
sglang/srt/layers/moe/fused_moe_native.py
CHANGED
@@ -26,6 +26,7 @@ def fused_moe_forward_native(
     apply_router_weight_on_input: bool = False,
     inplace: bool = True,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
 
     if apply_router_weight_on_input:
@@ -41,6 +42,7 @@ def fused_moe_forward_native(
         num_expert_group=num_expert_group,
         custom_routing_function=custom_routing_function,
         correction_bias=correction_bias,
+        routed_scaling_factor=routed_scaling_factor,
         torch_native=True,
     )
 
@@ -71,6 +73,7 @@ def moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
 
     from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
@@ -86,6 +89,7 @@ def moe_forward_native(
         custom_routing_function=custom_routing_function,
         correction_bias=correction_bias,
         torch_native=True,
+        routed_scaling_factor=routed_scaling_factor,
     )
 
 # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -13,6 +13,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -22,28 +23,25 @@ from sglang.srt.utils import (
 )
 
 _is_hip = is_hip()
-
-
-logger = logging.getLogger(__name__)
-padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
-
-enable_moe_align_block_size_triton = bool(
-    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
-)
-
 _is_cuda = is_cuda()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
-
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
 else:
     from vllm import _custom_ops as vllm_ops
+    from vllm._custom_ops import scaled_fp8_quant
 
 if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 
 
+logger = logging.getLogger(__name__)
+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+enable_moe_align_block_size_triton = bool(
+    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
+)
+
+
 @triton.jit
 def write_zeros_to_output(
     c_ptr,
@@ -770,14 +768,9 @@ def invoke_fused_moe_kernel(
             # activation tensor-wise fp8 quantization, dynamic or static
             padded_size = padding_size
             # activations apply per-token quantization when weights apply per-channel quantization by default
-            ... (1 removed line not shown in this rendered diff)
-                A, A_scale =
-            ... (1 removed line not shown in this rendered diff)
-                )
-            else:
-                A, A_scale = vllm_ops.scaled_fp8_quant(
-                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
-                )
+            A, A_scale = scaled_fp8_quant(
+                A, A_scale, use_per_token_if_dynamic=per_channel_quant
+            )
         else:
             # activation block-wise fp8 quantization
             assert len(block_shape) == 2
@@ -1554,6 +1547,7 @@ def fused_moe(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1608,6 +1602,7 @@ def fused_moe(
         topk_group=topk_group,
         num_expert_group=num_expert_group,
         custom_routing_function=custom_routing_function,
+        routed_scaling_factor=routed_scaling_factor,
     )
 
     return fused_experts(