sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/utils.py

@@ -9,18 +9,89 @@ TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
 
 @triton.jit
 def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,
+    req_to_token_ptr,
     req_pool_indices_ptr,
     page_kernel_lens_ptr,
     kv_indptr,
     kv_start_idx,
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
+    PAGE_SIZE: tl.constexpr = 1,
 ):
+    """
+    Create KV indices for FlashInfer attention backend.
+
+    This Triton kernel builds a lookup table that maps from logical request/token
+    coordinates to physical token locations in the global KV cache pool. It's used
+    by FlashInfer attention backends to efficiently access scattered KV cache data.
+
+    The kernel processes each request in parallel and converts the req_to_token
+    lookup table into a flat list of token indices that can be used by attention kernels.
+
+    general idea:
+    blocktables/kv_indices_ptr = [batch_size * max_pages (for graph mode with
+    fixed number of pages)]
+    max_pages = max_context_len / PAGED_SIZE
+    kv_indices_ptr will store the flat list of the pages used by each request
+    Args:
+    Inputs Arguments (non mutable):
+
+    req_to_token_ptr: Request to token location look up table
+        Shape: [max_batch, max_context_len]
+    req_pool_indices_ptr: Request to pool index look up table. Each request uses
+        one pool.
+        Shape: [batch_size]
+    page_kernel_lens_ptr: sequence lengths per request
+        Shape: [batch_size]
+    kv_indptr: Should be computed based on number of pages used by each request.
+        It is used by flashinfer attention kernels to index into the kv_indices_ptr
+        per request.
+        Shape: [batch_size + 1]
+        kv_indptr[i] = start index in kv_indices for request i
+    kv_start_idx: Pointer to array containing start offsets for each request in SGL.
+        Can be None. If provided, adds offset to token positions.
+
+    req_to_token_ptr_stride: Stride for the second dimension of req_to_token.
+        Equal to max_context_len.
+
+    PAGED_SIZE: Number of tokens per page. Default is 1 for FlashInfer.
+
+    Outputs:
+    kv_indices_ptr: Pointer to output array where KV indices will be stored.
+        Shape: [total_num_pages],
+        where total_num_pages = sum(seq_lens // PAGED_SIZE)
+
+    Example:
+    If we have:
+    - req_pool_indices = [0, 1] (request 0 uses pool 0, request 1 uses pool 1)
+    - page_kernel_lens = [3, 2] (request 0 has 3 tokens, request 1 has 2 tokens)
+    - req_to_token = [[10, 11, 12, -1], [20, 21, -1, -1]] (tokens are the elements
+      in the radix tree, use them as a pointer to the token location in the kv_indices_ptr)
+
+    The kernel will output:
+    If PAGE_SIZE = 1:
+      packed
+      - kv_indptr (passed in as input arg): [0, 3, 5]
+      - kv_indices = [10, 11, 12, 20, 21]
+      padded - max_pages is 10 tokens per req
+      - kv_indptr (passed in as input arg): [0, 10, 20]
+      - kv_indices = [10, 11, 12, -1, -1, -1, -1, -1, -1, -1,
+                      20, 21, -1, -1, -1, -1, -1, -1, -1, -1]
+
+    If PAGE_SIZE = 2
+      packed:
+      - kv_indptr (passed in as input arg): [0,3,4]
+      - kv_indices = [5, 6, 10]
+      padded: max_pages is 4
+      - kv_indptr (passed in as input arg): [0, 4, 8, ..] (note that 4 is the max_pages)
+      - kv_indices = [5, 6, -1, -1,
+                      10, -1, -1, -1]
+    This allows attention kernels to directly access the correct KV cache
+    entries for each request's tokens.
+    """
     BLOCK_SIZE: tl.constexpr = 512
+    NUM_PAGES_PER_BLOCK: tl.constexpr = BLOCK_SIZE // PAGE_SIZE
     pid = tl.program_id(axis=0)
-
-    # find the req pool idx, this is for batch to token
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)
 
@@ -31,19 +102,27 @@ def create_flashinfer_kv_indices_triton(
     kv_end = kv_start
     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
 
-
-
-
-
-
-
-
-
-
-
-
+    kv_range = kv_end - kv_start
+    num_pages = tl.cdiv(kv_range, PAGE_SIZE)
+    num_loops = tl.cdiv(kv_range, BLOCK_SIZE)
+    req_to_token_block_start = (
+        req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + kv_start
+    )
+    for i in range(num_loops):
+        token_offsets_in_block = (
+            tl.arange(0, NUM_PAGES_PER_BLOCK).to(tl.int64) + i * NUM_PAGES_PER_BLOCK
+        ) * PAGE_SIZE
+        page_offsets_in_block = token_offsets_in_block // PAGE_SIZE
+        valid_tokens = token_offsets_in_block < kv_range
+        valid_pages = page_offsets_in_block < num_pages
+        token_numbers = tl.load(
+            req_to_token_block_start + token_offsets_in_block, mask=valid_tokens
+        )
+        tl.store(
+            kv_indices_ptr + kv_indices_offset + page_offsets_in_block,
+            token_numbers // PAGE_SIZE,  # write the page numbers to kv_indices_ptr
+            mask=valid_pages,
         )
-        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
 
 
 @triton.jit
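Note on the kernel rewrite above: with PAGE_SIZE > 1 the loop now walks req_to_token in page-sized strides and stores page numbers (token index // PAGE_SIZE) rather than raw token slots. The following is a minimal host-side Python sketch of that indexing, not taken from the diff; the name build_kv_indices_reference is illustrative, and it assumes kv_indptr is simply a prefix sum of per-request page counts, which is one way the caller could compute it.

# Hypothetical host-side reference for the paged kv_indices layout built by
# create_flashinfer_kv_indices_triton (illustration only, not part of sglang).
def build_kv_indices_reference(req_to_token, req_pool_indices, seq_lens, page_size):
    kv_indptr = [0]
    kv_indices = []
    for pool_idx, seq_len in zip(req_pool_indices, seq_lens):
        tokens = req_to_token[pool_idx][:seq_len]
        # One entry per page: take the first token of each page and divide by page_size.
        pages = [tokens[i] // page_size for i in range(0, seq_len, page_size)]
        kv_indices.extend(pages)
        kv_indptr.append(len(kv_indices))
    return kv_indptr, kv_indices

req_to_token = [[10, 11, 12, -1], [20, 21, -1, -1]]
print(build_kv_indices_reference(req_to_token, [0, 1], [3, 2], page_size=1))
# -> ([0, 3, 5], [10, 11, 12, 20, 21])
print(build_kv_indices_reference(req_to_token, [0, 1], [3, 2], page_size=2))
# -> ([0, 2, 3], [5, 6, 10])

For page_size=1 this reproduces the packed example from the docstring; for larger pages each stored entry addresses a whole KV page.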
sglang/srt/layers/attention/vision.py

@@ -12,7 +12,12 @@ import torch.nn.functional as F
 from einops import rearrange
 
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    get_device_capability,
+    is_blackwell,
+    is_cuda,
+    print_info_once,
+)
 
 _is_cuda = is_cuda()
 
@@ -20,7 +25,6 @@ if _is_cuda:
     from sgl_kernel.flash_attn import flash_attn_varlen_func
 
 from sglang.srt.distributed import (
-    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
 )
@@ -402,18 +406,14 @@ class VisionAttention(nn.Module):
             self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
         )
 
-        #
-
-
-
-
-
-
-        qkv_backend = "sdpa"
+        # Select attention backend via a unified method
+        _passed_backend = qkv_backend
+        qkv_backend = self._determine_attention_backend(_passed_backend)
+        if (
+            global_server_args_dict["mm_attention_backend"] is None
+            and _passed_backend is None
+        ):
             print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
-        else:
-            qkv_backend = global_server_args_dict["mm_attention_backend"]
-
         print_info_once(f"Using {qkv_backend} as multimodal attention backend.")
 
         self.customized_position_embedding_applier = (
@@ -461,6 +461,33 @@ class VisionAttention(nn.Module):
             prefix=add_prefix("proj", prefix),
         )
 
+    def _determine_attention_backend(self, passed_backend: Optional[str]) -> str:
+        """Decide the multimodal attention backend string.
+
+        Priority: server args override > constructor arg > platform default.
+
+        Platform defaults:
+        - CUDA: "triton_attn"
+        - Non-CUDA: "sdpa"
+        """
+        override_backend = global_server_args_dict["mm_attention_backend"]
+        if override_backend is not None:
+            backend = override_backend
+        elif passed_backend is not None:
+            backend = passed_backend
+        elif is_cuda():
+            major, minor = get_device_capability()
+            if major == 9:
+                backend = "fa3"
+            else:
+                backend = "triton_attn"
+        else:
+            backend = "sdpa"
+        if backend == "fa3" and is_blackwell():
+            raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")
+
+        return backend
+
     def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
         """apply qk norm for internvl vit attn"""
         q = q.flatten(1, 2)
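The new _determine_attention_backend helper resolves the multimodal attention backend with the precedence: server-args override, then the constructor argument, then a platform default (fa3 on SM90 CUDA, triton_attn on other CUDA devices, sdpa elsewhere), and it rejects fa3 on Blackwell. Below is a self-contained sketch of that precedence; pick_mm_backend is a hypothetical stand-in for the method, with plain parameters replacing the sglang server args and device helpers.

# Standalone sketch of the selection order (illustrative only).
from typing import Optional

def pick_mm_backend(
    server_override: Optional[str],
    constructor_arg: Optional[str],
    is_cuda_device: bool,
    sm_major: int,
    is_blackwell_device: bool,
) -> str:
    if server_override is not None:
        backend = server_override
    elif constructor_arg is not None:
        backend = constructor_arg
    elif is_cuda_device:
        backend = "fa3" if sm_major == 9 else "triton_attn"
    else:
        backend = "sdpa"
    if backend == "fa3" and is_blackwell_device:
        raise ValueError("fa3 is not supported on Blackwell GPUs")
    return backend

assert pick_mm_backend(None, None, True, 9, False) == "fa3"      # Hopper default
assert pick_mm_backend(None, "sdpa", True, 9, False) == "sdpa"   # constructor arg beats the default
assert pick_mm_backend("fa3", "sdpa", True, 9, False) == "fa3"   # server args beat everything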
sglang/srt/layers/attention/vision_utils.py (new file)

@@ -0,0 +1,65 @@
+"""Utility functions for vision attention layers."""
+
+import torch
+
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+
+
+def update_vit_attn_dummy_heads_config(config):
+    """Update HF config to ensure vision attention num_attention_heads is divisible by tp_size"""
+    tp_size = get_attention_tp_size()
+    num_heads = getattr(
+        config.vision_config,
+        "num_heads",
+        getattr(config.vision_config, "num_attention_heads", None),
+    )
+    head_dim = config.vision_config.hidden_size // num_heads
+    num_dummy_heads = 0
+
+    if num_heads % tp_size != 0:
+        num_dummy_heads = ((num_heads + tp_size - 1) // tp_size) * tp_size - num_heads
+
+    setattr(config.vision_config, "head_dim", head_dim)
+    setattr(config.vision_config, "num_dummy_heads", num_dummy_heads)
+
+
+def pad_vit_attn_dummy_heads(config, name: str, loaded_weight: torch.Tensor):
+    """Pad attention qkv weights for dummy heads"""
+    num_dummy_heads = config.vision_config.num_dummy_heads
+    if num_dummy_heads == 0:
+        return loaded_weight
+    head_dim = config.vision_config.head_dim
+
+    if "attn.qkv_proj" in name:
+        wq, wk, wv = loaded_weight.chunk(3, dim=0)
+        if name.endswith(".weight"):
+            dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
+        elif name.endswith(".bias"):
+            dummy_shape = [num_dummy_heads, head_dim]
+        else:
+            raise RuntimeError(f"Unsupported weight with name={name}")
+        pad_func = lambda x: torch.cat(
+            [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
+        ).flatten(0, 1)
+        wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
+        loaded_weight = torch.cat([wq, wk, wv], dim=0)
+    elif any([_ in name for _ in ["attn.q_proj", "attn.k_proj", "attn.v_proj"]]):
+        if name.endswith(".weight"):
+            dummy_shape = [num_dummy_heads, head_dim, loaded_weight.shape[-1]]
+        elif name.endswith(".bias"):
+            dummy_shape = [num_dummy_heads, head_dim]
+        else:
+            raise RuntimeError(f"Unsupported weight with name={name}")
+        padded_weight = loaded_weight.new_zeros(dummy_shape)
+        loaded_weight = torch.cat(
+            [loaded_weight.unflatten(0, (-1, head_dim)), padded_weight], dim=0
+        ).flatten(0, 1)
+    elif "attn.proj.weight" in name:
+        padded_weight = loaded_weight.new_zeros(
+            loaded_weight.shape[0], head_dim * num_dummy_heads
+        )
+        loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+    elif "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+        padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+        loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+    return loaded_weight
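The new vision_utils module rounds the ViT head count up to a multiple of the attention tensor-parallel size and zero-pads the q/k/v projections (and the input side of the output projection) with dummy heads. A short worked example of the arithmetic and of the per-head padding follows; the sizes (25 heads, head_dim 64, TP 8) are made up for illustration and the snippet mirrors, but is not, pad_vit_attn_dummy_heads.

# Worked example of the dummy-head arithmetic (illustration only).
import torch

num_heads, head_dim, tp_size = 25, 64, 8
num_dummy_heads = ((num_heads + tp_size - 1) // tp_size) * tp_size - num_heads
print(num_dummy_heads)                              # 7 -> 32 heads total, 4 per rank

# Zero-padding a k_proj-style weight the same way the helper does:
hidden = num_heads * head_dim
w_k = torch.randn(num_heads * head_dim, hidden)     # [out_features, in_features]
pad = w_k.new_zeros(num_dummy_heads, head_dim, hidden)
w_k_padded = torch.cat([w_k.unflatten(0, (-1, head_dim)), pad], dim=0).flatten(0, 1)
print(w_k_padded.shape)                             # torch.Size([2048, 1600])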
sglang/srt/layers/communicator.py

@@ -17,7 +17,7 @@ from enum import Enum, auto
 from functools import partial
 from typing import Dict, Optional
 
-import torch
+import torch
 
 from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
@@ -32,6 +32,13 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_global_dp_buffer,
+    get_local_dp_buffer,
+    is_dp_attention_enabled,
+)
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -41,6 +48,8 @@ from sglang.srt.utils import is_cuda, is_flashinfer_available
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
 
+FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
+
 
 class ScatterMode(Enum):
     """
@@ -109,7 +118,11 @@ class LayerScatterModes:
         if context.is_layer_sparse:
             return (
                 ScatterMode.SCATTERED
-                if
+                if (
+                    # Token dispatch/combine will be handled outside of LayerCommunicator for these modes.
+                    not get_moe_a2a_backend().is_none()
+                    or should_use_flashinfer_cutlass_moe_fp4_allgather()
+                )
                 else ScatterMode.FULL
             )
         else:
@@ -152,11 +165,13 @@ class LayerCommunicator:
         post_attention_layernorm: torch.nn.Module,
         # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
         allow_reduce_scatter: bool = False,
+        is_last_layer: bool = False,
     ):
         self.layer_scatter_modes = layer_scatter_modes
         self.input_layernorm = input_layernorm
         self.post_attention_layernorm = post_attention_layernorm
         self.allow_reduce_scatter = allow_reduce_scatter
+        self.is_last_layer = is_last_layer
 
         self._context = CommunicateContext.init_new()
         self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
@@ -254,6 +269,41 @@ class LayerCommunicator:
             and forward_batch.dp_padding_mode.is_max_len()
         )
 
+    def should_fuse_mlp_allreduce_with_next_layer(
+        self, forward_batch: ForwardBatch
+    ) -> bool:
+        speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
+        if (
+            is_dp_attention_enabled()
+            and speculative_algo is not None
+            and speculative_algo.is_eagle()
+        ):
+            return False
+
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )
+        if batch_size > FUSE_ALLREDUCE_MAX_BATCH_SIZE:
+            return False
+
+        static_conditions_met = (
+            (not self.is_last_layer)
+            and (self._context.tp_size > 1)
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and _is_flashinfer_available
+        )
+
+        if not static_conditions_met:
+            return False
+
+        return (
+            batch_size > 0
+            and batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE
+            and (not self.is_last_layer)
+        )
+
 
 @dataclass
 class CommunicateContext:
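The new should_fuse_mlp_allreduce_with_next_layer predicate gates FlashInfer's fused allreduce on static conditions (not the last layer, TP size above 1, enable_flashinfer_allreduce_fusion set, FlashInfer available) and per-batch ones (0 < batch size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE), and it bails out when DP attention runs together with EAGLE speculative decoding. A condensed restatement as a pure function, with hypothetical boolean parameters standing in for the server args and communicator state the real method reads:

# Condensed restatement of the fusion gate (illustration only).
FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048

def can_fuse_mlp_allreduce(
    batch_size: int,
    is_last_layer: bool,
    tp_size: int,
    fusion_flag_enabled: bool,
    flashinfer_available: bool,
    dp_attention_with_eagle: bool,
) -> bool:
    if dp_attention_with_eagle:
        return False
    static_ok = (
        (not is_last_layer) and tp_size > 1 and fusion_flag_enabled and flashinfer_available
    )
    return static_ok and 0 < batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE

assert can_fuse_mlp_allreduce(256, False, 8, True, True, False)
assert not can_fuse_mlp_allreduce(4096, False, 8, True, True, False)   # batch too large
assert not can_fuse_mlp_allreduce(256, True, 8, True, True, False)     # last layer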
@@ -319,7 +369,7 @@ class CommunicateSimpleFn:
         context: CommunicateContext,
     ) -> torch.Tensor:
         hidden_states, local_hidden_states = (
-
+            get_local_dp_buffer(),
             hidden_states,
         )
         attn_tp_all_gather_into_tensor(
@@ -380,7 +430,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
         )
 
         raise NotImplementedError(
-            f"{hidden_states_input_mode=} {residual_input_mode=} {
+            f"{hidden_states_input_mode=} {residual_input_mode=} {hidden_states_output_mode=} {residual_output_mode=}"
         )
 
     @staticmethod
@@ -408,9 +458,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
     ):
         if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
             residual, local_residual = (
-
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
-                ),
+                get_local_dp_buffer(),
                 residual,
             )
             attn_tp_all_gather_into_tensor(residual, local_residual)
@@ -424,7 +472,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             residual = hidden_states
             hidden_states = layernorm(hidden_states)
             hidden_states, local_hidden_states = (
-
+                get_global_dp_buffer(),
                 hidden_states,
             )
             dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
@@ -548,7 +596,7 @@ class CommunicateSummableTensorPairFn:
         allow_reduce_scatter: bool = False,
     ):
         hidden_states, global_hidden_states = (
-
+            get_local_dp_buffer(),
            hidden_states,
        )
        if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
@@ -569,7 +617,7 @@ class CommunicateSummableTensorPairFn:
         hidden_states += residual
         residual = None
         hidden_states, local_hidden_states = (
-
+            get_local_dp_buffer(),
             hidden_states,
         )
         attn_tp_all_gather_into_tensor(
sglang/srt/layers/dp_attention.py

@@ -4,7 +4,7 @@ import functools
 import logging
 from contextlib import contextmanager
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import torch
 import triton
@@ -18,21 +18,26 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.server_args import ServerArgs
+
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
-_ATTN_TP_GROUP = None
-_ATTN_TP_RANK = None
-_ATTN_TP_SIZE = None
-_ATTN_DP_RANK = None
-_ATTN_DP_SIZE = None
-_LOCAL_ATTN_DP_SIZE = None
-_LOCAL_ATTN_DP_RANK = None
+_ATTN_TP_GROUP: Optional[GroupCoordinator] = None
+_ATTN_TP_RANK: Optional[int] = None
+_ATTN_TP_SIZE: Optional[int] = None
+_ATTN_DP_RANK: Optional[int] = None
+_ATTN_DP_SIZE: Optional[int] = None
+_LOCAL_ATTN_DP_SIZE: Optional[int] = None
+_LOCAL_ATTN_DP_RANK: Optional[int] = None
+_ENABLE_DP_ATTENTION_FLAG: bool = False
 
 
-class DPPaddingMode(IntEnum):
+class DpPaddingMode(IntEnum):
 
     # Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
     MAX_LEN = auto()
@@ -40,13 +45,13 @@ class DPPaddingMode(IntEnum):
     SUM_LEN = auto()
 
     def is_max_len(self):
-        return self == DPPaddingMode.MAX_LEN
+        return self == DpPaddingMode.MAX_LEN
 
     def is_sum_len(self):
-        return self == DPPaddingMode.SUM_LEN
+        return self == DpPaddingMode.SUM_LEN
 
     @classmethod
-    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
+    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
         # we choose the mode that minimizes the communication cost
         max_len = max(global_num_tokens)
         sum_len = sum(global_num_tokens)
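get_dp_padding_mode picks between the two padding modes by comparing their communication cost. A toy illustration of the buffer sizes involved, under the assumption that MAX_LEN gathers max_len rows from every DP rank while SUM_LEN packs exactly the total token count (the actual comparison sits in lines of this method that are not shown here):

# Toy buffer-size comparison for the two DP padding modes (illustration only).
global_num_tokens = [3, 1, 2]                                    # tokens contributed per DP rank
max_len_rows = max(global_num_tokens) * len(global_num_tokens)   # MAX_LEN: 3 * 3 = 9 padded rows
sum_len_rows = sum(global_num_tokens)                            # SUM_LEN: 6 packed rows
print(max_len_rows, sum_len_rows)                                # 9 6 -> SUM_LEN moves less data here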
@@ -56,10 +61,95 @@ class DPPaddingMode(IntEnum):
             return cls.SUM_LEN
 
     @classmethod
-    def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
+    def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
         return cls.MAX_LEN
 
 
+class _DpGatheredBufferWrapper:
+
+    _hidden_size: int
+    _dtype: torch.dtype
+    _device: torch.device
+    _global_dp_buffer_len: int
+    _local_dp_buffer_len: int
+    _global_num_tokens: Optional[List[int]]
+
+    @classmethod
+    def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
+        cls._hidden_size = hidden_size
+        cls._dtype = dtype
+        cls._device = device
+
+    @classmethod
+    def set_dp_buffer_len(
+        cls,
+        global_dp_buffer_len: int,
+        local_dp_buffer_len: int,
+        global_num_tokens: Optional[List[int]] = None,
+    ):
+        cls._global_dp_buffer_len = global_dp_buffer_len
+        cls._local_dp_buffer_len = local_dp_buffer_len
+        cls._global_num_tokens = global_num_tokens
+
+    @classmethod
+    def get_global_dp_buffer(cls) -> torch.Tensor:
+        return torch.empty(
+            (cls._global_dp_buffer_len, cls._hidden_size),
+            dtype=cls._dtype,
+            device=cls._device,
+        )
+
+    @classmethod
+    def get_local_dp_buffer(cls) -> torch.Tensor:
+        return torch.empty(
+            (cls._local_dp_buffer_len, cls._hidden_size),
+            dtype=cls._dtype,
+            device=cls._device,
+        )
+
+    @classmethod
+    def get_global_dp_buffer_len(cls) -> int:
+        return cls._global_dp_buffer_len
+
+    @classmethod
+    def get_local_dp_buffer_len(cls) -> int:
+        return cls._local_dp_buffer_len
+
+    @classmethod
+    def get_dp_global_num_tokens(cls) -> List[int]:
+        return cls._global_num_tokens
+
+
+def set_dp_buffer_len(
+    global_dp_buffer_len: int,
+    local_dp_buffer_len: int,
+    global_num_tokens: Optional[List[int]] = None,
+):
+    _DpGatheredBufferWrapper.set_dp_buffer_len(
+        global_dp_buffer_len, local_dp_buffer_len, global_num_tokens
+    )
+
+
+def get_global_dp_buffer() -> torch.Tensor:
+    return _DpGatheredBufferWrapper.get_global_dp_buffer()
+
+
+def get_local_dp_buffer() -> torch.Tensor:
+    return _DpGatheredBufferWrapper.get_local_dp_buffer()
+
+
+def get_global_dp_buffer_len() -> int:
+    return _DpGatheredBufferWrapper.get_global_dp_buffer_len()
+
+
+def get_local_dp_buffer_len() -> int:
+    return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
+
+
+def get_dp_global_num_tokens() -> List[int]:
+    return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
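The new _DpGatheredBufferWrapper centralizes allocation of the DP-attention gather buffers: initialize_dp_attention records the hidden size, dtype and device once, per-batch code records the global/local buffer lengths, and callers then obtain freshly allocated buffers through the module-level helpers. A hypothetical usage sketch follows; the sizes are made up, a CUDA device is assumed, and in sglang these calls happen internally rather than in user code.

# Hypothetical usage of the new DP buffer helpers (illustration only).
import torch

from sglang.srt.layers.dp_attention import (
    _DpGatheredBufferWrapper,
    get_global_dp_buffer,
    get_local_dp_buffer,
    set_dp_buffer_len,
)

_DpGatheredBufferWrapper.set_metadata(
    hidden_size=4096, dtype=torch.bfloat16, device=torch.device("cuda")
)
set_dp_buffer_len(
    global_dp_buffer_len=8192,
    local_dp_buffer_len=2048,
    global_num_tokens=[2048, 2048, 2048, 2048],
)

global_buf = get_global_dp_buffer()   # freshly allocated [8192, 4096] bf16 tensor on CUDA
local_buf = get_local_dp_buffer()     # freshly allocated [2048, 4096] bf16 tensor on CUDA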
@@ -89,18 +179,24 @@ def compute_dp_attention_local_info(
 
 
 def initialize_dp_attention(
-
-
-    tp_size: int,
-    dp_size: int,
-    moe_dense_tp_size: int,
-    pp_size: int,
+    server_args: ServerArgs,
+    model_config: ModelConfig,
 ):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
-    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
+    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK, _ENABLE_DP_ATTENTION_FLAG
 
     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
 
+    enable_dp_attention = server_args.enable_dp_attention
+    tp_size = server_args.tp_size
+    dp_size = server_args.dp_size
+    moe_dense_tp_size = server_args.moe_dense_tp_size
+    pp_size = server_args.pp_size
+
+    tp_rank = get_tensor_model_parallel_rank()
+
+    _ENABLE_DP_ATTENTION_FLAG = enable_dp_attention
+
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
@@ -135,38 +231,48 @@ def initialize_dp_attention(
         group_name="attention_tp",
     )
 
+    _DpGatheredBufferWrapper.set_metadata(
+        hidden_size=model_config.hidden_size,
+        dtype=model_config.dtype,
+        device=torch.device(server_args.device),
+    )
+
+
+def is_dp_attention_enabled() -> bool:
+    return _ENABLE_DP_ATTENTION_FLAG
 
-def get_attention_tp_group():
+
+def get_attention_tp_group() -> GroupCoordinator:
     assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
     return _ATTN_TP_GROUP
 
 
-def get_attention_tp_rank():
+def get_attention_tp_rank() -> int:
     assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
     return _ATTN_TP_RANK
 
 
-def get_attention_tp_size():
+def get_attention_tp_size() -> int:
     assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
     return _ATTN_TP_SIZE
 
 
-def get_attention_dp_rank():
+def get_attention_dp_rank() -> int:
     assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
     return _ATTN_DP_RANK
 
 
-def get_attention_dp_size():
+def get_attention_dp_size() -> int:
     assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
     return _ATTN_DP_SIZE
 
 
-def get_local_attention_dp_rank():
+def get_local_attention_dp_rank() -> int:
     assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
     return _LOCAL_ATTN_DP_RANK
 
 
-def get_local_attention_dp_size():
+def get_local_attention_dp_size() -> int:
     assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
     return _LOCAL_ATTN_DP_SIZE
 
@@ -292,6 +398,10 @@ def _dp_gather_via_all_gather(
     forward_batch: ForwardBatch,
     is_partial: bool,
 ):
+    if get_attention_tp_size() == 1:
+        get_tp_group().all_gather_into_tensor(global_tokens, local_tokens)
+        return
+
     if not is_partial:
         if get_attention_tp_rank() != 0:
             local_tokens.fill_(0)