sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0

sglang/srt/layers/attention/triton_ops/merge_state.py (new file)
@@ -0,0 +1,96 @@
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def merge_state_kernel(
+    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged
+    output_lse,  # [NUM_TOKENS, NUM_HEADS] s_merged
+    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a
+    prefix_lse,  # [NUM_TOKENS, NUM_HEADS] s_a
+    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b
+    suffix_lse,  # [NUM_TOKENS, NUM_HEADS] s_b
+    HEAD_SIZE: tl.constexpr,
+    PADDED_HEAD_SIZE: tl.constexpr,
+    OUTPUT_LSE: tl.constexpr,
+):
+    token_idx = tl.program_id(0)
+    num_tokens = tl.num_programs(0)
+    head_idx = tl.program_id(1)
+    num_heads = tl.num_programs(1)
+
+    p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx)
+    s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx)
+    p_lse = float("-inf") if p_lse == float("inf") else p_lse
+    s_lse = float("-inf") if s_lse == float("inf") else s_lse
+
+    max_lse = tl.maximum(p_lse, s_lse)
+    p_lse = p_lse - max_lse
+    s_lse = s_lse - max_lse
+    out_se = tl.exp(p_lse) + tl.exp(s_lse)
+
+    if OUTPUT_LSE:
+        out_lse = tl.log(out_se) + max_lse
+        tl.store(output_lse + token_idx * num_heads + head_idx, out_lse)
+
+    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
+    head_mask = head_arange < HEAD_SIZE
+    p_out = tl.load(
+        prefix_output
+        + token_idx * num_heads * HEAD_SIZE
+        + head_idx * HEAD_SIZE
+        + head_arange,
+        mask=head_mask,
+    )
+    s_out = tl.load(
+        suffix_output
+        + token_idx * num_heads * HEAD_SIZE
+        + head_idx * HEAD_SIZE
+        + head_arange,
+        mask=head_mask,
+    )
+
+    p_scale = tl.exp(p_lse) / out_se
+    s_scale = tl.exp(s_lse) / out_se
+    out = p_out * p_scale + s_out * s_scale
+    tl.store(
+        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
+        out,
+        mask=head_mask,
+    )
+
+
+def merge_state_triton(
+    prefix_output: torch.Tensor,
+    prefix_lse: torch.Tensor,
+    suffix_output: torch.Tensor,
+    suffix_lse: torch.Tensor,
+    output: Optional[torch.Tensor] = None,
+    output_lse: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    # Avoid creating new tensors if they are already provided
+    if output is None:
+        output = torch.empty_like(prefix_output)
+    if output_lse is None:
+        output_lse = torch.empty_like(prefix_lse)
+
+    num_tokens = output.shape[0]
+    num_query_heads = output.shape[1]
+    head_size = output.shape[2]
+    padded_head_size = triton.next_power_of_2(head_size)
+
+    merge_state_kernel[(num_tokens, num_query_heads)](
+        output,
+        output_lse,
+        prefix_output,
+        prefix_lse,
+        suffix_output,
+        suffix_lse,
+        head_size,
+        padded_head_size,
+        output_lse is not None,
+    )
+    return output, output_lse
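
For orientation, `merge_state_triton` implements the usual log-sum-exp merge of two partial attention results computed over disjoint key sets (e.g. a prefix and a suffix split of the KV cache): each partial output is weighted by its normalized `exp(lse - max_lse)` term, and the combined LSE is written back when requested. A minimal sketch of how it could be exercised on a CUDA device, with made-up shapes that are not taken from the package's tests:

import torch

from sglang.srt.layers.attention.triton_ops.merge_state import merge_state_triton

# Hypothetical sizes: 8 tokens, 4 heads, head size 64 (a CUDA device is needed for Triton).
num_tokens, num_heads, head_size = 8, 4, 64
prefix_out = torch.randn(num_tokens, num_heads, head_size, device="cuda")
suffix_out = torch.randn(num_tokens, num_heads, head_size, device="cuda")
prefix_lse = torch.randn(num_tokens, num_heads, device="cuda")
suffix_lse = torch.randn(num_tokens, num_heads, device="cuda")

merged, merged_lse = merge_state_triton(prefix_out, prefix_lse, suffix_out, suffix_lse)

# Reference: a convex combination weighted by the normalized exp(lse - max_lse) terms.
max_lse = torch.maximum(prefix_lse, suffix_lse)
p_w = torch.exp(prefix_lse - max_lse)
s_w = torch.exp(suffix_lse - max_lse)
ref = (
    prefix_out * (p_w / (p_w + s_w)).unsqueeze(-1)
    + suffix_out * (s_w / (p_w + s_w)).unsqueeze(-1)
)
torch.testing.assert_close(merged, ref, rtol=1e-3, atol=1e-3)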

sglang/srt/layers/attention/utils.py
@@ -28,7 +28,8 @@ def create_flashinfer_kv_indices_triton(
 
     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
     for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        # index into req_to_token_ptr needs to be int64
+        offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
         mask = offset < kv_end - kv_start
         data = tl.load(
             req_to_token_ptr
@@ -70,8 +71,9 @@ def create_flashmla_kv_indices_triton(
     num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
 
     for i in range(num_pages_loop):
+        # index into req_to_token_ptr needs to be int64
         paged_offset = (
-            tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
+            tl.arange(0, NUM_PAGE_PER_BLOCK).to(tl.int64) + i * NUM_PAGE_PER_BLOCK
        ) * PAGED_SIZE
         paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
 
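The `.to(tl.int64)` casts above address an overflow hazard: the flattened offset into the request-to-token table (request row times its stride plus the token position) inherits the dtype of the offset vector, and for large batch and context sizes the product no longer fits in 32 bits. A standalone illustration of the wraparound in plain PyTorch, with hypothetical sizes that are not taken from sglang:

import torch

# Hypothetical table: row stride of 2**20 token slots per request.
stride = 1 << 20
row, col = 2100, 12345  # row * stride is ~2.2e9, which exceeds 2**31 - 1

offset32 = torch.tensor([row], dtype=torch.int32) * stride + col
offset64 = torch.tensor([row], dtype=torch.int64) * stride + col

print(offset32.item())  # silently wraps to a negative index
print(offset64.item())  # correct flattened offset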

sglang/srt/layers/attention/vision.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
-from functools import lru_cache
+import math
+from functools import lru_cache, wraps
 from typing import Optional, Tuple
 
 import torch
@@ -8,6 +9,13 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel.flash_attn import flash_attn_varlen_func
+
 from sglang.srt.distributed import parallel_state
 from sglang.srt.distributed import utils as dist_utils
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
@@ -19,166 +27,31 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.quantization import QuantizationConfig
-from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, rotate_half
-from sglang.srt.utils import add_prefix
+from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.utils import add_prefix, logger
 
-
-class VisionAttention(nn.Module):
-    r"""
-    Multi-headed attention without any cache, mostly used for ViT.
+ROTARY_EMBED_CLASSES = {
+    "normal": apply_rotary_pos_emb,
+}
 
 
-    Args:
-        use_qkv_parallel (bool, optional): If True, use QKV-parallel attention.
-        use_context_forward (bool, default to True):
-            if ``True``, a flash_attn style attention will be applied
-            Otherwise, a full-sequence attention will be applied.
-        softmax_in_single_precision (bool, default to False):
-            if ``True``, the softmax will be performed in single-precision
-            Otherwise, it will be performed in half-precision
+def execute_once(func):
+    has_run = None
 
-    """
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        nonlocal has_run
+        if not has_run:
+            func(*args, **kwargs)
+            has_run = True
 
-    def __init__(
-        self,
-        embed_dim: int,
-        num_heads: int,
-        projection_size: int,
-        use_qkv_parallel: bool,
-        quant_config: Optional[QuantizationConfig] = None,
-        dropout: float = 0.0,
-        use_context_forward: bool = True,
-        softmax_in_single_precision: bool = False,
-        flatten_batch: bool = False,
-        prefix: str = "",
-    ):
-        super().__init__()
-        self.use_context_forward = use_context_forward
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
-        self.dropout = dropout
-        self.head_size = embed_dim // num_heads
-        self.hidden_size_per_attention_head = dist_utils.divide(
-            projection_size, num_heads
-        )
-        self.num_attention_heads_per_partition = dist_utils.divide(
-            num_heads, world_size
-        )
+    return wrapper
 
-        if self.use_context_forward:
-            self.qkv_backend = VisionTritonAttention()
-        else:
-            self.qkv_backend = VisionSdpaAttention(
-                head_size=self.head_size,
-                dropout=dropout,
-                flatten_batch=flatten_batch,
-                softmax_in_single_precision=softmax_in_single_precision,
-            )
 
-        self.use_qkv_parallel = use_qkv_parallel
-        if use_qkv_parallel:
-            self.qkv_proj = QKVParallelLinear(
-                hidden_size=embed_dim,
-                head_size=self.head_size,
-                total_num_heads=num_heads,
-                quant_config=quant_config,
-                prefix=add_prefix("qkv_proj", prefix),
-            )
-        else:
-            self.qkv_proj = ColumnParallelLinear(
-                input_size=embed_dim,
-                output_size=3 * projection_size,
-                quant_config=quant_config,
-                prefix=add_prefix("qkv_proj", prefix),
-            )
-        self.proj = RowParallelLinear(
-            input_size=embed_dim,
-            output_size=embed_dim,
-            quant_config=quant_config,
-            prefix=add_prefix("proj", prefix),
-        )
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        cu_seqlens: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        r"""
-        Args:
-            x: [b, s, embed_dim]
-            cu_seqlens: [b]
-        Returns:
-            [s, b, head * head_size]
-        """
-        bsz, s, _ = x.shape
-        head = self.num_attention_heads_per_partition
-        if self.use_qkv_parallel:
-            # [b, s, embed_dim] --> [b, s, embed_dim]
-            qkv, _ = self.qkv_proj(x)
-            q, k, v = qkv.chunk(3, dim=-1)
-
-            # [b, s, embed_dim] --> [b * s, head, head_size]
-            q, k, v = [x.reshape(bsz * s, head, -1).contiguous() for x in (q, k, v)]
-        else:
-            # [b, s, embed_dim] --> [s, b, embed_dim]
-            x = rearrange(x, "b s ... -> s b ...")
-            # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
-            qkv, _ = self.qkv_proj(x)
-            # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
-            new_x_shape = qkv.size()[:-1] + (
-                head,
-                3 * self.hidden_size_per_attention_head,
-            )
-            qkv = qkv.view(*new_x_shape)
-
-            # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
-            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
-
-            # [s, b, head, head_size] --> [b, s, head, head_size]
-            q, k, v = [
-                rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
-            ]
-
-        if position_embeddings is not None:
-            cos, sin = position_embeddings
-            original_shape = q.shape
-            # [total_tokens, head, head_size]
-            q = q.view(-1, head, self.head_size)
-            k = k.view(-1, head, self.head_size)
-
-            q, k = apply_rotary_pos_emb(q, k, cos, sin)
-
-            q = q.view(original_shape)
-            k = k.view(original_shape)
-
-        if self.use_qkv_parallel:
-            pass
-        else:
-            # [b, s, head, head_size] --> [b * s, head, head_size]
-            q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
-
-        output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask)
-
-        if self.use_qkv_parallel:
-            # [b * s, h, head_size] --> [b, s, h * head_size]
-            output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
-
-            # [b, s, h * head_size] --> [b, s, h * head_size]
-            output, _ = self.proj(output)
-        else:
-            # [b * s, h, head_size] --> [s, b, h * head_size]
-            context_layer = rearrange(
-                output, "(b s) h d -> s b (h d)", b=bsz, s=s
-            ).contiguous()
-
-            # [s, b, h * head_size] --> [s, b, h * head_size]
-            output, _ = self.proj(context_layer)
-
-            # [s, b, h * head_size] --> [b, s, h * head_size]
-            output = output.view(bsz, s, -1)
-
-        return output
+@execute_once
+def info_once(message: str):
+    logger.info(message)
 
 
 class VisionSdpaAttention(nn.Module):
@@ -189,16 +62,22 @@ class VisionSdpaAttention(nn.Module):
 
     def __init__(
         self,
-        head_size: int,
+        head_dim: int,
+        num_heads: int,
+        num_kv_heads: int,
         dropout: float = 0.0,
         flatten_batch: bool = False,
         softmax_in_single_precision: bool = False,
+        **kwargs,
     ):
         super().__init__()
-        self.head_size = head_size
+        self.head_size = head_dim
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
         self.flatten_batch = flatten_batch
         self.softmax_in_single_precision = softmax_in_single_precision
         self.dropout = dropout
+        self.scale = 1.0 / math.sqrt(self.head_size)
 
     @staticmethod
     @lru_cache(maxsize=128)
@@ -212,7 +91,7 @@ class VisionSdpaAttention(nn.Module):
             flatten_batch: whether to flatten batch dimension
             cu_seqlens: tuple of cumulative sequence lengths
         Returns:
-            attention mask tensor
+            attention mask tensor of shape [b, 1, s, s] or [1, s, s]
         """
         if flatten_batch:
             mask = torch.zeros([1, s, s], dtype=torch.bool)
@@ -241,7 +120,7 @@ class VisionSdpaAttention(nn.Module):
         flatten_batch: bool = False,
     ) -> Optional[torch.Tensor]:
         r"""
-        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, s, s)`.
         Args:
             s: sequence length
             cu_seqlens: cumulative sequence lengths tensor. If not, returns an empty mask
@@ -264,6 +143,7 @@ class VisionSdpaAttention(nn.Module):
         bsz: int,
         cu_seqlens: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         r"""
         Args:
@@ -274,6 +154,8 @@ class VisionSdpaAttention(nn.Module):
         if self.flatten_batch:
             assert bsz == 1, "flatten_batch is True, bsz must be 1"
 
+        assert q.dim() == 3, q.shape
+
         s = q.shape[0] // bsz
 
         # [b, 1, s, s]
@@ -291,10 +173,10 @@ class VisionSdpaAttention(nn.Module):
         q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
 
         if self.softmax_in_single_precision:
-            scale = self.head_size**-0.5
-            k_transposed = rearrange(k, "b h s d -> b h d s")
-            attn_weights = torch.matmul(q, k_transposed) * scale
-            del k, k_transposed
+            k = rearrange(k, "b h s d -> b h d s")
+            attn_weights = torch.matmul(q, k) * self.scale
+            del k
+            # masking
             attention_mask = (~attention_mask) * torch.finfo(q.dtype).min
             attn_weights = attn_weights + attention_mask
             del attention_mask
@@ -332,6 +214,7 @@ class VisionTritonAttention(nn.Module):
 
     def __init__(
         self,
+        **kwargs,
     ):
         super().__init__()
 
@@ -340,8 +223,8 @@ class VisionTritonAttention(nn.Module):
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        _bsz: int,
         cu_seqlens: Optional[torch.Tensor],
+        **kwargs,
     ) -> torch.Tensor:
         r"""
         Args:
@@ -366,3 +249,247 @@ class VisionTritonAttention(nn.Module):
         )
 
         return output
+
+
+class VisionFlash3Attention(nn.Module):
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        if not _is_cuda:
+            raise Exception("VisionFlash3Attention is only available for cuda")
+        super().__init__()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        cu_seqlens: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            cu_seqlens: [b]
+        Returns:
+            [b * s, h, head_size]
+        """
+        cu_seqlens = cu_seqlens.to(dtype=torch.int32).cuda()
+        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+        max_seqlen = seq_lens.max().item()
+        output = flash_attn_varlen_func(
+            q,
+            k,
+            v,
+            cu_seqlens_q=cu_seqlens,
+            cu_seqlens_k=cu_seqlens,
+            max_seqlen_q=max_seqlen,
+            max_seqlen_k=max_seqlen,
+        )
+
+        return output
+
+
+QKV_BACKEND_IMPL = {
+    "triton_attn": VisionTritonAttention,
+    "sdpa": VisionSdpaAttention,
+    "fa3": VisionFlash3Attention,
+}
+
+
+class VisionAttention(nn.Module):
+    r"""
+    Multi-headed attention without any cache, mostly used for multimodal transformers.
+
+
+    Args:
+        use_qkv_parallel (bool, optional): If True, use QKV-parallel attention.
+        softmax_in_single_precision (bool, default to False):
+            if ``True``, the softmax will be performed in single-precision
+            Otherwise, it will be performed in half-precision
+
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        projection_size: int,
+        use_qkv_parallel: bool,
+        qkv_backend: Optional[str] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        dropout: float = 0.0,
+        softmax_in_single_precision: bool = False,
+        flatten_batch: bool = False,
+        prefix: str = "",
+        proj_bias: bool = True,
+        **kwargs,
+    ):
+        super().__init__()
+        world_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.dropout = dropout
+        self.head_size = embed_dim // num_heads
+        self.hidden_size_per_attention_head = dist_utils.divide(
+            projection_size, num_heads
+        )
+        self.num_attention_heads_per_partition = dist_utils.divide(
+            num_heads, world_size
+        )
+        self.num_attention_kv_heads_per_partition = dist_utils.divide(
+            num_heads, world_size
+        )
+
+        self.q_size = self.num_attention_heads_per_partition * self.head_size
+        self.kv_size = self.num_attention_kv_heads_per_partition * self.head_size
+
+        if global_server_args_dict["mm_attention_backend"] is None:
+            if qkv_backend is None:
+                qkv_backend = "sdpa"
+            info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
+        else:
+            qkv_backend = global_server_args_dict["mm_attention_backend"]
+
+        info_once(f"Using {qkv_backend} as multimodal attention backend.")
+
+        self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
+            head_dim=self.head_size,
+            num_heads=self.num_attention_heads_per_partition,
+            num_kv_heads=self.num_attention_kv_heads_per_partition,
+            dropout=dropout,
+            flatten_batch=flatten_batch,
+            softmax_in_single_precision=softmax_in_single_precision,
+        )
+
+        self.use_qkv_parallel = use_qkv_parallel
+        if use_qkv_parallel:
+            self.qkv_proj = QKVParallelLinear(
+                hidden_size=embed_dim,
+                head_size=self.head_size,
+                total_num_heads=num_heads,
+                total_num_kv_heads=num_heads,
+                quant_config=quant_config,
+                prefix=add_prefix("qkv_proj", prefix),
+            )
+        else:
+            self.qkv_proj = ColumnParallelLinear(
+                input_size=embed_dim,
+                output_size=3 * projection_size,
+                quant_config=quant_config,
+                prefix=add_prefix("qkv_proj", prefix),
+            )
+        self.proj = RowParallelLinear(
+            input_size=embed_dim,
+            output_size=embed_dim,
+            bias=proj_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("proj", prefix),
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            x: [b, s, embed_dim]
+            cu_seqlens: [b]
+        Returns:
+            [s, b, head * head_size]
+        """
+        if x.dim() == 2:
+            x = x.unsqueeze(0)
+        assert x.dim() == 3, x.shape
+        bsz, s, _ = x.shape
+        head = self.num_attention_heads_per_partition
+        kv_head = self.num_attention_kv_heads_per_partition
+        if self.use_qkv_parallel:
+            # [b, s, embed_dim] --> [b, s, embed_dim]
+            qkv, _ = self.qkv_proj(x)
+
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+            # [b, s, embed_dim] --> [b * s, head, head_size]
+            q = q.reshape(bsz * s, head, -1).contiguous()
+            k = k.reshape(bsz * s, kv_head, -1).contiguous()
+            v = v.reshape(bsz * s, kv_head, -1).contiguous()
+        else:
+            # [b, s, embed_dim] --> [s, b, embed_dim]
+            x = rearrange(x, "b s ... -> s b ...")
+            # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
+            qkv, _ = self.qkv_proj(x)
+
+            # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
+            new_x_shape = qkv.size()[:-1] + (
+                head,
+                3 * self.hidden_size_per_attention_head,
+            )
+            qkv = qkv.view(*new_x_shape)
+
+            # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
+            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
+            # [s, b, head, head_size] --> [b, s, head, head_size]
+            q, k, v = [
+                rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
+            ]
+
+        if position_embeddings is not None:
+            cos, sin = position_embeddings
+            original_shape = q.shape
+            # [total_tokens, head, head_size]
+            q = q.view(-1, head, self.head_size)
+            k = k.view(-1, head, self.head_size)
+
+            q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+            q = q.view(original_shape)
+            k = k.view(original_shape)
+
+        if q.dim() == 4:
+            # [b, s, head, head_size] --> [b * s, head, head_size]
+            q = rearrange(q, "b s ... -> (b s) ...")
+        if k.dim() == 4:
+            # [b, s, head, head_size] --> [b * s, head, head_size]
+            k = rearrange(k, "b s ... -> (b s) ...")
+        if v.dim() == 4:
+            # [b, s, head, head_size] --> [b * s, head, head_size]
+            v = rearrange(v, "b s ... -> (b s) ...")
+
+        assert q.dim() == 3, q.dim()
+        assert k.dim() == 3, k.dim()
+        assert v.dim() == 3, v.dim()
+
+        output = self.qkv_backend.forward(
+            q=q,
+            k=k,
+            v=v,
+            bsz=bsz,
+            cu_seqlens=cu_seqlens,
+            attention_mask=attention_mask,
+        )
+
+        assert output.dim() == 3, output.shape
+
+        if self.use_qkv_parallel:
+            # [b * s, h, head_size] --> [b, s, h * head_size]
+            output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
+
+            # [b, s, h * head_size] --> [b, s, h * head_size]
+            output, _ = self.proj(output)
+        else:
+            # [b * s, h, head_size] --> [s, b, h * head_size]
+            context_layer = rearrange(
+                output, "(b s) h d -> s b (h d)", b=bsz, s=s
+            ).contiguous()
+
+            # [s, b, h * head_size] --> [s, b, h * head_size]
+            output, _ = self.proj(context_layer)
+
+            # [s, b, h * head_size] --> [b, s, h * head_size]
+            output = output.view(bsz, s, -1)
+
+        return output
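
To make the refactor above easier to read, here is a rough sketch of how the new `VisionAttention` entry point is meant to be constructed and called. The sizes and the `"sdpa"` backend choice are assumptions, and the snippet presumes sglang's tensor-parallel state and `global_server_args_dict` have already been initialized by the runtime, so treat it as an API-shape sketch rather than a standalone script:

import torch

from sglang.srt.layers.attention.vision import VisionAttention

# Assumed ViT-like dimensions; the backend may be "sdpa", "triton_attn", or "fa3" (CUDA only).
attn = VisionAttention(
    embed_dim=1024,
    num_heads=16,
    projection_size=1024,
    use_qkv_parallel=True,
    qkv_backend="sdpa",
    flatten_batch=True,
    prefix="visual.blocks.0.attn",
)

x = torch.randn(1, 256, 1024, device="cuda", dtype=torch.float16)  # [b, s, embed_dim]
cu_seqlens = torch.tensor([0, 256], dtype=torch.int32, device="cuda")
out = attn(x, cu_seqlens=cu_seqlens)  # [b, s, head * head_size]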