sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +3 -13
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +158 -8
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +119 -75
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +5 -2
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +18 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +71 -53
- sglang/srt/conversation.py +78 -46
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +11 -3
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +236 -138
- sglang/srt/disaggregation/nixl/conn.py +242 -71
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +51 -2
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +31 -4
- sglang/srt/entrypoints/http_server.py +45 -3
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +147 -51
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
- sglang/srt/layers/moe/ep_moe/layer.py +121 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +77 -71
- sglang/srt/layers/quantization/fp8.py +110 -97
- sglang/srt/layers/quantization/fp8_kernel.py +81 -62
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +11 -14
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +13 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +93 -23
- sglang/srt/managers/schedule_policy.py +11 -8
- sglang/srt/managers/scheduler.py +140 -100
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +157 -47
- sglang/srt/managers/tp_worker.py +21 -21
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +4 -2
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +57 -41
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +3 -3
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +77 -39
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +3 -1
- sglang/srt/models/llama4.py +58 -13
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +52 -42
- sglang/srt/openai_api/protocol.py +20 -16
- sglang/srt/reasoning_parser.py +1 -1
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +2 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +64 -10
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +41 -6
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +92 -15
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_ops/merge_state.py (new file)
@@ -0,0 +1,96 @@
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def merge_state_kernel(
+    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged
+    output_lse,  # [NUM_TOKENS, NUM_HEADS] s_merged
+    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a
+    prefix_lse,  # [NUM_TOKENS, NUM_HEADS] s_a
+    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b
+    suffix_lse,  # [NUM_TOKENS, NUM_HEADS] s_b
+    HEAD_SIZE: tl.constexpr,
+    PADDED_HEAD_SIZE: tl.constexpr,
+    OUTPUT_LSE: tl.constexpr,
+):
+    token_idx = tl.program_id(0)
+    num_tokens = tl.num_programs(0)
+    head_idx = tl.program_id(1)
+    num_heads = tl.num_programs(1)
+
+    p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx)
+    s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx)
+    p_lse = float("-inf") if p_lse == float("inf") else p_lse
+    s_lse = float("-inf") if s_lse == float("inf") else s_lse
+
+    max_lse = tl.maximum(p_lse, s_lse)
+    p_lse = p_lse - max_lse
+    s_lse = s_lse - max_lse
+    out_se = tl.exp(p_lse) + tl.exp(s_lse)
+
+    if OUTPUT_LSE:
+        out_lse = tl.log(out_se) + max_lse
+        tl.store(output_lse + token_idx * num_heads + head_idx, out_lse)
+
+    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
+    head_mask = head_arange < HEAD_SIZE
+    p_out = tl.load(
+        prefix_output
+        + token_idx * num_heads * HEAD_SIZE
+        + head_idx * HEAD_SIZE
+        + head_arange,
+        mask=head_mask,
+    )
+    s_out = tl.load(
+        suffix_output
+        + token_idx * num_heads * HEAD_SIZE
+        + head_idx * HEAD_SIZE
+        + head_arange,
+        mask=head_mask,
+    )
+
+    p_scale = tl.exp(p_lse) / out_se
+    s_scale = tl.exp(s_lse) / out_se
+    out = p_out * p_scale + s_out * s_scale
+    tl.store(
+        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
+        out,
+        mask=head_mask,
+    )
+
+
+def merge_state_triton(
+    prefix_output: torch.Tensor,
+    prefix_lse: torch.Tensor,
+    suffix_output: torch.Tensor,
+    suffix_lse: torch.Tensor,
+    output: Optional[torch.Tensor] = None,
+    output_lse: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    # Avoid creating new tensors if they are already provided
+    if output is None:
+        output = torch.empty_like(prefix_output)
+    if output_lse is None:
+        output_lse = torch.empty_like(prefix_lse)
+
+    num_tokens = output.shape[0]
+    num_query_heads = output.shape[1]
+    head_size = output.shape[2]
+    padded_head_size = triton.next_power_of_2(head_size)
+
+    merge_state_kernel[(num_tokens, num_query_heads)](
+        output,
+        output_lse,
+        prefix_output,
+        prefix_lse,
+        suffix_output,
+        suffix_lse,
+        head_size,
+        padded_head_size,
+        output_lse is not None,
+    )
+    return output, output_lse
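For reference, the merge the kernel performs is the standard numerically stable log-sum-exp combination of two partial attention results. Below is a plain-PyTorch sketch of the same math (an illustrative helper, not part of the package; the Triton kernel additionally maps +inf LSE values to -inf before merging).

import torch


def merge_state_reference(prefix_output, prefix_lse, suffix_output, suffix_lse):
    # prefix_output / suffix_output: [num_tokens, num_heads, head_size]
    # prefix_lse / suffix_lse:       [num_tokens, num_heads]
    max_lse = torch.maximum(prefix_lse, suffix_lse)
    p_w = torch.exp(prefix_lse - max_lse)
    s_w = torch.exp(suffix_lse - max_lse)
    out_lse = torch.log(p_w + s_w) + max_lse
    # Each partial output is weighted by its share of the combined softmax mass.
    p_scale = (p_w / (p_w + s_w)).unsqueeze(-1)
    s_scale = (s_w / (p_w + s_w)).unsqueeze(-1)
    return prefix_output * p_scale + suffix_output * s_scale, out_lse


tokens, heads, head_size = 4, 2, 64
v_a, v_b = torch.randn(2, tokens, heads, head_size).unbind(0)
lse_a, lse_b = torch.randn(2, tokens, heads).unbind(0)
out, lse = merge_state_reference(v_a, lse_a, v_b, lse_b)
print(out.shape, lse.shape)  # torch.Size([4, 2, 64]) torch.Size([4, 2])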
sglang/srt/layers/attention/utils.py
@@ -28,7 +28,8 @@ def create_flashinfer_kv_indices_triton(

     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
     for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        # index into req_to_token_ptr needs to be int64
+        offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
         mask = offset < kv_end - kv_start
         data = tl.load(
             req_to_token_ptr
@@ -70,8 +71,9 @@ def create_flashmla_kv_indices_triton(
     num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)

     for i in range(num_pages_loop):
+        # index into req_to_token_ptr needs to be int64
         paged_offset = (
-            tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
+            tl.arange(0, NUM_PAGE_PER_BLOCK).to(tl.int64) + i * NUM_PAGE_PER_BLOCK
         ) * PAGED_SIZE
         paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK

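Both hunks make the same fix: the offset used to index req_to_token_ptr is now computed in 64-bit. A small self-contained illustration of why (hypothetical sizes, not taken from the diff):

import torch

req_idx, stride, offset = 70_000, 32_768, 100  # hypothetical pool dimensions
# The flat index req_idx * stride + offset exceeds 2**31 - 1, so 32-bit
# arithmetic wraps around while 64-bit stays exact.
i32 = torch.tensor(req_idx, dtype=torch.int32) * stride + offset
i64 = torch.tensor(req_idx, dtype=torch.int64) * stride + offset
print(i32.item(), i64.item())  # the int32 result wraps negative; int64 gives 2293760100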
sglang/srt/layers/attention/vision.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

-
+import math
+from functools import lru_cache, wraps
 from typing import Optional, Tuple

 import torch
@@ -8,6 +9,13 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange

+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel.flash_attn import flash_attn_varlen_func
+
 from sglang.srt.distributed import parallel_state
 from sglang.srt.distributed import utils as dist_utils
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
@@ -19,166 +27,31 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.quantization import QuantizationConfig
-from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
-from sglang.srt.
+from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.utils import add_prefix, logger

-
-
-
-    Multi-headed attention without any cache, mostly used for ViT.
+ROTARY_EMBED_CLASSES = {
+    "normal": apply_rotary_pos_emb,
+}


-
-
-        use_context_forward (bool, default to True):
-            if ``True``, a flash_attn style attention will be applied
-            Otherwise, a full-sequence attention will be applied.
-        softmax_in_single_precision (bool, default to False):
-            if ``True``, the softmax will be performed in single-precision
-            Otherwise, it will be performed in half-precision
+def execute_once(func):
+    has_run = None

-
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        nonlocal has_run
+        if not has_run:
+            func(*args, **kwargs)
+            has_run = True

-
-        self,
-        embed_dim: int,
-        num_heads: int,
-        projection_size: int,
-        use_qkv_parallel: bool,
-        quant_config: Optional[QuantizationConfig] = None,
-        dropout: float = 0.0,
-        use_context_forward: bool = True,
-        softmax_in_single_precision: bool = False,
-        flatten_batch: bool = False,
-        prefix: str = "",
-    ):
-        super().__init__()
-        self.use_context_forward = use_context_forward
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
-        self.dropout = dropout
-        self.head_size = embed_dim // num_heads
-        self.hidden_size_per_attention_head = dist_utils.divide(
-            projection_size, num_heads
-        )
-        self.num_attention_heads_per_partition = dist_utils.divide(
-            num_heads, world_size
-        )
+    return wrapper

-        if self.use_context_forward:
-            self.qkv_backend = VisionTritonAttention()
-        else:
-            self.qkv_backend = VisionSdpaAttention(
-                head_size=self.head_size,
-                dropout=dropout,
-                flatten_batch=flatten_batch,
-                softmax_in_single_precision=softmax_in_single_precision,
-            )

-
-
-
-                hidden_size=embed_dim,
-                head_size=self.head_size,
-                total_num_heads=num_heads,
-                quant_config=quant_config,
-                prefix=add_prefix("qkv_proj", prefix),
-            )
-        else:
-            self.qkv_proj = ColumnParallelLinear(
-                input_size=embed_dim,
-                output_size=3 * projection_size,
-                quant_config=quant_config,
-                prefix=add_prefix("qkv_proj", prefix),
-            )
-        self.proj = RowParallelLinear(
-            input_size=embed_dim,
-            output_size=embed_dim,
-            quant_config=quant_config,
-            prefix=add_prefix("proj", prefix),
-        )
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        cu_seqlens: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        r"""
-        Args:
-            x: [b, s, embed_dim]
-            cu_seqlens: [b]
-        Returns:
-            [s, b, head * head_size]
-        """
-        bsz, s, _ = x.shape
-        head = self.num_attention_heads_per_partition
-        if self.use_qkv_parallel:
-            # [b, s, embed_dim] --> [b, s, embed_dim]
-            qkv, _ = self.qkv_proj(x)
-            q, k, v = qkv.chunk(3, dim=-1)
-
-            # [b, s, embed_dim] --> [b * s, head, head_size]
-            q, k, v = [x.reshape(bsz * s, head, -1).contiguous() for x in (q, k, v)]
-        else:
-            # [b, s, embed_dim] --> [s, b, embed_dim]
-            x = rearrange(x, "b s ... -> s b ...")
-            # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
-            qkv, _ = self.qkv_proj(x)
-            # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
-            new_x_shape = qkv.size()[:-1] + (
-                head,
-                3 * self.hidden_size_per_attention_head,
-            )
-            qkv = qkv.view(*new_x_shape)
-
-            # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
-            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
-
-            # [s, b, head, head_size] --> [b, s, head, head_size]
-            q, k, v = [
-                rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
-            ]
-
-        if position_embeddings is not None:
-            cos, sin = position_embeddings
-            original_shape = q.shape
-            # [total_tokens, head, head_size]
-            q = q.view(-1, head, self.head_size)
-            k = k.view(-1, head, self.head_size)
-
-            q, k = apply_rotary_pos_emb(q, k, cos, sin)
-
-            q = q.view(original_shape)
-            k = k.view(original_shape)
-
-        if self.use_qkv_parallel:
-            pass
-        else:
-            # [b, s, head, head_size] --> [b * s, head, head_size]
-            q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
-
-        output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask)
-
-        if self.use_qkv_parallel:
-            # [b * s, h, head_size] --> [b, s, h * head_size]
-            output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
-
-            # [b, s, h * head_size] --> [b, s, h * head_size]
-            output, _ = self.proj(output)
-        else:
-            # [b * s, h, head_size] --> [s, b, h * head_size]
-            context_layer = rearrange(
-                output, "(b s) h d -> s b (h d)", b=bsz, s=s
-            ).contiguous()
-
-            # [s, b, h * head_size] --> [s, b, h * head_size]
-            output, _ = self.proj(context_layer)
-
-            # [s, b, h * head_size] --> [b, s, h * head_size]
-            output = output.view(bsz, s, -1)
-
-        return output
+@execute_once
+def info_once(message: str):
+    logger.info(message)


 class VisionSdpaAttention(nn.Module):
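The execute_once/info_once helpers added above form a log-once utility. A standalone copy showing the behavior (illustrative, using Python's stdlib logging rather than sglang's logger):

import logging
from functools import wraps

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def execute_once(func):
    has_run = None

    @wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal has_run
        if not has_run:
            func(*args, **kwargs)
            has_run = True

    return wrapper


@execute_once
def info_once(message: str):
    logger.info(message)


info_once("multimodal attention backend: sdpa")  # logged once
info_once("multimodal attention backend: sdpa")  # suppressed: wrapper already ran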
@@ -189,16 +62,22 @@ class VisionSdpaAttention(nn.Module):

     def __init__(
         self,
-
+        head_dim: int,
+        num_heads: int,
+        num_kv_heads: int,
         dropout: float = 0.0,
         flatten_batch: bool = False,
         softmax_in_single_precision: bool = False,
+        **kwargs,
     ):
         super().__init__()
-        self.head_size =
+        self.head_size = head_dim
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
         self.flatten_batch = flatten_batch
         self.softmax_in_single_precision = softmax_in_single_precision
         self.dropout = dropout
+        self.scale = 1.0 / math.sqrt(self.head_size)

     @staticmethod
     @lru_cache(maxsize=128)
@@ -212,7 +91,7 @@ class VisionSdpaAttention(nn.Module):
            flatten_batch: whether to flatten batch dimension
            cu_seqlens: tuple of cumulative sequence lengths
        Returns:
-            attention mask tensor
+            attention mask tensor of shape [b, 1, s, s] or [1, s, s]
        """
        if flatten_batch:
            mask = torch.zeros([1, s, s], dtype=torch.bool)
@@ -241,7 +120,7 @@ class VisionSdpaAttention(nn.Module):
        flatten_batch: bool = False,
    ) -> Optional[torch.Tensor]:
        r"""
-        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1,
+        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, s, s)`.
        Args:
            s: sequence length
            cu_seqlens: cumulative sequence lengths tensor. If not, returns an empty mask
@@ -264,6 +143,7 @@ class VisionSdpaAttention(nn.Module):
        bsz: int,
        cu_seqlens: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
@@ -274,6 +154,8 @@ class VisionSdpaAttention(nn.Module):
        if self.flatten_batch:
            assert bsz == 1, "flatten_batch is True, bsz must be 1"

+        assert q.dim() == 3, q.shape
+
        s = q.shape[0] // bsz

        # [b, 1, s, s]
@@ -291,10 +173,10 @@ class VisionSdpaAttention(nn.Module):
        q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]

        if self.softmax_in_single_precision:
-
-
-
-
+            k = rearrange(k, "b h s d -> b h d s")
+            attn_weights = torch.matmul(q, k) * self.scale
+            del k
+            # masking
            attention_mask = (~attention_mask) * torch.finfo(q.dtype).min
            attn_weights = attn_weights + attention_mask
            del attention_mask
@@ -332,6 +214,7 @@ class VisionTritonAttention(nn.Module):

    def __init__(
        self,
+        **kwargs,
    ):
        super().__init__()

@@ -340,8 +223,8 @@ class VisionTritonAttention(nn.Module):
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
-        _bsz: int,
        cu_seqlens: Optional[torch.Tensor],
+        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
@@ -366,3 +249,247 @@ class VisionTritonAttention(nn.Module):
         )

         return output
+
+
+class VisionFlash3Attention(nn.Module):
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        if not _is_cuda:
+            raise Exception("VisionFlash3Attention is only available for cuda")
+        super().__init__()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        cu_seqlens: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            cu_seqlens: [b]
+        Returns:
+            [b * s, h, head_size]
+        """
+        cu_seqlens = cu_seqlens.to(dtype=torch.int32).cuda()
+        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+        max_seqlen = seq_lens.max().item()
+        output = flash_attn_varlen_func(
+            q,
+            k,
+            v,
+            cu_seqlens_q=cu_seqlens,
+            cu_seqlens_k=cu_seqlens,
+            max_seqlen_q=max_seqlen,
+            max_seqlen_k=max_seqlen,
+        )
+
+        return output
+
+
+QKV_BACKEND_IMPL = {
+    "triton_attn": VisionTritonAttention,
+    "sdpa": VisionSdpaAttention,
+    "fa3": VisionFlash3Attention,
+}
+
+
+class VisionAttention(nn.Module):
+    r"""
+    Multi-headed attention without any cache, mostly used for multimodal transformers.
+
+
+    Args:
+        use_qkv_parallel (bool, optional): If True, use QKV-parallel attention.
+        softmax_in_single_precision (bool, default to False):
+            if ``True``, the softmax will be performed in single-precision
+            Otherwise, it will be performed in half-precision
+
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        projection_size: int,
+        use_qkv_parallel: bool,
+        qkv_backend: Optional[str] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        dropout: float = 0.0,
+        softmax_in_single_precision: bool = False,
+        flatten_batch: bool = False,
+        prefix: str = "",
+        proj_bias: bool = True,
+        **kwargs,
+    ):
+        super().__init__()
+        world_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.dropout = dropout
+        self.head_size = embed_dim // num_heads
+        self.hidden_size_per_attention_head = dist_utils.divide(
+            projection_size, num_heads
+        )
+        self.num_attention_heads_per_partition = dist_utils.divide(
+            num_heads, world_size
+        )
+        self.num_attention_kv_heads_per_partition = dist_utils.divide(
+            num_heads, world_size
+        )
+
+        self.q_size = self.num_attention_heads_per_partition * self.head_size
+        self.kv_size = self.num_attention_kv_heads_per_partition * self.head_size
+
+        if global_server_args_dict["mm_attention_backend"] is None:
+            if qkv_backend is None:
+                qkv_backend = "sdpa"
+            info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
+        else:
+            qkv_backend = global_server_args_dict["mm_attention_backend"]
+
+        info_once(f"Using {qkv_backend} as multimodal attention backend.")
+
+        self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
+            head_dim=self.head_size,
+            num_heads=self.num_attention_heads_per_partition,
+            num_kv_heads=self.num_attention_kv_heads_per_partition,
+            dropout=dropout,
+            flatten_batch=flatten_batch,
+            softmax_in_single_precision=softmax_in_single_precision,
+        )
+
+        self.use_qkv_parallel = use_qkv_parallel
+        if use_qkv_parallel:
+            self.qkv_proj = QKVParallelLinear(
+                hidden_size=embed_dim,
+                head_size=self.head_size,
+                total_num_heads=num_heads,
+                total_num_kv_heads=num_heads,
+                quant_config=quant_config,
+                prefix=add_prefix("qkv_proj", prefix),
+            )
+        else:
+            self.qkv_proj = ColumnParallelLinear(
+                input_size=embed_dim,
+                output_size=3 * projection_size,
+                quant_config=quant_config,
+                prefix=add_prefix("qkv_proj", prefix),
+            )
+        self.proj = RowParallelLinear(
+            input_size=embed_dim,
+            output_size=embed_dim,
+            bias=proj_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("proj", prefix),
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            x: [b, s, embed_dim]
+            cu_seqlens: [b]
+        Returns:
+            [s, b, head * head_size]
+        """
+        if x.dim() == 2:
+            x = x.unsqueeze(0)
+        assert x.dim() == 3, x.shape
+        bsz, s, _ = x.shape
+        head = self.num_attention_heads_per_partition
+        kv_head = self.num_attention_kv_heads_per_partition
+        if self.use_qkv_parallel:
+            # [b, s, embed_dim] --> [b, s, embed_dim]
+            qkv, _ = self.qkv_proj(x)
+
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+            # [b, s, embed_dim] --> [b * s, head, head_size]
+            q = q.reshape(bsz * s, head, -1).contiguous()
+            k = k.reshape(bsz * s, kv_head, -1).contiguous()
+            v = v.reshape(bsz * s, kv_head, -1).contiguous()
+        else:
+            # [b, s, embed_dim] --> [s, b, embed_dim]
+            x = rearrange(x, "b s ... -> s b ...")
+            # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
+            qkv, _ = self.qkv_proj(x)
+
+            # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
+            new_x_shape = qkv.size()[:-1] + (
+                head,
+                3 * self.hidden_size_per_attention_head,
+            )
+            qkv = qkv.view(*new_x_shape)
+
+            # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
+            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
+            # [s, b, head, head_size] --> [b, s, head, head_size]
+            q, k, v = [
+                rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
+            ]
+
+        if position_embeddings is not None:
+            cos, sin = position_embeddings
+            original_shape = q.shape
+            # [total_tokens, head, head_size]
+            q = q.view(-1, head, self.head_size)
+            k = k.view(-1, head, self.head_size)
+
+            q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+            q = q.view(original_shape)
+            k = k.view(original_shape)
+
+        if q.dim() == 4:
+            # [b, s, head, head_size] --> [b * s, head, head_size]
+            q = rearrange(q, "b s ... -> (b s) ...")
+        if k.dim() == 4:
+            # [b, s, head, head_size] --> [b * s, head, head_size]
+            k = rearrange(k, "b s ... -> (b s) ...")
+        if v.dim() == 4:
+            # [b, s, head, head_size] --> [b * s, head, head_size]
+            v = rearrange(v, "b s ... -> (b s) ...")
+
+        assert q.dim() == 3, q.dim()
+        assert k.dim() == 3, k.dim()
+        assert v.dim() == 3, v.dim()
+
+        output = self.qkv_backend.forward(
+            q=q,
+            k=k,
+            v=v,
+            bsz=bsz,
+            cu_seqlens=cu_seqlens,
+            attention_mask=attention_mask,
+        )
+
+        assert output.dim() == 3, output.shape
+
+        if self.use_qkv_parallel:
+            # [b * s, h, head_size] --> [b, s, h * head_size]
+            output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
+
+            # [b, s, h * head_size] --> [b, s, h * head_size]
+            output, _ = self.proj(output)
+        else:
+            # [b * s, h, head_size] --> [s, b, h * head_size]
+            context_layer = rearrange(
+                output, "(b s) h d -> s b (h d)", b=bsz, s=s
+            ).contiguous()
+
+            # [s, b, h * head_size] --> [s, b, h * head_size]
+            output, _ = self.proj(context_layer)
+
+            # [s, b, h * head_size] --> [b, s, h * head_size]
+            output = output.view(bsz, s, -1)
+
+        return output
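The reworked VisionSdpaAttention single-precision path above scales scores by 1/sqrt(head_dim) and applies the boolean mask additively before the softmax. A minimal sketch of that computation outside SGLang (illustrative shapes and helper name, not from the package):

import math
import torch


def sdpa_fp32_softmax(q, k, v, mask, scale):
    # q, k, v: [b, h, s, d]; mask: [b, 1, s, s] boolean (True = keep)
    attn = torch.matmul(q, k.transpose(-1, -2)) * scale
    attn = attn + (~mask) * torch.finfo(q.dtype).min  # large negative where masked
    attn = torch.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
    return torch.matmul(attn, v)


b, h, s, d = 1, 2, 4, 8
q, k, v = torch.randn(3, b, h, s, d, dtype=torch.bfloat16).unbind(0)
mask = torch.ones(b, 1, s, s, dtype=torch.bool)
out = sdpa_fp32_softmax(q, k, v, mask, scale=1.0 / math.sqrt(d))
print(out.shape)  # torch.Size([1, 2, 4, 8])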