sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/utils.py
CHANGED
@@ -9,18 +9,89 @@ TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
 
 
 @triton.jit
 def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,
+    req_to_token_ptr,
     req_pool_indices_ptr,
     page_kernel_lens_ptr,
     kv_indptr,
     kv_start_idx,
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
+    PAGE_SIZE: tl.constexpr = 1,
 ):
+    """
+    Create KV indices for FlashInfer attention backend.
+
+    This Triton kernel builds a lookup table that maps from logical request/token
+    coordinates to physical token locations in the global KV cache pool. It's used
+    by FlashInfer attention backends to efficiently access scattered KV cache data.
+
+    The kernel processes each request in parallel and converts the req_to_token
+    lookup table into a flat list of token indices that can be used by attention kernels.
+
+    general idea:
+    blocktables/kv_indices_ptr = [batch_size * max_pages(for graph mode with
+    fixed number of pages)]
+    max_pages = max_context_len / PAGED_SIZE
+    kv_indices_ptr will store the flat list of the pages used by each request
+    Args:
+    Inputs Arguments (non mutable):
+
+    req_to_token_ptr: Request to token location look up table
+        Shape: [max_batch, max_context_len]
+    req_pool_indices_ptr: Request to pool index look up table. Each request uses
+        one pool.
+        Shape: [batch_size]
+    page_kernel_lens_ptr: sequence lengths per request
+        Shape: [batch_size]
+    kv_indptr: Should be computed based on number of pages used by each request.
+        It is used by flashinfer attention kernels to index into the kv_indices_ptr
+        per request.
+        Shape: [batch_size + 1]
+        kv_indptr[i] = start index in kv_indices for request i
+    kv_start_idx: Pointer to array containing start offsets for each request in SGL.
+        Can be None. If provided, adds offset to token positions.
+
+    req_to_token_ptr_stride: Stride for the second dimension of req_to_token.
+        Equal to max_context_len.
+
+    PAGED_SIZE: Number of tokens per page. Default is 1 for FlashInfer.
+
+    Outputs:
+    kv_indices_ptr: Pointer to output array where KV indices will be stored.
+        Shape: [total-num-pages],
+        where total_num_pages = sum(seq_lens // PAGED_SIZE)
+
+    Example:
+    If we have:
+    - req_pool_indices = [0, 1] (request 0 uses pool 0, request 1 uses pool 1)
+    - page_kernel_lens = [3, 2] (request 0 has 3 tokens, request 1 has 2 tokens)
+    - req_to_token = [[10, 11, 12, -1], [20, 21, -1, -1]] (tokens are the elements
+      in radix tree, use them as a pointer to the token location in the kv_indices_ptr)
+
+    The kernel will output:
+    If PAGE_SIZE = 1:
+    packed
+    - kv_indptr (passed in as input arg): [0,3,5]
+    - kv_indices = [10, 11, 12, 20, 21]
+    padded - max_pages is 10 tokens per req
+    - kv_indptr (passed in as input arg): [0,10, 20]
+    - kv_indices = [10, 11, 12, -1, -1, -1, -1, -1, -1, -1,
+                    20, 21, -1, -1, -1, -1, -1, -1, -1, -1]
+
+    If PAGE_SIZE = 2
+    packed:
+    - kv_indptr (passed in as input arg): [0,3,4]
+    - kv_indices = [5,6,10]
+    padded: max_pages is 4
+    - kv_indptr (passed in as input arg): [0,4,8,..] (note that 4 is the max_pages)
+    - kv_indices = [5, 6, -1, -1,
+                    10, -1, -1, -1]
+    This allows attention kernels to directly access the correct KV cache
+    entries for each request's tokens.
+    """
     BLOCK_SIZE: tl.constexpr = 512
+    NUM_PAGES_PER_BLOCK: tl.constexpr = BLOCK_SIZE // PAGE_SIZE
     pid = tl.program_id(axis=0)
-
-    # find the req pool idx, this is for batch to token
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)
 
@@ -31,19 +102,27 @@ def create_flashinfer_kv_indices_triton(
     kv_end = kv_start
     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
 
-
-
-
-
-
-
-
-
-
-
-
+    kv_range = kv_end - kv_start
+    num_pages = tl.cdiv(kv_range, PAGE_SIZE)
+    num_loops = tl.cdiv(kv_range, BLOCK_SIZE)
+    req_to_token_block_start = (
+        req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + kv_start
+    )
+    for i in range(num_loops):
+        token_offsets_in_block = (
+            tl.arange(0, NUM_PAGES_PER_BLOCK).to(tl.int64) + i * NUM_PAGES_PER_BLOCK
+        ) * PAGE_SIZE
+        page_offsets_in_block = token_offsets_in_block // PAGE_SIZE
+        valid_tokens = token_offsets_in_block < kv_range
+        valid_pages = page_offsets_in_block < num_pages
+        token_numbers = tl.load(
+            req_to_token_block_start + token_offsets_in_block, mask=valid_tokens
+        )
+        tl.store(
+            kv_indices_ptr + kv_indices_offset + page_offsets_in_block,
+            token_numbers // PAGE_SIZE,  # write the page numbers to kv_indices_ptr
+            mask=valid_pages,
         )
-        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
 
 
 @triton.jit
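To make the paging layout in the new docstring concrete, here is a minimal pure-Python sketch (a hypothetical helper, not part of sglang) that reproduces the packed kv_indptr/kv_indices for the docstring's example inputs:

# Pure-Python illustration of the page-index math described above; the Triton
# kernel does the same mapping in parallel per request.
def build_kv_indices(req_to_token, req_pool_indices, seq_lens, page_size):
    kv_indptr = [0]
    kv_indices = []
    for pool_idx, seq_len in zip(req_pool_indices, seq_lens):
        tokens = req_to_token[pool_idx][:seq_len]
        # one entry per page: the page number of the first token in each page
        pages = [tokens[i] // page_size for i in range(0, seq_len, page_size)]
        kv_indices.extend(pages)
        kv_indptr.append(kv_indptr[-1] + len(pages))
    return kv_indptr, kv_indices


req_to_token = [[10, 11, 12, -1], [20, 21, -1, -1]]
print(build_kv_indices(req_to_token, [0, 1], [3, 2], page_size=1))
# ([0, 3, 5], [10, 11, 12, 20, 21])
print(build_kv_indices(req_to_token, [0, 1], [3, 2], page_size=2))
# ([0, 2, 3], [5, 6, 10])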
sglang/srt/layers/attention/vision.py
CHANGED
@@ -12,7 +12,12 @@ import torch.nn.functional as F
 from einops import rearrange
 
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    get_device_capability,
+    is_blackwell,
+    is_cuda,
+    print_info_once,
+)
 
 _is_cuda = is_cuda()
 
@@ -20,7 +25,6 @@ if _is_cuda:
     from sgl_kernel.flash_attn import flash_attn_varlen_func
 
 from sglang.srt.distributed import (
-    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
 )
@@ -402,18 +406,14 @@ class VisionAttention(nn.Module):
                 self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
             )
 
-        #
-
-
-
-
-
-
-        qkv_backend = "sdpa"
+        # Select attention backend via a unified method
+        _passed_backend = qkv_backend
+        qkv_backend = self._determine_attention_backend(_passed_backend)
+        if (
+            global_server_args_dict["mm_attention_backend"] is None
+            and _passed_backend is None
+        ):
             print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
-        else:
-            qkv_backend = global_server_args_dict["mm_attention_backend"]
-
         print_info_once(f"Using {qkv_backend} as multimodal attention backend.")
 
         self.customized_position_embedding_applier = (
@@ -461,6 +461,33 @@ class VisionAttention(nn.Module):
             prefix=add_prefix("proj", prefix),
        )
 
+    def _determine_attention_backend(self, passed_backend: Optional[str]) -> str:
+        """Decide the multimodal attention backend string.
+
+        Priority: server args override > constructor arg > platform default.
+
+        Platform defaults:
+        - CUDA: "triton_attn"
+        - Non-CUDA: "sdpa"
+        """
+        override_backend = global_server_args_dict["mm_attention_backend"]
+        if override_backend is not None:
+            backend = override_backend
+        elif passed_backend is not None:
+            backend = passed_backend
+        elif is_cuda():
+            major, minor = get_device_capability()
+            if major == 9:
+                backend = "fa3"
+            else:
+                backend = "triton_attn"
+        else:
+            backend = "sdpa"
+        if backend == "fa3" and is_blackwell():
+            raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")
+
+        return backend
+
     def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
         """apply qk norm for internvl vit attn"""
         q = q.flatten(1, 2)
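The new _determine_attention_backend method encodes a simple priority chain. Below is a standalone sketch of that chain with the server-args override and device capability passed in explicitly; the helper name and its arguments are illustrative, not sglang API:

from typing import Optional, Tuple


def pick_mm_attention_backend(
    override: Optional[str],           # --mm-attention-backend server arg
    passed_backend: Optional[str],     # qkv_backend passed to the constructor
    cuda_capability: Optional[Tuple[int, int]],  # None when not running on CUDA
    blackwell: bool = False,
) -> str:
    if override is not None:
        backend = override
    elif passed_backend is not None:
        backend = passed_backend
    elif cuda_capability is not None:
        # Hopper (SM 9.x) defaults to FlashAttention 3, other CUDA GPUs to Triton.
        backend = "fa3" if cuda_capability[0] == 9 else "triton_attn"
    else:
        backend = "sdpa"
    if backend == "fa3" and blackwell:
        raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")
    return backend


assert pick_mm_attention_backend(None, None, (9, 0)) == "fa3"
assert pick_mm_attention_backend(None, None, None) == "sdpa"
assert pick_mm_attention_backend("sdpa", "fa3", (9, 0)) == "sdpa"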
sglang/srt/layers/attention/vision_utils.py
ADDED
@@ -0,0 +1,65 @@
+"""Utility functions for vision attention layers."""
+
+import torch
+
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+
+
+def update_vit_attn_dummy_heads_config(config):
+    """Update HF config to ensure vision attention num_attention_heads is divisible by tp_size"""
+    tp_size = get_attention_tp_size()
+    num_heads = getattr(
+        config.vision_config,
+        "num_heads",
+        getattr(config.vision_config, "num_attention_heads", None),
+    )
+    head_dim = config.vision_config.hidden_size // num_heads
+    num_dummy_heads = 0
+
+    if num_heads % tp_size != 0:
+        num_dummy_heads = ((num_heads + tp_size - 1) // tp_size) * tp_size - num_heads
+
+    setattr(config.vision_config, "head_dim", head_dim)
+    setattr(config.vision_config, "num_dummy_heads", num_dummy_heads)
+
+
+def pad_vit_attn_dummy_heads(config, name: str, loaded_weight: torch.Tensor):
+    """Pad attention qkv weights for dummy heads"""
+    num_dummy_heads = config.vision_config.num_dummy_heads
+    if num_dummy_heads == 0:
+        return loaded_weight
+    head_dim = config.vision_config.head_dim
+
+    if "attn.qkv_proj" in name:
+        wq, wk, wv = loaded_weight.chunk(3, dim=0)
+        if name.endswith(".weight"):
+            dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
+        elif name.endswith(".bias"):
+            dummy_shape = [num_dummy_heads, head_dim]
+        else:
+            raise RuntimeError(f"Unsupported weight with name={name}")
+        pad_func = lambda x: torch.cat(
+            [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
+        ).flatten(0, 1)
+        wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
+        loaded_weight = torch.cat([wq, wk, wv], dim=0)
+    elif any([_ in name for _ in ["attn.q_proj", "attn.k_proj", "attn.v_proj"]]):
+        if name.endswith(".weight"):
+            dummy_shape = [num_dummy_heads, head_dim, loaded_weight.shape[-1]]
+        elif name.endswith(".bias"):
+            dummy_shape = [num_dummy_heads, head_dim]
+        else:
+            raise RuntimeError(f"Unsupported weight with name={name}")
+        padded_weight = loaded_weight.new_zeros(dummy_shape)
+        loaded_weight = torch.cat(
+            [loaded_weight.unflatten(0, (-1, head_dim)), padded_weight], dim=0
+        ).flatten(0, 1)
+    elif "attn.proj.weight" in name:
+        padded_weight = loaded_weight.new_zeros(
+            loaded_weight.shape[0], head_dim * num_dummy_heads
+        )
+        loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+    elif "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+        padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+        loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+    return loaded_weight
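The dummy-head count produced by update_vit_attn_dummy_heads_config is just the distance from num_heads to the next multiple of the attention TP size. A small illustrative check (the helper and the numbers below are examples, not part of sglang):

def dummy_heads_needed(num_heads: int, tp_size: int) -> int:
    # Pad num_heads up to the next multiple of tp_size, as the new config helper does.
    if num_heads % tp_size == 0:
        return 0
    return ((num_heads + tp_size - 1) // tp_size) * tp_size - num_heads


# e.g. a 25-head ViT sharded across 4 attention-TP ranks needs 3 dummy heads
# (25 -> 28), so each rank holds 7 heads; the extra heads are zero-padded by
# pad_vit_attn_dummy_heads when the checkpoint is loaded.
assert dummy_heads_needed(25, 4) == 3
assert dummy_heads_needed(16, 8) == 0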
sglang/srt/layers/communicator.py
CHANGED
@@ -17,7 +17,7 @@ from enum import Enum, auto
 from functools import partial
 from typing import Dict, Optional
 
-import torch
+import torch
 
 from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
@@ -34,6 +34,11 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_size,
     get_global_dp_buffer,
     get_local_dp_buffer,
+    is_dp_attention_enabled,
+)
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -43,6 +48,8 @@ from sglang.srt.utils import is_cuda, is_flashinfer_available
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
 
+FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
+
 
 class ScatterMode(Enum):
     """
@@ -111,7 +118,11 @@ class LayerScatterModes:
         if context.is_layer_sparse:
             return (
                 ScatterMode.SCATTERED
-                if
+                if (
+                    # Token dispatch/combine will be handled outside of LayerCommunicator for these modes.
+                    not get_moe_a2a_backend().is_none()
+                    or should_use_flashinfer_cutlass_moe_fp4_allgather()
+                )
                 else ScatterMode.FULL
             )
         else:
@@ -154,11 +165,13 @@ class LayerCommunicator:
         post_attention_layernorm: torch.nn.Module,
         # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
         allow_reduce_scatter: bool = False,
+        is_last_layer: bool = False,
     ):
         self.layer_scatter_modes = layer_scatter_modes
         self.input_layernorm = input_layernorm
         self.post_attention_layernorm = post_attention_layernorm
         self.allow_reduce_scatter = allow_reduce_scatter
+        self.is_last_layer = is_last_layer
 
         self._context = CommunicateContext.init_new()
         self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
@@ -256,6 +269,41 @@ class LayerCommunicator:
             and forward_batch.dp_padding_mode.is_max_len()
         )
 
+    def should_fuse_mlp_allreduce_with_next_layer(
+        self, forward_batch: ForwardBatch
+    ) -> bool:
+        speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
+        if (
+            is_dp_attention_enabled()
+            and speculative_algo is not None
+            and speculative_algo.is_eagle()
+        ):
+            return False
+
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )
+        if batch_size > FUSE_ALLREDUCE_MAX_BATCH_SIZE:
+            return False
+
+        static_conditions_met = (
+            (not self.is_last_layer)
+            and (self._context.tp_size > 1)
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and _is_flashinfer_available
+        )
+
+        if not static_conditions_met:
+            return False
+
+        return (
+            batch_size > 0
+            and batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE
+            and (not self.is_last_layer)
+        )
+
 
 @dataclass
 class CommunicateContext:
@@ -382,7 +430,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
         )
 
         raise NotImplementedError(
-            f"{hidden_states_input_mode=} {residual_input_mode=} {
+            f"{hidden_states_input_mode=} {residual_input_mode=} {hidden_states_output_mode=} {residual_output_mode=}"
        )
 
    @staticmethod
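The new should_fuse_mlp_allreduce_with_next_layer gate combines several conditions. A condensed, dependency-free sketch of the same predicate, with all inputs passed explicitly for illustration (in sglang they come from server args, the forward batch, and the layer state):

FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048


def can_fuse_mlp_allreduce(
    batch_size: int,
    is_last_layer: bool,
    tp_size: int,
    flashinfer_fusion_enabled: bool,
    flashinfer_available: bool,
    dp_attention_with_eagle: bool,
) -> bool:
    # DP attention combined with EAGLE speculative decoding disables fusion.
    if dp_attention_with_eagle:
        return False
    # Static conditions: not the last layer, TP > 1, FlashInfer fusion on and available.
    if is_last_layer or tp_size <= 1:
        return False
    if not (flashinfer_fusion_enabled and flashinfer_available):
        return False
    # Dynamic condition: non-empty batch below the fusion size cap.
    return 0 < batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE


assert can_fuse_mlp_allreduce(64, False, 8, True, True, False)
assert not can_fuse_mlp_allreduce(4096, False, 8, True, True, False)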
sglang/srt/layers/dp_attention.py
CHANGED
@@ -72,6 +72,7 @@ class _DpGatheredBufferWrapper:
     _device: torch.device
     _global_dp_buffer_len: int
     _local_dp_buffer_len: int
+    _global_num_tokens: Optional[List[int]]
 
     @classmethod
     def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
@@ -80,9 +81,15 @@ class _DpGatheredBufferWrapper:
         cls._device = device
 
     @classmethod
-    def set_dp_buffer_len(
+    def set_dp_buffer_len(
+        cls,
+        global_dp_buffer_len: int,
+        local_dp_buffer_len: int,
+        global_num_tokens: Optional[List[int]] = None,
+    ):
         cls._global_dp_buffer_len = global_dp_buffer_len
         cls._local_dp_buffer_len = local_dp_buffer_len
+        cls._global_num_tokens = global_num_tokens
 
     @classmethod
     def get_global_dp_buffer(cls) -> torch.Tensor:
@@ -108,10 +115,18 @@ class _DpGatheredBufferWrapper:
     def get_local_dp_buffer_len(cls) -> int:
         return cls._local_dp_buffer_len
 
+    @classmethod
+    def get_dp_global_num_tokens(cls) -> List[int]:
+        return cls._global_num_tokens
+
 
-def set_dp_buffer_len(
+def set_dp_buffer_len(
+    global_dp_buffer_len: int,
+    local_dp_buffer_len: int,
+    global_num_tokens: Optional[List[int]] = None,
+):
     _DpGatheredBufferWrapper.set_dp_buffer_len(
-        global_dp_buffer_len, local_dp_buffer_len
+        global_dp_buffer_len, local_dp_buffer_len, global_num_tokens
     )
 
 
@@ -131,6 +146,10 @@ def get_local_dp_buffer_len() -> int:
     return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
 
 
+def get_dp_global_num_tokens() -> List[int]:
+    return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
@@ -215,7 +234,7 @@ def initialize_dp_attention(
     _DpGatheredBufferWrapper.set_metadata(
         hidden_size=model_config.hidden_size,
         dtype=model_config.dtype,
-        device=torch.device(
+        device=torch.device(server_args.device),
     )
 
 
sglang/srt/layers/elementwise.py
CHANGED
@@ -486,3 +486,97 @@ def gelu_and_mul_triton(
         return out_hidden_states, out_scales
     else:
         return out_hidden_states, None
+
+
+# silu on first half of vector
+@triton.jit
+def silu_and_mul_kernel(
+    out_hidden_states_ptr,  # (bs, hidden_dim)
+    out_scales_ptr,  # (bs,)
+    hidden_states_ptr,  # (bs, hidden_dim * 2)
+    quant_max: tl.constexpr,
+    static_scale: tl.constexpr,
+    hidden_dim: tl.constexpr,  # the output hidden_dim
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    input_start = pid * hidden_dim * 2
+    output_start = pid * hidden_dim
+
+    input1_offs = tl.arange(0, BLOCK_SIZE)
+    mask = tl.arange(0, BLOCK_SIZE) < hidden_dim  # shared for input1, input3, output
+    input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
+    output_offs = tl.arange(0, BLOCK_SIZE)
+
+    x1 = tl.load(
+        hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+    x3 = tl.load(
+        hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+
+    # silu
+    # cast down before mul to better match training?
+    silu_x1 = x1 * tl.sigmoid(x1)
+    out = x3 * silu_x1.to(hidden_states_ptr.dtype.element_ty)
+
+    if quant_max is not None:
+        raise NotImplementedError()
+
+    tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)
+
+
+def silu_and_mul_triton(
+    hidden_states,
+    scales=None,
+    quantize=None,  # dtype to quantize to
+    out=None,
+):
+    bs, in_hidden_dim = hidden_states.shape
+    hidden_dim = in_hidden_dim // 2
+
+    if out is None:
+        out_hidden_states = torch.empty(
+            (bs, hidden_dim),
+            dtype=quantize or hidden_states.dtype,
+            device=hidden_states.device,
+        )
+    else:
+        assert out.shape == (bs, hidden_dim)
+        assert out.dtype == (quantize or hidden_states.dtype)
+        out_hidden_states = out
+    out_scales = None
+    static_scale = False
+    if quantize is not None:
+        if scales is None:
+            out_scales = torch.empty(
+                (bs,), dtype=torch.float32, device=hidden_states.device
+            )
+        else:
+            out_scales = scales
+            static_scale = True
+
+    max_warps = 16 if _is_hip else 32
+    config = {
+        # 8 ele per thread (not tuned)
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
+        ),
+    }
+
+    silu_and_mul_kernel[(bs,)](
+        out_hidden_states,
+        out_scales,
+        hidden_states,
+        quant_max=torch.finfo(quantize).max if quantize is not None else None,
+        static_scale=static_scale,
+        hidden_dim=hidden_dim,
+        BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
+        **config,
+    )
+
+    if quantize is not None:
+        return out_hidden_states, out_scales
+    else:
+        return out_hidden_states, None
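The new silu_and_mul_kernel computes SiLU(x1) * x3 over the two halves of the last dimension. A CPU-side PyTorch reference of that semantics (a sketch for sanity-checking only; it ignores the unimplemented quantization path and the exact fp32/fp16 cast ordering inside the kernel):

import torch


def silu_and_mul_reference(hidden_states: torch.Tensor) -> torch.Tensor:
    # Split (bs, 2 * hidden_dim) into x1 and x3, apply SiLU to x1, multiply by x3.
    x1, x3 = hidden_states.chunk(2, dim=-1)
    return torch.nn.functional.silu(x1.float()).to(hidden_states.dtype) * x3


x = torch.randn(4, 2 * 128, dtype=torch.float16)
out = silu_and_mul_reference(x)
assert out.shape == (4, 128)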
sglang/srt/layers/flashinfer_comm_fusion.py
CHANGED
@@ -5,7 +5,11 @@ import torch
 import torch.distributed as dist
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    direct_register_custom_op,
+    is_flashinfer_available,
+    supports_custom_op,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -196,6 +200,30 @@ def flashinfer_allreduce_residual_rmsnorm(
     return norm_out, residual_out
 
 
+def fake_flashinfer_allreduce_residual_rmsnorm(
+    input_tensor: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    max_token_num: int = 2048,
+    use_oneshot: Optional[bool] = None,
+    trigger_completion_at_end: bool = False,
+    fp32_acc: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    residual_out = torch.empty_like(residual)
+    norm_out = torch.empty_like(input_tensor)
+    return norm_out, residual_out
+
+
+if supports_custom_op():
+    direct_register_custom_op(
+        "flashinfer_allreduce_residual_rmsnorm",
+        flashinfer_allreduce_residual_rmsnorm,
+        mutates_args=["input_tensor", "residual", "weight"],
+        fake_impl=fake_flashinfer_allreduce_residual_rmsnorm,
+    )
+
+
 def cleanup_flashinfer_workspace():
     global _workspace_manager
     if _workspace_manager is not None:
sglang/srt/layers/layernorm.py
CHANGED
@@ -27,6 +27,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    supports_custom_op,
 )
 
 _is_cuda = is_cuda()
@@ -202,8 +203,14 @@ class RMSNorm(CustomOp):
                 flashinfer_allreduce_residual_rmsnorm,
             )
 
+            fused_op = (
+                torch.ops.sglang.flashinfer_allreduce_residual_rmsnorm
+                if supports_custom_op()
+                else flashinfer_allreduce_residual_rmsnorm
+            )
+
             if get_tensor_model_parallel_world_size() > 1:
-                fused_result =
+                fused_result = fused_op(
                     input_tensor=x,
                     residual=residual,
                     weight=self.weight,
sglang/srt/layers/linear.py
CHANGED
@@ -110,6 +110,20 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
     return param[shard_id], loaded_weight
 
 
+def adjust_shard_offsets(shard_offsets, loaded_weight, dim):
+    actual_weight_size = loaded_weight.size(dim)
+    target_weight_size = shard_offsets[-1][-1] + shard_offsets[-1][-2]
+    if actual_weight_size != target_weight_size:
+        new_shard_offsets = []
+        new_offset = 0
+        for shard_id, shard_offset, shard_size in shard_offsets:
+            actual_shard_size = actual_weight_size * shard_size // target_weight_size
+            new_shard_offsets.append((shard_id, new_offset, actual_shard_size))
+            new_offset += actual_shard_size
+        return new_shard_offsets
+    return shard_offsets
+
+
 class LinearBase(torch.nn.Module):
     """Base linear layer.
 
@@ -535,6 +549,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             packed_dim = getattr(param, "packed_dim", None)
 
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+            if _is_cpu:
+                shard_offsets = adjust_shard_offsets(
+                    shard_offsets, loaded_weight, output_dim
+                )
+
             for shard_id, shard_offset, shard_size in shard_offsets:
                 # Special case for Quantization.
                 # If quantized, we need to adjust the offset and size to account
@@ -977,6 +996,11 @@ class QKVParallelLinear(ColumnParallelLinear):
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
 
             packed_dim = getattr(param, "packed_dim", None)
+            if _is_cpu:
+                shard_offsets = adjust_shard_offsets(
+                    shard_offsets, loaded_weight, output_dim
+                )
+
             for shard_id, shard_offset, shard_size in shard_offsets:
                 # Special case for Quantized Weights.
                 # If quantized, we need to adjust the offset and size to account
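The new adjust_shard_offsets helper rescales merged-weight shard offsets when the loaded tensor is smaller than the offsets expect (the CPU path above). A small standalone rework of the same arithmetic with illustrative shard names and sizes:

def adjust_shard_offsets_demo(shard_offsets, actual_weight_size):
    # Expected size is the end of the last shard: last offset + last size.
    target_weight_size = shard_offsets[-1][-1] + shard_offsets[-1][-2]
    if actual_weight_size == target_weight_size:
        return shard_offsets
    new_shard_offsets, new_offset = [], 0
    for shard_id, _, shard_size in shard_offsets:
        # Scale each shard proportionally to the actual tensor size.
        actual_shard_size = actual_weight_size * shard_size // target_weight_size
        new_shard_offsets.append((shard_id, new_offset, actual_shard_size))
        new_offset += actual_shard_size
    return new_shard_offsets


# ("gate", offset 0, size 4096) and ("up", offset 4096, size 4096), but the
# loaded tensor only has 2048 rows along the output dim -> both shards shrink.
print(adjust_shard_offsets_demo([("gate", 0, 4096), ("up", 4096, 4096)], 2048))
# [('gate', 0, 1024), ('up', 1024, 1024)]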
sglang/srt/layers/logits_processor.py
CHANGED
@@ -191,7 +191,11 @@ class LogitsMetadata:
         else:
             self.global_dp_buffer_len = self.global_dp_buffer_len
 
-        set_dp_buffer_len(
+        set_dp_buffer_len(
+            self.global_dp_buffer_len,
+            self.dp_local_num_tokens,
+            self.global_num_tokens_for_logprob_cpu,
+        )
 
 
 class LogitsProcessor(nn.Module):
|