sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +5 -1
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +17 -2
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +65 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +5 -9
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +148 -72
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +105 -66
- sglang/srt/function_call/function_call_parser.py +6 -4
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +46 -25
- sglang/srt/layers/moe/ep_moe/layer.py +172 -206
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/topk.py +88 -34
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +33 -14
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +188 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +62 -13
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +27 -11
- sglang/srt/managers/scheduler.py +48 -26
- sglang/srt/managers/tokenizer_manager.py +62 -28
- sglang/srt/managers/tp_worker.py +5 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +35 -18
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +109 -37
- sglang/srt/models/deepseek_v2.py +63 -30
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +2 -6
- sglang/srt/models/qwen3_moe.py +6 -8
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/reasoning_parser.py +48 -5
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +132 -60
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +113 -69
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_activation.py +50 -1
- sglang/test/test_utils.py +65 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/function_call/pythonic_detector.py
CHANGED
@@ -8,7 +8,6 @@ from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
-    StructureInfo,
     ToolCallItem,
     _GetInfoFunc,
 )
@@ -19,10 +18,17 @@ logger = logging.getLogger(__name__)
 
 class PythonicDetector(BaseFormatDetector):
     """
-    Detector for Llama-
-
-
-
+    Detector for Llama-4 models with Pythonic tool call format.
+
+    The Pythonic format uses Python function call syntax within square brackets,
+    with arguments as Python literals rather than JSON.
+
+    Format Structure:
+    ```
+    [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
+    ```
+
+    Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct?chat_template=default
     """
 
     def __init__(self):
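As an aside, the bracketed Pythonic format documented in the new docstring can be taken apart with the standard `ast` module. The snippet below is only a minimal sketch of that idea, not the detector's actual implementation; the tool names and arguments are invented.

```python
import ast

# Minimal sketch: parse the pythonic tool-call format shown in the docstring above.
# The real PythonicDetector also handles streaming, tool-index lookup, and errors.
text = '[get_weather(city="Paris", unit="celsius"), get_time(tz="UTC")]'
parsed = ast.parse(text, mode="eval").body          # an ast.List of ast.Call nodes
for call in parsed.elts:
    assert isinstance(call, ast.Call) and isinstance(call.func, ast.Name)
    name = call.func.id
    args = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
    print(name, args)                               # e.g. get_weather {'city': 'Paris', 'unit': 'celsius'}
```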
@@ -75,11 +81,7 @@ class PythonicDetector(BaseFormatDetector):
             return StreamingParseResult(normal_text=normal_text, calls=[])
 
         calls = []
-        tool_indices = {
-            tool.function.name: i
-            for i, tool in enumerate(tools)
-            if tool.function.name
-        }
+        tool_indices = self._get_tool_indices(tools)
         for call_index, call in enumerate(parsed.elts):
             if not isinstance(call.func, ast.Name):
                 continue
@@ -213,11 +215,11 @@ class PythonicDetector(BaseFormatDetector):
         else:
             raise ValueError("Tool call arguments must be literals")
 
-    def structure_info(self) -> _GetInfoFunc:
-
-        return StructureInfo(begin=f"[{name}(", end=")]", trigger=f"[{name}(")
+    def supports_structural_tag(self) -> bool:
+        return False
 
-
+    def structure_info(self) -> _GetInfoFunc:
+        raise NotImplementedError
 
     def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
         return EBNFComposer.build_ebnf(
sglang/srt/function_call/qwen25_detector.py
CHANGED
@@ -17,9 +17,18 @@ logger = logging.getLogger(__name__)
 
 class Qwen25Detector(BaseFormatDetector):
     """
-    Detector for Qwen 2.5
-
-
+    Detector for Qwen 2.5 and Qwen 3 model function call format.
+
+    Format Structure:
+    ```
+    <tool_call>\n{"name":"func1", "arguments":{...}}\n</tool_call>\n<tool_call>\n{"name":"func2", "arguments":{...}}\n</tool_call>
+    ```
+
+    Key Components:
+    - Tool Call Tags: `<tool_call>` and `</tool_call>` wrap each individual call
+    - Function Call Object: JSON object with "name" and "arguments" fields
+
+    Reference: https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?chat_template=default
     """
 
    def __init__(self):
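For illustration, the `<tool_call>` wrapping described in this docstring can be pulled apart with a regex plus `json.loads`. The sketch below is a simplified stand-in for the detector (which also handles streaming and partial JSON); the function names are invented.

```python
import json
import re

# Simplified sketch of extracting the <tool_call> format documented above.
text = (
    '<tool_call>\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n</tool_call>\n'
    '<tool_call>\n{"name": "get_time", "arguments": {"tz": "UTC"}}\n</tool_call>'
)
for block in re.findall(r"<tool_call>\n(.*?)\n</tool_call>", text, flags=re.DOTALL):
    call = json.loads(block)
    print(call["name"], call["arguments"])
```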
sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py}
RENAMED
@@ -9,7 +9,6 @@ from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
-    StructureInfo,
     ToolCallItem,
     _GetInfoFunc,
 )
@@ -29,7 +28,7 @@ def _safe_val(raw: str) -> Any:
         return raw
 
 
-class Qwen3XMLDetector(BaseFormatDetector):
+class Qwen3CoderDetector(BaseFormatDetector):
     """
     Detector for Qwen 3 models.
     Assumes function call format:
@@ -127,24 +126,27 @@ class Qwen3XMLDetector(BaseFormatDetector):
                 params[pname] = _safe_val(pval)
             raw = {"name": fname, "arguments": params}
             try:
+                # TODO: fix idx in function call, the index for a function
+                # call will always be -1 in parse_base_json
                 res.extend(self.parse_base_json(raw, tools))
             except Exception:
                 logger.warning("invalid tool call for %s dropped", fname)
         return res
 
+    def supports_structural_tag(self) -> bool:
+        return False
+
     def structure_info(self) -> _GetInfoFunc:
-        return lambda n: StructureInfo(
-            begin=f"{self.tool_call_start_token}\n<function={n}>",
-            end=f"</function>\n{self.tool_call_end_token}",
-            trigger=self.tool_call_start_token,
-        )
+        raise NotImplementedError
 
-    # TODO: fake ebnf for xml + outlines backend
     def build_ebnf(self, tools: List[Tool]):
         return EBNFComposer.build_ebnf(
             tools,
             individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"),
             individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"),
             tool_call_separator="\\n",
-            function_format="
+            function_format="xml",
+            call_rule_fmt='"<function={name}>\\n" {arguments_rule} "\\n</function>"',
+            key_value_rule_fmt='"<parameter={key}>\\n" {valrule} "\\n</parameter>"',
+            key_value_separator="\\n",
         )
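The new `call_rule_fmt` / `key_value_rule_fmt` arguments describe an XML-style call body. As a rough illustration of the shape those rules constrain (the tool and parameter names are invented, the exact newline placement is inferred from the format strings, and this is not how the detector itself parses):

```python
import re

# Hypothetical sample of the XML-style call shape implied by call_rule_fmt /
# key_value_rule_fmt above, with a made-up tool; extraction here is a toy regex.
text = (
    "<function=get_weather>\n"
    "<parameter=city>\nParis\n</parameter>\n"
    "<parameter=unit>\ncelsius\n</parameter>\n"
    "</function>"
)
fname = re.search(r"<function=([^>]+)>", text).group(1)
params = dict(re.findall(r"<parameter=([^>]+)>\n(.*?)\n</parameter>", text, flags=re.DOTALL))
print(fname, params)  # get_weather {'city': 'Paris', 'unit': 'celsius'}
```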
sglang/srt/layers/activation.py
CHANGED
@@ -33,6 +33,7 @@ from sglang.srt.utils import (
     cpu_has_amx_support,
     is_cpu,
     is_cuda,
+    is_hip,
     is_npu,
     set_weight_attrs,
 )
@@ -42,9 +43,12 @@ _is_cuda = is_cuda()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_hip = is_hip()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+elif _is_hip:
+    from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
 
 if is_npu():
     import torch_npu
@@ -126,9 +130,13 @@ class QuickGELU(CustomOp):
         return x * torch.sigmoid(1.702 * x)
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        # TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel
         return self.forward_native(x)
 
+    def forward_hip(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+        gelu_quick(x, out)
+        return out
+
 
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
@@ -222,8 +230,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
         return nn.Identity()
 
 
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     logger.info(
-        "sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
+        "sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
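Side note on the QuickGELU path: `forward_native` computes `x * sigmoid(1.702 * x)`, and the new `forward_hip` offloads that same activation to the `gelu_quick` op from sgl_kernel. The standalone snippet below only checks how close the approximation is to exact GELU; it does not touch sgl_kernel.

```python
import torch
import torch.nn.functional as F

# QuickGELU (x * sigmoid(1.702 * x)) versus exact GELU on a dense grid.
x = torch.linspace(-4.0, 4.0, steps=1001)
quick = x * torch.sigmoid(1.702 * x)
print(f"max |quick_gelu - gelu| = {(quick - F.gelu(x)).abs().max().item():.4f}")
```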
sglang/srt/layers/attention/base_attn_backend.py
CHANGED
@@ -65,7 +65,9 @@ class AttentionBackend(ABC):
         **kwargs,
     ):
         """Run forward on an attention layer."""
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_idle():
+            return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+        elif forward_batch.forward_mode.is_decode():
             return self.forward_decode(
                 q,
                 k,
sglang/srt/layers/attention/hybrid_attn_backend.py
ADDED
@@ -0,0 +1,100 @@
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+
+
+class HybridAttnBackend(AttentionBackend):
+    """Support different backends for prefill and decode."""
+
+    def __init__(
+        self, prefill_backend: AttentionBackend, decode_backend: AttentionBackend
+    ):
+        self.prefill_backend = prefill_backend
+        self.decode_backend = decode_backend
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        if forward_batch.forward_mode.is_decode():
+            self.decode_backend.init_forward_metadata(forward_batch)
+        else:
+            self.prefill_backend.init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        self.decode_backend.init_forward_metadata_capture_cuda_graph(
+            bs,
+            num_tokens,
+            req_pool_indices,
+            seq_lens,
+            encoder_lens,
+            forward_mode,
+            spec_info,
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        self.decode_backend.init_forward_metadata_replay_cuda_graph(
+            bs,
+            req_pool_indices,
+            seq_lens,
+            seq_lens_sum,
+            encoder_lens,
+            forward_mode,
+            spec_info,
+            seq_lens_cpu,
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return self.decode_backend.get_cuda_graph_seq_len_fill_value()
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        return self.decode_backend.forward_decode(
+            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+        )
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        return self.prefill_backend.forward_extend(
+            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+        )
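To make the routing in `HybridAttnBackend` concrete, here is a toy, self-contained sketch of the same dispatch rule: decode batches go to the decode backend, everything else to the prefill backend. The `_Fake*` classes are stand-ins invented for this example, not sglang types.

```python
# Toy stand-ins that mimic HybridAttnBackend's decode/prefill dispatch.
class _FakeMode:
    def __init__(self, decode: bool):
        self._decode = decode

    def is_decode(self) -> bool:
        return self._decode


class _FakeBatch:
    def __init__(self, decode: bool):
        self.forward_mode = _FakeMode(decode)


class _FakeBackend:
    def __init__(self, name: str):
        self.name = name

    def init_forward_metadata(self, forward_batch) -> None:
        print(f"{self.name} backend used")


def route(prefill_backend, decode_backend, batch) -> None:
    # Same rule as HybridAttnBackend: decode -> decode backend, else prefill backend.
    backend = decode_backend if batch.forward_mode.is_decode() else prefill_backend
    backend.init_forward_metadata(batch)


route(_FakeBackend("prefill"), _FakeBackend("decode"), _FakeBatch(decode=True))   # decode backend used
route(_FakeBackend("prefill"), _FakeBackend("decode"), _FakeBatch(decode=False))  # prefill backend used
```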
sglang/srt/layers/attention/vision.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import dataclasses
 import functools
 import math
-from functools import lru_cache
+from functools import lru_cache, partial
 from typing import Any, Optional, Tuple, Union
 
 import torch
@@ -18,11 +18,16 @@ _is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel.flash_attn import flash_attn_varlen_func
 
-from sglang.srt.distributed import parallel_state
+from sglang.srt.distributed import (
+    parallel_state,
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather,
+)
 from sglang.srt.distributed import utils as dist_utils
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -349,25 +354,44 @@ class VisionAttention(nn.Module):
         flatten_batch: bool = False,
         prefix: str = "",
         proj_bias: bool = True,
+        num_dummy_heads: int = 0,
+        qkv_bias: bool = True,
+        qk_normalization: bool = False,
+        layer_norm_eps: float = 1e-06,
         **kwargs,
     ):
         super().__init__()
         world_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_size = world_size
+        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.dropout = dropout
         self.head_size = embed_dim // num_heads
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads
         )
         self.num_attention_heads_per_partition = dist_utils.divide(
-            num_heads, world_size
+            num_dummy_heads + num_heads, world_size
         )
         self.num_attention_kv_heads_per_partition = dist_utils.divide(
-            num_heads, world_size
+            num_dummy_heads + num_heads, world_size
         )
 
         self.q_size = self.num_attention_heads_per_partition * self.head_size
         self.kv_size = self.num_attention_kv_heads_per_partition * self.head_size
 
+        self.qk_normalization = qk_normalization
+
+        # Additional dummy heads are used to enable TP for common GPU counts.
+        self.dummy_dim = (num_dummy_heads + num_heads) * self.head_size
+
+        if self.qk_normalization:
+            self.q_norm = RMSNorm(
+                self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
+            )
+            self.k_norm = RMSNorm(
+                self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
+            )
+
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
                 qkv_backend = "sdpa"
@@ -391,26 +415,46 @@ class VisionAttention(nn.Module):
             self.qkv_proj = QKVParallelLinear(
                 hidden_size=embed_dim,
                 head_size=self.head_size,
-                total_num_heads=num_heads,
-                total_num_kv_heads=num_heads,
+                total_num_heads=num_dummy_heads + num_heads,
+                total_num_kv_heads=num_dummy_heads + num_heads,
+                bias=qkv_bias,
                 quant_config=quant_config,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         else:
             self.qkv_proj = ColumnParallelLinear(
                 input_size=embed_dim,
-                output_size=3 *
+                output_size=3 * self.dummy_dim,
+                bias=qkv_bias,
                 quant_config=quant_config,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         self.proj = RowParallelLinear(
-            input_size=
+            input_size=self.dummy_dim,
             output_size=embed_dim,
             bias=proj_bias,
             quant_config=quant_config,
             prefix=add_prefix("proj", prefix),
         )
 
+    def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
+        """apply qk norm for internvl vit attn"""
+        q = q.flatten(1, 2)
+        k = k.flatten(1, 2)
+
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        q = q.unflatten(-1, (-1, self.head_size))
+        k = k.unflatten(-1, (-1, self.head_size))
+        return q, k
+
     def forward(
         self,
         x: torch.Tensor,
@@ -489,6 +533,10 @@ class VisionAttention(nn.Module):
         assert k.dim() == 3, k.dim()
         assert v.dim() == 3, v.dim()
 
+        # internvl
+        if self.qk_normalization:
+            q, k = self._apply_qk_norm(q, k)
+
         output = self.qkv_backend.forward(
             q=q,
             k=k,
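Conceptually, `_apply_qk_norm` flattens the per-head Q/K tensors, gathers them across tensor-parallel ranks, applies RMSNorm (with the variance computed over the un-padded `embed_dim` slice via `var_hidden_size`), and splits them back. The single-process sketch below shows only the flatten → normalize → reshape part; the TP gather/split and dummy-head padding are omitted.

```python
import torch

# Single-process sketch of the QK-norm reshaping in _apply_qk_norm above
# (no tensor parallelism, no dummy heads).
tokens, heads, head_size = 5, 4, 16
eps = 1e-6
q = torch.randn(tokens, heads, head_size)
weight = torch.ones(heads * head_size)                     # RMSNorm scale

q_flat = q.flatten(1, 2)                                   # (tokens, heads * head_size)
variance = q_flat.pow(2).mean(dim=-1, keepdim=True)
q_flat = q_flat * torch.rsqrt(variance + eps) * weight     # plain RMSNorm
q_normed = q_flat.unflatten(-1, (-1, head_size))           # back to (tokens, heads, head_size)
print(q_normed.shape)                                      # torch.Size([5, 4, 16])
```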
sglang/srt/layers/communicator.py
CHANGED
@@ -24,8 +24,8 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.dp_attention import (
-    attn_tp_all_gather,
-    attn_tp_reduce_scatter,
+    attn_tp_all_gather_into_tensor,
+    attn_tp_reduce_scatter_tensor,
     dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
@@ -309,8 +309,8 @@ class CommunicateSimpleFn:
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-        attn_tp_all_gather(
-            list(hidden_states.tensor_split(context.attn_tp_size)),
+        attn_tp_all_gather_into_tensor(
+            hidden_states,
             local_hidden_states,
         )
         return hidden_states
@@ -400,9 +400,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             ].clone(),
             residual,
         )
-        attn_tp_all_gather(
-            list(residual.tensor_split(context.attn_tp_size)), local_residual
-        )
+        attn_tp_all_gather_into_tensor(residual, local_residual)
         if context.attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
@@ -442,9 +440,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
         *,
         residual_input_mode,
     ):
-        tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
-        hidden_states = tensor_list[context.attn_tp_rank]
-        attn_tp_reduce_scatter(hidden_states, tensor_list)
+        input_hidden_states = hidden_states
+        hidden_states = hidden_states.tensor_split(context.attn_tp_size)[
+            context.attn_tp_rank
+        ]
+        attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states)
         if residual_input_mode == ScatterMode.TP_ATTN_FULL:
             residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank]
         if hidden_states.shape[0] != 0:
@@ -547,8 +547,8 @@ class CommunicateSummableTensorPairFn:
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-        attn_tp_all_gather(
-            list(hidden_states.tensor_split(context.attn_tp_size)),
+        attn_tp_all_gather_into_tensor(
+            hidden_states,
             local_hidden_states,
         )
         return hidden_states, residual
sglang/srt/layers/dp_attention.py
CHANGED
@@ -3,7 +3,8 @@ from __future__ import annotations
 import functools
 import logging
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, List
+from enum import IntEnum, auto
+from typing import TYPE_CHECKING, List, Tuple
 
 import torch
 import triton
@@ -30,6 +31,34 @@ _LOCAL_ATTN_DP_SIZE = None
 _LOCAL_ATTN_DP_RANK = None
 
 
+class DPPaddingMode(IntEnum):
+
+    # Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
+    MAX_LEN = auto()
+    # Padding tokens to sum length and then gather tokens using `all_reduce`
+    SUM_LEN = auto()
+
+    def is_max_len(self):
+        return self == DPPaddingMode.MAX_LEN
+
+    def is_sum_len(self):
+        return self == DPPaddingMode.SUM_LEN
+
+    @classmethod
+    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
+        # we choose the mode that minimizes the communication cost
+        max_len = max(global_num_tokens)
+        sum_len = sum(global_num_tokens)
+        if sum_len * 2 > max_len * get_attention_dp_size():
+            return cls.MAX_LEN
+        else:
+            return cls.SUM_LEN
+
+    @classmethod
+    def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
+        return cls.MAX_LEN
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
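A quick worked example of the `get_dp_padding_mode` heuristic. My reading of the comparison: padding every rank to the max length implies an all_gather over roughly `max_len * dp_size` tokens, while padding to the sum length implies an all_reduce that moves roughly `2 * sum_len`, so the smaller side wins. The token counts below are made up.

```python
# Hypothetical token counts for a DP group of size 4: one busy rank, three idle-ish ranks.
global_num_tokens = [512, 8, 8, 8]
dp_size = 4

max_len, sum_len = max(global_num_tokens), sum(global_num_tokens)
# Same comparison as DPPaddingMode.get_dp_padding_mode above.
mode = "MAX_LEN" if sum_len * 2 > max_len * dp_size else "SUM_LEN"
print(max_len, sum_len, mode)  # 512 536 SUM_LEN -> sum-length padding is cheaper here
```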
@@ -162,7 +191,7 @@ def disable_dp_size():
         _ATTN_DP_SIZE = old_dp_size
 
 
-def get_dp_local_info(forward_batch: ForwardBatch):
+def get_dp_local_info(forward_batch: ForwardBatch) -> Tuple[torch.Tensor, torch.Tensor]:
     # `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
     dp_rank = get_attention_dp_rank()
 
@@ -221,7 +250,7 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
     memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
 
 
-def _dp_gather(
+def _dp_gather_via_all_reduce(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
     forward_batch: ForwardBatch,
@@ -238,13 +267,6 @@ def _dp_gather(
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between global_tokens and local_tokens not allowed"
 
-    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
-    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
-    # actual size of the accepted tokens.
-    if forward_batch.forward_mode.is_draft_extend():
-        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
-        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
-
     memcpy_triton(
         global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
     )
@@ -263,6 +285,38 @@ def _dp_gather(
     global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
 
 
+def _dp_gather_via_all_gather(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+    is_partial: bool,
+):
+    if not is_partial:
+        if get_attention_tp_rank() != 0:
+            local_tokens.fill_(0)
+    scattered_local_tokens = local_tokens.tensor_split(get_attention_tp_size())[
+        get_attention_tp_rank()
+    ]
+    get_attention_tp_group().reduce_scatter_tensor(scattered_local_tokens, local_tokens)
+    get_tp_group().all_gather_into_tensor(global_tokens, scattered_local_tokens)
+
+
+def _dp_gather(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+    is_partial: bool,
+):
+    if forward_batch.dp_padding_mode.is_max_len():
+        _dp_gather_via_all_gather(
+            global_tokens, local_tokens, forward_batch, is_partial
+        )
+    else:
+        _dp_gather_via_all_reduce(
+            global_tokens, local_tokens, forward_batch, is_partial
+        )
+
+
 def dp_gather_partial(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
@@ -296,24 +350,18 @@ def dp_scatter(
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between local_tokens and global_tokens not allowed"
 
-    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
-    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
-    # actual size of the accepted tokens.
-    if forward_batch.forward_mode.is_draft_extend():
-        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
-        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
-
     memcpy_triton(
         local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
    )
 
 
-def attn_tp_reduce_scatter(
-    output: torch.Tensor,
-    input_list: List[torch.Tensor],
-):
-    return get_attention_tp_group().reduce_scatter(output, input_list)
+def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
+    return get_attention_tp_group().reduce_scatter_tensor(output, input)
+
+
+def attn_tp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor):
+    return get_attention_tp_group().all_gather_into_tensor(output, input)
 
 
-def attn_tp_all_gather(output_list: List[torch.Tensor],
-    return get_attention_tp_group().all_gather(
+def attn_tp_all_gather(output_list: List[torch.Tensor], input: torch.Tensor):
+    return get_attention_tp_group().all_gather(input, output_tensor_list=output_list)
sglang/srt/layers/layernorm.py
CHANGED
@@ -61,10 +61,15 @@ class RMSNorm(CustomOp):
         self,
         hidden_size: int,
         eps: float = 1e-6,
+        var_hidden_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.hidden_size = hidden_size
+        self.variance_size_override = (
+            None if var_hidden_size == hidden_size else var_hidden_size
+        )
         if _use_aiter:
             self._forward_method = self.forward_aiter
 
@@ -73,6 +78,8 @@ class RMSNorm(CustomOp):
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if self.variance_size_override is not None:
+            return self.forward_native(x, residual)
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
             return x, residual
@@ -138,7 +145,25 @@ class RMSNorm(CustomOp):
             x = x + residual.to(torch.float32)
             residual = x.to(orig_dtype)
 
-        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        hidden_size = x.shape[-1]
+        if hidden_size != self.hidden_size:
+            raise ValueError(
+                "Expected hidden_size to be "
+                f"{self.hidden_size}, but found: {hidden_size}"
+            )
+
+        if self.variance_size_override is None:
+            x_var = x
+        else:
+            if hidden_size < self.variance_size_override:
+                raise ValueError(
+                    "Expected hidden_size to be at least "
+                    f"{self.variance_size_override}, but found: {hidden_size}"
+                )
+
+            x_var = x[..., : self.variance_size_override]
+
+        variance = x_var.pow(2).mean(dim=-1, keepdim=True)
         x = x * torch.rsqrt(variance + self.variance_epsilon)
         x = (x * self.weight).to(orig_dtype)
         if residual is None: